Files changed (8) hide show
  1. LICENSE +12 -0
  2. README.md +20 -3
  3. configs/config.json +7 -0
  4. demo.ipynb +42 -0
  5. examples/example_usage.py +9 -0
  6. inference.py +20 -0
  7. model.py +858 -0
  8. requirements.txt +26 -0
LICENSE ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
README.md CHANGED
@@ -1,3 +1,20 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Model Repo
2
+
3
+ This repository contains the exported model from `awe.ipynb`.
4
+
5
+ ## Contents
6
+ - `model.py` : extracted model class definitions
7
+ - `inference.py` : load and run the model with PyTorch
8
+ - `demo.ipynb` : quick notebook demo for inference
9
+ - `requirements.txt` : dependencies
10
+ - `configs/config.json` : hyperparameters/config template
11
+ - `examples/` : sample usage scripts
12
+ - `scripts/upload_to_hf.py` : helper to push model to the Hub
13
+
14
+ ## Usage
15
+ ```bash
16
+ pip install -r requirements.txt
17
+ python inference.py
18
+ ```
19
+
20
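+ Replace `YourModelClass` in `inference.py` with the actual class defined in `model.py` (for example `IcosahedralRRF`), pass the constructor arguments that class requires, and adjust the dummy input shape to what the chosen model expects.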
configs/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "YourModelClass",
3
+ "input_shape": [1, 3, 224, 224],
4
+ "learning_rate": 0.001,
5
+ "batch_size": 32,
6
+ "epochs": 10
7
+ }
demo.ipynb ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Demo Notebook for Hugging Face Spaces\n",
8
+ "\n",
9
+ "This notebook demonstrates how to load and run inference with the extracted model."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import torch\n",
19
+ "from model import * # Import your model classes\n",
20
+ "from inference import load_model\n",
21
+ "\n",
22
+ "model = load_model(\"model.pt\")\n",
23
+ "x = torch.randn(1, 3, 224, 224)\n",
24
+ "y = model(x)\n",
25
+ "print('Output:', y.shape)"
26
+ ]
27
+ }
28
+ ],
29
+ "metadata": {
30
+ "kernelspec": {
31
+ "display_name": "Python 3",
32
+ "language": "python",
33
+ "name": "python3"
34
+ },
35
+ "language_info": {
36
+ "name": "python",
37
+ "version": "3.11"
38
+ }
39
+ },
40
+ "nbformat": 4,
41
+ "nbformat_minor": 2
42
+ }
examples/example_usage.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import *
3
+ from inference import load_model
4
+
5
if __name__ == "__main__":
    # Load the weights from "model.pt"; load_model picks CUDA when available.
    model = load_model("model.pt")

    # Dummy input - adjust the shape to what the loaded model expects.
    x = torch.randn(1, 3, 224, 224)
    # load_model may have moved the model to the GPU, so the input has to
    # follow it; otherwise the forward pass fails with a device mismatch.
    x = x.to(next(model.parameters()).device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y = model(x)
    print("Example output:", y.shape)
inference.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import * # import extracted model classes
3
+
4
def load_model(weights_path="model.pt", device=None, model_class=None, **model_kwargs):
    """Build a model, load a saved ``state_dict`` into it and switch it to eval mode.

    Args:
        weights_path: path to a file written with ``torch.save(model.state_dict(), ...)``.
        device: target device; when ``None``, ``"cuda"`` is used if available,
            otherwise ``"cpu"``.
        model_class: the ``nn.Module`` subclass to instantiate. When ``None``,
            a module-level name ``YourModelClass`` is used if one exists
            (the placeholder this template originally referred to).
        **model_kwargs: keyword arguments forwarded to ``model_class``.

    Returns:
        The instantiated model, on ``device``, in eval mode.

    Raises:
        NameError: if no ``model_class`` is given and ``YourModelClass`` is
            not defined in this module.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_class is None:
        # Backward compatibility with the template: fall back to the
        # placeholder name if the user defined/imported it at module level.
        model_class = globals().get("YourModelClass")
        if model_class is None:
            raise NameError(
                "name 'YourModelClass' is not defined: pass model_class=<class from "
                "model.py> (plus its constructor arguments) to load_model()"
            )

    model = model_class(**model_kwargs)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs and avoids executing arbitrary pickled code.
    state = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model
14
+
15
if __name__ == "__main__":
    # Load the default weights file; load_model picks CUDA when available.
    model = load_model()

    # Example dummy input - adjust to your model's expected input.
    x = torch.randn(1, 3, 224, 224)
    # load_model may have moved the model to the GPU, so the input has to
    # follow it; otherwise the forward pass fails with a device mismatch.
    x = x.to(next(model.parameters()).device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y = model(x)
    print("Output shape:", y.shape)
model.py ADDED
@@ -0,0 +1,858 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Synthesized wrapper model file (inspect and adapt before use)
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ # --- extracted class 1 ---
6
class LossWeights:
    """Default weighting coefficients for the terms of a composite loss.

    The values are plain class attributes, so they can be read from the class
    itself or from an instance.
    """

    lambda_task: float = 1.0  # weight of the task term
    lambda_res: float = 0.5   # weight of the "res" term (presumably a residual/resonance penalty - confirm)
    lambda_ent: float = 0.2   # weight of the "ent" term (presumably an entropy penalty - confirm)
10
+
11
+
12
+ # --- extracted class 2 ---
13
class RRF_Ultra_CNN(nn.Module):
    """Two 1-D convolutions followed by a two-layer head with a sigmoid output.

    Args:
        input_dim: number of input channels.
        output_dim: number of output units.
        seq_len: length of the input sequence. The convolutions use
            ``kernel_size=3, padding=1`` and therefore keep the length, so the
            flattened feature size is ``128 * seq_len``. Defaults to 160, the
            value that used to be hard-coded.
    """

    def __init__(self, input_dim=1, output_dim=1, seq_len=160):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * seq_len, 256)
        self.fc2 = nn.Linear(256, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, input_dim, seq_len]`` to ``[batch, output_dim]`` in (0, 1)."""
        # torch.relu is used instead of F.relu: this file never imports
        # torch.nn.functional as F, so the old spelling raised NameError.
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))
27
+
28
+
29
+ # --- extracted class 3 ---
30
class SavantRRF_Gauge(nn.Module):
    """Three 1-D convolutions, dropout, and a three-layer head with sigmoid output.

    Args:
        input_dim: number of input channels.
        hidden_dim: accepted for interface compatibility; the layer widths
            here are fixed (64/128/256 channels, 512/256 hidden units) and do
            not depend on it.
        output_dim: number of output units.
        seq_len: length of the input sequence. All convolutions use
            ``kernel_size=3, padding=1`` and keep the length, so ``fc1``
            receives ``256 * seq_len`` features. Defaults to 160, the value
            that used to be hard-coded (256 * 160 = 40960).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, seq_len=160):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(256 * seq_len, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, input_dim, seq_len]`` to ``[batch, output_dim]`` in (0, 1)."""
        # torch.relu is used instead of F.relu: this file never imports
        # torch.nn.functional as F, so the old spelling raised NameError.
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        # Dropout is applied to the flattened conv features, before the head.
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))
53
+
54
+
55
+ # --- extracted class 4 ---
56
class DiracGraphConv(nn.Module):
    """Graph convolution whose edge attention comes from cosine similarity of ``z``.

    For every edge ``(row, col)`` the layer computes a logit
    ``alpha * cos(z[row], z[col]) + bias_edge``, applies a softmax over the
    edges that share the same ``row``, additionally divides by the number of
    edges of that ``row``, aggregates ``x[col]`` into ``row`` and finally
    applies a linear map.

    Args:
        in_dim: feature size of the incoming node features.
        out_dim: feature size produced by the linear map.
        alpha: initial value of the learnable similarity scale.
        bias: whether the linear map has a bias term.
    """

    def __init__(self, in_dim: int, out_dim: int, alpha: float = 1.0, bias: bool = True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.bias_edge = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    @staticmethod
    def cosine_corr(z_i: torch.Tensor, z_j: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Return the cosine similarity along the last axis; the denominator is clamped to ``eps``."""
        num = (z_i * z_j).sum(dim=-1)
        den = torch.clamp(z_i.norm(dim=-1) * z_j.norm(dim=-1), min=eps)
        return num / den

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour features.

        Args:
            x: node features, indexed along dim 0 by the entries of ``edge_index``.
            edge_index: pair ``(row, col)`` of index tensors; messages flow
                from ``col`` into ``row``.
            z: per-node coordinates used only for the attention logits.

        Returns:
            ``self.lin`` applied to the aggregated features; nodes that never
            appear in ``row`` aggregate to zero before the linear map.
        """
        num_nodes = x.size(0)
        row, col = edge_index
        logits = self.alpha * self.cosine_corr(z[row], z[col]) + self.bias_edge

        # Scratch buffers follow the logits' dtype so the layer also works
        # after .double()/.half(); default-dtype buffers made the index ops
        # fail with a dtype mismatch.
        device, dtype = x.device, logits.dtype

        # Per-row maximum for a numerically stable softmax. Rows without edges
        # keep the fill value, which is never read back (only max_per_row[row]).
        max_per_row = torch.full((num_nodes,), torch.finfo(dtype).min, device=device, dtype=dtype)
        max_per_row = max_per_row.scatter_reduce_(0, row, logits, reduce="amax")

        exp_logits = torch.exp(logits - max_per_row[row])
        denom = torch.zeros(num_nodes, device=device, dtype=dtype).index_add_(0, row, exp_logits)
        attn = exp_logits / (denom[row] + 1e-9)

        # Out-degree of every row node, used as an extra 1/deg normalisation.
        ones = torch.ones(row.size(0), device=device, dtype=dtype)
        deg = torch.zeros(num_nodes, device=device, dtype=dtype).index_add_(0, row, ones)
        norm = 1.0 / torch.clamp(deg[row], min=1.0)

        msgs = norm.unsqueeze(-1) * attn.unsqueeze(-1) * x[col]
        # index_add_ needs matching dtypes; cast in case z and x differ.
        out = torch.zeros_like(x).index_add_(0, row, msgs.to(x.dtype))
        return self.lin(out)
88
+
89
+
90
+ # --- extracted class 5 ---
91
class GNNDiracRRF(nn.Module):
    """Stack of ``DiracGraphConv`` layers with GELU and dropout in between.

    Args:
        in_dim: feature size of the input node features.
        hidden_dim: feature size between layers.
        out_dim: feature size of the final layer's output.
        num_layers: total number of layers. An input layer and an output layer
            are always created, so any value of 2 or less yields two layers.
        z_dim: stored on the module as ``self.z_dim``; not used by ``forward``.
        alpha_attn: initial attention scale passed to every layer.
        dropout: dropout probability applied after each non-final layer.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int, z_dim: int,
                 alpha_attn: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.z_dim = z_dim
        self.layers = nn.ModuleList()
        self.layers.append(DiracGraphConv(in_dim, hidden_dim, alpha=alpha_attn))
        for _ in range(num_layers - 2):
            self.layers.append(DiracGraphConv(hidden_dim, hidden_dim, alpha=alpha_attn))
        self.layers.append(DiracGraphConv(hidden_dim, out_dim, alpha=alpha_attn))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Run every layer on ``(h, edge_index, z)``; the last layer's output is returned raw."""
        h = x
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            h = layer(h, edge_index, z)
            if i < last:
                # nn.functional.gelu instead of F.gelu: this file never imports
                # torch.nn.functional as F, so the old spelling raised NameError.
                h = nn.functional.gelu(h)
                h = self.dropout(h)
        return h
111
+
112
+
113
+ # --- extracted class 6 ---
114
class LossWeights:
    """Default weighting coefficients for the terms of a composite loss.

    The values are plain class attributes, so they can be read from the class
    itself or from an instance.
    """

    lambda_task: float = 1.0  # weight of the task term
    lambda_res: float = 0.5   # weight of the "res" term (presumably a residual/resonance penalty - confirm)
    lambda_ent: float = 0.2   # weight of the "ent" term (presumably an entropy penalty - confirm)
118
+
119
+
120
+ # --- extracted class 7 ---
121
class IcosahedralRRF(nn.Module):
    """Twelve ``SavantRRF_Gauge`` nodes combined by a linear core or by a GNN.

    Every gauge node processes the same input and yields ``[batch, output_dim]``.
    Without graph data the twelve outputs are concatenated and passed through
    the linear "ethical core" followed by a sigmoid. With graph data the
    twelve outputs are treated as the features of a 12-node graph, processed
    per batch item by ``GNNDiracRRF`` and averaged over the nodes.

    Args:
        input_dim: number of input channels of every gauge node.
        hidden_dim: passed to the gauge nodes and used as the GNN hidden size.
        output_dim: output size of every gauge node, of the core and of the GNN.
        gnn_num_layers: ``num_layers`` of the GNN.
        gnn_z_dim: ``z_dim`` of the GNN.
        gnn_alpha_attn: ``alpha_attn`` of the GNN.
        gnn_dropout: ``dropout`` of the GNN.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, gnn_num_layers=2, gnn_z_dim=16,
                 gnn_alpha_attn=1.0, gnn_dropout=0.1):
        super().__init__()
        # 12 gauge nodes.
        self.nodes = nn.ModuleList(
            [SavantRRF_Gauge(input_dim, hidden_dim, output_dim) for _ in range(12)]
        )
        # Ethical core: consumes the 12 gauge outputs concatenated along
        # dim=1, i.e. [batch_size, 12 * output_dim].
        self.ethical_core = nn.Linear(12 * output_dim, output_dim)
        # Subconscious (dodecahedron): a GNN over the 12 gauge nodes. Each
        # graph node carries one gauge output, so both the input and the
        # output feature size of the GNN are output_dim.
        self.memory_map = GNNDiracRRF(
            in_dim=output_dim,
            hidden_dim=hidden_dim,
            out_dim=output_dim,
            num_layers=gnn_num_layers,
            z_dim=gnn_z_dim,
            alpha_attn=gnn_alpha_attn,
            dropout=gnn_dropout,
        )

    def forward(self, x, edge_index=None, z=None):
        """Run the model.

        Args:
            x: input to the gauge nodes, ``[batch_size, input_dim, sequence_length]``.
            edge_index: optional edge index of the 12-node graph.
            z: optional per-node coordinates for the GNN attention.

        Returns:
            ``[batch_size, output_dim]``. When both ``edge_index`` and ``z``
            are given this is the node-averaged GNN output (no sigmoid is
            applied on this path); otherwise it is the sigmoid of the ethical
            core.
        """
        # List of 12 tensors, each [batch_size, output_dim].
        outputs = [node(x) for node in self.nodes]

        if edge_index is None or z is None:
            # No graph data: fall back to the ethical core.
            concat = torch.cat(outputs, dim=1)  # [batch_size, 12 * output_dim]
            return torch.sigmoid(self.ethical_core(concat))

        # The graph is shared by all batch items, so move it to the input's
        # device once instead of once per item.
        edge_index = edge_index.to(x.device)
        z = z.to(x.device)

        stacked_outputs = torch.stack(outputs, dim=1)  # [batch_size, 12, output_dim]
        # The GNN layers expect [num_nodes, features], so each batch item
        # ([12, output_dim]) is processed separately.
        gnn_outputs_stacked = torch.stack(
            [self.memory_map(node_feats, edge_index, z) for node_feats in stacked_outputs],
            dim=0,
        )  # [batch_size, 12, output_dim]

        # Aggregate the 12 node outputs into one vector per batch item.
        return gnn_outputs_stacked.mean(dim=1)
215
+
216
+
217
+ # --- extracted class 8 ---
218
class RRF_Ultra_CNN(nn.Module):
    """Two 1-D convolutions followed by a two-layer head with a sigmoid output.

    Args:
        input_dim: number of input channels.
        output_dim: number of output units.
        seq_len: length of the input sequence. The convolutions use
            ``kernel_size=3, padding=1`` and therefore keep the length, so the
            flattened feature size is ``128 * seq_len``. Defaults to 160, the
            value that used to be hard-coded.
    """

    def __init__(self, input_dim=1, output_dim=1, seq_len=160):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * seq_len, 256)
        self.fc2 = nn.Linear(256, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, input_dim, seq_len]`` to ``[batch, output_dim]`` in (0, 1)."""
        # torch.relu is used instead of F.relu: this file never imports
        # torch.nn.functional as F, so the old spelling raised NameError.
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))
232
+
233
+
234
+ # --- extracted class 9 ---
235
class SavantRRF_Gauge(nn.Module):
    """Three 1-D convolutions, dropout, and a three-layer head with sigmoid output.

    Args:
        input_dim: number of input channels.
        hidden_dim: accepted for interface compatibility; the layer widths
            here are fixed (64/128/256 channels, 512/256 hidden units) and do
            not depend on it.
        output_dim: number of output units.
        seq_len: length of the input sequence. All convolutions use
            ``kernel_size=3, padding=1`` and keep the length, so ``fc1``
            receives ``256 * seq_len`` features. Defaults to 160, the value
            that used to be hard-coded (256 * 160 = 40960).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, seq_len=160):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(256 * seq_len, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, input_dim, seq_len]`` to ``[batch, output_dim]`` in (0, 1)."""
        # torch.relu is used instead of F.relu: this file never imports
        # torch.nn.functional as F, so the old spelling raised NameError.
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        # Dropout is applied to the flattened conv features, before the head.
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))
258
+
259
+
260
+ # --- extracted class 10 ---
261
class DiracGraphConv(nn.Module):
    """Graph convolution whose edge attention comes from cosine similarity of ``z``.

    For every edge ``(row, col)`` the layer computes a logit
    ``alpha * cos(z[row], z[col]) + bias_edge``, applies a softmax over the
    edges that share the same ``row``, additionally divides by the number of
    edges of that ``row``, aggregates ``x[col]`` into ``row`` and finally
    applies a linear map.

    Args:
        in_dim: feature size of the incoming node features.
        out_dim: feature size produced by the linear map.
        alpha: initial value of the learnable similarity scale.
        bias: whether the linear map has a bias term.
    """

    def __init__(self, in_dim: int, out_dim: int, alpha: float = 1.0, bias: bool = True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.bias_edge = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    @staticmethod
    def cosine_corr(z_i: torch.Tensor, z_j: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Return the cosine similarity along the last axis; the denominator is clamped to ``eps``."""
        num = (z_i * z_j).sum(dim=-1)
        den = torch.clamp(z_i.norm(dim=-1) * z_j.norm(dim=-1), min=eps)
        return num / den

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour features.

        Args:
            x: node features, indexed along dim 0 by the entries of ``edge_index``.
            edge_index: pair ``(row, col)`` of index tensors; messages flow
                from ``col`` into ``row``.
            z: per-node coordinates used only for the attention logits.

        Returns:
            ``self.lin`` applied to the aggregated features; nodes that never
            appear in ``row`` aggregate to zero before the linear map.
        """
        num_nodes = x.size(0)
        row, col = edge_index
        logits = self.alpha * self.cosine_corr(z[row], z[col]) + self.bias_edge

        # Scratch buffers follow the logits' dtype so the layer also works
        # after .double()/.half(); default-dtype buffers made the index ops
        # fail with a dtype mismatch.
        device, dtype = x.device, logits.dtype

        # Per-row maximum for a numerically stable softmax. Rows without edges
        # keep the fill value, which is never read back (only max_per_row[row]).
        max_per_row = torch.full((num_nodes,), torch.finfo(dtype).min, device=device, dtype=dtype)
        max_per_row = max_per_row.scatter_reduce_(0, row, logits, reduce="amax")

        exp_logits = torch.exp(logits - max_per_row[row])
        denom = torch.zeros(num_nodes, device=device, dtype=dtype).index_add_(0, row, exp_logits)
        attn = exp_logits / (denom[row] + 1e-9)

        # Out-degree of every row node, used as an extra 1/deg normalisation.
        ones = torch.ones(row.size(0), device=device, dtype=dtype)
        deg = torch.zeros(num_nodes, device=device, dtype=dtype).index_add_(0, row, ones)
        norm = 1.0 / torch.clamp(deg[row], min=1.0)

        msgs = norm.unsqueeze(-1) * attn.unsqueeze(-1) * x[col]
        # index_add_ needs matching dtypes; cast in case z and x differ.
        out = torch.zeros_like(x).index_add_(0, row, msgs.to(x.dtype))
        return self.lin(out)
293
+
294
+
295
+ # --- extracted class 11 ---
296
class GNNDiracRRF(nn.Module):
    """Stack of ``DiracGraphConv`` layers with GELU and dropout in between.

    Args:
        in_dim: feature size of the input node features.
        hidden_dim: feature size between layers.
        out_dim: feature size of the final layer's output.
        num_layers: total number of layers. An input layer and an output layer
            are always created, so any value of 2 or less yields two layers.
        z_dim: stored on the module as ``self.z_dim``; not used by ``forward``.
        alpha_attn: initial attention scale passed to every layer.
        dropout: dropout probability applied after each non-final layer.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int, z_dim: int,
                 alpha_attn: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.z_dim = z_dim
        self.layers = nn.ModuleList()
        self.layers.append(DiracGraphConv(in_dim, hidden_dim, alpha=alpha_attn))
        for _ in range(num_layers - 2):
            self.layers.append(DiracGraphConv(hidden_dim, hidden_dim, alpha=alpha_attn))
        self.layers.append(DiracGraphConv(hidden_dim, out_dim, alpha=alpha_attn))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Run every layer on ``(h, edge_index, z)``; the last layer's output is returned raw."""
        h = x
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            h = layer(h, edge_index, z)
            if i < last:
                # nn.functional.gelu instead of F.gelu: this file never imports
                # torch.nn.functional as F, so the old spelling raised NameError.
                h = nn.functional.gelu(h)
                h = self.dropout(h)
        return h
316
+
317
+
318
+ # --- extracted class 12 ---
319
class LossWeights:
    """Default weighting coefficients for the terms of a composite loss.

    The values are plain class attributes, so they can be read from the class
    itself or from an instance.
    """

    lambda_task: float = 1.0  # weight of the task term
    lambda_res: float = 0.5   # weight of the "res" term (presumably a residual/resonance penalty - confirm)
    lambda_ent: float = 0.2   # weight of the "ent" term (presumably an entropy penalty - confirm)
323
+
324
+
325
+ # --- extracted class 13 ---
326
class IcosahedralRRF(nn.Module):
    """Twelve ``SavantRRF_Gauge`` nodes combined by a linear core or by a GNN.

    Every gauge node processes the same input and yields ``[batch, output_dim]``.
    Without graph data the twelve outputs are concatenated and passed through
    the linear "ethical core" followed by a sigmoid. With graph data the
    twelve outputs are treated as the features of a 12-node graph, processed
    per batch item by ``GNNDiracRRF`` and averaged over the nodes.

    Args:
        input_dim: number of input channels of every gauge node.
        hidden_dim: passed to the gauge nodes and used as the GNN hidden size.
        output_dim: output size of every gauge node, of the core and of the GNN.
        gnn_num_layers: ``num_layers`` of the GNN.
        gnn_z_dim: ``z_dim`` of the GNN.
        gnn_alpha_attn: ``alpha_attn`` of the GNN.
        gnn_dropout: ``dropout`` of the GNN.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, gnn_num_layers=2, gnn_z_dim=16,
                 gnn_alpha_attn=1.0, gnn_dropout=0.1):
        super().__init__()
        # 12 gauge nodes.
        self.nodes = nn.ModuleList(
            [SavantRRF_Gauge(input_dim, hidden_dim, output_dim) for _ in range(12)]
        )
        # Ethical core: consumes the 12 gauge outputs concatenated along
        # dim=1, i.e. [batch_size, 12 * output_dim].
        self.ethical_core = nn.Linear(12 * output_dim, output_dim)
        # Subconscious (dodecahedron): a GNN over the 12 gauge nodes. Each
        # graph node carries one gauge output, so both the input and the
        # output feature size of the GNN are output_dim.
        self.memory_map = GNNDiracRRF(
            in_dim=output_dim,
            hidden_dim=hidden_dim,
            out_dim=output_dim,
            num_layers=gnn_num_layers,
            z_dim=gnn_z_dim,
            alpha_attn=gnn_alpha_attn,
            dropout=gnn_dropout,
        )

    def forward(self, x, edge_index=None, z=None):
        """Run the model.

        Args:
            x: input to the gauge nodes, ``[batch_size, input_dim, sequence_length]``.
            edge_index: optional edge index of the 12-node graph.
            z: optional per-node coordinates for the GNN attention.

        Returns:
            ``[batch_size, output_dim]``. When both ``edge_index`` and ``z``
            are given this is the node-averaged GNN output (no sigmoid is
            applied on this path); otherwise it is the sigmoid of the ethical
            core.
        """
        # List of 12 tensors, each [batch_size, output_dim].
        outputs = [node(x) for node in self.nodes]

        if edge_index is None or z is None:
            # No graph data: fall back to the ethical core.
            concat = torch.cat(outputs, dim=1)  # [batch_size, 12 * output_dim]
            return torch.sigmoid(self.ethical_core(concat))

        # The graph is shared by all batch items, so move it to the input's
        # device once instead of once per item.
        edge_index = edge_index.to(x.device)
        z = z.to(x.device)

        stacked_outputs = torch.stack(outputs, dim=1)  # [batch_size, 12, output_dim]
        # The GNN layers expect [num_nodes, features], so each batch item
        # ([12, output_dim]) is processed separately.
        gnn_outputs_stacked = torch.stack(
            [self.memory_map(node_feats, edge_index, z) for node_feats in stacked_outputs],
            dim=0,
        )  # [batch_size, 12, output_dim]

        # Aggregate the 12 node outputs into one vector per batch item.
        return gnn_outputs_stacked.mean(dim=1)
420
+
421
+
422
+ # --- extracted class 14 ---
423
class LossWeights:
    """Default weighting coefficients for the terms of a composite loss.

    The values are plain class attributes, so they can be read from the class
    itself or from an instance.
    """

    lambda_task: float = 1.0  # weight of the task term
    lambda_res: float = 0.5   # weight of the "res" term (presumably a residual/resonance penalty - confirm)
    lambda_ent: float = 0.2   # weight of the "ent" term (presumably an entropy penalty - confirm)
427
+
428
+
429
+ # --- extracted class 15 ---
430
+ class IcosahedralRRFDataset(InMemoryDataset):
431
+ def __init__(self, num_graphs: int = 64, k_modes: int = 16, feat_dim: int = 8,
432
+ task_type: str = 'classification', split: str = 'train', transform=None, pre_transform=None):
433
+ super().__init__('.', transform, pre_transform)
434
+ self.task_type = task_type
435
+ self.num_graphs = num_graphs
436
+ self.k_modes = k_modes
437
+ self.feat_dim = feat_dim
438
+
439
+ # Generate graphs and process them
440
+ data_list = []
441
+ rng = np.random.default_rng(42 if split == 'train' else (43 if split == 'val' else 44))
442
+
443
+ for i in range(num_graphs):
444
+ G = nx.icosahedral_graph()
445
+ n_nodes = G.number_of_nodes()
446
+
447
+ # Build Dirac operator and compute spectral modes
448
+ D = build_dirac_operator(G, normalize=True)
449
+ # Use the modified dirac_eigendecomp that uses np.linalg.eigh
450
+ vals, vecs = dirac_eigendecomp(D, k=k_modes)
451
+ Z = node_spectral_coords_from_dirac(vecs, n_nodes) # N x k
452
+
453
+ # Get edge index
454
+ edge_list = list(G.edges())
455
+ edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
456
+ # Add reverse edges for undirected graph
457
+ row, col = edge_index
458
+ edge_index = torch.cat([edge_index, torch.stack([col, row], dim=0)], dim=1)
459
+
460
+ # Generate synthetic node features (x) and labels (y)
461
+ # Features: [n_nodes, feat_dim]
462
+ x = torch.randn(n_nodes, feat_dim, dtype=torch.float32)
463
+
464
+ # Labels: based on task_type
465
+ if task_type == 'classification':
466
+ # Example: Binary classification based on a simple rule, e.g., sum of features > threshold
467
+ threshold = 0.0 # Example threshold
468
+ y = (x.sum(dim=-1) > threshold).long() # [n_nodes]
469
+ elif task_type == 'regression':
470
+ # Example: Regression target based on sum of features
471
+ y = x.sum(dim=-1) # [n_nodes]
472
+ else:
473
+ raise ValueError("task_type must be 'classification' or 'regression'")
474
+
475
+
476
+ # Create Data object
477
+ # Note: The IcosahedralRRF model expects input 'x' as [batch_size, input_dim, sequence_length],
478
+ # edge_index [2, num_edges], and z [num_nodes, z_dim].
479
+ # The IcosahedralRRFDataset provides batch.x [num_nodes, feat_dim], batch.edge_index [2, num_edges], and batch.U [num_nodes, k_modes].
480
+ # There is a mismatch in the expected input format for the IcosahedralRRF model's forward pass when using the DataLoader.
481
+ # The IcosahedralRRF expects a single batch tensor `x` for the gauge nodes, and graph data (edge_index, z) for the GNN part which operates on gauge outputs.
482
+ # The IcosahedralRRFDataset provides node features `batch.x` that are intended as features *for the graph nodes themselves*, not as input to the gauge nodes.
483
+ # The current IcosahedralRRF forward pass processes a single input `x` [batch_size, input_dim, sequence_length] through all gauge nodes.
484
+ # The GNN then operates on the *outputs* of these gauge nodes, using the provided edge_index and z.
485
+
486
+ # To use the IcosahedralRRFDataset with the current IcosahedralRRF model structure,
487
+ # we need to map the dataset's structure to the model's expectations.
488
+ # The dataset provides graphs, each with nodes (typically 12 for icosahedral), node features (batch.x), edge_index, and spectral coords (batch.U).
489
+ # The IcosahedralRRF model has 12 gauge nodes, each designed to process a sequence [input_dim, sequence_length].
490
+ # It seems there is a conceptual mismatch in how the IcosahedralRRFDataset is structured (graph-centric with node features)
491
+ # and how the IcosahedralRRF model processes input (sequence-centric through gauge nodes first).
492
+
493
+ # Alternative Interpretation: The IcosahedralRRFDataset is meant to provide data where each *graph* is a sample in the batch.
494
+ # batch.x would be the concatenated node features for all graphs in the batch: [total_num_nodes_in_batch, feat_dim].
495
+ # batch.edge_index would be the block-diagonal edge indices for all graphs: [2, total_num_edges_in_batch].
496
+ # batch.U would be the concatenated spectral coordinates for all nodes: [total_num_nodes_in_batch, k_modes].
497
+ # In this case, the input to the IcosahedralRRF model's forward pass is still expected to be a single tensor `x` for the gauge nodes.
498
+ # The IcosahedralRRFDataset does *not* provide this `x` input directly in the expected format.
499
+
500
+ # There is a fundamental incompatibility in how the IcosahedralRRFDataset provides data (graph-batching)
501
+ # and how the IcosahedralRRF model expects input (single batch of sequences + graph data for GNN).
502
+
503
+ # To make this cell runnable, we need to either:
504
+ # 1. Modify the IcosahedralRRF model's forward pass to handle graph batches from DataLoader.
505
+ # 2. Modify the IcosahedralRRFDataset or create a custom Dataset/DataLoader that provides data in the format expected by the IcosahedralRRF model.
506
+ # 3. Use a simplified evaluation approach that aligns with the synthetic data generation method used in the training loop (single batch).
507
+
508
+ # Given the current structure, the simplest approach to get the cell running is to align the evaluation data generation
509
+ # with the training data generation (single synthetic batch) and evaluate on that.
510
+ # This bypasses the DataLoader incompatibility but doesn't fully test with graph batching.
511
+
512
+ # Let's revert to generating a single synthetic batch for evaluation, similar to training.
513
+ # This requires defining x_val and y_val outside the DataLoader loop.
514
+
515
+ # Reverting the evaluation loop to use the single synthetic batch approach:
516
+
517
+ # Check if x_val and y_val are defined (from previous code cell)
518
+ if 'x_val' not in locals() or 'y_val' not in locals():
519
+ # Generate synthetic validation data if not already defined
520
+ val_batch_size = 16 # Example validation batch size
521
+ x_val = torch.randn(val_batch_size, input_dim, sequence_length, dtype=torch.float32).to(device)
522
+ y_val = torch.randint(0, 2, (val_batch_size,), dtype=torch.long).to(device) # Binary labels
523
+ print("Generated synthetic validation data for evaluation.")
524
+
525
+ # Ensure z and edge_index are on the correct device
526
+ if 'z' in locals() and 'edge_index' in locals():
527
+ z = z.to(device)
528
+ edge_index = edge_index.to(device)
529
+ else:
530
+ print("⚠️ Warning: Graph data (z, edge_index) not found. Skipping evaluation.")
531
+ # Exit the evaluation block if graph data is missing
532
+ # break # This will exit the with torch.no_grad(): block - REMOVED/COMMENTED OUT DUE TO SyntaxError
533
+ pass # Use pass instead of break to avoid SyntaxError outside a loop
534
+
535
+ # Forward pass on validation data using the single batch
536
+ # Pass the validation input features (x_val), edge index, and spectral coordinates (z) through the model
537
+ val_outputs = hybrid_model(x_val, edge_index, z) # Shape: [val_batch_size, output_dim]
538
+
539
+ # Calculate the validation loss (using BCEWithLogitsLoss as corrected in training)
540
+ val_loss = F.binary_cross_entropy_with_logits(val_outputs.squeeze(-1), y_val.float())
541
+
542
+ # Calculate evaluation metrics (e.g., accuracy for binary classification)
543
+ # Convert logits to predicted class (0 or 1)
544
+ predicted_classes = (torch.sigmoid(val_outputs.squeeze(-1)) > 0.5).long()
545
+
546
+ # Calculate accuracy
547
+ correct_predictions = (predicted_classes == y_val).sum().item()
548
+ accuracy = correct_predictions / val_batch_size
549
+
550
+ print(f'Validation Loss: {val_loss.item():.4f}, Validation Accuracy: {accuracy:.4f}')
551
+
552
+
553
+ # --- extracted class 16 ---
554
class DiracGraphConv(nn.Module):
    """Graph convolution with attention derived from cosine similarity of node coordinates.

    For every edge ``(row[e], col[e])`` the logit
    ``alpha * cos(z[row[e]], z[col[e]]) + bias_edge`` is computed, soft-maxed
    over all edges sharing the same ``row`` node, divided by that node's edge
    count, and used to weight ``x[col[e]]``.  The weighted messages are summed
    into ``row`` and passed through a linear layer.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        alpha: Initial value of the learnable similarity scale.
        bias: Whether the output linear layer has a bias term.
    """

    def __init__(self, in_dim: int, out_dim: int, alpha: float = 1.0, bias: bool = True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.bias_edge = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    @staticmethod
    def cosine_corr(z_i: torch.Tensor, z_j: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Return the cosine similarity of ``z_i`` and ``z_j`` along the last dimension.

        The product of the two norms is clamped to at least ``eps`` so that
        zero vectors give 0 instead of a division by zero.
        """
        num = (z_i * z_j).sum(dim=-1)
        den = torch.clamp(z_i.norm(dim=-1) * z_j.norm(dim=-1), min=eps)
        return num / den

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour features with similarity-based attention.

        Args:
            x: Node features, indexed by node along dim 0.
            edge_index: Pair of index tensors ``(row, col)``; messages flow
                from ``col`` into ``row``.
            z: Per-node coordinates used only to compute the edge logits.

        Returns:
            ``self.lin`` applied to the aggregated messages; nodes that never
            appear in ``row`` aggregate to zero.
        """
        num_nodes = x.size(0)
        row, col = edge_index
        logits = self.alpha * self.cosine_corr(z[row], z[col]) + self.bias_edge

        # Per-node maximum for a numerically stable softmax.  A single "amax"
        # scatter is sufficient; seeding the buffer through index_put with
        # repeated indices is undefined behaviour and was redundant.  Buffers
        # follow the logits' dtype/device so float64 or half inputs work.
        row_max = logits.new_full((num_nodes,), float("-inf")).scatter_reduce(
            0, row, logits, reduce="amax", include_self=True
        )
        exp_logits = torch.exp(logits - row_max[row])
        denom = exp_logits.new_zeros(num_nodes).index_add_(0, row, exp_logits)
        attn = exp_logits / (denom[row] + 1e-9)

        # Extra 1/degree scaling on top of the softmax, as in the original layer.
        deg = logits.new_zeros(num_nodes).index_add_(0, row, torch.ones_like(logits))
        norm = 1.0 / torch.clamp(deg[row], min=1.0)

        # Cast the edge weights to the feature dtype so z and x may differ in precision.
        weights = (norm * attn).to(x.dtype)
        msgs = weights.unsqueeze(-1) * x[col]
        out = torch.zeros_like(x).index_add_(0, row, msgs)
        return self.lin(out)
586
+
587
+
588
+ # --- extracted class 17 ---
589
class GNNDiracRRF(nn.Module):
    """Stack of ``DiracGraphConv`` layers with GELU and dropout between them.

    The stack always contains an input layer (``in_dim -> hidden_dim``) and an
    output layer (``hidden_dim -> out_dim``); ``num_layers - 2`` hidden layers
    (``hidden_dim -> hidden_dim``) are placed in between when that number is
    positive.  ``z_dim`` is stored on the module but not used by any layer.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int, z_dim: int,
                 alpha_attn: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.z_dim = z_dim
        # Feature widths at every layer boundary: in, hidden..., out.
        widths = [in_dim] + [hidden_dim] * max(num_layers - 1, 1) + [out_dim]
        self.layers = nn.ModuleList([
            DiracGraphConv(fan_in, fan_out, alpha=alpha_attn)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Run every layer in order; the last layer's output is returned without activation."""
        *inner, head = self.layers
        hidden = x
        for conv in inner:
            hidden = self.dropout(F.gelu(conv(hidden, edge_index, z)))
        return head(hidden, edge_index, z)
609
+
610
+
611
+ # --- extracted class 18 ---
612
class DiracGraphConv(nn.Module):
    """Graph convolution with attention derived from cosine similarity of node coordinates.

    For every edge ``(row[e], col[e])`` the logit
    ``alpha * cos(z[row[e]], z[col[e]]) + bias_edge`` is computed, soft-maxed
    over all edges sharing the same ``row`` node, divided by that node's edge
    count, and used to weight ``x[col[e]]``.  The weighted messages are summed
    into ``row`` and passed through a linear layer.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        alpha: Initial value of the learnable similarity scale.
        bias: Whether the output linear layer has a bias term.
    """

    def __init__(self, in_dim: int, out_dim: int, alpha: float = 1.0, bias: bool = True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.bias_edge = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    @staticmethod
    def cosine_corr(z_i: torch.Tensor, z_j: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Return the cosine similarity of ``z_i`` and ``z_j`` along the last dimension.

        The product of the two norms is clamped to at least ``eps`` so that
        zero vectors give 0 instead of a division by zero.
        """
        # The numerator must be the dot product of the two *different*
        # vectors; using z_i twice yields |z_i| / |z_j| instead of a cosine.
        num = (z_i * z_j).sum(dim=-1)
        den = torch.clamp(z_i.norm(dim=-1) * z_j.norm(dim=-1), min=eps)
        return num / den

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour features with similarity-based attention.

        Args:
            x: Node features, indexed by node along dim 0.
            edge_index: Pair of index tensors ``(row, col)``; messages flow
                from ``col`` into ``row``.
            z: Per-node coordinates used only to compute the edge logits.

        Returns:
            ``self.lin`` applied to the aggregated messages; nodes that never
            appear in ``row`` aggregate to zero.
        """
        num_nodes = x.size(0)
        row, col = edge_index
        logits = self.alpha * self.cosine_corr(z[row], z[col]) + self.bias_edge

        # Per-node maximum for a numerically stable softmax.  A single "amax"
        # scatter is sufficient; seeding the buffer through index_put with
        # repeated indices is undefined behaviour and was redundant.  Buffers
        # follow the logits' dtype/device so float64 or half inputs work.
        row_max = logits.new_full((num_nodes,), float("-inf")).scatter_reduce(
            0, row, logits, reduce="amax", include_self=True
        )
        exp_logits = torch.exp(logits - row_max[row])
        denom = exp_logits.new_zeros(num_nodes).index_add_(0, row, exp_logits)
        attn = exp_logits / (denom[row] + 1e-9)

        # Extra 1/degree scaling on top of the softmax, as in the original layer.
        deg = logits.new_zeros(num_nodes).index_add_(0, row, torch.ones_like(logits))
        norm = 1.0 / torch.clamp(deg[row], min=1.0)

        # Cast the edge weights to the feature dtype so z and x may differ in precision.
        weights = (norm * attn).to(x.dtype)
        msgs = weights.unsqueeze(-1) * x[col]
        out = torch.zeros_like(x).index_add_(0, row, msgs)
        return self.lin(out)
644
+
645
+
646
+ # --- extracted class 19 ---
647
class GNNDiracRRF(nn.Module):
    """Stack of ``DiracGraphConv`` layers with GELU and dropout between them.

    The stack always contains an input layer (``in_dim -> hidden_dim``) and an
    output layer (``hidden_dim -> out_dim``); ``num_layers - 2`` hidden layers
    (``hidden_dim -> hidden_dim``) are placed in between when that number is
    positive.  ``z_dim`` is stored on the module but not used by any layer.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int, z_dim: int,
                 alpha_attn: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.z_dim = z_dim
        # Feature widths at every layer boundary: in, hidden..., out.
        widths = [in_dim] + [hidden_dim] * max(num_layers - 1, 1) + [out_dim]
        self.layers = nn.ModuleList([
            DiracGraphConv(fan_in, fan_out, alpha=alpha_attn)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Run every layer in order; the last layer's output is returned without activation."""
        *inner, head = self.layers
        hidden = x
        for conv in inner:
            hidden = self.dropout(F.gelu(conv(hidden, edge_index, z)))
        return head(hidden, edge_index, z)
667
+
668
+
669
+ # --- extracted class 20 ---
670
class SavantRRF_Gauge(nn.Module):
    """1-D convolutional gauge node: three conv layers followed by a three-layer MLP head.

    The convolutions use ``kernel_size=3, padding=1`` and therefore keep the
    sequence length, so the flattened feature size fed to ``fc1`` is
    ``256 * seq_len``.

    Args:
        input_dim: Number of input channels.
        hidden_dim: Accepted for interface compatibility; no layer uses it.
        output_dim: Number of output features.
        seq_len: Length of the input sequences.  Defaults to 160, the value
            that was previously hard-coded into ``fc1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, seq_len=160):
        super().__init__()
        self.seq_len = seq_len
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.25)
        # Flattened conv output: 256 channels times the (unchanged) sequence length.
        self.fc1 = nn.Linear(256 * seq_len, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, input_dim, seq_len]`` to sigmoid outputs ``[batch, output_dim]``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))
691
+
692
+
693
+ # --- extracted class 21 ---
694
class IcosahedralRRF(nn.Module):
    """Ensemble of 12 ``SavantRRF_Gauge`` nodes combined by a linear core or a graph network.

    Every gauge node receives the same input.  Without graph data the 12
    outputs are concatenated and passed through the linear "ethical core"
    followed by a sigmoid.  With ``edge_index`` and ``z`` the 12 outputs are
    treated as the features of a 12-node graph, processed by ``GNNDiracRRF``
    and averaged over the nodes.

    Args:
        input_dim: Number of input channels of every gauge node.
        hidden_dim: Hidden size forwarded to the gauge nodes and the GNN.
        output_dim: Output size of every gauge node, the core and the GNN.
        gnn_num_layers: ``num_layers`` of the GNN.
        gnn_z_dim: ``z_dim`` of the GNN.
        gnn_alpha_attn: ``alpha_attn`` of the GNN.
        gnn_dropout: ``dropout`` of the GNN.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, gnn_num_layers=2, gnn_z_dim=16, gnn_alpha_attn=1.0, gnn_dropout=0.1):
        super().__init__()
        # 12 gauge nodes.
        self.nodes = nn.ModuleList([
            SavantRRF_Gauge(input_dim, hidden_dim, output_dim) for _ in range(12)
        ])
        # Ethical core: linear map over the concatenated gauge outputs.
        self.ethical_core = nn.Linear(12 * output_dim, output_dim)

        # Subconscious (dodecahedron/icosahedron) implemented with GNNDiracRRF.
        # The GNN sees one graph per batch item whose node features are the
        # gauge outputs, i.e. [12, output_dim].
        self.memory_map = GNNDiracRRF(in_dim=output_dim,
                                      hidden_dim=hidden_dim,
                                      out_dim=output_dim,
                                      num_layers=gnn_num_layers,
                                      z_dim=gnn_z_dim,
                                      alpha_attn=gnn_alpha_attn,
                                      dropout=gnn_dropout)

    def forward(self, x, edge_index=None, z=None):
        """Run the gauge nodes and combine their outputs.

        Args:
            x: Gauge input, shape ``[batch_size, input_dim, sequence_length]``.
            edge_index: Optional edge index of the 12-node graph.
            z: Optional per-node coordinates for the GNN attention.

        Returns:
            ``[batch_size, output_dim]``: the node-averaged GNN output when both
            ``edge_index`` and ``z`` are given, otherwise the sigmoid of the
            ethical core.
        """
        # List of 12 tensors, each [batch_size, output_dim].
        outputs = [node(x) for node in self.nodes]

        if edge_index is None or z is None:
            # The core is only evaluated on the path that actually returns it.
            concat = torch.cat(outputs, dim=1)  # [batch_size, 12 * output_dim]
            return torch.sigmoid(self.ethical_core(concat))

        stacked_outputs = torch.stack(outputs, dim=1)  # [batch_size, 12, output_dim]

        # The graph is shared by all samples, so move it to the device once
        # rather than once per batch item.
        edge_index = edge_index.to(x.device)
        z = z.to(x.device)

        # One GNN pass per sample; each result is [12, output_dim].
        gnn_outputs = [
            self.memory_map(node_features, edge_index, z)
            for node_features in stacked_outputs
        ]
        return torch.stack(gnn_outputs, dim=0).mean(dim=1)  # [batch_size, output_dim]
746
+
747
+
748
+ # --- extracted class 22 ---
749
class DiracGraphConv(nn.Module):
    """Graph convolution with attention derived from cosine similarity of node coordinates.

    For every edge ``(row[e], col[e])`` the logit
    ``alpha * cos(z[row[e]], z[col[e]]) + bias_edge`` is computed, soft-maxed
    over all edges sharing the same ``row`` node, divided by that node's edge
    count, and used to weight ``x[col[e]]``.  The weighted messages are summed
    into ``row`` and passed through a linear layer.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        alpha: Initial value of the learnable similarity scale.
        bias: Whether the output linear layer has a bias term.
    """

    def __init__(self, in_dim: int, out_dim: int, alpha: float = 1.0, bias: bool = True):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.bias_edge = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    @staticmethod
    def cosine_corr(z_i: torch.Tensor, z_j: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Return the cosine similarity of ``z_i`` and ``z_j`` along the last dimension.

        The product of the two norms is clamped to at least ``eps`` so that
        zero vectors give 0 instead of a division by zero.
        """
        num = (z_i * z_j).sum(dim=-1)
        den = torch.clamp(z_i.norm(dim=-1) * z_j.norm(dim=-1), min=eps)
        return num / den

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour features with similarity-based attention.

        Args:
            x: Node features, shape ``[num_nodes, in_dim]``.
            edge_index: Pair of index tensors ``(row, col)``; messages flow
                from ``col`` into ``row``.
            z: Per-node coordinates, shape ``[num_nodes, z_dim]``, used only
                to compute the edge logits.

        Returns:
            ``self.lin`` applied to the aggregated messages; nodes that never
            appear in ``row`` aggregate to zero.
        """
        num_nodes = x.size(0)
        row, col = edge_index
        logits = self.alpha * self.cosine_corr(z[row], z[col]) + self.bias_edge

        # Per-node maximum for a numerically stable softmax.  A single "amax"
        # scatter is sufficient; seeding the buffer through index_put with
        # repeated indices is undefined behaviour and was redundant.  Buffers
        # follow the logits' dtype/device so float64 or half inputs work.
        row_max = logits.new_full((num_nodes,), float("-inf")).scatter_reduce(
            0, row, logits, reduce="amax", include_self=True
        )
        exp_logits = torch.exp(logits - row_max[row])
        denom = exp_logits.new_zeros(num_nodes).index_add_(0, row, exp_logits)
        attn = exp_logits / (denom[row] + 1e-9)

        # Extra 1/degree scaling on top of the softmax, as in the original layer.
        deg = logits.new_zeros(num_nodes).index_add_(0, row, torch.ones_like(logits))
        norm = 1.0 / torch.clamp(deg[row], min=1.0)

        # Cast the edge weights to the feature dtype so z and x may differ in precision.
        weights = (norm * attn).to(x.dtype)
        msgs = weights.unsqueeze(-1) * x[col]
        out = torch.zeros_like(x).index_add_(0, row, msgs)
        return self.lin(out)
787
+
788
+
789
+ # --- extracted class 23 ---
790
class GNNDiracRRF(nn.Module):
    """Stack of ``DiracGraphConv`` layers with GELU and dropout between them.

    The stack always contains an input layer (``in_dim -> hidden_dim``) and an
    output layer (``hidden_dim -> out_dim``); ``num_layers - 2`` hidden layers
    (``hidden_dim -> hidden_dim``) are placed in between when that number is
    positive.  ``z_dim`` is stored on the module but not used by any layer.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int, z_dim: int,
                 alpha_attn: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.z_dim = z_dim
        # Feature widths at every layer boundary: in, hidden..., out.
        # DiracGraphConv must already be defined when this constructor runs.
        widths = [in_dim] + [hidden_dim] * max(num_layers - 1, 1) + [out_dim]
        self.layers = nn.ModuleList([
            DiracGraphConv(fan_in, fan_out, alpha=alpha_attn)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Run every layer in order; the last layer's output is returned without activation."""
        *inner, head = self.layers
        hidden = x
        for conv in inner:
            hidden = self.dropout(F.gelu(conv(hidden, edge_index, z)))
        return head(hidden, edge_index, z)
811
+
812
+
813
+ # --- extracted class 24 ---
814
class RRF_Dataset(Dataset):
    """Dataset of non-overlapping fixed-length windows cut from a strain series.

    Sample ``i`` covers ``strain[i * seq_len:(i + 1) * seq_len]``; trailing
    values that do not fill a whole window are ignored.  The regression
    target is the window mean scaled by the mean of ``weights``.

    Args:
        strain: Sequence of strain values supporting ``len`` and slicing.
        weights: Values whose mean is used as the global resonance factor.
        seq_len: Window length; defaults to 160 to match the model input.

    Raises:
        ValueError: If ``seq_len`` is not positive or ``strain`` is shorter
            than one window.
    """

    def __init__(self, strain, weights, seq_len=160):
        if seq_len <= 0:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}.")
        self.seq_len = seq_len
        self.strain = strain
        self.weights = weights
        print(f"Debug: RRF_Dataset __init__ - len(strain): {len(strain)}, seq_len: {self.seq_len}")  # Debug print
        # Floor division already yields 0 when the strain is shorter than one window.
        self.n = len(strain) // seq_len
        print(f"Debug: RRF_Dataset __init__ - Calculated self.n: {self.n}")  # Debug print
        if self.n == 0:
            raise ValueError(f"Strain data length ({len(strain)}) is less than sequence length ({seq_len}). Cannot create any samples.")

    def __len__(self):
        """Return the number of complete windows."""
        return self.n

    def __getitem__(self, idx):
        """Return ``(x, y)`` for window ``idx``.

        ``x`` is a float32 tensor of shape ``[1, seq_len]`` and ``y`` a float32
        scalar tensor.  Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                This is what lets plain ``for sample in dataset`` terminate.
        """
        if idx < 0:
            idx += self.n
        if not 0 <= idx < self.n:
            raise IndexError(f"RRF_Dataset index out of range for {self.n} samples.")

        start = idx * self.seq_len
        x = self.strain[start:start + self.seq_len]  # Shape: [seq_len]

        # Global resonance factor: mean of the provided weights.
        w = np.mean(self.weights)

        # Synthetic label (proxy resonance): window mean scaled by w.
        y = np.mean(x) * w

        # The model expects a leading channel dimension, hence unsqueeze(0).
        return (
            torch.tensor(x, dtype=torch.float32).unsqueeze(0),
            torch.tensor(y, dtype=torch.float32),
        )
849
+
850
+
851
+
852
def load_model_state(path, model_instance, map_location='cpu'):
    """Load a checkpoint from ``path`` into ``model_instance`` and return the model.

    The checkpoint may be a bare state dict or a dict that wraps one under the
    ``'state_dict'`` key.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints from
    sources you trust.
    """
    checkpoint = torch.load(path, map_location=map_location)
    # Unwrap {'state_dict': {...}} style checkpoints; otherwise use the object as-is.
    nested = checkpoint.get('state_dict') if isinstance(checkpoint, dict) else None
    weights = nested if isinstance(nested, dict) else checkpoint
    model_instance.load_state_dict(weights)
    return model_instance
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ IPython
2
+ astroquery
3
+ dataclasses
4
+ datasets
5
+ gnn_dirac_rrf
6
+ google
7
+ gwpy
8
+ kagglehub
9
+ matplotlib
10
+ networkx
11
+ numpy
12
+ os
13
+ pandas
14
+ pesummary
15
+ plotly
16
+ random
17
+ requests
18
+ safetensors
19
+ scikit-learn
20
+ scipy
21
+ shutil
22
+ torch
23
+ torch_geometric
24
+ torchsummary
25
+ typing
26
+ zipfile