# ============================================================
# 🚀 IcosahedralRRF - Unified code
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Nothing in this file uses PyTorch Geometric; importing it optionally keeps
# the module importable (and runnable) when the package is not installed.
try:
    import torch_geometric
    from torch_geometric.data import Data as GeoData
    from torch_geometric.utils import to_dense_batch
except ImportError:
    torch_geometric = None
    GeoData = None
    to_dense_batch = None


# ============================================================
# 1️⃣ Icosahedral dataset
# ============================================================
class IcosahedralRRFDataset(Dataset):
    """Synthetic "icosahedral" dataset for the gauge nodes and the Dirac GNN.

    Each sample is a dict with:

    * ``"x"``: node features, shape ``[num_nodes, feat_dim]``;
    * ``"edge_index"``: complete-graph connectivity (an approximation of the
      icosahedron), shape ``[2, num_nodes * (num_nodes - 1)]`` containing both
      directions of every edge. The same tensor object is shared by all
      samples -- treat it as read-only;
    * ``"z"``: spectral (Dirac) coordinates, shape ``[num_nodes, z_dim]``;
    * ``"y"``: dummy regression target, shape ``[1]``.

    Args:
        num_samples: Value reported by ``__len__``.
        num_nodes: Nodes per graph.
        feat_dim: Feature size per node.
        z_dim: Spectral-coordinate size per node.
        seed: If not ``None``, sample ``idx`` is drawn from a generator seeded
            with ``seed + idx``, so reading the same index twice returns the
            same data. If ``None`` (default), every read draws fresh values
            from the global RNG.
    """

    def __init__(self, num_samples=1000, num_nodes=12, feat_dim=8, z_dim=4, seed=None):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.feat_dim = feat_dim
        self.z_dim = z_dim
        self.seed = seed
        # The topology is identical for every sample, so build it once.
        # torch.combinations only yields pairs (i, j) with i < j; appending the
        # reversed pairs lets messages flow both ways in the GNN (otherwise the
        # last node would never receive a message).
        pairs = torch.combinations(torch.arange(num_nodes), r=2).t()
        self.edge_index = torch.cat([pairs, pairs.flip(0)], dim=1)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        gen = None
        if self.seed is not None:
            gen = torch.Generator().manual_seed(self.seed + idx)
        x = torch.randn(self.num_nodes, self.feat_dim, generator=gen)
        z = torch.randn(self.num_nodes, self.z_dim, generator=gen)
        y = torch.randn(1, generator=gen)
        return {"x": x, "edge_index": self.edge_index, "z": z, "y": y}


# ============================================================
# 2️⃣ Dataset → gauge-node sequence transform
# ============================================================
def map_to_gauge_sequence(batch):
    """Collate a list of dataset samples into the gauge-node input format.

    Each sample's node-feature matrix ``[num_nodes, feat_dim]`` is transposed
    to ``[feat_dim, num_nodes]``, so the nodes become the sequence axis of the
    ``Conv1d`` layers in ``SavantRRF_Gauge``.

    Args:
        batch: Non-empty list of sample dicts (keys ``"x"``, ``"y"``,
            ``"edge_index"``, ``"z"``), e.g. as produced by a ``DataLoader``
            whose ``collate_fn`` returns the raw list.

    Returns:
        ``(x_batch, y_batch, edge_index_list, z_list)``. ``x_batch`` is
        ``[batch_size, feat_dim, num_nodes]``. ``y_batch`` stacks the targets
        along a new leading dimension. The graph tensors stay as per-sample
        lists because graphs may differ in size.
    """
    x_batch = torch.stack([sample["x"].T for sample in batch])
    y_batch = torch.stack([sample["y"] for sample in batch])
    edge_index_list = [sample["edge_index"] for sample in batch]
    z_list = [sample["z"] for sample in batch]
    return x_batch, y_batch, edge_index_list, z_list


# ============================================================
# 3️⃣ Gauge node
# ============================================================
class SavantRRF_Gauge(nn.Module):
    """Three pointwise (kernel_size=1) Conv1d layers, mean-pooled over the sequence.

    Input ``[batch_size, input_dim, seq_len]`` -> output ``[batch_size, output_dim]``.
    """

    def __init__(self, input_dim=8, hidden_dim=16, output_dim=8):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x: [batch_size, input_dim, seq_len]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.mean(dim=2)  # average over the sequence axis
        return self.fc(x)  # [batch_size, output_dim]


# ============================================================
# 4️⃣ Dirac GNN
# ============================================================
class DiracGraphConv(nn.Module):
    """Single message-passing layer over node features plus Dirac coordinates.

    Node features ``x`` and coordinates ``z`` are concatenated and projected
    to ``hidden_dim``. Each node then adds (residually) the mean of its
    neighbours' hidden states. The result is projected to ``output_dim`` and
    mean-pooled over nodes into a single graph-level vector.
    """

    def __init__(self, node_dim=8, z_dim=4, hidden_dim=16, output_dim=8):
        super().__init__()
        self.fc1 = nn.Linear(node_dim + z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, z, edge_index):
        """Embed one graph.

        Args:
            x: Node features ``[num_nodes, node_dim]``.
            z: Node coordinates ``[num_nodes, z_dim]``.
            edge_index: Long tensor ``[2, num_edges]``. Node ``edge_index[0, e]``
                receives the message from node ``edge_index[1, e]``.

        Returns:
            Graph embedding of shape ``[1, output_dim]``.
        """
        h = F.relu(self.fc1(torch.cat([x, z], dim=1)))
        row, col = edge_index
        # Mean over neighbours: sum incoming messages, then divide by the
        # in-degree (clamped to 1 so isolated nodes get a zero aggregate).
        agg = torch.zeros_like(h).index_add_(0, row, h[col])
        deg = torch.bincount(row, minlength=h.size(0)).clamp(min=1).to(h.dtype)
        agg = agg / deg.unsqueeze(1)
        h = self.fc2(h + agg)
        return h.mean(dim=0, keepdim=True)  # [1, output_dim]


# ============================================================
# 5️⃣ IcosahedralRRF model
# ============================================================
class IcosahedralRRF(nn.Module):
    """Ensemble of gauge nodes, optionally refined by a Dirac GNN.

    Each of the ``num_gauge_nodes`` gauge networks maps the whole input to a
    ``[batch_size, node_output_dim]`` vector. For each sample, these vectors
    are the node features of a graph with one node per gauge network. The GNN
    pools that graph into one ``node_output_dim`` vector, which is tiled once
    per gauge node and added to the concatenated gauge outputs before the
    final linear layer.
    """

    def __init__(self, input_dim=8, hidden_dim=16, node_output_dim=8,
                 num_gauge_nodes=12, z_dim=4, output_dim=1):
        super().__init__()
        self.num_nodes = num_gauge_nodes
        self.gauge_nodes = nn.ModuleList(
            [SavantRRF_Gauge(input_dim, hidden_dim, node_output_dim) for _ in range(self.num_nodes)]
        )
        # The GNN consumes gauge outputs as node features (node_dim =
        # node_output_dim). It emits node_output_dim features so that tiling
        # its output num_nodes times matches the concatenated gauge outputs.
        self.gnn = DiracGraphConv(node_dim=node_output_dim, z_dim=z_dim,
                                  hidden_dim=hidden_dim, output_dim=node_output_dim)
        self.fc_final = nn.Linear(self.num_nodes * node_output_dim, output_dim)

    def forward(self, x, edge_index_list=None, z_list=None):
        """Run the model.

        Args:
            x: ``[batch_size, input_dim, seq_len]``.
            edge_index_list: Optional list (one per sample) of ``[2, num_edges]``
                long tensors whose indices are < ``num_gauge_nodes``.
            z_list: Optional list (one per sample) of ``[num_gauge_nodes, z_dim]``
                tensors. The GNN is applied only when both lists are given.

        Returns:
            ``[batch_size, output_dim]``.

        Raises:
            ValueError: If the graph lists do not have one entry per sample.
        """
        batch_size = x.size(0)
        gauge_outputs = [gauge(x) for gauge in self.gauge_nodes]  # each [B, node_output_dim]
        gauges_cat = torch.cat(gauge_outputs, dim=1)  # [B, num_nodes * node_output_dim]

        if edge_index_list is not None and z_list is not None:
            if len(edge_index_list) != batch_size or len(z_list) != batch_size:
                raise ValueError(
                    f"expected {batch_size} graphs, got {len(edge_index_list)} "
                    f"edge_index tensors and {len(z_list)} z tensors"
                )
            # node_feats[b] holds the num_nodes gauge outputs of sample b.
            node_feats = torch.stack(gauge_outputs, dim=1)  # [B, num_nodes, node_output_dim]
            gnn_outs = torch.cat(
                [
                    # Graph tensors are collated on the CPU; follow x's device.
                    self.gnn(node_feats[b], z_list[b].to(x.device), edge_index_list[b].to(x.device))
                    for b in range(batch_size)
                ],
                dim=0,
            )  # [B, node_output_dim]
            gauges_cat = gauges_cat + gnn_outs.repeat(1, self.num_nodes)

        return self.fc_final(gauges_cat)


# ============================================================
# 6️⃣ Training / evaluation
# ============================================================
def _run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over ``dataloader``, stepping ``optimizer`` if one is given.

    Returns the per-sample mean loss over the samples actually yielded, which
    stays correct with ``drop_last=True`` or subset samplers. Returns ``nan``
    if the loader yields nothing.
    """
    total_loss = 0.0
    num_seen = 0
    for batch in dataloader:
        x_batch, y_batch, edge_index_list, z_list = map_to_gauge_sequence(batch)
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        if optimizer is not None:
            optimizer.zero_grad()
        out = model(x_batch, edge_index_list, z_list)
        loss = criterion(out, y_batch)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        batch_size = x_batch.size(0)
        total_loss += loss.item() * batch_size  # criterion is assumed to be a batch mean
        num_seen += batch_size
    return total_loss / num_seen if num_seen else float("nan")


def train(model, dataloader, optimizer, criterion, device="cpu"):
    """Train ``model`` for one epoch; return the mean training loss per sample.

    ``dataloader`` must yield raw lists of sample dicts (see
    ``map_to_gauge_sequence``).
    """
    model.train()
    return _run_epoch(model, dataloader, criterion, device, optimizer)


def evaluate(model, dataloader, criterion, device="cpu"):
    """Evaluate ``model`` without gradients; return the mean loss per sample."""
    model.eval()
    with torch.no_grad():
        return _run_epoch(model, dataloader, criterion, device)


# ============================================================
# 7️⃣ Hyperparameters and example run
# ============================================================
if __name__ == "__main__":
    # Seed the global RNG so the synthetic data and the initial weights are
    # reproducible from run to run.
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Batches are passed through as plain lists of sample dicts. ``list`` does
    # the same as an identity lambda but is picklable, so the loaders also
    # work with num_workers > 0.
    train_loader = DataLoader(IcosahedralRRFDataset(num_samples=200), batch_size=8,
                              shuffle=True, collate_fn=list)
    # Separate, unshuffled loader so the eval loss is not computed on the
    # exact loader the optimizer is stepping on.
    eval_loader = DataLoader(IcosahedralRRFDataset(num_samples=50), batch_size=8,
                             shuffle=False, collate_fn=list)

    model = IcosahedralRRF().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(5):
        loss_train = train(model, train_loader, optimizer, criterion, device)
        loss_eval = evaluate(model, eval_loader, criterion, device)
        print(f"Epoch {epoch+1}: Train Loss = {loss_train:.4f}, Eval Loss = {loss_eval:.4f}")