joaquinsturtz committed on
Commit fa4c4e1 · verified · 1 Parent(s): 5d61cd2

Fix: Bind to 0.0.0.0 and show_api=False to resolve runtime and schema errors

Files changed (50) — this view is limited to 50 files because the commit contains more changes; see the raw diff for the full set.
  1. README.md +80 -6
  2. app.py +169 -0
  3. config.json +58 -0
  4. gfn/__init__.py +30 -0
  5. gfn/realizations/__init__.py +19 -0
  6. gfn/realizations/api.py +66 -0
  7. gfn/realizations/gssm/__init__.py +10 -0
  8. gfn/realizations/gssm/api.py +102 -0
  9. gfn/realizations/gssm/config/__init__.py +63 -0
  10. gfn/realizations/gssm/config/defaults.py +111 -0
  11. gfn/realizations/gssm/config/loader.py +163 -0
  12. gfn/realizations/gssm/config/schema.py +199 -0
  13. gfn/realizations/gssm/config/serialization.py +39 -0
  14. gfn/realizations/gssm/config/validator.py +109 -0
  15. gfn/realizations/gssm/constants.py +67 -0
  16. gfn/realizations/gssm/core/__init__.py +7 -0
  17. gfn/realizations/gssm/core/state.py +60 -0
  18. gfn/realizations/gssm/core/types.py +27 -0
  19. gfn/realizations/gssm/csrc/README.md +2 -0
  20. gfn/realizations/gssm/csrc/compile_cuda_12.9.bat +68 -0
  21. gfn/realizations/gssm/csrc/extension.cpp +87 -0
  22. gfn/realizations/gssm/csrc/geometry/low_rank.cu +160 -0
  23. gfn/realizations/gssm/csrc/integrators/integrators.cpp +252 -0
  24. gfn/realizations/gssm/csrc/integrators/integrators.h +41 -0
  25. gfn/realizations/gssm/csrc/losses/toroidal.cu +99 -0
  26. gfn/realizations/gssm/csrc/setup.py +38 -0
  27. gfn/realizations/gssm/cuda/__init__.py +11 -0
  28. gfn/realizations/gssm/cuda/autograd/__init__.py +0 -0
  29. gfn/realizations/gssm/cuda/kernels/__init__.py +0 -0
  30. gfn/realizations/gssm/cuda/kernels/geometry_kernels.py +99 -0
  31. gfn/realizations/gssm/cuda/kernels/integrator_kernels.py +73 -0
  32. gfn/realizations/gssm/cuda/ops/__init__.py +52 -0
  33. gfn/realizations/gssm/data/__init__.py +16 -0
  34. gfn/realizations/gssm/data/dataset.py +14 -0
  35. gfn/realizations/gssm/data/loader.py +53 -0
  36. gfn/realizations/gssm/data/replay.py +130 -0
  37. gfn/realizations/gssm/data/transforms.py +40 -0
  38. gfn/realizations/gssm/errors.py +23 -0
  39. gfn/realizations/gssm/geometry/__init__.py +42 -0
  40. gfn/realizations/gssm/geometry/adaptive.py +83 -0
  41. gfn/realizations/gssm/geometry/base.py +70 -0
  42. gfn/realizations/gssm/geometry/euclidean.py +20 -0
  43. gfn/realizations/gssm/geometry/factory.py +117 -0
  44. gfn/realizations/gssm/geometry/hierarchical.py +84 -0
  45. gfn/realizations/gssm/geometry/holographic.py +91 -0
  46. gfn/realizations/gssm/geometry/hyperbolic.py +97 -0
  47. gfn/realizations/gssm/geometry/low_rank.py +324 -0
  48. gfn/realizations/gssm/geometry/reactive.py +109 -0
  49. gfn/realizations/gssm/geometry/spherical.py +47 -0
  50. gfn/realizations/gssm/geometry/torus.py +274 -0
README.md CHANGED
@@ -1,12 +1,86 @@
 ---
-title: G Ssm Mniah
-emoji:
-colorFrom: blue
-colorTo: purple
+title: GFN MNIAH Solver
+emoji: 📍
+colorFrom: red
+colorTo: yellow
 sdk: gradio
-sdk_version: 6.9.0
+sdk_version: 4.44.1
+python_version: "3.11"
 app_file: app.py
 pinned: false
+license: cc-by-nc-nd-4.0
+library_name: gfn
+language: en
+pipeline_tag: other
+tags:
+- gfn
+- physics-informed
+- geometric-deep-learning
+- g-ssm
+- needle-in-a-haystack
+- long-context
+model-index:
+- name: gfn-gssm-mniah-k2
+  results:
+  - task:
+      type: other
+      name: Multi-Needle Retrieval
+    dataset:
+      name: synthetic-needle-haystack
+      type: synthetic
+    metrics:
+    - name: Accuracy
+      type: accuracy
+      value: 100.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# 📍 GFN MNIAH (Multi-Needle-in-a-Haystack) Solver
+
+[![DOI: 10.5281/zenodo.19141133](https://img.shields.io/badge/DOI-10.5281/zenodo.19141133-blue.svg)](https://doi.org/10.5281/zenodo.19141133)
+[![Models: Hugging Face](https://img.shields.io/badge/Models-Hugging%20Face-orange.svg)](https://huggingface.co/DepthMuun)
+[![GitHub: GFN Framework](https://img.shields.io/badge/GitHub-GFN--Framework-black.svg?logo=github)](https://github.com/DepthMuun/gfn)
+
+This repository demonstrates the **Geometric Flow Network (GFN)** framework's ability to handle extreme-length context retrieval without a KV-cache. It implements the **Geodesic State Space Model (G-SSM)** specialized for the K=2 Needle-in-a-Haystack task.
+
+## 🚀 Technical Highlights
+- **Context Length**: Validated up to **32,000 tokens** (linear complexity $O(L)$).
+- **Inductive Bias**: Topologically encodes "needles" as geodesic impulses that curve the world-state manifold into a target decision region.
+- **Inference VRAM**: ~38 MB constant ($O(1)$ memory).
+
+## 🛠️ Local Installation & Usage
+To run this model locally, you need the **GFN Framework** and the model assets.
+
+### 1. Install the GFN Framework
+```bash
+pip install git+https://github.com/DepthMuun/gfn.git
+```
+
+### 2. Clone this repository
+```bash
+git lfs install
+git clone https://huggingface.co/spaces/DepthMuun/gfn-gssm-mniah-k2-space
+cd gfn-gssm-mniah-k2-space
+```
+
+### 3. Run the Interactive Demo
+```bash
+python app.py
+```
+
+## 📜 Citation
+If you use this work, please cite:
+```bibtex
+@article{sturtz2026geometry,
+  title={Geometric Flow Networks: A Physics-Informed Paradigm for Sequential Intelligence},
+  author={Stürtz, Joaquín},
+  journal={Zenodo Preprints},
+  year={2026},
+  doi={10.5281/zenodo.19141133},
+  url={https://doi.org/10.5281/zenodo.19141133}
+}
+```
+
+## 🔗 Resources
+- **Official Checkpoint**: [GFN MNIAH Model](https://huggingface.co/DepthMuun/gfn-gssm-mniah-k2)
+- **Framework Source**: [GitHub: DepthMuun/gfn](https://github.com/DepthMuun/gfn)
+- **Official Paper**: [Zenodo](https://doi.org/10.5281/zenodo.19141133)
app.py ADDED
@@ -0,0 +1,169 @@
+import gradio as gr
+import torch
+import math
+import sys
+import os
+import time
+import json
+import tempfile
+from pathlib import Path
+
+# Add local gfn folder to path if it exists (for HF Spaces)
+script_dir = os.path.dirname(os.path.abspath(__file__))
+if os.path.exists(os.path.join(script_dir, "gfn")):
+    sys.path.insert(0, script_dir)
+
+import gfn
+
+def load_model():
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # Load config safely using an absolute path
+    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    model = gfn.gssm.create(
+        vocab_size=config['architecture']['vocab_size'],
+        dim=config['architecture']['dim'],
+        depth=config['architecture']['depth'],
+        heads=config['architecture']['heads'],
+        integrator=config['architecture']['integrator'],
+        impulse_scale=config['architecture']['impulse_scale'],
+        dynamics_type=config['architecture']['dynamics_type'],
+        topology_type=config['architecture']['topology_type'],
+        physics=config['physics'],
+        holographic=config['architecture'].get('holographic', True),
+    ).to(device)
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    checkpoint_path = os.path.join(script_dir, "mniah_model_final.pt")
+    if not os.path.exists(checkpoint_path):
+        raise FileNotFoundError(f"Missing model weights: {checkpoint_path}. Please place the trained checkpoint here.")
+
+    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)
+    model.load_state_dict(ckpt['model'])
+    model.eval()
+    return model, device
+
+model, device = load_model()
+
+def run_mniah(seq_len, needle_pos_str, num_needles):
+    try:
+        if needle_pos_str.strip() == "":
+            # Random needles
+            lo = 1
+            pool = torch.randperm(seq_len)[:num_needles]
+            positions = sorted((pool + lo).tolist())
+        else:
+            positions = sorted([int(p.strip()) for p in needle_pos_str.split(",")])
+            if len(positions) != num_needles:
+                return f"Error: Number of positions ({len(positions)}) must match needle count ({num_needles}).", None
+            if any(p < 1 or p > seq_len for p in positions):
+                return f"Error: Positions must be between 1 and {seq_len}.", None
+
+        # Input generation
+        # Format: [Haystack, Needle, ...] (no context token for the K=2 model)
+        x = torch.zeros(1, seq_len, dtype=torch.long, device=device)
+
+        # Needles are 1-indexed in the UI, so subtract 1 for tensor indexing
+        for p in positions:
+            x[0, p - 1] = 1  # Needle token
+
+        t0 = time.time()
+        with torch.no_grad():
+            output = model(x)
+            x_pred = output[0]  # [B, L, D] or [B, L, H, D]
+            if x_pred.ndim == 4:
+                x_pred = x_pred.mean(dim=2)
+
+        elapsed = time.time() - t0
+
+        # Binary prediction (toroidal)
+        PI = math.pi
+        TWO_PI = 2.0 * PI
+        half_pi = PI * 0.5
+        dist_pos = torch.min(
+            torch.abs(x_pred - half_pi) % TWO_PI,
+            TWO_PI - (torch.abs(x_pred - half_pi) % TWO_PI)
+        ).mean(dim=-1)
+        dist_neg = torch.min(
+            torch.abs(x_pred + half_pi) % TWO_PI,
+            TWO_PI - (torch.abs(x_pred + half_pi) % TWO_PI)
+        ).mean(dim=-1)
+
+        preds = (dist_pos < dist_neg).long()[0]  # [L]
+
+        # Result summary
+        last_pos = positions[-1] - 1  # Adjust to 0-indexed
+        acc_after = (preds[last_pos+1:] == 1).float().mean().item() if last_pos+1 < seq_len else 1.0
+        acc_before = (preds[:last_pos] == 0).float().mean().item() if last_pos > 0 else 1.0
+        result_data = {
+            "Status": "Success",
+            "Context Length": seq_len,
+            "Needle Positions": positions,
+            "Predictions (subset)": preds.tolist()[:500],
+            "Acc After Last Needle": f"{acc_after:.2%}",
+            "Acc Before Flip": f"{acc_before:.2%}",
+            "Inference Time": f"{elapsed:.3f}s",
+            "Full Trace": preds.tolist()
+        }
+
+        # Save to a temp file for download
+        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode='w')
+        json.dump(result_data, temp_file, indent=4)
+        temp_file.close()
+
+        # Create a Markdown table for the UI
+        md_summary = f"""
+### 📊 Invariant Evaluation Summary
+
+| Metric | Value |
+|:---|:---|
+| **Status** | 🟢 {result_data['Status']} |
+| **Context Length ($L$)** | {seq_len:,} tokens |
+| **Needles Count ($K$)** | {num_needles} (Positions: {positions}) |
+| **Accuracy (Before Critical Threshold)** | **{acc_before:.2%}** |
+| **Accuracy (After Critical Threshold)** | **{acc_after:.2%}** |
+| **Inference Time** | {elapsed:.3f}s |
+
+> *Note: the model was trained explicitly on $K=2$. Extrapolating to higher $K$ demonstrates how the physical integrator accumulates forces. The "Accuracy Before" metric measures strict adherence before the final expected needle, but geometric flips may occur exactly when the integrated energy crosses the learned $K=2$ threshold.*
+"""
+
+        return md_summary, temp_file.name
+    except Exception as e:
+        return f"### ❌ Error\n{str(e)}", None
+
+with gr.Blocks(title="G-SSM MNIAH Solver") as demo:
+    gr.Markdown("# 📍 G-SSM Multi-Needle-in-a-Haystack (MNIAH)")
+    gr.Markdown("""
+    This space demonstrates the **Geodesic State Space Model (G-SSM)** on the Multi-Needle task.
+    The model must detect and remember exactly **K needles** scattered across a long sequence
+    of "haystack" tokens. Its internal state only "flips" to the target configuration after **all K impulses**
+    have been integrated into the geodesic flow.
+    """)
+
+    with gr.Row():
+        seq_len = gr.Number(value=1000, minimum=64, maximum=10_000_000_000, precision=0, label="Sequence Length (Up to 10B, be mindful of VRAM!)")
+        num_k = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Number of Needles (K)")
+        positions = gr.Textbox(label="Manual Positions (comma separated, empty for random)", placeholder="e.g. 100, 450")
+
+    submit_btn = gr.Button("Evaluate Geometric Memory", variant="primary")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            output_md = gr.Markdown("### 📊 Results will appear here...")
+        with gr.Column(scale=1):
+            download_btn = gr.File(label="Full Trace (JSON)")
+
+    submit_btn.click(fn=run_mniah, inputs=[seq_len, positions, num_k], outputs=[output_md, download_btn])
+
+    gr.Examples(
+        examples=[
+            [1000, "100, 900", 2],
+            [4000, "50, 1500, 3900", 3],
+            [32000, "", 1]
+        ],
+        inputs=[seq_len, positions, num_k]
+    )
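The binary readout in `run_mniah` relies on wrap-around (circular) distance on the torus: each output angle is classified by whichever of the two anchor angles (±π/2) it lies closer to along the circle. A minimal pure-Python sketch of that distance, mirroring the modular-arithmetic pattern above (function names here are illustrative, not part of the app):

```python
import math

TWO_PI = 2.0 * math.pi

def circular_distance(a: float, b: float) -> float:
    """Shortest angular distance between angles a and b on a circle.

    Mirrors the app's pattern: min(|a-b| mod 2*pi, 2*pi - |a-b| mod 2*pi).
    """
    d = abs(a - b) % TWO_PI
    return min(d, TWO_PI - d)

def classify(angle: float) -> int:
    """1 if the angle is closer to +pi/2 than to -pi/2, else 0."""
    half_pi = math.pi / 2
    return int(circular_distance(angle, half_pi) < circular_distance(angle, -half_pi))

# An angle just past +pi/2 classifies as 1; one near -pi/2 classifies as 0.
print(classify(1.6), classify(-1.5))  # -> 1 0
```

Note that the wrap-around matters: 0.1 and 2π − 0.1 are only 0.2 apart on the circle even though their naive difference is nearly 2π.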
config.json ADDED
@@ -0,0 +1,58 @@
+{
+  "architecture": {
+    "vocab_size": 2,
+    "dim": 16,
+    "depth": 2,
+    "heads": 2,
+    "integrator": "leapfrog",
+    "impulse_scale": 80.0,
+    "dynamics_type": "direct",
+    "topology_type": "torus",
+    "holographic": true
+  },
+  "physics": {
+    "embedding": {
+      "type": "functional",
+      "mode": "linear",
+      "coord_dim": 16
+    },
+    "readout": {
+      "type": "implicit",
+      "coord_dim": 16
+    },
+    "active_inference": {
+      "enabled": true,
+      "dynamic_time": {
+        "enabled": true
+      },
+      "reactive_curvature": {
+        "enabled": true,
+        "plasticity": 0.2
+      },
+      "singularities": {
+        "enabled": true,
+        "strength": 20.0,
+        "threshold": 0.8
+      }
+    },
+    "fractal": {
+      "enabled": true,
+      "threshold": 0.5,
+      "alpha": 0.2
+    },
+    "topology": {
+      "type": "torus",
+      "riemannian_type": "low_rank"
+    },
+    "stability": {
+      "enable_trace_normalization": true,
+      "base_dt": 0.4,
+      "velocity_saturation": 15.0,
+      "friction": 2.0,
+      "toroidal_curvature_scale": 0.01
+    },
+    "hysteresis": {
+      "enabled": false
+    }
+  }
+}
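The app consumes this file with plain `json.load` and indexes into `config['architecture']` and `config['physics']`. A small stdlib-only sketch of that access pattern (the embedded JSON is an abbreviated copy of the file above, keeping only the keys `app.py` reads):

```python
import json

# Abbreviated copy of config.json above; only the keys app.py reads.
raw = """
{
  "architecture": {
    "vocab_size": 2, "dim": 16, "depth": 2, "heads": 2,
    "integrator": "leapfrog", "impulse_scale": 80.0,
    "dynamics_type": "direct", "topology_type": "torus",
    "holographic": true
  },
  "physics": {"topology": {"type": "torus", "riemannian_type": "low_rank"}}
}
"""
config = json.loads(raw)
arch = config["architecture"]

# Same access pattern as load_model() in app.py, with a safe default
# for the optional "holographic" flag.
kwargs = dict(
    vocab_size=arch["vocab_size"],
    topology_type=arch["topology_type"],
    holographic=arch.get("holographic", True),
)
print(kwargs["topology_type"])  # -> torus
```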
gfn/__init__.py ADDED
@@ -0,0 +1,30 @@
+"""
+GFN (Geodesic Flow Network) Package
+===================================
+Unified framework for Geodesic State Space Models (G-SSM)
+and Inertial State Networks (ISN).
+
+This package implements the GFN paradigm as a platform for
+physics-informed neural dynamics.
+"""
+
+# ── Realizations ──────────────────────────────────────────────────────────────
+from .realizations import api, gssm, isn
+from .realizations.api import create, load, save
+
+# ── Dynamic Registry ──────────────────────────────────────────────────────────
+REALIZATIONS = api.list_available()
+
+# ── Package Metadata ──────────────────────────────────────────────────────────
+__version__ = "2.7.0"
+__author__ = "DepthMuun"
+
+__all__ = [
+    "gssm",
+    "isn",
+    "api",
+    "create",
+    "load",
+    "save",
+    "REALIZATIONS",
+]
gfn/realizations/__init__.py ADDED
@@ -0,0 +1,19 @@
+"""
+GFN Realizations Subpackage
+===========================
+Contains specific implementations of the GFN paradigm:
+- G-SSM: Geodesic State Space Model (Riemannian/Symplectic)
+- ISN: Inertial State Network (Physics-Informed Interaction Engine)
+"""
+
+from . import api
+from .api import create, list_available
+
+# Trigger registration of standard realizations
+from . import gssm
+from . import isn
+
+# Future realizations can be added here or via external plugins
+# from . import rt
+
+__all__ = ['gssm', 'isn', 'api', 'create', 'list_available']
gfn/realizations/api.py ADDED
@@ -0,0 +1,66 @@
+"""
+GFN Realizations API Router (Purified Version)
+==============================================
+Agnostic factory and dynamic registry for GFN architectural realizations.
+Follows SOLID principles: open for extension, closed for modification.
+"""
+
+import logging
+from typing import List, Dict, Any, Optional, Protocol, runtime_checkable
+import torch.nn as nn
+
+logger = logging.getLogger(__name__)
+
+@runtime_checkable
+class RealizationProvider(Protocol):
+    """Protocol defining the interface any GFN realization must provide."""
+    def create(self, **kwargs) -> nn.Module: ...
+    def save(self, model: nn.Module, path: str): ...
+    def load(self, path: str, **kwargs) -> nn.Module: ...
+
+# The Dynamic Registry
+_REGISTRY: Dict[str, RealizationProvider] = {}
+
+def register(name: str, provider: RealizationProvider):
+    """
+    Register a new realization architecture.
+
+    Args:
+        name: Unique identifier for the realization.
+        provider: An object or module implementing the RealizationProvider protocol.
+    """
+    name = name.lower()
+    if name in _REGISTRY:
+        logger.debug(f"Overwriting GFN realization provider: {name}")
+    _REGISTRY[name] = provider
+
+def list_available() -> List[str]:
+    """List all dynamically registered architectural realizations."""
+    return list(_REGISTRY.keys())
+
+def create(name: str, **kwargs) -> nn.Module:
+    """
+    Unified factory to create any registered GFN realization by name.
+    """
+    name = name.lower()
+    if name not in _REGISTRY:
+        raise ValueError(
+            f"GFN Error: Realization '{name}' is not registered. "
+            f"Ensure the subpackage is imported. Available: {list_available()}"
+        )
+    return _REGISTRY[name].create(**kwargs)
+
+def save(model: nn.Module, path: str, realization: Optional[str] = None):
+    """Unified save interface delegate."""
+    if realization and realization.lower() in _REGISTRY:
+        _REGISTRY[realization.lower()].save(model, path)
+    else:
+        import torch
+        torch.save(model.state_dict(), path)
+
+def load(path: str, realization: str, **kwargs) -> nn.Module:
+    """Unified load interface delegate."""
+    realization = realization.lower()
+    if realization not in _REGISTRY:
+        raise ValueError(f"GFN Error: Realization provider for '{realization}' not found.")
+    return _REGISTRY[realization].load(path, **kwargs)
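The router above is a plain name → provider mapping behind `register`/`create`. Stripped of the `torch.nn.Module` dependency, the same pattern can be sketched with stdlib-only types (names and the lambda factory below are illustrative, not part of the library):

```python
from typing import Any, Callable, Dict, List

_REGISTRY: Dict[str, Callable[..., Any]] = {}

def register(name: str, factory: Callable[..., Any]) -> None:
    # Case-insensitive registration, silently overwriting, like the real API.
    _REGISTRY[name.lower()] = factory

def list_available() -> List[str]:
    return list(_REGISTRY.keys())

def create(name: str, **kwargs: Any) -> Any:
    name = name.lower()
    if name not in _REGISTRY:
        raise ValueError(
            f"Realization '{name}' is not registered. Available: {list_available()}"
        )
    return _REGISTRY[name](**kwargs)

# Hypothetical provider standing in for the gssm subpackage.
register("GSSM", lambda dim=16: {"kind": "gssm", "dim": dim})
print(create("gssm", dim=32))  # -> {'kind': 'gssm', 'dim': 32}
```

This is why `gfn/realizations/__init__.py` imports the subpackages at load time: the import side effect is what populates the registry before `create()` is ever called.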
gfn/realizations/gssm/__init__.py ADDED
@@ -0,0 +1,10 @@
+# Export core API
+from .api import create, save, load, Model, Manifold, loss, Trainer
+
+# Register with central realization registry
+try:
+    from .. import api as central_api
+    from . import api as gssm_api
+    central_api.register('gssm', gssm_api)
+except ImportError:
+    pass  # Fallback for standalone GSSM usage
gfn/realizations/gssm/api.py ADDED
@@ -0,0 +1,102 @@
+"""
+gfn/api.py — GFN V5
+Simplified public interface and high-level orchestration.
+Centralizes model creation, loading, and evaluation.
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Dict, Any, Union
+
+from .models.factory import ModelFactory
+from .models.manifold import ManifoldModel
+from .losses.factory import LossFactory
+from .training.trainer import GFNTrainer
+from .training.evaluation import ManifoldMetricEvaluator
+
+# -- Main aliases
+Model = ManifoldModel
+Manifold = ManifoldModel
+Trainer = GFNTrainer
+
+def create(*args, **kwargs):
+    """Factory for Manifold models (V5)."""
+    return ModelFactory.create(*args, **kwargs)
+
+def loss(config, **kwargs):
+    """Factory for loss functions (V5)."""
+    return LossFactory.create(config, **kwargs)
+
+def save(model: nn.Module, path: str):
+    """
+    Saves the model and its configuration (Hugging Face style).
+    """
+    if hasattr(model, 'save_pretrained'):
+        model.save_pretrained(path)
+    else:
+        # Fallback for models that do not inherit from BaseModel
+        torch.save({'state_dict': model.state_dict()}, path)
+
+def load(path: str, device: Optional[str] = None):
+    """
+    Loads a saved model together with its configuration.
+    Supports directories (HF style) or legacy .pth/.bin files.
+    """
+    import os
+    if os.path.isdir(path):
+        return ModelFactory.from_pretrained(path)
+
+    # Fallback for legacy standalone files
+    checkpoint = torch.load(path, map_location=device or 'cpu', weights_only=True)
+    config = checkpoint.get('config')
+    if config is None:
+        raise ValueError(f"No configuration found in checkpoint {path}. Use HF-style directories for full loading.")
+
+    model = create(config=config)
+
+    # Robust state_dict extraction (handles different saving conventions)
+    state_dict = checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint
+
+    # Filter the state_dict against the model's actual parameters
+    model_state = model.state_dict()
+    filtered_state = {k: v for k, v in state_dict.items() if k in model_state}
+
+    # Log filtered keys for debugging (optional)
+    n_filtered = len(state_dict) - len(filtered_state)
+    if n_filtered > 0:
+        import logging
+        logging.getLogger("gssm.api").info(f"Filtered {n_filtered} unexpected keys from state_dict (legacy or auxiliary data).")
+
+    # Load with strict=False to handle potentially missing non-essential parameters
+    model.load_state_dict(filtered_state, strict=False)
+    return model
+
+def benchmark(model: nn.Module, dataloader: torch.utils.data.DataLoader,
+              device: Optional[str] = None) -> Dict[str, float]:
+    """
+    Runs a quick evaluation of geometric and task metrics.
+    """
+    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+    model.to(device)
+    model.eval()
+
+    evaluator = ManifoldMetricEvaluator(model)
+    all_x, all_v, all_y = [], [], []
+
+    with torch.no_grad():
+        for x, y in dataloader:
+            x, y = x.to(device), y.to(device)
+            logits, (xf, vf), info = model(x)
+
+            all_x.append(xf.detach().cpu())
+            all_v.append(vf.detach().cpu())
+            all_y.append(y.detach().cpu())
+
+    if not all_x:
+        return {}
+
+    x_total = torch.cat(all_x, dim=0)
+    v_total = torch.cat(all_v, dim=0)
+    y_total = torch.cat(all_y, dim=0)
+
+    return evaluator.full_report(x_total, v_total, y_total)
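`load()` above tolerates three checkpoint layouts ('state_dict', 'model', or a bare mapping) and drops any keys the model does not expect before calling `load_state_dict(..., strict=False)`. That extraction-and-filter step is pure dict logic and can be sketched without torch (function and key names below are illustrative):

```python
def extract_state(checkpoint: dict, model_keys: set) -> dict:
    """Pick the state dict out of a checkpoint and filter unknown keys.

    Mirrors the pattern above:
    checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint
    """
    state = checkpoint.get("state_dict") or checkpoint.get("model") or checkpoint
    return {k: v for k, v in state.items() if k in model_keys}

model_keys = {"embed.weight", "head.weight"}
# A checkpoint saved under the 'model' convention, with an extra auxiliary key.
ckpt = {"model": {"embed.weight": 1, "head.weight": 2, "optimizer_step": 99}}
print(extract_state(ckpt, model_keys))  # -> {'embed.weight': 1, 'head.weight': 2}
```

The `or` chain falls through to the bare checkpoint when neither wrapper key is present, which is exactly what makes legacy single-file checkpoints loadable.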
gfn/realizations/gssm/config/__init__.py ADDED
@@ -0,0 +1,63 @@
+"""
+gfn/config/__init__.py
+Public API for the configuration module — GFN V5
+Centralized configuration system for all GFN components.
+"""
+
+from .schema import (
+    TopologyConfig,
+    StabilityConfig,
+    DynamicTimeConfig,
+    HysteresisConfig,
+    ActiveInferenceConfig,
+    EmbeddingConfig,
+    ReadoutConfig,
+    MixtureConfig,
+    DynamicsConfig,
+    FractalConfig,
+    SingularityConfig,
+    PhysicsConfig,
+    TrainerConfig,
+    ManifoldConfig,
+)
+
+from .defaults import (
+    PHYSICS_DEFAULTS,
+    MODEL_DEFAULTS,
+    TRAINING_DEFAULTS,
+    LOSS_DEFAULTS,
+    get_default,
+)
+
+from .loader import dict_to_physics_config
+from .validator import ConfigValidator, validate_manifold_config, validate_and_print
+
+__all__ = [
+    # Schema classes
+    "TopologyConfig",
+    "StabilityConfig",
+    "DynamicTimeConfig",
+    "HysteresisConfig",
+    "ActiveInferenceConfig",
+    "EmbeddingConfig",
+    "ReadoutConfig",
+    "MixtureConfig",
+    "DynamicsConfig",
+    "FractalConfig",
+    "SingularityConfig",
+    "PhysicsConfig",
+    "TrainerConfig",
+    "ManifoldConfig",
+    # Defaults
+    "PHYSICS_DEFAULTS",
+    "MODEL_DEFAULTS",
+    "TRAINING_DEFAULTS",
+    "LOSS_DEFAULTS",
+    "get_default",
+    # Loader
+    "dict_to_physics_config",
+    # Validator
+    "ConfigValidator",
+    "validate_manifold_config",
+    "validate_and_print",
+]
gfn/realizations/gssm/config/defaults.py ADDED
@@ -0,0 +1,111 @@
+"""
+config/defaults.py — GFN V5
+Centralized default values for all configurations.
+Eliminates hardcoded values scattered across implementations.
+"""
+
+from typing import Dict, Any
+from ..constants import (
+    DEFAULT_DT, DEFAULT_FRICTION, DEFAULT_PLASTICITY,
+    MAX_VELOCITY, CURVATURE_CLAMP, VELOCITY_FRICTION_SCALE,
+    SINGULARITY_THRESHOLD, BLACK_HOLE_STRENGTH, EPSILON_STANDARD,
+    TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN
+)
+
+# ─── Physics ─────────────────────────────────────────────────────────────────
+PHYSICS_DEFAULTS: Dict[str, Any] = {
+    # Topology
+    'topology_type': TOPOLOGY_EUCLIDEAN,
+    'riemannian_type': 'low_rank',
+    'major_radius_R': 2.0,
+    'minor_radius_r': 1.0,
+
+    # Stability — references to constants.py (single source of truth)
+    'base_dt': DEFAULT_DT,
+    'adaptive_dt': True,
+    'friction': DEFAULT_FRICTION,
+    'velocity_clamp': MAX_VELOCITY,
+    'curvature_clamp': CURVATURE_CLAMP,
+    'enable_trace_normalization': True,
+    'velocity_friction_scale': VELOCITY_FRICTION_SCALE,
+    'integrator_type': 'leapfrog',
+    'friction_mode': 'static',
+
+    # Active inference
+    'active_inference_enabled': True,
+    'holographic_geometry': False,
+    'plasticity': DEFAULT_PLASTICITY,
+
+    # Singularities
+    'singularity_enabled': False,
+    'singularity_threshold': SINGULARITY_THRESHOLD,
+    'singularity_strength': BLACK_HOLE_STRENGTH,
+    'singularity_epsilon': EPSILON_STANDARD,
+
+    # Hysteresis
+    'hysteresis_enabled': False,
+    'hysteresis_decay': 0.95,
+    'hysteresis_ghost_force': True,
+
+    # Stochasticity / Curiosity
+    'stochasticity_enabled': False,
+    'stochasticity_type': 'brownian',
+    'stochasticity_sigma': 0.01,
+    'curiosity_enabled': False,
+    'curiosity_strength': 0.1,
+}
+
+# ─── Model ───────────────────────────────────────────────────────────────────
+MODEL_DEFAULTS: Dict[str, Any] = {
+    'dim': 64,
+    'heads': 4,
+    'depth': 2,
+    'rank': 16,
+    'vocab_size': 256,
+    'holographic': False,
+    'pooling_type': None,
+    'initial_spread': 1e-3,
+    'n_trajectories': 1,
+}
+
+# ─── Training ────────────────────────────────────────────────────────────────
+TRAINING_DEFAULTS: Dict[str, Any] = {
+    'lr': 1e-3,
+    'optimizer_type': 'adam',
+    'weight_decay': 0.0,
+    'grad_clip': 1.0,
+    'epochs': 10,
+    'batch_size': 32,
+    'scheduler_type': 'cosine_warmup',
+    'warmup_steps': 100,
+    'min_lr': 1e-6,
+    'task': 'lm',
+}
+
+# ─── Losses ──────────────────────────────────────────────────────────────────
+LOSS_DEFAULTS: Dict[str, Any] = {
+    'type': 'generative',
+    'mode': 'nll',
+    'entropy_coef': 0.0,
+    'label_smoothing': 0.0,
+
+    # Physics-informed
+    'lambda_physics': 0.01,
+    'lambda_geo': 0.001,
+    'lambda_ham': 0.0,
+    'lambda_kin': 0.0,
+}
+
+
+def get_default(section: str, key: str, fallback=None):
+    """
+    Gets a default value from the corresponding section.
+    Usage: get_default('physics', 'base_dt') -> 0.1
+    """
+    mapping = {
+        'physics': PHYSICS_DEFAULTS,
+        'model': MODEL_DEFAULTS,
+        'training': TRAINING_DEFAULTS,
+        'loss': LOSS_DEFAULTS,
+    }
+    return mapping.get(section, {}).get(key, fallback)
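The `get_default` lookup is a two-level dict access where both an unknown section and an unknown key fall through to `fallback`. A compact stand-alone sketch of the same semantics (the tiny tables below are stand-ins for the real default dicts):

```python
from typing import Any, Dict

# Stand-ins for PHYSICS_DEFAULTS / MODEL_DEFAULTS above.
PHYSICS_DEFAULTS: Dict[str, Any] = {"base_dt": 0.1, "friction": 1.0}
MODEL_DEFAULTS: Dict[str, Any] = {"dim": 64, "depth": 2}

def get_default(section: str, key: str, fallback: Any = None) -> Any:
    mapping = {"physics": PHYSICS_DEFAULTS, "model": MODEL_DEFAULTS}
    # mapping.get(section, {}) makes an unknown section behave like an
    # empty section, so both failure modes return `fallback`.
    return mapping.get(section, {}).get(key, fallback)

print(get_default("physics", "base_dt"))      # -> 0.1
print(get_default("physics", "missing", -1))  # -> -1
print(get_default("nope", "base_dt", "n/a"))  # -> n/a
```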
gfn/realizations/gssm/config/loader.py ADDED
@@ -0,0 +1,163 @@
+"""
+config/loader.py — GFN V5
+Converts configuration dicts into a typed PhysicsConfig.
+Supports nested overrides on top of existing configs.
+"""
+from typing import Dict, Any, Optional
+from .schema import (
+    PhysicsConfig, TopologyConfig, StabilityConfig, DynamicsConfig,
+    ActiveInferenceConfig, DynamicTimeConfig, HysteresisConfig,
+    EmbeddingConfig, FractalConfig, SingularityConfig,
+)
+
+
+def dict_to_physics_config(d: Dict[str, Any]) -> PhysicsConfig:
+    """
+    Converts a nested dict into a typed PhysicsConfig.
+
+    Supports all PhysicsConfig sub-fields. Fields not present in the dict
+    keep their default values from the schema.
+    If `d` is already a PhysicsConfig, it is returned untouched.
+    """
+    if isinstance(d, PhysicsConfig):
+        return d
+
+    cfg = PhysicsConfig()
+    _apply_dict_to_physics_config(cfg, d)
+    return cfg
+
+
+def apply_physics_overrides(cfg: PhysicsConfig, overrides: Dict[str, Any]) -> PhysicsConfig:
+    """
+    Applies a dict of overrides onto an EXISTING PhysicsConfig (in place).
+
+    Unlike dict_to_physics_config(), this function does NOT start from defaults;
+    it modifies only the fields present in the dict and leaves the rest untouched.
+    It is the function ModelFactory uses when combining a preset with the physics kwarg.
+
+    Args:
+        cfg: Existing PhysicsConfig (e.g. the result of get_preset())
+        overrides: Nested dict of fields to overwrite
+
+    Returns:
+        The same cfg, modified in place (also returned for chaining).
+    """
+    if not overrides:
+        return cfg
+    _apply_dict_to_physics_config(cfg, overrides)
+    return cfg
+
+
+def _apply_dict_to_physics_config(cfg: PhysicsConfig, d: Dict[str, Any]) -> None:
+    """Internal helper — applies the dict's fields onto cfg in place."""
+
+    # ── Topology ──────────────────────────────────────────────────────────────
+    t_d = d.get('topology', d.get('topology_config', {}))
+    if isinstance(t_d, dict) and t_d:
+        _apply(cfg.topology, t_d, [
+            'type', 'R', 'r', 'curvature',
+            'riemannian_type', 'riemannian_rank', 'riemannian_class',
+            'geometry_scope'
+        ])
+        if 'major_radius' in t_d: cfg.topology.R = t_d['major_radius']
+        if 'minor_radius' in t_d: cfg.topology.r = t_d['minor_radius']
+
+    # ── Stability ─────────────────────────────────────────────────────────────
+    s_d = d.get('stability', d.get('stability_config', {}))
+    if isinstance(s_d, dict) and s_d:
+        _apply(cfg.stability, s_d, [
+            'base_dt', 'adaptive', 'dt_min', 'dt_max',
+            'enable_trace_normalization', 'wrap_x',
+            'friction', 'velocity_friction_scale',
+            'curvature_clamp', 'friction_mode',
+            'integrator_type',
+            # legacy alias
+            'velocity_saturation',  # → ignored; does not exist in StabilityConfig
+        ])
+        # Legacy name aliases
+        if 'toroidal_curvature_scale' in s_d:
+            cfg.stability.curvature_clamp = s_d['toroidal_curvature_scale']
+
+    # ── Dynamics ──────────────────────────────────────────────────────────────
+    dyn_d = d.get('dynamics', d.get('dynamics_config', {}))
+    if isinstance(dyn_d, dict) and dyn_d:
+        if 'type' in dyn_d:
+            cfg.dynamics.type = dyn_d['type']
+
+    # ── Active Inference ──────────────────────────────────────────────────────
+    ai_d = d.get('active_inference', d.get('active_inference_config', {}))
+    if isinstance(ai_d, dict) and ai_d:
+        _apply(cfg.active_inference, ai_d, [
+            'enabled', 'holographic_geometry',
+            'thermodynamic_geometry', 'plasticity',
+        ])
+        # Dynamic time
+        dt_d = ai_d.get('dynamic_time', {})
+        if isinstance(dt_d, dict) and dt_d:
+            _apply(cfg.active_inference.dynamic_time, dt_d, ['enabled', 'type'])
+        # Reactive curvature — internal dict
+        rc_d = ai_d.get('reactive_curvature', {})
+        if isinstance(rc_d, dict) and rc_d:
+            cfg.active_inference.reactive_curvature.update(rc_d)
+        # Stochasticity — internal dict
+        st_d = ai_d.get('stochasticity', {})
+        if isinstance(st_d, dict) and st_d:
+            cfg.active_inference.stochasticity.update(st_d)
+        # Curiosity — internal dict
+        cu_d = ai_d.get('curiosity', {})
+        if isinstance(cu_d, dict) and cu_d:
+            cfg.active_inference.curiosity.update(cu_d)
+
+    # ── Hysteresis (may be at the root OR inside active_inference) ────────────
+    hyst_src = d.get('hysteresis', ai_d.get('hysteresis', {}) if isinstance(ai_d, dict) else {})
+    if isinstance(hyst_src, dict) and hyst_src:
+        _apply(cfg.hysteresis, hyst_src, [
+            'enabled', 'ghost_force', 'hyst_decay',
+            'hyst_update_w', 'hyst_update_b',
+            'hyst_readout_w', 'hyst_readout_b',
+        ])
+
+    # ── Singularities (may be at the root OR inside active_inference) ─────────
+    sing_src = d.get('singularities', ai_d.get('singularities', {}) if isinstance(ai_d, dict) else {})
+    if isinstance(sing_src, dict) and sing_src:
122
+ _apply(cfg.singularities, sing_src, [
123
+ 'enabled', 'epsilon', 'strength', 'threshold'
124
+ ])
125
+
126
+ # ── Embedding ─────────────────────────────────────────────────────────────
127
+ emb_d = d.get('embedding', d.get('embedding_config', {}))
128
+ if isinstance(emb_d, dict) and emb_d:
129
+ _apply(cfg.embedding, emb_d, [
130
+ 'type', 'mode', 'coord_dim', 'impulse_scale', 'omega_0'
131
+ ])
132
+
133
+ # ── Readout ───────────────────────────────────────────────────────────────
134
+ read_d = d.get('readout', d.get('readout_config', {}))
135
+ if isinstance(read_d, dict) and read_d:
136
+ _apply(cfg.readout, read_d, ['type'])
137
+
138
+ # ── Mixture ───────────────────────────────────────────────────────────────
139
+ mix_d = d.get('mixture', d.get('mixture_config', {}))
140
+ if isinstance(mix_d, dict) and mix_d:
141
+ _apply(cfg.mixture, mix_d, ['coupler_mode'])
142
+
143
+ # ── Fractal ───────────────────────────────────────────────────────────────
144
+ frac_d = d.get('fractal', {})
145
+ if isinstance(frac_d, dict) and frac_d:
146
+ _apply(cfg.fractal, frac_d, ['enabled', 'threshold', 'alpha'])
147
+
148
+ # ── Top-level trajectory_mode ─────────────────────────────────────────────
149
+ if 'trajectory_mode' in d:
150
+ cfg.trajectory_mode = d['trajectory_mode']
151
+
152
+ # ── Attention/mixer alias (legacy ECG configs) ────────────────────────────
153
+ # 'attention': {'mixer_type': 'low_rank'} — se ignora acá, aplica en ManifoldConfig
154
+
155
+
156
+ def _apply(target, source: dict, keys: list) -> None:
157
+ """Copia las claves presentes en source hacia target (setattr)."""
158
+ for k in keys:
159
+ if k in source:
160
+ try:
161
+ setattr(target, k, source[k])
162
+ except AttributeError:
163
+ pass # clave no existe en el dataclass — ignorar silenciosamente
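The override semantics above — copy only the keys actually present in the dict, leave everything else at its current value — can be sketched with a minimal stand-in dataclass (`StubStability` here is a hypothetical example, not part of the codebase):

```python
from dataclasses import dataclass


@dataclass
class StubStability:          # hypothetical stand-in for StabilityConfig
    base_dt: float = 0.1
    friction: float = 0.01


def _apply(target, source, keys):
    # Copy only the keys actually present in `source` onto `target`
    for k in keys:
        if k in source:
            setattr(target, k, source[k])


cfg = StubStability()
_apply(cfg, {'friction': 0.5}, ['base_dt', 'friction'])
print(cfg.base_dt, cfg.friction)  # base_dt keeps its default, friction is overridden
```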
gfn/realizations/gssm/config/schema.py ADDED
@@ -0,0 +1,199 @@
+ # schema.py — GFN V5
+ # Configuration class definitions (schema).
+ # SEPARATION: default values go in defaults.py, physical constants in constants.py
+
+ from dataclasses import dataclass, field, asdict
+ from typing import Dict, Any, Optional, List
+
+ # Import physical constants
+ from ..constants import (
+     EPSILON_STANDARD,
+     TOPOLOGY_TORUS,
+     MIN_DT,
+     MAX_DT,
+     CURVATURE_CLAMP,
+     SINGULARITY_THRESHOLD,
+     BLACK_HOLE_STRENGTH,
+     DEFAULT_DT,
+     DEFAULT_FRICTION,
+     DEFAULT_PLASTICITY,
+     MAX_VELOCITY,
+ )
+
+
+ @dataclass
+ class TopologyConfig:
+     """Manifold topology configuration."""
+     type: str = TOPOLOGY_TORUS
+     R: float = 2.0   # Major torus radius (default)
+     r: float = 1.0   # Minor torus radius (default)
+     curvature: float = 0.0
+     riemannian_type: str = 'reactive'
+     riemannian_rank: int = 16
+     riemannian_class: Optional[str] = None
+     geometry_scope: str = 'local'  # 'local' (dim/heads) or 'global' (full dim)
+     # NEW: learnable parameters
+     learnable_R: bool = True  # Make R learnable (as described in the paper)
+     learnable_r: bool = True  # Make r learnable (as described in the paper)
+
+
+ @dataclass
+ class StabilityConfig:
+     """Numerical stability configuration."""
+     base_dt: float = DEFAULT_DT
+     adaptive: bool = True
+     dt_min: float = MIN_DT
+     dt_max: float = MAX_DT
+     enable_trace_normalization: bool = True
+     wrap_x: bool = True
+     friction: float = DEFAULT_FRICTION
+     velocity_friction_scale: float = 0.0
+     velocity_saturation: float = 0.0  # P2.3: 0 = disabled, >0 = clamp magnitude via tanh
+     curvature_clamp: float = CURVATURE_CLAMP
+     friction_mode: str = 'static'  # 'static' or 'lif'
+     integrator_type: str = 'leapfrog'
+     toroidal_curvature_scale: float = 0.01  # scale for torus Christoffel contribution
+
+
+ @dataclass
+ class DynamicTimeConfig:
+     enabled: bool = False
+     type: str = 'riemannian'
+
+
+ @dataclass
+ class HysteresisConfig:
+     enabled: bool = False
+     ghost_force: bool = True
+     hyst_decay: float = 0.1
+     hyst_update_w: float = 1.0
+     hyst_update_b: float = 0.0
+     hyst_readout_w: float = 1.0
+     hyst_readout_b: float = 0.0
+
+
+ @dataclass
+ class ActiveInferenceConfig:
+     enabled: bool = False
+     holographic_geometry: bool = False
+     thermodynamic_geometry: bool = False
+     plasticity: float = DEFAULT_PLASTICITY
+     dynamic_time: DynamicTimeConfig = field(default_factory=DynamicTimeConfig)
+     reactive_curvature: Dict[str, Any] = field(default_factory=lambda: {
+         "enabled": False,
+         "plasticity": 0.0
+     })
+     geodesic_lensing: Dict[str, Any] = field(default_factory=lambda: {"enabled": False})
+
+     # Exploration / Noise
+     stochasticity: Dict[str, Any] = field(default_factory=lambda: {
+         "enabled": False,
+         "type": "brownian",
+         "sigma": 0.01,
+         "theta": 0.15,
+         "mu": 0.0
+     })
+     curiosity: Dict[str, Any] = field(default_factory=lambda: {
+         "enabled": False,
+         "strength": 0.1,
+         "decay": 0.99
+     })
+
+
+ @dataclass
+ class EmbeddingConfig:
+     type: str = 'standard'
+     mode: str = 'linear'
+     coord_dim: int = 16
+     impulse_scale: float = 1.0
+     omega_0: float = 30.0
+
+
+ @dataclass
+ class ReadoutConfig:
+     type: str = 'standard'
+
+
+ @dataclass
+ class MixtureConfig:
+     coupler_mode: str = 'mean_field'
+
+
+ @dataclass
+ class DynamicsConfig:
+     type: str = 'direct'
+
+
+ @dataclass
+ class FractalConfig:
+     enabled: bool = False
+     threshold: float = 0.5
+     alpha: float = 0.2
+
+
+ @dataclass
+ class SingularityConfig:
+     enabled: bool = False
+     epsilon: float = EPSILON_STANDARD
+     strength: float = BLACK_HOLE_STRENGTH
+     threshold: float = SINGULARITY_THRESHOLD
+
+
+ @dataclass
+ class PhysicsConfig:
+     """Complete physics configuration."""
+     topology: TopologyConfig = field(default_factory=TopologyConfig)
+     stability: StabilityConfig = field(default_factory=StabilityConfig)
+     dynamics: DynamicsConfig = field(default_factory=DynamicsConfig)
+     active_inference: ActiveInferenceConfig = field(default_factory=ActiveInferenceConfig)
+     embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
+     readout: ReadoutConfig = field(default_factory=ReadoutConfig)
+     mixture: MixtureConfig = field(default_factory=MixtureConfig)
+     fractal: FractalConfig = field(default_factory=FractalConfig)
+     hysteresis: HysteresisConfig = field(default_factory=HysteresisConfig)
+     singularities: SingularityConfig = field(default_factory=SingularityConfig)
+     trajectory_mode: str = 'partition'
+     lensing: Dict[str, Any] = field(default_factory=lambda: {'enabled': False})
+     checkpointing: Dict[str, Any] = field(default_factory=lambda: {'enabled': False})
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+
+ @dataclass
+ class TrainerConfig:
+     lr: float = 1e-3
+     optimizer: str = 'adamw'
+     max_lr: Optional[float] = None
+     total_steps: Optional[int] = None
+     loss_config: Dict[str, Any] = field(default_factory=lambda: {
+         'lambda_g': 0.001,
+         'lambda_h': 0.0,
+         'geodesic_mode': 'magnitude'
+     })
+
+
+ @dataclass
+ class ManifoldConfig:
+     """Main configuration for the Manifold model."""
+     vocab_size: int
+     dim: int = 512
+     depth: int = 4
+     heads: int = 4
+     rank: int = 32
+     integrator: str = 'leapfrog'
+     physics: PhysicsConfig = field(default_factory=PhysicsConfig)
+     trainer: TrainerConfig = field(default_factory=TrainerConfig)
+     adjoint_enabled: bool = False
+     adjoint_rtol: float = 1e-4
+     adjoint_atol: float = 1e-4
+     holographic: bool = False
+     impulse_scale: float = 1.0
+     dynamics_type: str = 'direct'
+     mixer_type: str = 'low_rank'
+     trajectory_mode: str = 'partition'
+     coupler_mode: str = 'mean_field'
+     initial_spread: float = 1e-3
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
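The nesting pattern used by PhysicsConfig — every sub-config supplied via `default_factory` so instances never share mutable state, with `asdict` recursing for `to_dict` — can be checked with a minimal two-level sketch (the stub class names are illustrative, not from the codebase):

```python
from dataclasses import dataclass, field, asdict


@dataclass
class StubTopology:          # illustrative stand-in for a sub-config
    type: str = 'torus'
    R: float = 2.0


@dataclass
class StubPhysics:
    # default_factory builds a fresh StubTopology per instance
    topology: StubTopology = field(default_factory=StubTopology)

    def to_dict(self):
        return asdict(self)  # recurses into nested dataclasses


a, b = StubPhysics(), StubPhysics()
a.topology.R = 99.0          # must not leak into b
print(b.topology.R, a.to_dict())
```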
gfn/realizations/gssm/config/serialization.py ADDED
@@ -0,0 +1,39 @@
+ import dataclasses
+ from typing import Any, Dict, Type, TypeVar, get_type_hints, get_args, get_origin, Union
+
+ T = TypeVar('T')
+
+
+ def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
+     """
+     Reconstructs a nested dataclass from a dictionary.
+     Handles nested dataclasses and basic types.
+     """
+     if not dataclasses.is_dataclass(cls):
+         return data
+
+     field_types = get_type_hints(cls)
+     kwargs = {}
+
+     for field in dataclasses.fields(cls):
+         if field.name in data:
+             value = data[field.name]
+             field_type = field_types[field.name]
+
+             # Handle Optional[T]
+             origin = get_origin(field_type)
+             if origin is Union:
+                 args = get_args(field_type)
+                 if type(None) in args:
+                     # It's an Optional. Find the non-None type
+                     field_type = [arg for arg in args if arg is not type(None)][0]
+
+             # Handle nested dataclasses
+             if dataclasses.is_dataclass(field_type):
+                 if value is not None:
+                     kwargs[field.name] = from_dict(field_type, value)
+                 else:
+                     kwargs[field.name] = None
+             else:
+                 kwargs[field.name] = value
+
+     return cls(**kwargs)
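A round trip through this kind of loader can be exercised with a small nested dataclass. The helper below is a simplified sketch of the logic above (no `Optional` handling), included only so the example is self-contained and runnable; `Inner`/`Outer` are illustrative names:

```python
import dataclasses
from typing import get_type_hints


def from_dict(cls, data):
    # Simplified sketch of the loader above: recurse into nested dataclasses,
    # pass everything else through, fall back to field defaults for missing keys.
    if not dataclasses.is_dataclass(cls):
        return data
    hints = get_type_hints(cls)
    kwargs = {}
    for f in dataclasses.fields(cls):
        if f.name in data:
            t = hints[f.name]
            kwargs[f.name] = from_dict(t, data[f.name]) if dataclasses.is_dataclass(t) else data[f.name]
    return cls(**kwargs)


@dataclasses.dataclass
class Inner:
    x: int = 0


@dataclasses.dataclass
class Outer:
    inner: Inner = dataclasses.field(default_factory=Inner)
    name: str = 'gfn'


o = from_dict(Outer, {'inner': {'x': 7}})
print(o.inner.x, o.name)
```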
gfn/realizations/gssm/config/validator.py ADDED
@@ -0,0 +1,109 @@
+ """
+ Configuration validation — GFN V5
+ Checks parameter compatibility before building components.
+ Merged from utils/validation.py and the original config/validator.py.
+ """
+
+ from typing import Dict, Any, List, Optional
+ from .schema import ManifoldConfig, PhysicsConfig
+ from ..constants import TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN, TOPOLOGY_SPHERE
+
+
+ class ConfigValidationError(Exception):
+     """Critical configuration validation error."""
+     pass
+
+
+ class ConfigValidator:
+     """Central validator for GFN configurations."""
+
+     @staticmethod
+     def validate_physics(cfg: PhysicsConfig, dim: Optional[int] = None, heads: Optional[int] = None):
+         """
+         Validate physical and architectural consistency of PhysicsConfig.
+         Raises ConfigValidationError if strict topology/stability rules are violated.
+         """
+         # 1. Topology checks
+         if cfg.topology.type == TOPOLOGY_TORUS:
+             if dim is not None and heads is not None:
+                 head_dim = dim // heads
+                 if head_dim % 2 != 0:
+                     raise ConfigValidationError(
+                         f"Toroidal geometry requires head_dim (dim//heads) to be even. "
+                         f"Found {dim}//{heads}={head_dim}"
+                     )
+
+         if cfg.topology.type == TOPOLOGY_SPHERE and cfg.topology.curvature <= 0:
+             raise ConfigValidationError("Spherical topology requires positive curvature.")
+
+         # 2. Stability checks
+         if cfg.stability.base_dt <= 0:
+             raise ConfigValidationError("base_dt must be positive.")
+         if cfg.stability.friction < 0:
+             raise ConfigValidationError("friction cannot be negative.")
+
+         # 3. Mode compatibility
+         if cfg.trajectory_mode == 'ensemble' and heads is not None and heads <= 1:
+             raise ConfigValidationError("Ensemble trajectory mode requires more than 1 head.")
+
+
+ def validate_manifold_config(config: ManifoldConfig) -> List[str]:
+     """
+     Validates a complete ManifoldConfig and its nested PhysicsConfig.
+     Returns a list of warnings (empty if everything is OK).
+     Raises ConfigValidationError on critical or compatibility errors.
+     """
+     warnings = []
+
+     # Critical validations (raise exceptions)
+     if config.dim % config.heads != 0:
+         raise ConfigValidationError(
+             f"dim={config.dim} is not divisible by heads={config.heads}. "
+             f"head_dim={config.dim / config.heads:.1f} is not an integer."
+         )
+
+     if config.vocab_size <= 0:
+         raise ConfigValidationError(f"vocab_size={config.vocab_size} must be > 0.")
+
+     if config.depth <= 0:
+         raise ConfigValidationError(f"depth={config.depth} must be > 0.")
+
+     # Validate physics properties via the centralized method
+     ConfigValidator.validate_physics(config.physics, config.dim, config.heads)
+
+     # Soft validations (warnings)
+     head_dim = config.dim // config.heads
+     topo_type = config.physics.topology.type.lower()
+
+     if topo_type == TOPOLOGY_TORUS and head_dim % 2 != 0:
+         warnings.append(
+             f"[WARN] For toroidal geometry, head_dim={head_dim} should be even "
+             f"for sin/cos representations. Consider heads={config.dim // (head_dim + 1)} or similar."
+         )
+
+     if config.rank > config.dim:
+         warnings.append(
+             f"[WARN] rank={config.rank} > dim={config.dim}. "
+             f"The decomposition is not low-rank. Intentional?"
+         )
+
+     dt = config.physics.stability.base_dt
+     if dt > 1.0:
+         warnings.append(f"[WARN] base_dt={dt} > 1.0 may cause numerical instability.")
+     if dt < 1e-5:
+         warnings.append(f"[WARN] base_dt={dt} < 1e-5 may slow convergence.")
+
+     return warnings
+
+
+ def validate_and_print(config: ManifoldConfig) -> bool:
+     """
+     Validates the configuration and prints warnings.
+     Returns True if valid, False if there were errors.
+     """
+     try:
+         warnings = validate_manifold_config(config)
+         for w in warnings:
+             print(w)
+         return True
+     except ConfigValidationError as e:
+         print(f"[CONFIG ERROR] {e}")
+         return False
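The structural rules enforced above reduce to simple integer arithmetic. A minimal sketch of the two divisibility checks (`check_dims` is an illustrative helper, not part of the module):

```python
def check_dims(dim: int, heads: int) -> list:
    # Mirrors the critical checks: dim % heads == 0,
    # and an even head_dim for toroidal (sin/cos) geometry.
    errors = []
    if dim % heads != 0:
        errors.append(f"dim={dim} is not divisible by heads={heads}")
    elif (dim // heads) % 2 != 0:
        errors.append(f"head_dim={dim // heads} must be even for toroidal geometry")
    return errors


print(check_dims(512, 4))   # 512 // 4 = 128, even -> no errors
print(check_dims(512, 3))   # not divisible -> error
print(check_dims(10, 2))    # head_dim = 5, odd -> error
```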
gfn/realizations/gssm/constants.py ADDED
@@ -0,0 +1,67 @@
+ # constants.py — GFN V5
+ # Universal physical and mathematical constants.
+ # Does NOT contain training hyperparameters or configurable values.
+
+ import torch
+
+ # ─── Mathematical constants ────────────────────────────────────────────────
+ PI = 3.141592653589793
+ E = 2.718281828459045
+ SQRT_2 = 1.4142135623730951
+ LOG_2 = 0.6931471805599453
+
+ # ─── Numerical stability ───────────────────────────────────────────────────
+ EPS = 1e-8
+ INF = 1e12
+ EPSILON_STANDARD = 1e-7
+ EPSILON_SMOOTH = 1e-9
+ EPSILON_STRONG = 1e-6
+ CLAMP_MIN_STRONG = 1e-4
+
+ # ─── Physical limits ───────────────────────────────────────────────────────
+ MIN_DT = 0.001
+ MAX_DT = 1.0
+
+ # ─── Geometry / curvature ──────────────────────────────────────────────────
+ CURVATURE_CLAMP = 5.0   # Maximum absolute value of Christoffel output
+ FRICTION_SCALE = 0.1    # Global friction scaling factor
+ VELOCITY_FRICTION_SCALE = 0.01
+
+ # ─── Gate initialization constants ─────────────────────────────────────────
+ GATE_BIAS_OPEN = 2.0     # sigmoid(2.0) ≈ 0.88
+ GATE_BIAS_CLOSED = -2.0  # sigmoid(-2.0) ≈ 0.12
+
+ # ─── Singularity / active inference ────────────────────────────────────────
+ SINGULARITY_THRESHOLD = 0.5
+ BLACK_HOLE_STRENGTH = 3.0
+ SINGULARITY_GATE_SLOPE = 10.0
+
+ # ─── Torus geometry ────────────────────────────────────────────────────────
+ TOROIDAL_MAJOR_RADIUS = 1.0
+ TOROIDAL_MINOR_RADIUS = 0.3
+ TOROIDAL_PERIOD = 2.0 * PI
+ TOROIDAL_CURVATURE_SCALE = 0.1
+
+ # ─── Default dtype ─────────────────────────────────────────────────────────
+ DTYPE = torch.float32
+
+ # ─── Topology names ────────────────────────────────────────────────────────
+ TOPOLOGY_TORUS = "torus"
+ TOPOLOGY_SPHERE = "spherical"
+ TOPOLOGY_HYPERBOLIC = "hyperbolic"
+ TOPOLOGY_EUCLIDEAN = "euclidean"
+
+ # ─── Dynamics modes ────────────────────────────────────────────────────────
+ DYNAMICS_DIRECT = "direct"
+ DYNAMICS_RESIDUAL = "residual"
+ DYNAMICS_MIX = "mix"
+ DYNAMICS_GATED = "gated"
+ DYNAMICS_STOCHASTIC = "stochastic"
+
+ # ─── Compatibility aliases (defaults that moved to defaults.py) ────────────
+ # NOTE: These values are kept here for compatibility; new code should
+ # import them from config/defaults.py instead.
+ DEFAULT_FRICTION = 0.01
+ DEFAULT_DT = 0.1
+ DEFAULT_PLASTICITY = 0.05
+ MAX_VELOCITY = 10.0
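The gate-bias comments above (sigmoid(2.0) ≈ 0.88, sigmoid(-2.0) ≈ 0.12) are easy to verify numerically with plain math, no torch required:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


GATE_BIAS_OPEN, GATE_BIAS_CLOSED = 2.0, -2.0
# An "open" gate starts at ~0.88, a "closed" gate at ~0.12
print(round(sigmoid(GATE_BIAS_OPEN), 2), round(sigmoid(GATE_BIAS_CLOSED), 2))
```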
gfn/realizations/gssm/core/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """
+ core/__init__.py — GFN V5
+ """
+ from ..core.types import ManifoldState, Trajectory, StepResult, ModelOutput
+ from ..core.state import ManifoldStateManager
+
+ __all__ = ['ManifoldState', 'Trajectory', 'StepResult', 'ModelOutput', 'ManifoldStateManager']
gfn/realizations/gssm/core/state.py ADDED
@@ -0,0 +1,60 @@
+ """
+ core/state.py — GFN V5
+ Manifold state handling (position + velocity).
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Tuple
+
+
+ class ManifoldStateManager:
+     """
+     Manages initialization and manipulation of the (x, v) state.
+     Compatible with batches and multiple heads.
+     """
+
+     @staticmethod
+     def initialize(x0: nn.Parameter, v0: nn.Parameter,
+                    batch_size: int, n_trajectories: int = 1,
+                    initial_spread: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Initializes the (x, v) state for a given batch.
+
+         Args:
+             x0, v0: Initial parameters [1, H, HD]
+             batch_size: Batch size
+             n_trajectories: Number of parallel trajectories
+             initial_spread: Initial noise
+
+         Returns:
+             (x, v) — [B, H, HD]
+         """
+         x = x0.expand(batch_size, -1, -1)
+         v = v0.expand(batch_size, -1, -1)
+
+         if initial_spread > 0:
+             x = x + torch.randn_like(x) * initial_spread
+
+         return x.contiguous(), v.contiguous()
+
+     @staticmethod
+     def from_tuple(state: Optional[Tuple], x0: nn.Parameter, v0: nn.Parameter,
+                    batch_size: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Builds (x, v) from a previous state or from the initial parameters.
+         Compatible with the BasicModel API.
+         """
+         if state is not None and isinstance(state, (tuple, list)) and len(state) == 2:
+             return state[0], state[1]
+         return ManifoldStateManager.initialize(x0, v0, batch_size, **kwargs)
+
+     @staticmethod
+     def wrap_torus(x: torch.Tensor) -> torch.Tensor:
+         """Projects the position onto the toroidal domain [-π, π]."""
+         return torch.atan2(torch.sin(x), torch.cos(x))
+
+     @staticmethod
+     def energy(v: torch.Tensor) -> torch.Tensor:
+         """Kinetic energy H = 0.5 * ||v||² per sample."""
+         return 0.5 * (v ** 2).sum(dim=-1)
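`wrap_torus` relies on the identity atan2(sin x, cos x) = x wrapped into (−π, π]. A scalar sketch of the same identity with plain `math` (standing in for the elementwise torch version) shows the behaviour:

```python
import math


def wrap_scalar(x: float) -> float:
    # Same identity as ManifoldStateManager.wrap_torus, on a single float
    return math.atan2(math.sin(x), math.cos(x))


print(wrap_scalar(0.5))                # already in range -> unchanged
print(wrap_scalar(0.5 + 2 * math.pi))  # one full turn away -> wraps back to ~0.5
print(wrap_scalar(4.0))                # outside (-pi, pi] -> maps to 4.0 - 2*pi
```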
gfn/realizations/gssm/core/types.py ADDED
@@ -0,0 +1,27 @@
+ """
+ core/types.py — GFN V5
+ Framework types and type aliases.
+ """
+
+ import torch
+ from typing import Dict, Any, Tuple, Optional, List, Union
+
+ # ─── State types ─────────────────────────────────────────────────────────────
+ # (position, velocity) pair
+ ManifoldState = Tuple[torch.Tensor, torch.Tensor]
+
+ # Trajectory: list of (x, v) states over time
+ Trajectory = List[ManifoldState]
+
+ # Force tensor (same shape as x, v)
+ Force = torch.Tensor
+
+ # Integration step result
+ StepResult = Dict[str, torch.Tensor]  # {'x': ..., 'v': ...}
+
+ # ─── Config types ────────────────────────────────────────────────────────────
+ ConfigDict = Dict[str, Any]
+
+ # ─── Forward pass outputs ────────────────────────────────────────────────────
+ # (logits, state, info_dict)
+ ModelOutput = Tuple[torch.Tensor, ManifoldState, Dict[str, Any]]
gfn/realizations/gssm/csrc/README.md ADDED
@@ -0,0 +1,2 @@
+ # Native C++/CUDA Extensions
+ Store raw .cpp and .cu files here. Bindings remain in new_gfn/cuda/
gfn/realizations/gssm/csrc/compile_cuda_12.9.bat ADDED
@@ -0,0 +1,68 @@
+ @echo off
+
+ REM Ensure we are in the script's directory so we run the correct setup.py
+ cd /d "%~dp0"
+
+ echo ================================================================
+ echo [GFN] Custom CUDA Kernel Compilation Pipeline (VS 2022 + CUDA 12.9)
+ echo ================================================================
+
+ REM --- 1. Find Visual Studio 2022 Installation ---
+ set "VS_PATH="
+
+ if exist "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" (
+     set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat"
+ ) else if exist "C:\Program Files\Microsoft Visual Studio\2022\Professional\VC\Auxiliary\Build\vcvars64.bat" (
+     set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Professional\VC\Auxiliary\Build\vcvars64.bat"
+ ) else if exist "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" (
+     set "VS_PATH=C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
+ )
+
+ if "%VS_PATH%"=="" (
+     echo [ERROR] Could not find Visual Studio 2022 installation.
+     pause
+     exit /b 1
+ )
+
+ echo [*] Found MSVC Environment: "%VS_PATH%"
+ echo [*] Initializing Developer Console...
+ call "%VS_PATH%"
+
+ REM --- 2. Setup CUDA Environment (12.9) ---
+ set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9"
+ set "PATH=%CUDA_PATH%\bin;%PATH%"
+
+ echo [*] Using CUDA Path: "%CUDA_PATH%"
+ nvcc --version
+
+ REM --- 3. Compile Kernels ---
+ echo [*] Cleaning old builds...
+ rmdir /s /q build
+ rmdir /s /q gfn_cuda.egg-info
+ del /q *.pyc
+ rmdir /s /q __pycache__
+
+ echo.
+ echo [*] Starting Setup Compilation (In-place)...
+
+ REM Fix for "It seems that the VC environment is activated..." warning
+ set DISTUTILS_USE_SDK=1
+
+ python setup.py build_ext --inplace
+
+ if %errorlevel% neq 0 (
+     echo [ERROR] Compilation failed.
+     echo Ensure you have PyTorch installed for CUDA 12.x:
+     echo     pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+     pause
+     exit /b 1
+ )
+
+ echo.
+ echo [SUCCESS] Kernels compiled to local .pyd file!
+ echo Verified import:
+
+ python -c "import gfn_cuda; print('SUCCESS: gfn_cuda module imported directly!')"
+ pause
gfn/realizations/gssm/csrc/extension.cpp ADDED
@@ -0,0 +1,87 @@
+ #include <torch/extension.h>
+ #include <vector>
+ #include "integrators/integrators.h"
+
+ // Function declarations (toroidal loss)
+ torch::Tensor toroidal_distance_loss_fwd(const torch::Tensor& y_pred, const torch::Tensor& y_true);
+ torch::Tensor toroidal_distance_loss_bwd(const torch::Tensor& grad_output, const torch::Tensor& y_pred, const torch::Tensor& y_true);
+
+ // Function declarations (low-rank Christoffel)
+ torch::Tensor low_rank_christoffel_fwd(
+     const torch::Tensor& v, const torch::Tensor& U, const torch::Tensor& W,
+     double clamp_val, bool enable_trace_norm, bool is_paper_version);
+
+ // Pure ATen C++ backward implementation (compiled by MSVC, sidestepping an NVCC CICC bug)
+ std::vector<torch::Tensor> low_rank_christoffel_bwd(
+     const torch::Tensor& grad_gamma,
+     const torch::Tensor& v,
+     const torch::Tensor& U,
+     const torch::Tensor& W,
+     const torch::Tensor& gamma_out,
+     double clamp_val,
+     bool enable_trace_norm,
+     bool is_paper_version)
+ {
+     // Fast pure ATen operations avoiding Python overhead
+     auto g_norm = gamma_out / clamp_val;
+     auto d_tanh = 1.0 - g_norm.pow(2);
+     auto grad_raw = grad_gamma * d_tanh;  // [B, H, D]
+
+     if (enable_trace_norm) {
+         auto mean_d = grad_raw.mean(-1, /*keepdim=*/true);
+         grad_raw = grad_raw - mean_d;
+     }
+
+     // Explicit batched matrix multiplication (bmm) to avoid matmul broadcast crashes
+     // W is [H, D, R], grad_raw is [B, H, D]
+     auto grad_raw_h = grad_raw.permute({1, 0, 2});  // [H, B, D]
+     auto d_sq_h = torch::bmm(grad_raw_h, W);        // [H, B, D] @ [H, D, R] -> [H, B, R]
+     auto d_sq = d_sq_h.permute({1, 0, 2});          // [B, H, R]
+
+     auto v_h = v.permute({1, 0, 2});                // [H, B, D]
+     auto v_r_h = torch::bmm(v_h, U);                // [H, B, D] @ [H, D, R] -> [H, B, R]
+     auto v_r = v_r_h.permute({1, 0, 2});            // [B, H, R]
+
+     torch::Tensor d_vr;
+     if (is_paper_version) {
+         auto vr_norm = torch::norm(v_r, 2, -1, true);
+         auto denom = 1.0 + vr_norm;
+
+         // Correct chain rule for the normalized-denominator coupling:
+         // grad_vr_j = grad_phi_j * (2v_j / denom) - v_j * Sum_k(grad_phi_k * v_k^2) / (||v|| * denom^2)
+         auto S = (d_sq * v_r.pow(2)).sum(-1, true);
+         auto term1 = d_sq * (2.0 * v_r / denom);
+         auto term2 = (v_r * S) / (vr_norm * denom.pow(2) + 1e-8);
+         d_vr = term1 - term2;
+     } else {
+         d_vr = d_sq * 2.0 * v_r;
+     }
+
+     auto d_vr_h = d_vr.permute({1, 0, 2});          // [H, B, R]
+     auto U_t = U.transpose(-1, -2);                 // [H, R, D]
+     auto d_v_h = torch::bmm(d_vr_h, U_t);           // [H, B, R] @ [H, R, D] -> [H, B, D]
+     auto d_v = d_v_h.permute({1, 0, 2});            // [B, H, D]
+
+     // Accumulate W and U gradients over the batch:
+     auto sq = is_paper_version ? v_r.pow(2) / (1.0 + torch::norm(v_r, 2, -1, true)) : v_r.pow(2);  // [B, H, R]
+     auto sq_h = sq.permute({1, 0, 2});              // [H, B, R]
+
+     auto grad_raw_h_t = grad_raw_h.transpose(-1, -2);  // [H, D, B]
+     auto d_W = torch::bmm(grad_raw_h_t, sq_h);         // [H, D, B] @ [H, B, R] -> [H, D, R]
+
+     auto v_h_t = v_h.transpose(-1, -2);                // [H, D, B]
+     auto d_U = torch::bmm(v_h_t, d_vr_h);              // [H, D, B] @ [H, B, R] -> [H, D, R]
+
+     return {d_v, d_U, d_W};
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("toroidal_distance_loss_fwd", &toroidal_distance_loss_fwd, "Toroidal Distance Loss Forward (CUDA)");
+     m.def("toroidal_distance_loss_bwd", &toroidal_distance_loss_bwd, "Toroidal Distance Loss Backward (CUDA)");
+
+     m.def("low_rank_christoffel_fwd", &low_rank_christoffel_fwd, "Low Rank Christoffel Forward Kernel");
+     m.def("low_rank_christoffel_bwd", &low_rank_christoffel_bwd, "Low Rank Christoffel Backward ATen");
+
+     m.def("yoshida_fwd", &yoshida_fwd_aten, "Yoshida C++ Macro Integrator Step");
+     m.def("leapfrog_fwd", &leapfrog_fwd_aten, "Leapfrog C++ Macro Integrator Step");
+ }
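The toroidal distance loss bound above is implemented in losses/toroidal.cu (not shown in this diff). Assuming it is built from the per-component shortest angular distance on the circle — the same atan2(sin, cos) wrapping used elsewhere in the codebase — a plain-Python reference of that distance is a useful sanity check against the CUDA path (`circular_diff` is an illustrative helper, not an exported symbol):

```python
import math


def circular_diff(a: float, b: float) -> float:
    # Shortest signed angular difference, wrapped into (-pi, pi]
    return math.atan2(math.sin(a - b), math.cos(a - b))


# The distance respects wrap-around at 2*pi: points at 0.1 and 2*pi - 0.1
# are only ~0.2 apart on the circle, not ~6.08 apart on the real line.
d1 = abs(circular_diff(0.1, 2 * math.pi - 0.1))
print(d1)
```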
gfn/realizations/gssm/csrc/geometry/low_rank.cu ADDED
@@ -0,0 +1,160 @@
+ #include <torch/extension.h>
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+
+ template <typename scalar_t>
+ __global__ void low_rank_christoffel_fwd_kernel(
+     const scalar_t* __restrict__ v,      // [B, H, D]
+     const scalar_t* __restrict__ U,      // [H, D, R]
+     const scalar_t* __restrict__ W,      // [H, D, R]
+     scalar_t* __restrict__ gamma,        // [B, H, D]
+     const int B, const int H, const int D, const int R,
+     const scalar_t clamp_val,
+     const bool enable_trace_norm,
+     const bool is_paper_version)
+ {
+     // A block computes the output for one (b, h) pair
+     int bh = blockIdx.x;
+     if (bh >= B * H) return;
+
+     int h = bh % H;
+
+     // Dynamic shared memory allocations
+     extern __shared__ char smem[];
+     scalar_t* v_s_d = reinterpret_cast<scalar_t*>(smem);            // Size: D
+     scalar_t* vr_sq_s = reinterpret_cast<scalar_t*>(&v_s_d[D]);     // Size: R
+     scalar_t* gamma_s_d = reinterpret_cast<scalar_t*>(&vr_sq_s[R]); // Size: D
+
+     const scalar_t* v_b = v + bh * D;
+     scalar_t* gamma_b = gamma + bh * D;
+
+     // Pointers for H offset
+     const scalar_t* U_h = U + h * D * R;
+     const scalar_t* W_h = W + h * D * R;
+
+     const int tid = threadIdx.x;
+     const int bdim = blockDim.x;
+
+     // 1. Load v into shared memory
+     for (int i = tid; i < D; i += bdim) {
+         v_s_d[i] = v_b[i];
+     }
+     __syncthreads();
+
+     // 2. Compute v_r = v @ U -> sq = v_r^2
+     for (int r = tid; r < R; r += bdim) {
+         scalar_t sum = 0;
+         for (int j = 0; j < D; ++j) {
+             sum += v_s_d[j] * U_h[j * R + r];
+         }
+         vr_sq_s[r] = sum * sum;
+     }
+     __syncthreads();
+
+     // 3. Optional: Paper low-rank denominator logic
+     if (is_paper_version) {
+         __shared__ scalar_t block_sum_sq;
+         if (tid == 0) block_sum_sq = 0;
+         __syncthreads();
+
+         scalar_t local_sq = 0;
+         for (int r = tid; r < R; r += bdim) {
+             local_sq += vr_sq_s[r];
+         }
+         atomicAdd(&block_sum_sq, local_sq);
+         __syncthreads();
+
+         scalar_t norm_vr = sqrt(block_sum_sq);
+         scalar_t denom = 1.0 + norm_vr;
+
+         for (int r = tid; r < R; r += bdim) {
+             vr_sq_s[r] = vr_sq_s[r] / denom;
+         }
+         __syncthreads();
+     }
+
+     // 4. Compute gamma_raw = sq @ W.T
+     for (int d = tid; d < D; d += bdim) {
+         scalar_t sum = 0;
+         for (int r = 0; r < R; ++r) {
+             sum += vr_sq_s[r] * W_h[d * R + r];
+         }
+         gamma_s_d[d] = sum;
+     }
+     __syncthreads();
+
+     // 5. Trace normalization (mean subtraction)
+     scalar_t mean_val = 0;
+     if (enable_trace_norm) {
+         __shared__ scalar_t block_sum_gamma;
+         if (tid == 0) block_sum_gamma = 0;
+         __syncthreads();
+
+         scalar_t local_gamma_sum = 0;
+         for (int d = tid; d < D; d += bdim) {
+             local_gamma_sum += gamma_s_d[d];
+         }
+         atomicAdd(&block_sum_gamma, local_gamma_sum);
+         __syncthreads();
+
+         mean_val = block_sum_gamma / D;
+     }
+
+     // 6. Normalization and storage
+     for (int d = tid; d < D; d += bdim) {
+         scalar_t g = gamma_s_d[d];
+         if (enable_trace_norm) {
+             g -= mean_val;
+         }
+         g = clamp_val * tanh(g / clamp_val);
+         gamma_b[d] = g;
+     }
+ }
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor low_rank_christoffel_fwd(
+     const torch::Tensor& v,
+     const torch::Tensor& U,
+     const torch::Tensor& W,
+     double clamp_val,
+     bool enable_trace_norm,
+     bool is_paper_version)
+ {
+     CHECK_INPUT(v);
+     CHECK_INPUT(U);
+     CHECK_INPUT(W);
+
+     // Ensure shapes: v is [B, H, D], U is [H, D, R], W is [H, D, R]
+     int B = v.size(0);
+     int H = v.size(1);
+     int D = v.size(2);
+     int R = U.size(2);
+
+     auto gamma = torch::empty_like(v);
+
+     const int threads = 256;
+     const int blocks = B * H;
+
+     // Shared memory size: (D + R + D) * sizeof(float)
142
+ const int shared_mem_size = (2 * D + R) * sizeof(float);
143
+
144
+ if (v.scalar_type() == torch::kFloat32) {
145
+ low_rank_christoffel_fwd_kernel<float><<<blocks, threads, shared_mem_size>>>(
146
+ v.data_ptr<float>(),
147
+ U.data_ptr<float>(),
148
+ W.data_ptr<float>(),
149
+ gamma.data_ptr<float>(),
150
+ B, H, D, R,
151
+ static_cast<float>(clamp_val),
152
+ enable_trace_norm,
153
+ is_paper_version
154
+ );
155
+ } else {
156
+ TORCH_CHECK(false, "low_rank_christoffel_fwd only supports float32");
157
+ }
158
+
159
+ return gamma;
160
+ }
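For checking the kernel math from Python, the forward pass above can be mirrored with plain NumPy. This is a minimal sketch (the function name `low_rank_christoffel_ref` is my own); it reproduces the projection, optional paper denominator, trace normalization, and tanh soft clamp, not the CUDA memory layout.

```python
import numpy as np

def low_rank_christoffel_ref(v, U, W, clamp_val=5.0,
                             enable_trace_norm=False, is_paper_version=False):
    """NumPy mirror of the kernel: v [B,H,D], U/W [H,D,R] -> gamma [B,H,D]."""
    v_r = np.einsum('bhd,hdr->bhr', v, U)        # project velocity onto rank-R basis
    sq = v_r ** 2
    if is_paper_version:                          # paper denominator: 1 + ||v_r||
        sq = sq / (1.0 + np.linalg.norm(v_r, axis=-1, keepdims=True))
    gamma = np.einsum('bhr,hdr->bhd', sq, W)      # sq @ W.T per head
    if enable_trace_norm:                         # mean subtraction over D
        gamma = gamma - gamma.mean(axis=-1, keepdims=True)
    return clamp_val * np.tanh(gamma / clamp_val)  # soft clamp

B, H, D, R = 2, 3, 4, 2
rng = np.random.default_rng(0)
v = rng.normal(size=(B, H, D))
U = rng.normal(size=(H, D, R))
W = rng.normal(size=(H, D, R))
g = low_rank_christoffel_ref(v, U, W, clamp_val=5.0, enable_trace_norm=True)
print(g.shape)
```

Because of the tanh, every output component is bounded by `clamp_val` in absolute value, which is a cheap sanity check against the kernel output.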
gfn/realizations/gssm/csrc/integrators/integrators.cpp ADDED
@@ -0,0 +1,252 @@
+ #include <torch/extension.h>
+ #include <vector>
+ #define _USE_MATH_DEFINES
+ #include <cmath>
+
+ #ifndef M_PI
+ #define M_PI 3.14159265358979323846
+ #endif
+
+ // -------------------------------------------------------------
+ // Pure ATen Implementation of the Integrators Loop
+ // -------------------------------------------------------------
+ // We build this in standard C++ using ATen so it runs in a single
+ // Python call but is compiled by MSVC, averting NVCC OOM bugs.
+ // It performs `steps` iterations using the exact GFN LowRank geometry.
+
+ // Helper to compute the Christoffel term Gamma
+ torch::Tensor _compute_gamma(
+     const torch::Tensor& v,
+     const torch::Tensor& U,
+     const torch::Tensor& W,
+     double clamp_val,
+     bool enable_trace_norm,
+     bool is_paper_version)
+ {
+     auto v_r = torch::matmul(v.unsqueeze(-2), U).squeeze(-2); // [..., R]
+     torch::Tensor sq;
+     if (is_paper_version) {
+         auto vr_norm = torch::norm(v_r, 2, -1, true);
+         sq = v_r.pow(2) / (1.0 + vr_norm);
+     } else {
+         sq = v_r.pow(2);
+     }
+
+     auto gamma = torch::matmul(sq.unsqueeze(-2), W.transpose(-1, -2)).squeeze(-2); // [..., D]
+
+     if (enable_trace_norm) {
+         auto mean_g = gamma.mean(-1, /*keepdim=*/true);
+         gamma = gamma - mean_g;
+     }
+
+     return clamp_val * torch::tanh(gamma / clamp_val);
+ }
+
+ // Helper for velocity saturation (soft-clamp)
+ torch::Tensor _clamp_velocity(const torch::Tensor& v, double v_sat) {
+     if (v_sat > 0) {
+         return v_sat * torch::tanh(v / v_sat);
+     }
+     return v;
+ }
+
+ // Yoshida 4th-order coefficients
+ const double w1 = 1.3512071919596576;
+ const double w0 = -1.7024143839193153;
+ const double y_c1 = w1 / 2.0;
+ const double y_c2 = (w0 + w1) / 2.0;
+ const double y_c3 = y_c2;
+ const double y_c4 = y_c1;
+ const double y_d1 = w1;
+ const double y_d2 = w0;
+ const double y_d3 = w1;
+
+ // Helper for Gated Friction (Active Inference)
+ torch::Tensor _compute_mu(
+     const torch::Tensor& x,
+     const torch::Tensor& v,
+     const torch::Tensor& gate_w,
+     const torch::Tensor& gate_b,
+     double base_friction,
+     double vel_fric_scale)
+ {
+     const double eps = 1e-8;
+     const int64_t D = x.size(-1);
+
+     // mu_base = base_friction
+     torch::Tensor mu = torch::full_like(x.select(-1, 0).unsqueeze(-1), base_friction);
+
+     // If gate weights are provided, calculate the learnable friction component
+     if (gate_w.numel() > 0) {
+         torch::Tensor feat;
+         // Check whether Torus features [sin, cos] are needed (gate_w dim will be 2*D)
+         if (gate_w.size(1) == 2 * D) {
+             feat = torch::cat({torch::sin(x), torch::cos(x)}, -1); // [..., 2D]
+         } else {
+             feat = x; // Euclidean / Flat
+         }
+
+         // Linear gate: sigmoid(feat @ w + b)
+         auto gate_out = torch::matmul(feat.unsqueeze(-2), gate_w).squeeze(-2); // [B, H, 1]
+         if (gate_b.numel() > 0) {
+             gate_out = gate_out + gate_b;
+         }
+         mu = mu + torch::sigmoid(gate_out);
+     }
+
+     // Velocity-dependent scaling: mu * (1 + scale * ||v||)
+     auto v_norm = torch::norm(v, 2, -1, true) / (std::sqrt(static_cast<double>(D)) + eps);
+     mu = mu * (1.0 + vel_fric_scale * v_norm);
+
+     return mu;
+ }
+
+ // Helper for Singularity Damping
+ torch::Tensor _apply_singularity_damping(
+     const torch::Tensor& acc,
+     const torch::Tensor& v,
+     const torch::Tensor& U,
+     double sing_thresh,
+     double sing_strength)
+ {
+     if (sing_strength <= 1.0 || sing_thresh <= 0.0) return acc;
+
+     // Detect singularity: metrics are low near singular points.
+     // In LowRank, g_diag = sum(U^2).
+     auto g_diag = (U.pow(2)).sum(-1); // [H, D]
+
+     // Potential = sigmoid(5.0 * (g - thresh))
+     auto soft_mask = torch::sigmoid(5.0 * (g_diag - sing_thresh));
+
+     // Scale acceleration by soft_mask (Damping Shield)
+     // Near a singularity: soft_mask -> 0, damping the forces.
+     return acc * soft_mask;
+ }
+
+ std::vector<torch::Tensor> yoshida_fwd_aten(
+     const torch::Tensor& x_init,
+     const torch::Tensor& v_init,
+     const torch::Tensor& U,
+     const torch::Tensor& W,
+     const torch::Tensor& force,
+     const torch::Tensor& dt,
+     int steps,
+     double clamp_val,
+     double friction,
+     double vel_fric_scale,
+     double vel_sat,
+     const torch::Tensor& gate_w,
+     const torch::Tensor& gate_b,
+     double sing_thresh,
+     double sing_strength,
+     bool enable_trace_norm,
+     bool is_paper_version)
+ {
+     auto x = x_init.clone();
+     auto v = v_init.clone();
+
+     const double eps = 1e-8;
+
+     for (int i = 0; i < steps; ++i) {
+         // Sub-step 1
+         x = x + y_c1 * dt * v;
+         x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI; // Toroidal wrap
+
+         auto gamma1 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
+         auto a1_nf = force - gamma1;
+         a1_nf = _apply_singularity_damping(a1_nf, v, U, sing_thresh, sing_strength);
+
+         auto mu1 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);
+
+         v = (v + y_d1 * dt * a1_nf) / (1.0 + y_d1 * dt * mu1 + eps);
+         v = _clamp_velocity(v, vel_sat);
+
+         // Sub-step 2
+         x = x + y_c2 * dt * v;
+         x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;
+
+         auto gamma2 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
+         auto a2_nf = force - gamma2;
+         a2_nf = _apply_singularity_damping(a2_nf, v, U, sing_thresh, sing_strength);
+
+         auto mu2 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);
+
+         v = (v + y_d2 * dt * a2_nf) / (1.0 + y_d2 * dt * mu2 + eps);
+         v = _clamp_velocity(v, vel_sat);
+
+         // Sub-step 3
+         x = x + y_c3 * dt * v;
+         x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;
+
+         auto gamma3 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
+         auto a3_nf = force - gamma3;
+         a3_nf = _apply_singularity_damping(a3_nf, v, U, sing_thresh, sing_strength);
+
+         auto mu3 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);
+
+         v = (v + y_d3 * dt * a3_nf) / (1.0 + y_d3 * dt * mu3 + eps);
+         v = _clamp_velocity(v, vel_sat);
+
+         // Final drift
+         x = x + y_c4 * dt * v;
+         x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;
+     }
+
+     return {x, v};
+ }
+
+ std::vector<torch::Tensor> leapfrog_fwd_aten(
+     const torch::Tensor& x_init,
+     const torch::Tensor& v_init,
+     const torch::Tensor& U,
+     const torch::Tensor& W,
+     const torch::Tensor& force,
+     const torch::Tensor& dt,
+     int steps,
+     double clamp_val,
+     double friction,
+     double vel_fric_scale,
+     double vel_sat,
+     const torch::Tensor& gate_w,
+     const torch::Tensor& gate_b,
+     double sing_thresh,
+     double sing_strength,
+     bool enable_trace_norm,
+     bool is_paper_version)
+ {
+     auto x = x_init.clone();
+     auto v = v_init.clone();
+
+     const double eps = 1e-8;
+
+     for (int i = 0; i < steps; ++i) {
+         // Half-kick 1
+         auto gamma1 = _compute_gamma(v, U, W, clamp_val, enable_trace_norm, is_paper_version);
+         auto a1_nf = force - gamma1;
+         a1_nf = _apply_singularity_damping(a1_nf, v, U, sing_thresh, sing_strength);
+
+         auto mu1 = _compute_mu(x, v, gate_w, gate_b, friction, vel_fric_scale);
+
+         auto v_half = (v + 0.5 * dt * a1_nf) / (1.0 + 0.5 * dt * mu1 + eps);
+         v_half = _clamp_velocity(v_half, vel_sat);
+
+         // Drift
+         x = x + dt * v_half;
+         x = torch::remainder(x + M_PI, 2 * M_PI) - M_PI;
+
+         // Half-kick 2
+         auto gamma2 = _compute_gamma(v_half, U, W, clamp_val, enable_trace_norm, is_paper_version);
+         auto a2_nf = force - gamma2;
+         a2_nf = _apply_singularity_damping(a2_nf, v_half, U, sing_thresh, sing_strength);
+
+         auto mu2 = _compute_mu(x, v_half, gate_w, gate_b, friction, vel_fric_scale);
+
+         auto a_avg = (a1_nf + a2_nf) / 2.0;
+         auto mu_avg = (mu1 + mu2) / 2.0;
+
+         v = (v + dt * a_avg) / (1.0 + dt * mu_avg + eps);
+         v = _clamp_velocity(v, vel_sat);
+     }
+
+     return {x, v};
+ }
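The Yoshida constants above admit a quick consistency check: the drift coefficients (y_c1..y_c4) and the kick coefficients (y_d1..y_d3) must each sum to 1 so that one macro step advances the state by exactly dt. A small Python sketch of that invariant:

```python
import math

# Yoshida 4th-order coefficients as defined in integrators.cpp
w1 = 1.3512071919596576
w0 = -1.7024143839193153
c = [w1 / 2, (w0 + w1) / 2, (w0 + w1) / 2, w1 / 2]  # drift sub-steps (y_c1..y_c4)
d = [w1, w0, w1]                                     # kick sub-steps (y_d1..y_d3)

# Each family must sum to 1 (up to floating-point roundoff)
print(sum(c), sum(d))
```

The same check applies if the coefficients are ever retuned; a mismatch would silently rescale the effective step size.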
gfn/realizations/gssm/csrc/integrators/integrators.h ADDED
@@ -0,0 +1,41 @@
+ #include <torch/extension.h>
+ #include <vector>
+
+ // Forward declarations for the integrators
+ std::vector<torch::Tensor> yoshida_fwd_aten(
+     const torch::Tensor& x_init,
+     const torch::Tensor& v_init,
+     const torch::Tensor& U,
+     const torch::Tensor& W,
+     const torch::Tensor& force,
+     const torch::Tensor& dt,
+     int steps,
+     double clamp_val,
+     double friction,
+     double vel_fric_scale,
+     double vel_sat,
+     const torch::Tensor& gate_w,
+     const torch::Tensor& gate_b,
+     double sing_thresh,
+     double sing_strength,
+     bool enable_trace_norm,
+     bool is_paper_version);
+
+ std::vector<torch::Tensor> leapfrog_fwd_aten(
+     const torch::Tensor& x_init,
+     const torch::Tensor& v_init,
+     const torch::Tensor& U,
+     const torch::Tensor& W,
+     const torch::Tensor& force,
+     const torch::Tensor& dt,
+     int steps,
+     double clamp_val,
+     double friction,
+     double vel_fric_scale,
+     double vel_sat,
+     const torch::Tensor& gate_w,
+     const torch::Tensor& gate_b,
+     double sing_thresh,
+     double sing_strength,
+     bool enable_trace_norm,
+     bool is_paper_version);
gfn/realizations/gssm/csrc/losses/toroidal.cu ADDED
@@ -0,0 +1,99 @@
+ #include <torch/extension.h>
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+ #include <math_constants.h>
+
+ // ------------------------------------------------------------------------
+ // Toroidal Distance Loss
+ // L(y_pred, y_true) = (atan2(sin(y_pred - y_true), cos(y_pred - y_true)))^2
+ // ------------------------------------------------------------------------
+
+ template <typename scalar_t>
+ __global__ void toroidal_distance_loss_fwd_kernel(
+     const scalar_t* __restrict__ y_pred,
+     const scalar_t* __restrict__ y_true,
+     scalar_t* __restrict__ out,
+     const int numel)
+ {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx < numel) {
+         scalar_t diff = y_pred[idx] - y_true[idx];
+         scalar_t wrapped = atan2(sin(diff), cos(diff));
+         out[idx] = wrapped * wrapped;
+     }
+ }
+
+ template <typename scalar_t>
+ __global__ void toroidal_distance_loss_bwd_kernel(
+     const scalar_t* __restrict__ grad_output,
+     const scalar_t* __restrict__ y_pred,
+     const scalar_t* __restrict__ y_true,
+     scalar_t* __restrict__ grad_pred,
+     const int numel)
+ {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx < numel) {
+         // The derivative of atan2(sin(x), cos(x))^2 with respect to x is 2 * atan2(sin(x), cos(x))
+         scalar_t diff = y_pred[idx] - y_true[idx];
+         scalar_t wrapped = atan2(sin(diff), cos(diff));
+         grad_pred[idx] = grad_output[idx] * 2.0 * wrapped;
+     }
+ }
+
+ // ------------------------------------------------------------------------
+ // ATen wrappers
+ // ------------------------------------------------------------------------
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ torch::Tensor toroidal_distance_loss_fwd(const torch::Tensor& y_pred, const torch::Tensor& y_true) {
+     CHECK_INPUT(y_pred);
+     CHECK_INPUT(y_true);
+
+     auto out = torch::empty_like(y_pred);
+     int numel = y_pred.numel();
+
+     const int threads = 256;
+     const int blocks = (numel + threads - 1) / threads;
+
+     if (y_pred.scalar_type() == torch::kFloat32) {
+         toroidal_distance_loss_fwd_kernel<float><<<blocks, threads>>>(
+             y_pred.data_ptr<float>(),
+             y_true.data_ptr<float>(),
+             out.data_ptr<float>(),
+             numel
+         );
+     } else {
+         TORCH_CHECK(false, "toroidal_distance_loss_fwd only supports float32");
+     }
+
+     return out;
+ }
+
+ torch::Tensor toroidal_distance_loss_bwd(const torch::Tensor& grad_output, const torch::Tensor& y_pred, const torch::Tensor& y_true) {
+     CHECK_INPUT(grad_output);
+     CHECK_INPUT(y_pred);
+     CHECK_INPUT(y_true);
+
+     auto grad_pred = torch::empty_like(y_pred);
+     int numel = y_pred.numel();
+
+     const int threads = 256;
+     const int blocks = (numel + threads - 1) / threads;
+
+     if (y_pred.scalar_type() == torch::kFloat32) {
+         toroidal_distance_loss_bwd_kernel<float><<<blocks, threads>>>(
+             grad_output.data_ptr<float>(),
+             y_pred.data_ptr<float>(),
+             y_true.data_ptr<float>(),
+             grad_pred.data_ptr<float>(),
+             numel
+         );
+     } else {
+         TORCH_CHECK(false, "toroidal_distance_loss_bwd only supports float32");
+     }
+
+     return grad_pred;
+ }
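The forward and backward formulas above are scalar and easy to verify on the CPU. A minimal sketch (function names are my own): the `atan2(sin, cos)` trick wraps the raw difference into (-pi, pi], so angles that differ by a full turn have zero loss.

```python
import math

def toroidal_loss(y_pred, y_true):
    """Squared wrapped angular distance, matching the forward kernel."""
    d = y_pred - y_true
    wrapped = math.atan2(math.sin(d), math.cos(d))  # wrap into (-pi, pi]
    return wrapped * wrapped

def toroidal_grad(y_pred, y_true):
    """d(loss)/d(y_pred) = 2 * wrapped difference, matching the backward kernel."""
    d = y_pred - y_true
    return 2.0 * math.atan2(math.sin(d), math.cos(d))

# Periodicity: 0 and 2*pi are the same point on the circle
print(toroidal_loss(2 * math.pi, 0.0))  # ~0 up to float roundoff
```

The gradient is continuous everywhere except at the antipode (difference of exactly pi), where the wrap direction flips sign.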
gfn/realizations/gssm/csrc/setup.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ # Force sequential compilation to prevent NVCC Out-Of-Memory (LLVM ERROR)
+ os.environ["MAX_JOBS"] = "1"
+
+ # Base directory
+ csrc_dir = os.path.dirname(os.path.abspath(__file__))
+
+ sources = [
+     os.path.join(csrc_dir, "extension.cpp"),
+     os.path.join(csrc_dir, "losses", "toroidal.cu"),
+     os.path.join(csrc_dir, "geometry", "low_rank.cu"),
+     os.path.join(csrc_dir, "integrators", "integrators.cpp")
+ ]
+
+ # Platform-specific configuration (MSVC/Windows vs Linux)
+ extra_compile_args = {
+     'cxx': ['-O2'],
+     'nvcc': ['-O2', '-allow-unsupported-compiler']
+ }
+ if os.name == 'nt':
+     extra_compile_args['cxx'].append('/std:c++17')
+
+ setup(
+     name='gfn_cuda',
+     ext_modules=[
+         CUDAExtension(
+             name='gfn_cuda',
+             sources=sources,
+             extra_compile_args=extra_compile_args,
+         )
+     ],
+     cmdclass={
+         'build_ext': BuildExtension
+     }
+ )
gfn/realizations/gssm/cuda/__init__.py ADDED
@@ -0,0 +1,11 @@
+ """
+ gfn/cuda/__init__.py
+ CUDA infrastructure for GFN V5.
+ """
+ import torch
+
+ CUDA_AVAILABLE = torch.cuda.is_available()
+
+ def is_cuda_active(tensor: torch.Tensor) -> bool:
+     """Checks that CUDA is available and the tensor lives on a GPU device."""
+     return CUDA_AVAILABLE and tensor.is_cuda
gfn/realizations/gssm/cuda/autograd/__init__.py ADDED
File without changes
gfn/realizations/gssm/cuda/kernels/__init__.py ADDED
File without changes
gfn/realizations/gssm/cuda/kernels/geometry_kernels.py ADDED
@@ -0,0 +1,99 @@
+ """
+ Geometry Kernels — GFN V5
+ Unified entry points for geometric computations with hardware dispatching.
+ """
+
+ import torch
+ from typing import Any
+ from ...registry import GEOMETRY_REGISTRY
+ from ...cuda import is_cuda_active
+
+ # Lazy import for CUDA ops to avoid loading if not available
+ _christoffel_cuda = None
+
+ def _get_cuda_ops():
+     global _christoffel_cuda
+     if _christoffel_cuda is None:
+         try:
+             from ...cuda.ops import christoffel_cuda_fwd
+             _christoffel_cuda = christoffel_cuda_fwd
+         except ImportError:
+             pass
+     return _christoffel_cuda
+
+ def unified_christoffel_fwd(
+     x: torch.Tensor,
+     v: torch.Tensor,
+     U: torch.Tensor,
+     W: torch.Tensor,
+     clamp_val: float = 5.0,
+     **kwargs: Any
+ ) -> torch.Tensor:
+     """
+     Unified forward pass for Christoffel symbols.
+     Dispatches to the CUDA kernel if available and on GPU, otherwise falls back to PyTorch.
+     """
+     if is_cuda_active(v):
+         cuda_op = _get_cuda_ops()
+         if cuda_op is not None:
+             try:
+                 return _run_cuda_christoffel(x, v, U, W, clamp_val, cuda_op, **kwargs)
+             except Exception:
+                 # Fall back to the PyTorch path on any kernel error
+                 pass
+
+     return _run_pytorch_christoffel(x, v, U, W, clamp_val, **kwargs)
+
+ def _run_cuda_christoffel(x, v, U, W, clamp_val, cuda_op, **kwargs):
+     # Kernel expects head-aware tensors [B, H, HD]
+     # Check whether x, v are [B, D] or [B, H, HD]
+     if v.dim() == 2:
+         x_k, v_k = x.unsqueeze(1), v.unsqueeze(1)
+     else:
+         x_k, v_k = x, v
+
+     # U, W handling (assuming LowRank format)
+     # This logic matches the legacy kernel expectations
+     # (rank R is preserved per output dimension)
+     if U.dim() == 3:
+         U_k = U.transpose(1, 2).contiguous()  # [H, R, HD]
+         W_k = W.transpose(1, 2).contiguous()  # [H, R, HD]
+     else:
+         U_k = U.T.unsqueeze(0).contiguous()   # [1, R, HD]
+         W_k = W.T.unsqueeze(0).contiguous()   # [1, R, HD]
+
+     # Execute CUDA kernel
+     gamma = cuda_op(U_k, W_k, x_k, v_k, 0, 2.0, 1.0, 0.0)
+
+     if v.dim() == 2:
+         gamma = gamma.squeeze(1)
+
+     return clamp_val * torch.tanh(gamma / clamp_val)
+
+ def _run_pytorch_christoffel(x, v, U, W, clamp_val, **kwargs):
+     # Multi-head PyTorch fallback
+     if v.dim() == 3:
+         B, H, HD = v.shape
+         if U.dim() == 3:
+             # Per-head parameters: U [H, HD, R], W [H, HD, R]
+             proj = torch.bmm(v.transpose(0, 1), U).transpose(0, 1)       # [B, H, R]
+             sq = proj * proj
+             W_t = W.transpose(-1, -2)                                     # [H, R, HD]
+             gamma = torch.bmm(sq.transpose(0, 1), W_t).transpose(0, 1)   # [B, H, HD]
+         else:
+             # Shared U, W across heads: flatten batch and heads for efficient matmuls
+             v_flat = v.reshape(B * H, HD)
+             proj = torch.matmul(v_flat, U)        # [B*H, R]
+             sq = proj * proj
+             gamma_flat = torch.matmul(sq, W.t())  # [B*H, HD]
+             gamma = gamma_flat.view(B, H, HD)
+     else:
+         # Single head [B, D]
+         proj = torch.matmul(v, U[0] if U.dim() == 3 else U)
+         sq = proj * proj
+         W_t = (W[0] if W.dim() == 3 else W).t()
+         gamma = torch.matmul(sq, W_t)
+
+     return clamp_val * torch.tanh(gamma / clamp_val)
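The shared-U branch of the fallback relies on a flattening identity: reshaping [B, H, HD] to [B*H, HD] and doing one big matmul gives the same result as looping over heads, because every head uses the same U and W. A small NumPy sketch of that equivalence (shapes chosen arbitrarily):

```python
import numpy as np

# Flattening [B, H, HD] -> [B*H, HD] is equivalent to a per-head loop
# when U and W are shared across heads.
B, H, HD, R = 2, 3, 5, 2
rng = np.random.default_rng(1)
v = rng.normal(size=(B, H, HD))
U = rng.normal(size=(HD, R))
W = rng.normal(size=(HD, R))

# Fused path: one matmul over the flattened batch*head axis
fused = ((v.reshape(B * H, HD) @ U) ** 2) @ W.T
fused = fused.reshape(B, H, HD)

# Reference: explicit loop over heads
looped = np.stack([((v[:, h] @ U) ** 2) @ W.T for h in range(H)], axis=1)
print(np.allclose(fused, looped))
```

The fused form avoids H separate kernel launches, which is why the fallback prefers it whenever the parameters are not per-head.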
gfn/realizations/gssm/cuda/kernels/integrator_kernels.py ADDED
@@ -0,0 +1,73 @@
+ """
+ Integrator Kernels — GFN V5
+ Unified entry points for numerical integration with hardware dispatching.
+ """
+
+ import torch
+ from typing import Optional, Tuple
+ from ...cuda import is_cuda_active
+
+ # Lazy imports for CUDA kernels
+ _euler_fused = None
+ _rk4_fused = None
+ _leapfrog_fused = None
+
+ def _get_cuda_integrators():
+     global _euler_fused, _rk4_fused, _leapfrog_fused
+     if _euler_fused is None:
+         try:
+             from ...cuda.ops import euler_fused, rk4_fused, leapfrog_fused
+             _euler_fused = euler_fused
+             _rk4_fused = rk4_fused
+             _leapfrog_fused = leapfrog_fused
+         except ImportError:
+             pass
+     return _euler_fused, _rk4_fused, _leapfrog_fused
+
+ def unified_leapfrog_step(
+     x: torch.Tensor,
+     v: torch.Tensor,
+     force: Optional[torch.Tensor],
+     U: torch.Tensor,
+     W: torch.Tensor,
+     dt: float,
+     steps: int = 1,
+     **kwargs
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+     """Unified Leapfrog integration step.
+     Returns (None, None) when no CUDA kernel is available, signalling the
+     caller to fall back to the Python integrators."""
+     if is_cuda_active(v):
+         _, _, f_leapfrog = _get_cuda_integrators()
+         if f_leapfrog is not None:
+             try:
+                 # Prepare parameters for the CUDA kernel
+                 topo_id = kwargs.get('topology_id', 0)
+                 R = kwargs.get('R', 2.0)
+                 r = kwargs.get('r', 1.0)
+                 H = x.shape[1] if x.dim() == 3 else 1
+
+                 # U, W transformations
+                 if U.dim() == 2:
+                     # [D, R] -> [1, R, D] -> [H, R, D]
+                     U_k = U.T.unsqueeze(0).expand(H, -1, -1).contiguous()
+                 else:
+                     # [H, D, R] -> [H, R, D]
+                     U_k = U.transpose(1, 2).contiguous()
+
+                 if W.dim() == 2:
+                     # [D, R] -> [1, R] -> [H, R]
+                     # Use mean instead of sum to preserve the effective force scale
+                     W_k = W.mean(dim=0).unsqueeze(0).expand(H, -1).contiguous()
+                 else:
+                     # [H, D, R] -> [H, R]
+                     W_k = W.abs().mean(dim=1).contiguous()
+
+                 cx, cv = x, v
+                 for _ in range(steps):
+                     cx, cv = f_leapfrog(U_k, W_k, cx, cv, force, float(dt), int(topo_id), float(R), float(r), 0.0)
+                 return cx, cv
+             except Exception:
+                 pass
+
+     # The Python fallback is handled by the higher-level Integrator classes in gfn/integrators/
+     # This unified layer is primarily for hardware acceleration.
+     return None, None  # Signal fallback
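The leapfrog integrators in this repo treat friction implicitly: each kick divides by (1 + dt*mu) rather than multiplying by (1 - dt*mu), which stays stable for any positive dt. A simplified 1-D sketch of that kick-drift-kick structure (my own function name, curvature and gating omitted):

```python
# Simplified 1-D kick-drift-kick with implicit linear friction. Each kick
# divides by (1 + dt*mu); an explicit scheme multiplying by (1 - dt*mu)
# would flip the velocity's sign whenever dt*mu > 1.
def leapfrog_friction(x, v, force, mu, dt, steps):
    for _ in range(steps):
        v = (v + 0.5 * dt * force) / (1.0 + 0.5 * dt * mu)  # half-kick
        x = x + dt * v                                       # drift
        v = (v + 0.5 * dt * force) / (1.0 + 0.5 * dt * mu)  # half-kick
    return x, v

# Free particle with strong drag and a large step: the implicit form
# simply decays the velocity toward zero, never overshooting.
x, v = leapfrog_friction(0.0, 1.0, force=0.0, mu=5.0, dt=1.0, steps=10)
print(v)
```

Note this sketch is not the repo's exact update (which averages accelerations and friction across the two kicks); it only isolates the stability property of the implicit divide.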
gfn/realizations/gssm/cuda/ops/__init__.py ADDED
@@ -0,0 +1,52 @@
+ """
+ gfn/cuda/ops/__init__.py
+ Exports all fused CUDA operations.
+ Gracefully returns None for any op whose kernel is not compiled.
+ """
+ from ...cuda import CUDA_AVAILABLE
+ import os
+ import sys
+
+ # Ensure the csrc directory is on sys.path so the compiled extension (.pyd/.so) is found
+ _csrc_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "csrc"))
+ if _csrc_path not in sys.path:
+     sys.path.insert(0, _csrc_path)
+
+ def _get_op(module_path: str, name: str):
+     """Safely import a CUDA binding, returning None on failure."""
+     try:
+         import importlib
+         mod = importlib.import_module(module_path)
+         return getattr(mod, name, None)
+     except Exception:
+         return None
+
+ # ── Geometry ──────────────────────────────────────────────────────────────────
+ christoffel_cuda_fwd = _get_op("gfn_cuda", "compute_christoffel_symbols_fwd")
+ christoffel_cuda_bwd = _get_op("gfn_cuda", "compute_christoffel_symbols_bwd")
+ low_rank_christoffel_fwd = _get_op("gfn_cuda", "low_rank_christoffel_fwd")
+ low_rank_christoffel_bwd = _get_op("gfn_cuda", "low_rank_christoffel_bwd")
+ toroidal_christ_fwd = _get_op("gfn_cuda", "toroidal_geo_christoffel_fwd")
+
+ # ── Integrators ───────────────────────────────────────────────────────────────
+ heun_fused = _get_op("gfn_cuda", "heun_fwd")
+ leapfrog_fused = _get_op("gfn_cuda", "leapfrog_fwd")
+ yoshida_fused = _get_op("gfn_cuda", "yoshida_fwd")
+ rk4_fused = _get_op("gfn_cuda", "rk4_fwd")
+
+ # ── Loss ──────────────────────────────────────────────────────────────────────
+ toroidal_loss_fwd = _get_op("gfn_cuda", "toroidal_distance_loss_fwd")
+ toroidal_loss_bwd = _get_op("gfn_cuda", "toroidal_distance_loss_bwd")
+
+ def __getattr__(name):
+     # Unknown-but-plausible op names resolve to None instead of raising,
+     # so callers can probe for optional kernels.
+     if name.endswith(("_fused", "_fwd", "_bwd", "_cuda")):
+         return None
+     raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+ __all__ = [
+     "CUDA_AVAILABLE",
+     "christoffel_cuda_fwd", "christoffel_cuda_bwd",
+     "low_rank_christoffel_fwd", "low_rank_christoffel_bwd",
+     "toroidal_christ_fwd", "heun_fused", "leapfrog_fused",
+     "yoshida_fused", "rk4_fused", "toroidal_loss_fwd", "toroidal_loss_bwd"
+ ]
gfn/realizations/gssm/data/__init__.py ADDED
@@ -0,0 +1,16 @@
+ """
+ data/__init__.py — GFN V5
+ """
+ from ..data.dataset import SequenceDataset
+ from ..data.loader import create_dataloaders
+ from ..data.transforms import shift_targets, add_bos_token, pad_sequences
+ from ..data.replay import TrajectoryReplayBuffer
+
+ __all__ = [
+     'SequenceDataset',
+     'create_dataloaders',
+     'shift_targets',
+     'add_bos_token',
+     'pad_sequences',
+     'TrajectoryReplayBuffer'
+ ]
gfn/realizations/gssm/data/dataset.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ from torch.utils.data import Dataset
+
+ class SequenceDataset(Dataset):
+     """Simple sequence dataset for (X, Y) pairs."""
+     def __init__(self, x: torch.Tensor, y: torch.Tensor):
+         self.x = x
+         self.y = y
+
+     def __len__(self):
+         return len(self.x)
+
+     def __getitem__(self, idx):
+         return self.x[idx], self.y[idx]
gfn/realizations/gssm/data/loader.py ADDED
@@ -0,0 +1,53 @@
+ """
+ data/loader.py — GFN V5
+ DataLoaders and datasets for GFN tasks.
+ WATCHOUT: the DataLoader is created ONCE, outside the training loop.
+ """
+
+ import torch
+ from torch.utils.data import DataLoader, Dataset, random_split
+ from typing import Tuple, Optional
+ from ..data.dataset import SequenceDataset
+
+
+ def create_dataloaders(
+     x: torch.Tensor,
+     y: torch.Tensor,
+     batch_size: int = 32,
+     val_split: float = 0.1,
+     shuffle: bool = True,
+     num_workers: int = 0,
+     seed: int = 42,
+ ) -> Tuple[DataLoader, Optional[DataLoader]]:
+     """
+     Creates train and validation DataLoaders from tensors.
+     IMPORTANT: create the DataLoaders ONCE, outside the loop — not inside it.
+
+     Args:
+         x, y: Input and target tensors
+         batch_size: Batch size
+         val_split: Fraction of data held out for validation (0 = no validation)
+         shuffle: Shuffle the training data
+         num_workers: Workers for data loading
+         seed: Seed for a reproducible split
+
+     Returns:
+         (train_loader, val_loader) — val_loader is None if val_split=0
+     """
+     dataset = SequenceDataset(x, y)
+
+     if val_split > 0:
+         n_val = max(1, int(len(dataset) * val_split))
+         n_train = len(dataset) - n_val
+         generator = torch.Generator().manual_seed(seed)
+         train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=generator)
+
+         train_loader = DataLoader(train_ds, batch_size=batch_size,
+                                   shuffle=shuffle, num_workers=num_workers)
+         val_loader = DataLoader(val_ds, batch_size=batch_size,
+                                 shuffle=False, num_workers=num_workers)
+         return train_loader, val_loader
+
+     train_loader = DataLoader(dataset, batch_size=batch_size,
+                               shuffle=shuffle, num_workers=num_workers)
+     return train_loader, None
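The split rule in create_dataloaders is worth spelling out: `n_val = max(1, int(N * val_split))` guarantees at least one validation sample even for tiny datasets. A minimal sketch of just that arithmetic (helper name is my own):

```python
# Mirror of create_dataloaders' split sizing:
# n_val = max(1, int(N * val_split)), remainder goes to training.
def split_sizes(n, val_split):
    n_val = max(1, int(n * val_split))
    return n - n_val, n_val

print(split_sizes(100, 0.1))  # (90, 10)
print(split_sizes(5, 0.1))    # (4, 1) — the floor of 0.5 would be 0, max() rescues it
```

The `max(1, ...)` floor means val_split values that round down to zero still reserve one sample, so a loop that assumes a non-empty val_loader will not crash on small datasets.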
gfn/realizations/gssm/data/replay.py ADDED
@@ -0,0 +1,130 @@
+ """
+ Replay / Trajectory Buffer — GFN V5
+ Handles storage and sampling of physical states (x, v, forces)
+ to support off-policy training and exploration for true GFlowNets.
+ """
+
+ import torch
+ from typing import Optional, Tuple
+
+ class TrajectoryReplayBuffer:
+     """
+     A persistent buffer for storing and managing manifold trajectories (x, v states).
+     Serves as replay memory for Hamiltonian/Geodesic flows in V5.
+     """
+     def __init__(
+         self,
+         capacity: int,
+         dim: int,
+         device: torch.device = torch.device('cpu'),
+         dtype: torch.dtype = torch.float32
+     ):
+         self.capacity = capacity
+         self.dim = dim
+         self.device = device
+         self.dtype = dtype
+
+         # Buffers for state (x), velocity (v), and optional force
+         # Shape: [capacity, dim] or [capacity, heads, head_dim] depending on input
+         # Note: we flatten the capacity dimension but keep the geometry shape.
+         self._initialized_shape = False
+
+         self.pointer = 0
+         self.size = 0
+         self.is_full = False
+
+     def _init_buffers(self, example_shape: torch.Size):
+         """Initializes the tensor buffers based on the first observed shape."""
+         # example_shape might be [Batch, Dim] or [Batch, Heads, HeadDim]
+         # We need [Capacity, *shape[1:]]
+         element_shape = example_shape[1:]
+
+         # Memory check safeguard
+         import math
+         bytes_per_el = torch.empty((), dtype=self.dtype).element_size()
+         total_elements = self.capacity * math.prod(element_shape)
+         # 3 buffers (x, v, force)
+         total_mb = (3 * total_elements * bytes_per_el) / (1024 ** 2)
+         if total_mb > 1024 and self.device.type == 'cuda':
+             import logging
+             logging.warning(f"ReplayBuffer: Allocating {total_mb:.1f} MB on CUDA. Risk of OOM.")
+
+         self.x_buffer = torch.zeros((self.capacity, *element_shape), device=self.device, dtype=self.dtype)
+         self.v_buffer = torch.zeros((self.capacity, *element_shape), device=self.device, dtype=self.dtype)
+         self.force_buffer = torch.zeros((self.capacity, *element_shape), device=self.device, dtype=self.dtype)
+         self._initialized_shape = True
+
+     def add(
+         self,
+         x: torch.Tensor,
+         v: torch.Tensor,
+         force: Optional[torch.Tensor] = None
+     ):
+         """
+         Adds a batch of transitions to the buffer.
+         """
+         batch_size = x.size(0)
+
+         if not self._initialized_shape:
+             self._init_buffers(x.shape)
+
+         # Handle wrap-around indexing
+         indices = torch.arange(self.pointer, self.pointer + batch_size) % self.capacity
+
+         self.x_buffer[indices] = x.to(self.device).detach()
+         self.v_buffer[indices] = v.to(self.device).detach()
+         if force is not None:
+             self.force_buffer[indices] = force.to(self.device).detach()
+
+         self.pointer = (self.pointer + batch_size) % self.capacity
+         self.size = min(self.size + batch_size, self.capacity)
+         if self.size == self.capacity:
+             self.is_full = True
+
+     def sample_random(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Randomly samples a batch of states from the buffer.
+         Returns: (x, v, force)
+         """
+         if self.size == 0:
+             raise ValueError("Cannot sample from an empty buffer.")
+
+         indices = torch.randint(0, self.size, (batch_size,), device=self.device)
+
+         return (
+             self.x_buffer[indices],
+             self.v_buffer[indices],
+             self.force_buffer[indices]
+         )
+
+     def sample_recent(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Samples the most recently added transitions."""
+         if self.size == 0:
+             raise ValueError("Cannot sample from an empty buffer.")
+
+         if self.size < batch_size:
+             idx = torch.arange(0, self.size, device=self.device)
+         else:
+             idx = (torch.arange(self.pointer - batch_size, self.pointer, device=self.device) % self.capacity)
+
+         return (
+             self.x_buffer[idx],
+             self.v_buffer[idx],
+             self.force_buffer[idx]
114
+ )
115
+
116
+ def sample_with_noise(self, batch_size: int, noise_std: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
117
+ """Samples with Gaussian jitter to improve robust training."""
118
+ x, v, _ = self.sample_random(batch_size)
119
+ x_noisy = x + torch.randn_like(x) * noise_std
120
+ return x_noisy, v
121
+
122
+ def clear(self):
123
+ """Resets the buffer."""
124
+ self.pointer = 0
125
+ self.size = 0
126
+ self.is_full = False
127
+ self._initialized_shape = False
128
+
129
+ def __len__(self):
130
+ return self.size
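The wrap-around arithmetic shared by `add()` and `sample_recent()` can be sketched without torch; this is a framework-free illustration of the modular indexing only, not the class itself:

```python
# Framework-free sketch of TrajectoryReplayBuffer's ring indexing
# (pure Python; the class above does the same with torch.arange % capacity).
capacity = 5

def write_indices(pointer, batch_size):
    # Slots that add() would overwrite for a given pointer/batch.
    return [(pointer + i) % capacity for i in range(batch_size)]

def recent_indices(pointer, batch_size):
    # Slots that sample_recent() would read back after the write.
    return [(pointer - batch_size + i) % capacity for i in range(batch_size)]

idx = write_indices(pointer=3, batch_size=4)   # wraps past capacity: [3, 4, 0, 1]
new_pointer = (3 + 4) % capacity               # -> 2
```

Reading the most recent batch from the advanced pointer recovers exactly the slots just written, which is what makes the negative-offset `arange` in `sample_recent` correct.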
gfn/realizations/gssm/data/transforms.py ADDED
@@ -0,0 +1,40 @@
+"""
+data/transforms.py — GFN V5
+Data transforms for sequences.
+"""
+
+import torch
+from typing import Tuple, Optional
+
+
+def shift_targets(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Creates shifted (input, target) pairs for language modeling.
+    input = x[:, :-1]
+    target = x[:, 1:]
+    """
+    return x[:, :-1], x[:, 1:]
+
+
+def add_bos_token(x: torch.Tensor, bos_id: int = 0) -> torch.Tensor:
+    """Prepends a BOS token to each sequence."""
+    bos = torch.full((x.size(0), 1), bos_id, dtype=x.dtype, device=x.device)
+    return torch.cat([bos, x], dim=1)
+
+
+def pad_sequences(sequences, max_len: int, pad_id: int = 0) -> torch.Tensor:
+    """Pads a list of variable-length sequences."""
+    result = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
+    for i, seq in enumerate(sequences):
+        length = min(len(seq), max_len)
+        result[i, :length] = torch.tensor(seq[:length])
+    return result
+
+
+def create_attention_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
+    """
+    Builds an attention mask from sequence lengths.
+    Returns: [B, max_len] with True where data is valid.
+    """
+    indices = torch.arange(max_len, device=lengths.device).unsqueeze(0)
+    return indices < lengths.unsqueeze(1)
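The semantics of these helpers can be mirrored with plain lists; this is a pure-Python sketch of the same shift/pad/mask logic (lists stand in for tensors), not the module's API:

```python
# Pure-Python equivalents of the sequence transforms above.
def shift_targets(batch):
    # (input, target) pairs shifted by one token, as in language modeling.
    return [seq[:-1] for seq in batch], [seq[1:] for seq in batch]

def pad_sequences(sequences, max_len, pad_id=0):
    # Truncate to max_len, then right-pad with pad_id.
    return [(list(s[:max_len]) + [pad_id] * max_len)[:max_len] for s in sequences]

def attention_mask(lengths, max_len):
    # True where position index < sequence length.
    return [[j < length for j in range(max_len)] for length in lengths]

inputs, targets = shift_targets([[5, 6, 7, 8]])        # ([[5, 6, 7]], [[6, 7, 8]])
padded = pad_sequences([[1, 2], [3, 4, 5, 6, 7]], max_len=4)
mask = attention_mask([2], max_len=3)                  # [[True, True, False]]
```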
gfn/realizations/gssm/errors.py ADDED
@@ -0,0 +1,23 @@
+class GFNError(Exception):
+    """Base exception for all GFN errors."""
+    pass
+
+class ConfigurationError(GFNError):
+    """Raised when a configuration is invalid or inconsistent."""
+    pass
+
+class GeometryError(GFNError):
+    """Raised when a geometric operation fails (e.g., out of manifold)."""
+    pass
+
+class PhysicsError(GFNError):
+    """Raised during physics engine failures (e.g., NaN detected)."""
+    pass
+
+class IntegrationError(GFNError):
+    """Raised during numerical integration failures."""
+    pass
+
+class TrainingError(GFNError):
+    """Raised during model training or optimization failures."""
+    pass
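Since every specialized error inherits from `GFNError`, one `except GFNError` clause traps all of them. A minimal standalone demonstration (the classes are re-declared here so the snippet is self-contained):

```python
# Minimal re-declaration of the hierarchy above, showing that catching
# the base class also traps the specialized subclasses.
class GFNError(Exception): pass
class GeometryError(GFNError): pass
class IntegrationError(GFNError): pass

def step(dt):
    # Hypothetical integrator step used only for illustration.
    if dt <= 0:
        raise IntegrationError(f"non-positive dt: {dt}")
    return dt

try:
    step(-1e-3)
    caught = None
except GFNError as e:   # base class catches IntegrationError too
    caught = type(e).__name__
```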
gfn/realizations/gssm/geometry/__init__.py ADDED
@@ -0,0 +1,42 @@
+"""
+gfn/geometry/__init__.py
+Public API for the geometry module — GFN V5
+"""
+
+# Base and factory
+from ..geometry.base import BaseGeometry
+from ..geometry.factory import GeometryFactory
+
+# Concrete geometries (imports trigger @register_geometry decorators)
+from ..geometry.euclidean import EuclideanGeometry
+from ..geometry.torus import ToroidalRiemannianGeometry, FlatToroidalRiemannianGeometry
+from ..geometry.low_rank import LowRankRiemannianGeometry, PaperLowRankRiemannianGeometry
+from ..geometry.adaptive import AdaptiveRiemannianGeometry
+from ..geometry.reactive import ReactiveRiemannianGeometry
+from ..geometry.hyperbolic import HyperRiemannianGeometry
+from ..geometry.holographic import HolographicRiemannianGeometry
+from ..geometry.spherical import SphericalGeometry
+from ..geometry.hierarchical import HierarchicalGeometry
+
+# Re-export FrictionGate from unified physics.components location
+from ..physics.components.friction import FrictionGate
+
+__all__ = [
+    # Base
+    "BaseGeometry",
+    "GeometryFactory",
+    # Implementations
+    "EuclideanGeometry",
+    "ToroidalRiemannianGeometry",
+    "FlatToroidalRiemannianGeometry",
+    "LowRankRiemannianGeometry",
+    "PaperLowRankRiemannianGeometry",
+    "AdaptiveRiemannianGeometry",
+    "ReactiveRiemannianGeometry",
+    "HyperRiemannianGeometry",
+    "HolographicRiemannianGeometry",
+    "SphericalGeometry",
+    "HierarchicalGeometry",
+    # Shared components
+    "FrictionGate",
+]
gfn/realizations/gssm/geometry/adaptive.py ADDED
@@ -0,0 +1,83 @@
+"""
+AdaptiveRiemannianGeometry — GFN V5
+Adaptive rank Christoffel symbol decomposition.
+Migrated from gfn/geo/riemannian/adaptive_geometry.py
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Union, Tuple
+
+from ..constants import CURVATURE_CLAMP
+from ..config.schema import PhysicsConfig
+from ..geometry.base import BaseGeometry
+from ..registry import register_geometry
+
+
+@register_geometry('adaptive')
+class AdaptiveRiemannianGeometry(BaseGeometry):
+    """
+    Adjusts the effective curvature rank dynamically based on velocity complexity.
+
+    Architecture:
+        eff_rank = f(||v||) in [min_rank, max_rank]
+        Γ(v) = W[:, :eff_rank] @ (U[:, :eff_rank]^T v)^2
+    """
+
+    def __init__(self, dim: int, max_rank: int = 64, config: Optional[PhysicsConfig] = None):
+        super().__init__(config)
+        self.dim = dim
+        self.max_rank = max_rank
+        self.min_rank_ratio = 0.1
+
+        self.U_full = nn.Parameter(torch.randn(dim, max_rank) * 0.01)
+        self.W_full = nn.Parameter(torch.randn(dim, max_rank) * 0.01)
+
+        # Complexity predictor: maps v → rank_ratio ∈ [0, 1]
+        self.complexity_net = nn.Sequential(
+            nn.Linear(dim, 32),
+            nn.ReLU(),
+            nn.Linear(32, 1),
+            nn.Sigmoid()
+        )
+        # Initialize bias to start with a mostly-open rank to avoid vanishing gradients
+        nn.init.constant_(self.complexity_net[-2].bias, 1.0)
+
+        self.return_friction_separately = True
+
+    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
+                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if v is None:
+            return torch.zeros_like(x)
+
+        # 1. Predict rank weight (soft-mask)
+        # Use complexity_net to predict a value p in [0, 1]
+        p = self.complexity_net(v)  # [B, 1]
+
+        # Create a soft mask for the rank dimension: [B, max_rank]
+        # Mask[i] = sigmoid(slope * (p * max_rank - i))
+        # This approximates hard-slicing but is differentiable.
+        indices = torch.arange(self.max_rank, device=v.device).float()
+        slope = 10.0
+        soft_mask = torch.sigmoid(slope * (p * self.max_rank - indices))  # [B, max_rank]
+
+        # 2. Christoffel using all components modulated by mask
+        proj = torch.matmul(v, self.U_full)       # [B, max_rank]
+        sq = proj * proj                          # [B, max_rank]
+        modulated_sq = sq * soft_mask             # [B, max_rank]
+        gamma = torch.matmul(modulated_sq, self.W_full.t())  # [B, dim]
+
+        # 3. Friction (ensure mu is not just zero)
+        # Fallback to config friction or a base value
+        friction_base = getattr(self.config.stability, 'friction', 0.1)
+        mu = torch.full_like(v, friction_base)
+
+        gamma_clamped = CURVATURE_CLAMP * torch.tanh(gamma / CURVATURE_CLAMP)
+
+        if self.return_friction_separately:
+            return gamma_clamped, mu
+
+        return gamma_clamped + mu * v
+
+    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ones_like(x)
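The differentiable rank mask is the core trick here: a steep sigmoid over rank indices approximates hard slicing at `p * max_rank` while staying differentiable in `p`. A scalar sketch of just that mask (pure Python, no torch):

```python
import math

# Standalone sketch of the soft rank mask used above:
# mask[i] = sigmoid(slope * (p * max_rank - i)).
# Indices well below p * max_rank get weight ~1, those above get ~0,
# and the cut index itself sits at exactly 0.5.
def soft_rank_mask(p, max_rank, slope=10.0):
    sigmoid = lambda z: 1.0 / (1.0 + math.exp(-z))
    return [sigmoid(slope * (p * max_rank - i)) for i in range(max_rank)]

mask = soft_rank_mask(p=0.5, max_rank=8)   # soft cut around index 4
```

Raising `slope` sharpens the transition toward a hard slice at the cost of smaller gradients near the cut.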
gfn/realizations/gssm/geometry/base.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple, Union
+from ..interfaces.geometry import Geometry
+from ..config.schema import PhysicsConfig
+from ..constants import TOPOLOGY_EUCLIDEAN
+
+class BaseGeometry(nn.Module):
+    """
+    Base implementation for Riemannian Geometries in GFN V5.
+    Conforms to the Geometry protocol.
+    """
+    def __init__(self, config: Optional[PhysicsConfig] = None):
+        super().__init__()
+        self.config = config or PhysicsConfig()
+        self.return_friction_separately = True
+        self.topology_type = self.config.topology.type
+
+    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
+        """Default metric is identity (Euclidean). Subclasses should override."""
+        return torch.ones_like(x)
+
+    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
+        """Default Christoffel symbols are zero (Euclidean). Subclasses should override."""
+        return torch.zeros_like(x)
+
+    def compute_kinetic_energy(self, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        """
+        Calculates Riemannian kinetic energy: T = (1/2) Σ_i g_ii v_i²
+        Supports position-dependent metrics (like Torus).
+        """
+        g = self.metric_tensor(x)  # [..., D]
+        return 0.5 * (g * v.pow(2)).sum(dim=-1)
+
+    def compute_potential_energy(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Calculates physical potential energy V(x).
+        Default is 0.0 unless overridden by specific topologies or forces.
+        """
+        return torch.zeros_like(x).sum(dim=-1)
+
+    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None, force: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Computes acceleration: acc = -Gamma(v, v) + F/g
+        Subclasses can override for more complex physics.
+        """
+        if v is None:
+            return torch.zeros_like(x)
+
+        gamma = self.christoffel_symbols(x)
+        # Standard geodesic acceleration: -Gamma^k_ij v^i v^j
+        # In our simplified 1D-per-dimension metric, it's often just a point-wise product
+        acc = -gamma * (v**2)
+
+        if force is not None:
+            g = self.metric_tensor(x)
+            acc = acc + (force / (g + 1e-8))
+
+        if getattr(self, 'return_friction_separately', False):
+            # v is guaranteed non-None here (early return above)
+            return acc, torch.zeros_like(v)
+
+        return acc
+
+    def project(self, x: torch.Tensor) -> torch.Tensor:
+        """Default projection is identity. Subclasses should override for periodic spaces."""
+        return x
+
+    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        """Default distance is Euclidean. Subclasses should override."""
+        return torch.norm(x1 - x2, dim=-1)
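For the diagonal metrics used here, the energy and acceleration conventions reduce to per-coordinate scalar formulas. A framework-free sketch of exactly those two formulas (lists stand in for tensors):

```python
# Scalar sketch of BaseGeometry's conventions for a diagonal metric:
#   T     = 0.5 * sum_i g_ii * v_i^2
#   acc_i = -gamma_i * v_i^2 + F_i / (g_ii + eps)
def kinetic_energy(g, v):
    return 0.5 * sum(gi * vi * vi for gi, vi in zip(g, v))

def acceleration(gamma, v, force, g, eps=1e-8):
    return [-gi * vi * vi + fi / (mi + eps)
            for gi, vi, fi, mi in zip(gamma, v, force, g)]

T = kinetic_energy([1.0, 1.0], [3.0, 4.0])             # Euclidean: 0.5 * 25
a = acceleration([0.0, 0.0], [1.0, 2.0], [2.0, 4.0], [1.0, 1.0])
```

With zero Christoffel terms and an identity metric this collapses to Newtonian `a = F/m`, which is the Euclidean base case subclasses deviate from.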
gfn/realizations/gssm/geometry/euclidean.py ADDED
@@ -0,0 +1,20 @@
+import torch
+from .base import BaseGeometry
+from ..registry import register_geometry
+from ..constants import TOPOLOGY_EUCLIDEAN
+
+@register_geometry(TOPOLOGY_EUCLIDEAN)
+class EuclideanGeometry(BaseGeometry):
+    """Standard Euclidean Space (Flat)."""
+
+    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ones_like(x)
+
+    def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.zeros_like(x)
+
+    def project(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        return torch.norm(x1 - x2, dim=-1)
gfn/realizations/gssm/geometry/factory.py ADDED
@@ -0,0 +1,117 @@
+"""
+GeometryFactory — GFN V5
+Creates geometry instances from PhysicsConfig.
+Supports: euclidean, torus, low_rank, reactive, adaptive, hyperbolic, holographic.
+"""
+
+from typing import Optional
+from ..config.schema import PhysicsConfig
+from ..registry import GEOMETRY_REGISTRY
+from ..constants import TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN
+import logging
+
+logger = logging.getLogger(__name__)
+
+_GEOMETRIES_REGISTERED = False
+
+def _register_all_geometries():
+    """Explicitly imports the submodules so the geometries get registered."""
+    global _GEOMETRIES_REGISTERED
+    if _GEOMETRIES_REGISTERED:
+        return
+    from . import euclidean
+    from . import torus
+    from . import low_rank
+    from . import adaptive
+    from . import reactive
+    from . import hyperbolic
+    _GEOMETRIES_REGISTERED = True
+
+class GeometryFactory:
+    """
+    Creates manifold geometries from configuration.
+
+    Primary key: topology.type ('euclidean', 'torus', 'hyperbolic', ...)
+    Secondary key: topology.riemannian_type ('low_rank', 'reactive', 'adaptive', ...)
+
+    riemannian_type overrides topology.type when explicitly set and registered.
+    """
+
+    @staticmethod
+    def _lookup_key(config: PhysicsConfig) -> str:
+        _register_all_geometries()
+        topo_type = config.topology.type.lower()
+        riem_type = getattr(config.topology, 'riemannian_type', 'reactive').lower()
+        available = GEOMETRY_REGISTRY.list_keys()
+
+        # Priority Logic:
+        # 1. Prioritize learned Riemannian geometries (low_rank, reactive, adaptive)
+        #    even if the topology is specialized (torus, etc.), as they handle topology via features.
+        learned_types = {'low_rank', 'reactive', 'adaptive', 'low_rank_paper'}
+        if riem_type in learned_types and riem_type in available:
+            return riem_type
+
+        # 2. Otherwise, if topology is specific (torus, hyperbolic, etc.), use its analytical model.
+        if topo_type in available and topo_type != TOPOLOGY_EUCLIDEAN:
+            return topo_type
+
+        # 3. Fallback to riem_type or topo_type
+        if riem_type in available:
+            return riem_type
+
+        return topo_type
+
+    @staticmethod
+    def create(config: PhysicsConfig):
+        """
+        Create geometry using default dim from config.
+        Looks for 'dim' in topology config or falls back to 64.
+        """
+        lookup_key = GeometryFactory._lookup_key(config)
+        available = GEOMETRY_REGISTRY.list_keys()
+
+        if lookup_key in available:
+            geometry_cls = GEOMETRY_REGISTRY.get(lookup_key)
+            try:
+                dim = getattr(config, 'dim', 64)
+                rank = getattr(config.topology, 'riemannian_rank', 16)
+                return geometry_cls(dim=dim, rank=rank, config=config)
+            except TypeError:
+                try:
+                    return geometry_cls(config=config)
+                except TypeError:
+                    return geometry_cls()
+
+        logger.warning(f"Geometry '{lookup_key}' not found. Using EuclideanGeometry.")
+        from .euclidean import EuclideanGeometry
+        return EuclideanGeometry(config=config)
+
+    @staticmethod
+    def create_with_dim(dim: int, rank: int, num_heads: int, config: PhysicsConfig):
+        """
+        Create geometry with explicit dim and rank.
+        Used by ModelFactory to pass head_dim (not total dim) to the geometry,
+        since geometry operates on per-head tensors [B, H, HD].
+        """
+        lookup_key = GeometryFactory._lookup_key(config)
+        available = GEOMETRY_REGISTRY.list_keys()
+
+        if lookup_key in available:
+            geometry_cls = GEOMETRY_REGISTRY.get(lookup_key)
+            try:
+                return geometry_cls(dim=dim, rank=rank, num_heads=num_heads, config=config)
+            except TypeError:
+                try:
+                    return geometry_cls(dim=dim, rank=rank, config=config)
+                except TypeError:
+                    try:
+                        return geometry_cls(config=config)
+                    except TypeError:
+                        return geometry_cls()
+
+        logger.warning(f"Geometry '{lookup_key}' not found. Using EuclideanGeometry.")
+        from .euclidean import EuclideanGeometry
+        try:
+            return EuclideanGeometry(dim=dim, num_heads=num_heads, config=config)
+        except TypeError:
+            return EuclideanGeometry(config=config)
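The nested `try/except TypeError` ladder lets one factory serve geometry classes with different constructor signatures: try the richest signature first and narrow on failure. A self-contained sketch of just that pattern, with two hypothetical classes standing in for registered geometries:

```python
# Illustration of the factory's constructor-fallback ladder. RichGeo and
# BareGeo are hypothetical stand-ins for registered geometry classes.
class RichGeo:
    def __init__(self, dim, rank, config=None):
        self.signature = ("dim", "rank", "config")

class BareGeo:
    def __init__(self):
        self.signature = ()

def construct(cls, dim=64, rank=16, config=None):
    # Try the richest signature first; narrow on TypeError.
    try:
        return cls(dim=dim, rank=rank, config=config)
    except TypeError:
        try:
            return cls(config=config)
        except TypeError:
            return cls()

rich = construct(RichGeo)   # accepts the full signature
bare = construct(BareGeo)   # falls all the way through to cls()
```

One caveat of this design: a `TypeError` raised *inside* a matching constructor is indistinguishable from a signature mismatch, so constructors should avoid raising bare `TypeError` for other reasons.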
gfn/realizations/gssm/geometry/hierarchical.py ADDED
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+from typing import List, Optional, Union, Tuple, Any
+from ..geometry.base import BaseGeometry
+from ..geometry.low_rank import LowRankRiemannianGeometry
+from ..registry import register_geometry
+
+@register_geometry('hierarchical')
+class HierarchicalGeometry(BaseGeometry):
+    """
+    Multi-Scale Riemannian Geometry (Christoffel Mixture).
+    Combines multiple geometries (typically Low-Rank) with different scales.
+
+    Migrated from gfn_old HierarchicalRiemannianGeometry.
+    """
+    def __init__(self, dim: int, rank: int = 16, ranks: Optional[List[int]] = None,
+                 num_heads: int = 1, config: Optional[Any] = None, **kwargs):
+        super().__init__(config)
+        self.dim = dim
+        self.ranks = ranks if ranks is not None else [8, 16, 32]
+        if rank not in self.ranks:
+            # Optionally include the factory-suggested rank
+            self.ranks = sorted(list(set(self.ranks + [rank])))
+        self.num_heads = num_heads
+
+        # Initialize sub-geometries (defaulting to LowRank)
+        self.scales = nn.ModuleList([
+            LowRankRiemannianGeometry(dim, rank=r, num_heads=num_heads, config=config)
+            for r in self.ranks
+        ])
+
+        # Learnable mixing weights
+        self.scale_weights = nn.Parameter(torch.ones(len(self.ranks)) / len(self.ranks))
+        self.return_friction_separately = False
+
+    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
+                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+        gammas = []
+        frictions = []
+
+        # Execute each scale
+        for scale in self.scales:
+            # Temporarily ensure consistent return mode
+            was_sep = getattr(scale, 'return_friction_separately', False)
+            scale.return_friction_separately = True
+
+            res = scale(x, v, force=force, **kwargs)
+            if isinstance(res, tuple):
+                g, f = res
+            else:
+                g, f = res, (torch.zeros_like(v) if v is not None else torch.zeros_like(x))
+
+            gammas.append(g)
+            frictions.append(f)
+            scale.return_friction_separately = was_sep
+
+        # Mix using softmax weights
+        weights = torch.softmax(self.scale_weights, dim=0)
+
+        gamma_mixed = sum(w * g for w, g in zip(weights, gammas))
+        friction_mixed = sum(w * f for w, f in zip(weights, frictions))
+
+        if self.return_friction_separately:
+            return gamma_mixed, friction_mixed
+
+        if v is not None:
+            return gamma_mixed + friction_mixed * v
+        return gamma_mixed
+
+    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
+        weights = torch.softmax(self.scale_weights, dim=0)
+        metrics = [scale.metric_tensor(x) for scale in self.scales]
+        return sum(w * m for w, m in zip(weights, metrics))
+
+    def project(self, x: torch.Tensor) -> torch.Tensor:
+        # Hierarchical usually just projects using the first scale (geometry consistent)
+        from typing import cast
+        return cast(BaseGeometry, self.scales[0]).project(x)
+
+    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        weights = torch.softmax(self.scale_weights, dim=0)
+        dists = [scale.dist(x1, x2) for scale in self.scales]
+        return sum(w * d for w, d in zip(weights, dists))
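The mixing step is a convex combination: softmax over the learnable logits, then a weighted sum of per-scale outputs. A scalar sketch of that step (pure Python, lists standing in for tensors):

```python
import math

# Sketch of the scale-mixing step above: softmax over learnable logits,
# then a convex combination of the per-scale outputs.
def softmax(logits):
    m = max(logits)                       # subtract max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def mix(weights, outputs):
    # Element-wise convex combination of equally shaped outputs.
    return [sum(w * o[i] for w, o in zip(weights, outputs))
            for i in range(len(outputs[0]))]

w = softmax([0.0, 0.0, 0.0])              # equal logits -> uniform 1/3 each
mixed = mix(w, [[3.0], [6.0], [9.0]])     # (3 + 6 + 9) / 3
```

Because the weights sum to one, the mixed Christoffel term stays on the same scale as its components regardless of how many ranks participate.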
gfn/realizations/gssm/geometry/holographic.py ADDED
@@ -0,0 +1,91 @@
+"""
+HolographicRiemannianGeometry — GFN V5
+AdS/CFT-inspired holographic extensions (Paper 18).
+Migrated from gfn/geo/physical/holographic_geometry.py
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Union, Tuple
+
+from ..config.schema import PhysicsConfig
+from ..geometry.base import BaseGeometry
+from ..registry import register_geometry
+
+
+@register_geometry('holographic')
+class HolographicRiemannianGeometry(BaseGeometry):
+    """
+    Conformal manifold inspired by Bulk-Boundary (AdS/CFT) correspondence.
+
+    Lifts boundary state x → bulk (x, z) where z is the holographic radial dim.
+    Conformal metric: g_ij = (1/z(x)²) · δ_ij
+
+    The Christoffel correction adds an AdS-geodesic term to any base geometry.
+    """
+
+    def __init__(self, base_geometry: BaseGeometry, z_min: float = 0.1,
+                 z_max: float = 10.0, config: Optional[PhysicsConfig] = None):
+        super().__init__(config)
+        self.base_geometry = base_geometry
+        self.dim = getattr(base_geometry, 'dim', None)
+        self.z_min = z_min
+        self.z_max = z_max
+
+        dim = self.dim or 0
+        if dim > 0:
+            self.radial_net: nn.Module = nn.Sequential(
+                nn.Linear(dim, dim // 2),
+                nn.SiLU(),
+                nn.Linear(dim // 2, 1),
+                nn.Softplus()
+            )
+        else:
+            self.radial_net = nn.Identity()
+
+    def get_z_and_grad(self, x: torch.Tensor):
+        x_req = x.detach().requires_grad_(True)
+        with torch.enable_grad():
+            z = self.radial_net(x_req) + self.z_min
+            z = torch.clamp(z, max=self.z_max)
+            grad_z = torch.autograd.grad(
+                z.sum(), x_req,
+                create_graph=self.training,
+                retain_graph=False
+            )[0]
+        return z, grad_z
+
+    def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
+        z, _ = self.get_z_and_grad(x)
+        return (1.0 / z.pow(2)) * torch.ones_like(x)
+
+    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
+                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        out_base = self.base_geometry(x, v, force=force, **kwargs)
+        if isinstance(out_base, tuple):
+            gamma_base, mu = out_base
+        else:
+            gamma_base, mu = out_base, (torch.zeros_like(v) if v is not None else torch.zeros_like(x))
+
+        if v is None:
+            if self.return_friction_separately:
+                return gamma_base, mu
+            return gamma_base
+
+        z, grad_z = self.get_z_and_grad(x)
+        v_dot_gradz = (v * grad_z).sum(dim=-1, keepdim=True)
+        v_sq = (v * v).sum(dim=-1, keepdim=True)
+        gamma_ads = -(1.0 / z) * (2.0 * v_dot_gradz * v - v_sq * grad_z)
+
+        gamma_total = gamma_base + gamma_ads
+
+        if self.return_friction_separately:
+            return gamma_total, mu
+
+        return gamma_total + mu * v
+
+    def project(self, x: torch.Tensor) -> torch.Tensor:
+        return self.base_geometry.project(x)
+
+    def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        return self.base_geometry.dist(x1, x2)
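The AdS correction combines two scalar contractions of the velocity with the radial gradient. A scalar sketch of those formulas (pure Python, lists standing in for the per-sample tensors):

```python
# Scalar sketch of the conformal quantities above:
#   g         = 1 / z^2 per coordinate, and
#   gamma_ads = -(1/z) * (2 (v . grad_z) v - |v|^2 grad_z).
def ads_gamma(z, v, grad_z):
    v_dot_gradz = sum(a * b for a, b in zip(v, grad_z))
    v_sq = sum(a * a for a in v)
    return [-(1.0 / z) * (2.0 * v_dot_gradz * vi - v_sq * gi)
            for vi, gi in zip(v, grad_z)]

g = 1.0 / (2.0 ** 2)   # conformal metric factor at z = 2
# Velocity along x with grad_z along y: the correction bends the path
# purely along grad_z (the v . grad_z cross term vanishes).
gamma = ads_gamma(z=2.0, v=[1.0, 0.0], grad_z=[0.0, 1.0])
```

Note the cross term `2 (v · ∇z) v` vanishes when the motion is orthogonal to the radial gradient, leaving only the `|v|² ∇z` deflection, which matches the standard AdS geodesic structure this class is modeled on.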
gfn/realizations/gssm/geometry/hyperbolic.py ADDED
@@ -0,0 +1,97 @@
+"""
+HyperRiemannianGeometry — GFN V5
+Context-dependent (gated) Christoffel symbols.
+Migrated from gfn/geo/topological/hyperbolic_geometry.py
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Union, Tuple
+
+from ..constants import CURVATURE_CLAMP, EPS, TOPOLOGY_TORUS
+from ..config.schema import PhysicsConfig
+from ..geometry.low_rank import LowRankRiemannianGeometry
+from ..registry import register_geometry
+
+
+@register_geometry('hyperbolic')
+class HyperRiemannianGeometry(LowRankRiemannianGeometry):
+    """
+    Hyper-Christoffel: geometry conditioned on current position.
+
+    Architecture:
+        U(x) = U_static * diag(Gate_u(x)) — position-scaled basis
+        W(x) = W_static * diag(Gate_w(x))
+        Γ(v | x) = W(x) @ (U(x)^T v)²
+
+    Gates output values in [0, 2] initialized near 1.0 (identity).
+    """
+
+    def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
+                 config: Optional[PhysicsConfig] = None):
+        super().__init__(dim, rank, num_heads=num_heads, config=config)
+        self.return_friction_separately = True
+
+        self.gate_u = nn.Linear(dim, rank)
+        self.gate_w = nn.Linear(dim, rank)
+        nn.init.zeros_(self.gate_u.weight); nn.init.zeros_(self.gate_u.bias)
+        nn.init.zeros_(self.gate_w.weight); nn.init.zeros_(self.gate_w.bias)
+
+    def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
+                force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if v is None:
+            return torch.zeros_like(x)
+
+        original_shape = v.shape
+        # Handle multi-head [B, H, HD] -> [B*H, HD]
+        if v.dim() == 3:
+            B, H, HD = v.shape
+            v_flat = v.reshape(B * H, HD)
+            x_flat = x.reshape(B * H, HD)
+        else:
+            v_flat = v
+            x_flat = x
+            B, H = None, None
+
+        # Context gates in [0, 2]
+        g_u = torch.sigmoid(self.gate_u(x_flat)) * 2.0  # [B*H, rank]
+        g_w = torch.sigmoid(self.gate_w(x_flat)) * 2.0
+
+        # Modulate static basis
+        # self.U is [HD, rank] or [H, HD, rank]
+        U_eff = self.U if self.U.dim() == 2 else self.U.mean(0)
+        proj_static = torch.matmul(v_flat, U_eff)  # [B*H, rank]
+        proj_dynamic = proj_static * g_u
+
+        # Soft-saturation to prevent energy explosion
+        sq_dynamic = (proj_dynamic * proj_dynamic) / (1.0 + torch.abs(proj_dynamic) + EPS)
+        sq_modulated = sq_dynamic * g_w
+
+        W_t = self.W.t() if self.W.dim() == 2 else self.W.mean(0).t()
+        gamma = torch.matmul(sq_modulated, W_t)  # [B*H, HD]
+
+        # Restore original shape if multi-head
+        if B is not None:
+            gamma = gamma.view(original_shape)
+
+        # Friction
+        x_in = torch.cat([torch.sin(x_flat), torch.cos(x_flat)], dim=-1) \
+            if self.topology_type == TOPOLOGY_TORUS else x_flat
+        mu_base = self.friction + self.friction_gate(x_in, force=force)
+        v_norm = torch.norm(v, dim=-1, keepdim=True) / (self.dim ** 0.5 + EPS)
+        mu = mu_base * (1.0 + self.velocity_friction_scale * v_norm)
+        if mu.shape != v.shape:
+            mu = mu.view_as(v) if mu.numel() == v.numel() else mu.mean(dim=-1, keepdim=True)
+
+        gamma = self._normalize(gamma)
+        gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)
+
+        if self.return_friction_separately:
+            return gamma, mu
+
+        return gamma + mu * v
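Zero-initializing the gate weights and biases is what makes the gates start as the identity: with a zero pre-activation, `sigmoid(0) * 2 = 1`, so at initialization the class behaves exactly like its `LowRankRiemannianGeometry` parent. A numeric check of that property:

```python
import math

# Numeric check of the gate initialization above: with zeroed weights
# and biases the pre-activation is 0, and sigmoid(0) * 2 = 1, so the
# gates begin as identity scaling in the [0, 2] output range.
def gate(pre_activation):
    return 2.0 / (1.0 + math.exp(-pre_activation))

g_init = gate(0.0)     # zero-init -> identity scaling
g_open = gate(10.0)    # upper end of the range approaches 2
g_closed = gate(-10.0) # lower end approaches 0
```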
gfn/realizations/gssm/geometry/low_rank.py ADDED
@@ -0,0 +1,324 @@
1
+ """
2
+ LowRankRiemannianGeometry — GFN V5
3
+ Computes Christoffel symbols via a low-rank decomposition.
4
+ Migrated from gfn/geo/riemannian/low_rank_geometry.py
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import Optional, Union, Tuple, Dict, Any
10
+
11
+ from ..constants import (
12
+ EPS, MAX_VELOCITY, TOPOLOGY_TORUS, TOPOLOGY_EUCLIDEAN,
13
+ DEFAULT_FRICTION, CURVATURE_CLAMP, GATE_BIAS_OPEN
14
+ )
15
+ from ..config.schema import PhysicsConfig
16
+ from ..geometry.base import BaseGeometry
17
+ from ..registry import register_geometry
18
+ from ..cuda.ops import CUDA_AVAILABLE, low_rank_christoffel_fwd, low_rank_christoffel_bwd
19
+
20
+ class LowRankChristoffelFunction(torch.autograd.Function):
21
+ @staticmethod
22
+ def forward(ctx, v, U, W, clamp_val, enable_trace_norm, is_paper_version=False):
23
+ v_c = v.contiguous()
24
+ U_c = U.contiguous()
25
+ W_c = W.contiguous()
26
+
27
+ gamma = low_rank_christoffel_fwd(v_c, U_c, W_c, float(clamp_val), enable_trace_norm, is_paper_version)
28
+ ctx.save_for_backward(v_c, U_c, W_c, gamma)
29
+ ctx.clamp_val = float(clamp_val)
30
+ ctx.enable_trace_norm = enable_trace_norm
31
+ ctx.is_paper_version = is_paper_version
32
+ return gamma
33
+
34
+ @staticmethod
35
+ def backward(ctx, grad_gamma):
36
+ v_c, U_c, W_c, gamma_out = ctx.saved_tensors
37
+ if grad_gamma is None:
38
+ return None, None, None, None, None, None
39
+
40
+ grad_gamma_c = grad_gamma.contiguous()
41
+ d_v, d_U, d_W = low_rank_christoffel_bwd(
42
+ grad_gamma_c, v_c, U_c, W_c, gamma_out,
43
+ ctx.clamp_val, ctx.enable_trace_norm, ctx.is_paper_version
44
+ )
45
+ return d_v, d_U, d_W, None, None, None
46
+
47
+
48
+ # Use unified FrictionGate from physics.components (no duplication)
49
+ from ..physics.components.friction import FrictionGate
50
+
51
+
52
+ @register_geometry('low_rank')
53
+ class LowRankRiemannianGeometry(BaseGeometry):
54
+ r"""
55
+ Low-rank Christoffel symbol decomposition.
56
+
57
+ Γ^k_ij ≈ Σ_r W_{rk} * (U_ir * U_jr)
58
+ This is an approximation — symmetry is preserved but Bianchi identities are not guaranteed.
59
+ Chosen for computational efficiency.
60
+
61
+ Args:
62
+ dim: Manifold dimension.
63
+ rank: Rank of the decomposition.
64
+ num_heads: Number of parallel heads.
65
+ config: PhysicsConfig instance.
66
+ """
67
+
68
+ def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
69
+ config: Optional[PhysicsConfig] = None):
70
+ super().__init__(config)
71
+ self.dim = dim
72
+ self.rank = rank
73
+ self.num_heads = num_heads
74
+
75
+ topo = self.config.topology.type.lower()
76
+ self.topology = topo
77
+ self.clamp_val = self.config.stability.curvature_clamp
78
+ self.enable_trace_normalization = self.config.stability.enable_trace_normalization
80
+ # Friction parameters are now handled by PhysicsEngine to avoid duplication
81
+
82
+ # Feature dimension for gate input (Fourier for torus)
83
+ gate_input_dim = dim * 2 if topo == TOPOLOGY_TORUS else dim
84
+
85
+ # Low-rank basis parameters - initialized with small noise to break symmetry
86
+ if num_heads > 1:
87
+ self.U = nn.Parameter(torch.randn(num_heads, dim, rank) * 1e-4)
88
+ self.W = nn.Parameter(torch.randn(num_heads, dim, rank) * 1e-4)
89
+ else:
90
+ self.U = nn.Parameter(torch.randn(dim, rank) * 1e-4)
91
+ self.W = nn.Parameter(torch.randn(dim, rank) * 1e-4)
92
+
93
+ # Friction gate
94
+ friction_mode = getattr(self.config.stability, 'friction_mode', 'static')
95
+ self.friction_gate = FrictionGate(dim, gate_input_dim, mode=friction_mode, num_heads=num_heads)
96
+
97
+ # CONTRACT: LowRank ALWAYS returns (gamma_christoffel, mu_friction) separately.
98
+ # The physics engine is the single authority on when/how friction is applied.
99
+ # This prevents the P0.1 double-friction bug (geometry + engine both applying mu*v).
100
+ self.return_friction_separately = True
101
+
102
+ def _get_features(self, x: torch.Tensor) -> torch.Tensor:
103
+ """Convert position to gate input features."""
104
+ if self.topology == TOPOLOGY_TORUS:
105
+ return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
106
+ return x
107
+
108
+ def connection(self, v: torch.Tensor, w: torch.Tensor,
109
+ x: Optional[torch.Tensor] = None) -> torch.Tensor:
110
+ """
111
+ Bilinear Christoffel contraction: Γ(v,w)^k
112
+ Γ^k_ij ≈ Σ_r W[k,r] * (U[i,r] * U[j,r])
113
+ """
114
+ if v.dim() == 3 and self.U.dim() == 3:
115
+ v_r = torch.einsum('bhd, hdr -> bhr', v, self.U)
116
+ w_r = torch.einsum('bhd, hdr -> bhr', w, self.U)
117
+ vw_r = v_r * w_r
118
+ gamma = torch.einsum('bhr, hdr -> bhd', vw_r, self.W)
119
+ else:
120
+ v_r = v @ self.U # [..., rank] (works for both 2D and 3D U)
121
+ w_r = w @ self.U
122
+ vw_r = v_r * w_r
123
+ W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
124
+ gamma = vw_r @ W_t
125
+ return torch.clamp(gamma, -self.clamp_val, self.clamp_val)
126
+
127
+ def _normalize(self, gamma: torch.Tensor) -> torch.Tensor:
128
+ """Symmetry-preserving trace normalization."""
129
+ if gamma.dim() < 2:
130
+ return gamma
131
+
132
+ is_multi_head = (gamma.dim() == 3 and self.num_heads > 1)
133
+
134
+ # Matrix case [..., D, D]
135
+ if not is_multi_head and gamma.dim() >= 3 and gamma.shape[-1] == gamma.shape[-2]:
136
+ gamma_sym = 0.5 * (gamma + gamma.transpose(-1, -2))
137
+ if self.enable_trace_normalization:
138
+ diag_mean = torch.diagonal(gamma_sym, dim1=-1, dim2=-2).mean(dim=-1, keepdim=True)
139
+ correction = torch.diag_embed(diag_mean.expand(-1, self.dim))
140
+ return gamma_sym - correction
141
+ return gamma_sym
142
+
143
+ # Vector case [..., D]
144
+ if self.enable_trace_normalization:
145
+ return gamma - gamma.mean(dim=-1, keepdim=True)
146
+ return gamma
147
+
148
+ def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
149
+ force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
150
+ if v is None:
151
+ return torch.zeros_like(x)
152
+
153
+ original_shape = v.shape
154
+ # Handle multi-head [B, H, HD] → reshape to [B*H, HD] for matmul with U=[HD, rank]
155
+ if v.dim() == 3 and self.U.dim() == 2:
156
+ B, H, HD = v.shape
157
+ v_flat = v.reshape(B * H, HD) # [B*H, HD]
158
+ x_flat = x.reshape(B * H, HD)
159
+ else:
160
+ v_flat = v
161
+ x_flat = x
162
+ B, H, HD = None, None, v_flat.shape[-1]
163
+ R = self.rank
164
+
165
+ # Check if we can take the fast CUDA path
166
+ use_cuda_fused = (
167
+ CUDA_AVAILABLE and
168
+ low_rank_christoffel_fwd is not None and
169
+ v_flat.is_cuda and
170
+ v_flat.dtype == torch.float32 and
171
+ self.W.dim() == 3
172
+ )
173
+
174
+ if use_cuda_fused:
175
+ # Reshape [B*H, HD] -> [B, H, HD] to match what kernel expects
176
+ actual_B = original_shape[0] if v.dim() == 3 else 1
177
+ actual_H = self.num_heads
178
+
179
+ v_re = v_flat.view(actual_B, actual_H, HD)
180
+ U_re = self.U.view(actual_H, HD, R) # self.U is [H, D, R]
181
+ W_re = self.W.view(actual_H, HD, R)
182
+
183
+ gamma_re = LowRankChristoffelFunction.apply(
184
+ v_re, U_re, W_re, self.clamp_val, self.enable_trace_normalization, False
185
+ )
186
+ gamma = gamma_re.view_as(v_flat)
187
+ else:
188
+ # Christoffel symbols via self-connection (Native Fallback)
189
+ if v_flat.dim() == 3 and self.U.dim() == 3:
190
+ v_r = torch.einsum('bhd, hdr -> bhr', v_flat, self.U)
191
+ sq = v_r * v_r
192
+ gamma = torch.einsum('bhr, hdr -> bhd', sq, self.W)
193
+ else:
194
+ v_r = v_flat @ self.U # [..., rank]
195
+ sq = v_r * v_r
196
+ W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
197
+ gamma = sq @ W_t # [..., HD]
198
+
199
+ # Friction coefficient (position-dependent, gated)
200
+ x_in = self._get_features(x_flat)
201
+ mu = self.friction_gate(x_in, force=force)
202
+
203
+ # Note: PhysicsEngine will add base friction and apply velocity scaling
204
+
205
+ # Friction and normalizing only if not already done in CUDA
206
+ if not use_cuda_fused:
207
+ gamma = self._normalize(gamma)
208
+ gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)
209
+
210
+ # Restore original shape if we reshaped
211
+ if B is not None:
212
+ gamma = gamma.view(original_shape)
213
+ mu = mu.view(original_shape)
214
+
215
+ # CONTRACT: always return (gamma_pure, mu) so engine has single authority over friction
216
+ return gamma, mu
217
+
218
+ def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
219
+ """
220
+ Implicit Riemannian metric from the low-rank decomposition.
221
+
222
+ The Christoffel parametrization Γ^k_ij ≈ Σ_r W_rk (U_ir U_jr) implies
223
+ an underlying metric: g_ij ≈ Σ_r U_ir * U_jr = diag(U @ Uᵀ)
224
+
225
+ Returns per-coordinate metric scale [..., D] or ones if shape unknown.
226
+ T = (1/2) Σ_i g_ii v_i² (Riemannian kinetic energy)
227
+ """
228
+ if self.U.dim() == 2:
229
+ # Single head: U is [D, rank] → g_diag is [D]
230
+ g_diag = (self.U ** 2).sum(dim=-1) # [D]
231
+ # Broadcast to x shape: handles [B, D], [B*H, D], any [..., D]
232
+ return g_diag.expand_as(x)
233
+ else:
234
+ # Multi-head: U is [H, D, rank] → g_diag is [H, D]
235
+ g_diag = (self.U ** 2).sum(dim=-1) # [H, D]
236
+ if x.dim() == 3 and x.shape[1] == self.num_heads:
237
+ # [B, H, D]: structured multi-head
238
+ return g_diag.unsqueeze(0).expand_as(x)
239
+ else:
240
+ # [B, H*D]: flat layout — expand g_diag [H,D] → [H*D], broadcast to [B, H*D]
241
+ g_flat = g_diag.reshape(-1) # [H*D]
242
+ return g_flat.expand(x.shape[0], -1) if x.dim() == 2 else g_flat.expand_as(x)
243
+
244
+ def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
245
+ if self.topology == TOPOLOGY_TORUS:
246
+ diff = x1 - x2
247
+ diff = torch.atan2(torch.sin(diff), torch.cos(diff))
248
+ return torch.norm(diff, dim=-1)
249
+ return torch.norm(x1 - x2, dim=-1)
250
+
251
+ def project(self, x: torch.Tensor) -> torch.Tensor:
252
+ if self.topology == TOPOLOGY_TORUS:
253
+ return torch.atan2(torch.sin(x), torch.cos(x))
254
+ return x
255
+
256
+
257
+ @register_geometry('low_rank_paper')
258
+ class PaperLowRankRiemannianGeometry(LowRankRiemannianGeometry):
259
+ def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
260
+ config: Optional[PhysicsConfig] = None):
261
+ super().__init__(dim, rank, num_heads=num_heads, config=config)
262
+ self.return_friction_separately = True
263
+
264
+ def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
265
+ force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
266
+ if v is None:
267
+ return torch.zeros_like(x)
268
+
269
+ original_shape = v.shape
270
+ if v.dim() == 3 and self.U.dim() == 2:
271
+ B, H, HD = v.shape
272
+ v_flat = v.reshape(B * H, HD)
273
+ x_flat = x.reshape(B * H, HD)
274
+ else:
275
+ v_flat = v
276
+ x_flat = x
277
+ B, H, HD = None, None, v_flat.shape[-1]
278
+ R = self.rank
279
+
280
+ use_cuda_fused = (
281
+ CUDA_AVAILABLE and
282
+ low_rank_christoffel_fwd is not None and
283
+ v_flat.is_cuda and
284
+ v_flat.dtype == torch.float32 and
285
+ self.W.dim() == 3
286
+ )
287
+
288
+ if use_cuda_fused:
289
+ actual_B = original_shape[0] if v.dim() == 3 else 1
290
+ actual_H = self.num_heads
291
+ v_re = v_flat.view(actual_B, actual_H, HD)
292
+ U_re = self.U.view(actual_H, HD, R)
293
+ W_re = self.W.view(actual_H, HD, R)
294
+
295
+ gamma_re = LowRankChristoffelFunction.apply(
296
+ v_re, U_re, W_re, self.clamp_val, self.enable_trace_normalization, True
297
+ )
298
+ gamma = gamma_re.view_as(v_flat)
299
+ else:
300
+ if v_flat.dim() == 3 and self.U.dim() == 3:
301
+ v_r = torch.einsum('bhd, hdr -> bhr', v_flat, self.U)
302
+ denom = 1.0 + torch.norm(v_r, dim=-1, keepdim=True)
303
+ phi = (v_r * v_r) / denom
304
+ gamma = torch.einsum('bhr, hdr -> bhd', phi, self.W)
305
+ else:
306
+ v_r = v_flat @ self.U
307
+ denom = 1.0 + torch.norm(v_r, dim=-1, keepdim=True)
308
+ phi = (v_r * v_r) / denom
309
+ W_t = self.W.transpose(-1, -2) if self.W.dim() == 3 else self.W.t()
310
+ gamma = phi @ W_t
311
+
312
+ x_in = self._get_features(x_flat)
313
+ mu = self.friction_gate(x_in, force=force)
314
+
315
+ if not use_cuda_fused:
316
+ gamma = self._normalize(gamma)
317
+ gamma = self.clamp_val * torch.tanh(gamma / self.clamp_val)
318
+
319
+ if B is not None:
320
+ gamma = gamma.view(original_shape)
321
+ mu = mu.view(original_shape)
322
+
323
+ # CONTRACT: Always return (gamma_pure, mu) for unified PhysicsEngine handling.
324
+ return gamma, mu
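For intuition, the low-rank contraction implemented above (single-head self-connection path) can be sketched in pure Python. This is an illustrative toy with hypothetical dimensions, not the torch/CUDA implementation; it shows why the cost is O(D·R) instead of the O(D³) of a dense Christoffel tensor: project v onto the rank-R basis, square elementwise, and map back through W.

```python
# Toy sketch of the low-rank Christoffel contraction:
#   gamma_k = sum_r W[k][r] * (sum_i U[i][r] * v[i]) ** 2
# Pure-Python stand-in for the einsum path above (illustrative only).

def low_rank_christoffel(v, U, W):
    D = len(v)          # manifold dimension
    R = len(U[0])       # decomposition rank
    # Project velocity onto the rank-R basis: v_r = U^T v   (O(D*R))
    v_r = [sum(U[i][r] * v[i] for i in range(D)) for r in range(R)]
    sq = [c * c for c in v_r]  # elementwise square, as in sq = v_r * v_r
    # Map back through W: gamma_k = sum_r W[k][r] * sq[r]   (O(D*R))
    return [sum(W[k][r] * sq[r] for r in range(R)) for k in range(D)]

# D=2, R=1: U = [[1],[0]], W = [[2],[3]], v = [0.5, 1.0]
# v_r = [0.5], sq = [0.25], gamma = [0.5, 0.75]
```

In the real module the result is additionally symmetrized, trace-normalized, and soft-clamped with `clamp_val * tanh(gamma / clamp_val)`.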
gfn/realizations/gssm/geometry/reactive.py ADDED
@@ -0,0 +1,109 @@
1
+ """
2
+ ReactiveRiemannianGeometry — GFN V5
3
+ Active-inference geometry: curvature reacts to system state.
4
+ Migrated from gfn/geo/physical/reactive_field_geometry.py
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import Optional, Union, Tuple, Dict, Any
10
+
11
+ from ..constants import CURVATURE_CLAMP, EPS, DEFAULT_PLASTICITY, TOPOLOGY_TORUS
12
+ from ..config.schema import PhysicsConfig
13
+ from ..geometry.low_rank import LowRankRiemannianGeometry
14
+ from ..registry import register_geometry
15
+
16
+ # Default constants
17
+ SINGULARITY_THRESHOLD = 0.5
18
+ BLACK_HOLE_STRENGTH = 3.0
19
+ SINGULARITY_GATE_SLOPE = 10.0
20
+
21
+
22
+ @register_geometry('reactive')
23
+ class ReactiveRiemannianGeometry(LowRankRiemannianGeometry):
24
+ """
25
+ Geometry that reacts to the system's own state via active inference.
26
+
27
+ Enhancements over LowRank:
28
+ 1. Plasticity: Christoffel symbols scaled by kinetic energy (curv. amplification ≈ attention).
29
+ 2. Singularities: Soft curvature amplification near semantic attractors.
30
+
31
+ Note: These are regularization/attention mechanisms, NOT physical manifold properties.
32
+ """
33
+
34
+ def __init__(self, dim: int, rank: int = 16, num_heads: int = 1,
35
+ config: Optional[PhysicsConfig] = None):
36
+ super().__init__(dim, rank, num_heads=num_heads, config=config)
37
+ self.return_friction_separately = True
38
+
39
+ self.active_cfg = self.config.active_inference
40
+ self.plasticity = getattr(self.active_cfg, 'plasticity', DEFAULT_PLASTICITY)
41
+
42
+ sing_cfg = self.config.singularities
43
+ sing_enabled = getattr(sing_cfg, 'enabled', False)
44
+
45
+ if sing_enabled:
46
+ self.semantic_certainty_threshold = getattr(sing_cfg, 'threshold', SINGULARITY_THRESHOLD)
47
+ self.curvature_amplification_factor = getattr(sing_cfg, 'strength', BLACK_HOLE_STRENGTH)
48
+ gate_input_dim = dim * 2 if self.topology == TOPOLOGY_TORUS else dim
49
+ if num_heads > 1:
50
+ self.V_weight = nn.Parameter(torch.zeros(num_heads, gate_input_dim, 1))
51
+ else:
52
+ self.V = nn.Linear(gate_input_dim, 1)
53
+ nn.init.zeros_(self.V.weight)
54
+ nn.init.constant_(self.V.bias, -2.0) # Start gate closed
55
+ else:
56
+ self.semantic_certainty_threshold = SINGULARITY_THRESHOLD
57
+ self.curvature_amplification_factor = BLACK_HOLE_STRENGTH
58
+ self.V = None
59
+
60
+ def _get_potential(self, x_in: torch.Tensor) -> Optional[torch.Tensor]:
61
+ """Compute singularity potential, returns None if disabled."""
62
+ if not getattr(self.config.singularities, 'enabled', False):
63
+ return None
64
+ if self.num_heads > 1:
65
+ return torch.sigmoid(torch.matmul(x_in.unsqueeze(-2), self.V_weight).squeeze(-2))
66
+ elif self.V is not None:
67
+ return torch.sigmoid(self.V(x_in))
68
+ return None
69
+
70
+ def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
71
+ force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
72
+ if v is None:
73
+ return torch.zeros_like(x)
74
+
75
+ # 1. Base curvature from LowRank
76
+ res = super().forward(x, v, force=force, **kwargs)
77
+ if isinstance(res, tuple):
78
+ gamma, mu = res
79
+ else:
80
+ gamma, mu = res, torch.zeros_like(v) if v is not None else torch.zeros_like(x)
81
+
82
+ if not self.active_cfg.enabled:
83
+ if self.return_friction_separately:
84
+ return gamma, mu
85
+ return gamma + mu * v if v is not None else gamma
86
+
87
+ # 2. Plasticity: scale curvature by kinetic energy
88
+ react_cfg = self.active_cfg.reactive_curvature
89
+ react_enabled = react_cfg.get('enabled', False) if isinstance(react_cfg, dict) else False
90
+ if react_enabled and self.plasticity > 0.0:
91
+ energy = torch.tanh(v.pow(2).mean(dim=-1, keepdim=True))
92
+ gamma = gamma * (1.0 + self.plasticity * energy)
93
+
94
+ # 3. Singularity amplification
95
+ if getattr(self.config.singularities, 'enabled', False) and x is not None:
96
+ x_in = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) if self.topology == TOPOLOGY_TORUS else x
97
+ potential = self._get_potential(x_in)
98
+ if potential is not None:
99
+ gate_slope = getattr(self.config.singularities, 'gate_slope', SINGULARITY_GATE_SLOPE)
100
+ is_amplified = torch.sigmoid(gate_slope * (potential - self.semantic_certainty_threshold))
101
+ amp = 1.0 + is_amplified * (self.curvature_amplification_factor - 1.0)
102
+ gamma = gamma * amp
103
+ limit = self.curvature_amplification_factor * CURVATURE_CLAMP
104
+ gamma = limit * torch.tanh(gamma / limit)
105
+
106
+ if self.return_friction_separately:
107
+ return gamma, mu
108
+
109
+ return gamma + mu * v if v is not None else gamma
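The singularity amplification step above (a sigmoid gate around the certainty threshold) can be sketched standalone. A minimal pure-Python sketch, using the same default constants as the module (`SINGULARITY_THRESHOLD = 0.5`, `BLACK_HOLE_STRENGTH = 3.0`, `SINGULARITY_GATE_SLOPE = 10.0`):

```python
import math

def singularity_amplification(potential, threshold=0.5, strength=3.0, gate_slope=10.0):
    """Soft gate: amp -> 1 far below the threshold, -> strength far above it."""
    gate = 1.0 / (1.0 + math.exp(-gate_slope * (potential - threshold)))
    return 1.0 + gate * (strength - 1.0)
```

At `potential = threshold` the gate is exactly half-open, so the amplification is the midpoint `(1 + strength) / 2`; the soft gate keeps gradients flowing, unlike a hard `potential > threshold` mask.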
gfn/realizations/gssm/geometry/spherical.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from typing import Optional, Union, Tuple, Any
3
+ from ..geometry.base import BaseGeometry
4
+ from ..registry import register_geometry
5
+ from ..constants import EPS, TOPOLOGY_SPHERE
6
+
7
+ @register_geometry(TOPOLOGY_SPHERE)
8
+ class SphericalGeometry(BaseGeometry):
9
+ """
10
+ Spherical Geometry (Analytical).
11
+ Computes Christoffel symbols for a constant positive curvature space.
12
+ """
13
+ def __init__(self, dim: int, rank: int = 16, config: Optional[Any] = None, **kwargs):
14
+ super().__init__(config)
15
+ self.dim = dim
16
+
17
+ def christoffel_symbols(self, x: torch.Tensor) -> torch.Tensor:
18
+ """Chart-level Christoffel symbols are not modeled explicitly; forward() applies the simplified restoring-force coupling instead, so this returns zero."""
19
+ # For our GFN purposes, we often use the simplified 'spherical_christoffel_torch'
20
+ # which acts as a centering/restoring force towards the sphere surface.
21
+ return torch.zeros_like(x)
22
+
23
+ def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
24
+ force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
25
+ if v is None:
26
+ return torch.zeros_like(x)
27
+
28
+ # Simplified analytical spherical coupling (restoring force)
29
+ xv = torch.sum(x * v, dim=-1, keepdim=True)
30
+ vv = torch.sum(v * v, dim=-1, keepdim=True)
31
+
32
+ # Gamma = -(2.0 * xv * v - vv * x)
33
+ gamma = -(2.0 * xv * v - vv * x)
34
+
35
+ # Apply standard V5 clamping
36
+ clamp_val = getattr(self, 'clamp_val', 5.0)
37
+ return clamp_val * torch.tanh(gamma / clamp_val)
38
+
39
+ def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
40
+ """Identity metric for simplified spherical chart."""
41
+ return torch.ones_like(x)
42
+
43
+ def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
44
+ """Great-circle distance approximation."""
45
+ dot = torch.sum(x1 * x2, dim=-1)
46
+ # Assuming points are on unit sphere
47
+ return torch.acos(torch.clamp(dot, -1.0 + EPS, 1.0 - EPS))
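The restoring-force coupling `gamma = -(2·(x·v)·v - (v·v)·x)` used in `forward` has a simple special case worth checking: for a point on the unit sphere with a tangent velocity (x·v = 0), the term reduces to `(v·v)·x`, i.e. a force along x that opposes drifting off the sphere. A pure-Python sketch (toy vectors, no torch):

```python
def spherical_gamma(x, v):
    # gamma_i = -(2*(x.v)*v_i - (v.v)*x_i), per coordinate
    xv = sum(a * b for a, b in zip(x, v))
    vv = sum(b * b for b in v)
    return [-(2.0 * xv * vi - vv * xi) for xi, vi in zip(x, v)]

# x = [1, 0] on the unit circle, tangent v = [0, 2]:
# xv = 0, vv = 4  ->  gamma = (v.v) * x = [4, 0]
```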
gfn/realizations/gssm/geometry/torus.py ADDED
@@ -0,0 +1,274 @@
1
+ """
2
+ ToroidalRiemannianGeometry — GFN V5
3
+ Full analytical torus geometry (canonical implementation).
4
+ Replaces old stub. Migrated from gfn/geo/topological/toroidal_geometry.py
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import Optional, Union, Tuple, Dict, Any
10
+
11
+ from ..constants import (
12
+ EPS, TOPOLOGY_TORUS, CURVATURE_CLAMP,
13
+ DEFAULT_FRICTION, GATE_BIAS_CLOSED
14
+ )
15
+ from ..config.schema import PhysicsConfig
16
+ from ..geometry.base import BaseGeometry
17
+ from ..registry import register_geometry
18
+
19
+ # Import modular friction component
20
+ from ..physics.components.friction import FrictionGate
21
+
22
+ # Torus-specific constants
23
+ TOROIDAL_CURVATURE_SCALE = 0.1
24
+ CLAMP_MIN_STRONG = 1e-4
25
+ EPSILON_SMOOTH = 1e-9
26
+
27
+ try:
28
+ import toroidal_cuda
29
+ HAS_CUDA_EXT = True
30
+ except ImportError:
31
+ HAS_CUDA_EXT = False
32
+
33
+
34
+
35
+ @register_geometry(TOPOLOGY_TORUS)
36
+ class ToroidalRiemannianGeometry(BaseGeometry):
37
+ """
38
+ Curved torus geometry with exact (analytical) Christoffel symbols.
39
+
40
+ Metric (2D torus paired dimensions):
41
+ g_theta = r²
42
+ g_phi = (R + r·cos θ)²
43
+
44
+ Generalizes to N dims by pairing (theta_0,phi_0), (theta_1,phi_1), ...
45
+ """
46
+
47
+ def __init__(self, dim: int = 64, rank: int = 16, num_heads: int = 1,
48
+ config: Optional[PhysicsConfig] = None):
49
+ super().__init__(config)
50
+ self.dim = dim
51
+ self.num_heads = num_heads
52
+
53
+ topo = self.config.topology
54
+
55
+ # FIX: Make R and r learnable per GFN_Paper_Complete.md Section 5.1
57
+ # "where R_i are the (learnable) radii"
58
+ # And per 01_HYPER_TORUS.md Section 2.2: R is the major radius, r is the minor radius
58
+ learnable_R = getattr(topo, 'learnable_R', True)
59
+ learnable_r = getattr(topo, 'learnable_r', True)
60
+
61
+ if learnable_R:
62
+ # R as a learnable parameter
63
+ self.R = nn.Parameter(torch.tensor(topo.R, dtype=torch.float32))
64
+ else:
65
+ # R as a non-trainable buffer
66
+ self.register_buffer('R', torch.tensor(topo.R, dtype=torch.float32))
67
+
68
+ if learnable_r:
69
+ # r as a learnable parameter
70
+ self.r = nn.Parameter(torch.tensor(topo.r, dtype=torch.float32))
71
+ else:
72
+ # r as a non-trainable buffer
73
+ self.register_buffer('r', torch.tensor(topo.r, dtype=torch.float32))
74
+
75
+ self.topology = topo.type.lower()
76
+ self.clamp_val = self.config.stability.curvature_clamp
77
+
78
+ active_cfg = self.config.active_inference
79
+ self.active_cfg = active_cfg
80
+ rc = active_cfg.reactive_curvature
81
+ self.plasticity = rc.get('plasticity', getattr(active_cfg, 'plasticity', 0.0)) \
82
+ if isinstance(rc, dict) else getattr(active_cfg, 'plasticity', 0.0)
83
+
84
+ sing_cfg = self.config.singularities
85
+ self.singularity_threshold = sing_cfg.threshold
86
+ self.black_hole_strength = sing_cfg.strength
87
+
88
+ gate_input_dim = dim * 2 # [sin(x), cos(x)]
89
+ friction_mode = getattr(self.config.stability, 'friction_mode', 'static')
90
+ self.friction_gate = FrictionGate(
91
+ dim, gate_input_dim=gate_input_dim, mode=friction_mode, num_heads=num_heads
92
+ )
93
+
94
+ # Optional singularity potential gate
95
+ if getattr(sing_cfg, 'enabled', False):
96
+ if num_heads > 1:
97
+ self.V_weight = nn.Parameter(torch.zeros(num_heads, gate_input_dim, 1))
98
+ self.V_bias = nn.Parameter(torch.full((num_heads, 1), GATE_BIAS_CLOSED))
99
+ else:
100
+ self.V = nn.Linear(gate_input_dim, 1)
101
+ nn.init.zeros_(self.V.weight)
102
+ nn.init.constant_(self.V.bias, GATE_BIAS_CLOSED)
103
+ else:
104
+ self.V = None
105
+
106
+ def validate_dimensions(self, x: torch.Tensor):
107
+ if x.shape[-1] % 2 != 0:
108
+ raise ValueError(
109
+ f"ToroidalGeometry requires even dim for (θ,φ) pairing. Got {x.shape[-1]}"
110
+ )
111
+
112
+ def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
113
+ self.validate_dimensions(x)
114
+ g = torch.ones_like(x)
115
+ for i in range(0, x.shape[-1], 2):
116
+ th = x[..., i]
117
+ g[..., i] = self.r ** 2
118
+ g[..., i + 1] = (self.R + self.r * torch.cos(th)) ** 2
119
+ return g
120
+
121
+ def connection(self, v: torch.Tensor, w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
122
+ d = x.shape[-1]
123
+ gamma = torch.zeros_like(v)
124
+ for i in range(0, (d // 2) * 2, 2):
125
+ th = x[..., i]
126
+ v_th, v_ph = v[..., i], v[..., i + 1]
127
+ w_th, w_ph = w[..., i], w[..., i + 1]
128
+ denom = torch.clamp(self.R + self.r * torch.cos(th), min=CLAMP_MIN_STRONG)
129
+ term_th = (denom * torch.sin(th) / (self.r + EPSILON_SMOOTH)) * TOROIDAL_CURVATURE_SCALE
130
+ gamma[..., i] = term_th * (v_ph * w_ph)
131
+ term_ph = (-(self.r * torch.sin(th)) / (denom + EPSILON_SMOOTH)) * TOROIDAL_CURVATURE_SCALE
132
+ gamma[..., i + 1] = term_ph * (v_ph * w_th + v_th * w_ph)
133
+ return gamma
134
+
135
+ def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
136
+ force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
137
+ if v is None:
138
+ return torch.zeros_like(x)
139
+
140
+ self.validate_dimensions(x)
141
+
142
+ if HAS_CUDA_EXT and x.is_cuda and v is not None and v.is_cuda:
143
+ gamma = toroidal_cuda.forward(
144
+ x, v, self.R, self.r,
145
+ TOROIDAL_CURVATURE_SCALE, EPSILON_SMOOTH, CLAMP_MIN_STRONG
146
+ )
147
+ else:
148
+ is_odd = (self.dim % 2 != 0)
149
+ x_pad = torch.nn.functional.pad(x, (0, 1)) if is_odd else x
150
+ v_pad = torch.nn.functional.pad(v, (0, 1)) if is_odd else v
151
+
152
+ th = x_pad[..., 0::2]
153
+ v_th = v_pad[..., 0::2]
154
+ v_ph = v_pad[..., 1::2]
155
+
156
+ denom = torch.clamp(self.R + self.r * torch.cos(th), min=CLAMP_MIN_STRONG)
157
+ term_th = (denom * torch.sin(th) / (self.r + EPSILON_SMOOTH))
158
+ term_ph = -(self.r * torch.sin(th)) / (denom + EPSILON_SMOOTH)
159
+
160
+ gamma_th = term_th * (v_ph ** 2) * TOROIDAL_CURVATURE_SCALE
161
+ gamma_ph = 2.0 * term_ph * v_ph * v_th * TOROIDAL_CURVATURE_SCALE
162
+
163
+ half = x.shape[-1] // 2
164
+ gamma = torch.zeros_like(x)
165
+ gamma[..., 0::2] = gamma_th[..., :half + x.shape[-1] % 2]
166
+ gamma[..., 1::2] = gamma_ph[..., :half]
167
+
168
+ # Friction gate
169
+ x_in = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
170
+ mu = self.friction_gate(x_in, force=force)
171
+ if mu.shape[-1] != v.shape[-1]:
172
+ if mu.shape[-1] == 2 * v.shape[-1]:
173
+ mu = (mu[..., :v.shape[-1]] + mu[..., v.shape[-1]:]) / 2.0
174
+
175
+ # Active inference
176
+ if self.active_cfg.enabled:
177
+ rc = self.active_cfg.reactive_curvature
178
+ react_enabled = rc.get('enabled', False) if isinstance(rc, dict) else False
179
+ if react_enabled and self.plasticity > 0.0:
180
+ energy = torch.tanh(v.pow(2).mean(dim=-1, keepdim=True))
181
+ gamma = gamma * (1.0 + self.plasticity * energy)
182
+
183
+ if getattr(self.config.singularities, 'enabled', False):
184
+ if self.num_heads > 1:
185
+ potential = torch.sigmoid(
186
+ torch.matmul(x_in.unsqueeze(-2), self.V_weight).squeeze(-2)
187
+ + self.V_bias
188
+ )
189
+ elif self.V is not None:
190
+ potential = torch.sigmoid(self.V(x_in))
191
+ else:
192
+ potential = None
193
+
194
+ if potential is not None:
195
+ soft_m = torch.sigmoid(5.0 * (potential - self.singularity_threshold))
196
+ gamma = gamma * (1.0 + soft_m * (self.black_hole_strength - 1.0))
197
+
198
+ # CONTRACT: Always return (gamma_pure, mu) — engine applies friction, not geometry.
199
+ # This unifies the contract with LowRankRiemannianGeometry (P0.2 fix).
200
+ return gamma, mu
201
+
202
+ def project(self, x: torch.Tensor) -> torch.Tensor:
203
+ return torch.atan2(torch.sin(x), torch.cos(x))
204
+
205
+ def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
206
+ diff = x1 - x2
207
+ return torch.norm(torch.atan2(torch.sin(diff), torch.cos(diff)), dim=-1)
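The diagonal torus metric from the class docstring (`g_theta = r²`, `g_phi = (R + r·cos θ)²` per (θ, φ) pair) is easy to sketch on its own. A pure-Python toy with the hypothetical defaults `R = 1.0`, `r = 0.5` (the actual values come from `config.topology` and may be learnable parameters):

```python
import math

def torus_metric_diag(theta, R=1.0, r=0.5):
    """Diagonal metric entries of one (theta, phi) pair on the curved torus."""
    g_theta = r * r                              # poloidal direction: constant r^2
    g_phi = (R + r * math.cos(theta)) ** 2       # toroidal direction: shrinks on the inner rim
    return g_theta, g_phi

# theta = 0 (outer equator): g = (0.25, 2.25); theta = pi (inner rim): g = (0.25, 0.25)
```

The θ-dependence of `g_phi` is what makes the Christoffel terms in `connection`/`forward` nonzero: `sin θ` appears wherever `g_phi` is differentiated.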
208
+
209
+
210
+ @register_geometry('flat_torus')
211
+ class FlatToroidalRiemannianGeometry(BaseGeometry):
212
+ def __init__(self, dim: int = 64, rank: int = 16, num_heads: int = 1,
213
+ config: Optional[PhysicsConfig] = None):
214
+ super().__init__(config)
215
+ self.dim = dim
216
+ self.num_heads = num_heads
217
+ topo = self.config.topology
218
+
219
+ # FIX: Make R and r learnable in FlatTorus as well
220
+ learnable_R = getattr(topo, 'learnable_R', True)
221
+ learnable_r = getattr(topo, 'learnable_r', True)
222
+
223
+ if learnable_R:
224
+ self.R = nn.Parameter(torch.tensor(topo.R, dtype=torch.float32))
225
+ else:
226
+ self.register_buffer('R', torch.tensor(topo.R, dtype=torch.float32))
227
+
228
+ if learnable_r:
229
+ self.r = nn.Parameter(torch.tensor(topo.r, dtype=torch.float32))
230
+ else:
231
+ self.register_buffer('r', torch.tensor(topo.r, dtype=torch.float32))
232
+
233
+ self.topology = topo.type.lower()
234
+ self.return_friction_separately = True
235
+ gate_input_dim = dim * 2
236
+ friction_mode = getattr(self.config.stability, 'friction_mode', 'static')
237
+ self.friction_gate = FrictionGate(
238
+ dim, gate_input_dim=gate_input_dim, mode=friction_mode, num_heads=num_heads
239
+ )
240
+
241
+ def validate_dimensions(self, x: torch.Tensor):
242
+ if x.shape[-1] % 2 != 0:
243
+ raise ValueError(
244
+ f"FlatToroidalGeometry requires even dim for (θ,φ) pairing. Got {x.shape[-1]}"
245
+ )
246
+
247
+ def metric_tensor(self, x: torch.Tensor) -> torch.Tensor:
248
+ self.validate_dimensions(x)
249
+ return torch.ones_like(x)
250
+
251
+ def forward(self, x: torch.Tensor, v: Optional[torch.Tensor] = None,
252
+ force: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
253
+ if v is None:
254
+ return torch.zeros_like(x)
255
+
256
+ self.validate_dimensions(x)
257
+ x_in = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
258
+ mu = self.friction_gate(x_in, force=force)
259
+ if mu.shape[-1] != v.shape[-1]:
260
+ if mu.shape[-1] == 2 * v.shape[-1]:
261
+ mu = (mu[..., :v.shape[-1]] + mu[..., v.shape[-1]:]) / 2.0
262
+ else:
263
+ mu = mu.expand_as(v)
264
+ gamma = torch.zeros_like(v)
265
+ if self.return_friction_separately:
266
+ return gamma, mu
267
+ return gamma + mu * v
268
+
269
+ def project(self, x: torch.Tensor) -> torch.Tensor:
270
+ return torch.atan2(torch.sin(x), torch.cos(x))
271
+
272
+ def dist(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
273
+ diff = x1 - x2
274
+ return torch.norm(torch.atan2(torch.sin(diff), torch.cos(diff)), dim=-1)
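Both torus classes compute distance with the `atan2(sin, cos)` trick, which wraps each angular difference into (-π, π] before taking the norm. A minimal pure-Python version of the same idea (toy scalars, no torch):

```python
import math

def toroidal_dist(x1, x2):
    """Flat-torus distance: wrap each per-dimension angle difference, then L2 norm."""
    wrapped = [math.atan2(math.sin(a - b), math.cos(a - b)) for a, b in zip(x1, x2)]
    return math.sqrt(sum(d * d for d in wrapped))

# Angles 0.1 and 2*pi - 0.1 are 0.2 apart across the seam, not ~6.08.
```

The same wrapping serves as `project()` in both classes, keeping positions in the canonical (-π, π] chart after each integration step.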