# Source: RRF/model.py (uploaded by antonypamo; commit 222aa9e "Update model.py", verified; 7.66 kB)
# ============================================================
# 🚀 IcosahedralRRF - Código Unificado
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch_geometric
from torch_geometric.data import Data as GeoData
from torch_geometric.utils import to_dense_batch
# ============================================================
# 1️⃣ Dataset Icosaédrico
# ============================================================
class IcosahedralRRFDataset(Dataset):
    """Synthetic icosahedral dataset for gauge nodes and the Dirac GNN.

    Every sample contains random node features ``x``, a complete-graph
    ``edge_index`` (the icosahedron is approximated by a complete graph),
    random spectral coordinates ``z`` and a dummy scalar target ``y``.

    Args:
        num_samples: number of samples reported by ``len()``.
        num_nodes: nodes per graph.
        feat_dim: dimension of the node features ``x``.
        z_dim: dimension of the spectral coordinates ``z``.
        seed: optional base seed. When given, sample ``idx`` is drawn from a
            generator seeded with ``seed + idx``, so the same index always
            yields the same tensors. When ``None`` (default), every access
            draws fresh values from the global torch RNG.
    """

    def __init__(self, num_samples=1000, num_nodes=12, feat_dim=8, z_dim=4, seed=None):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.feat_dim = feat_dim
        self.z_dim = z_dim
        self.seed = seed
        # The topology is identical for every sample, so build it once.
        if num_nodes < 2:
            self.edge_index = torch.empty((2, 0), dtype=torch.long)
        else:
            # torch.combinations yields each unordered pair once (i < j);
            # append the reversed pairs so the graph is truly complete in both
            # directions and every node can receive messages from all others.
            pairs = torch.combinations(torch.arange(num_nodes), r=2).t()
            self.edge_index = torch.cat([pairs, pairs.flip(0)], dim=1)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys ``x``, ``edge_index``, ``z``, ``y``.

        Shapes: ``x`` is ``[num_nodes, feat_dim]``, ``edge_index`` is
        ``[2, num_nodes * (num_nodes - 1)]``, ``z`` is ``[num_nodes, z_dim]``
        and ``y`` is ``[1]``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Raising here also lets ``for sample in dataset`` terminate.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.num_samples}"
            )
        idx %= self.num_samples  # normalize negative indices
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + idx)
        # Draw order (x, z, y) matches the original so unseeded runs are
        # reproducible under the same global seed.
        x = torch.randn(self.num_nodes, self.feat_dim, generator=generator)
        z = torch.randn(self.num_nodes, self.z_dim, generator=generator)
        # Dummy label (regression or classification).
        y = torch.randn(1, generator=generator)
        # Clone so a consumer mutating one sample's graph cannot affect others.
        return {"x": x, "edge_index": self.edge_index.clone(), "z": z, "y": y}
# ============================================================
# 2️⃣ Transformador Dataset → Secuencias nodos gauge
# ============================================================
def map_to_gauge_sequence(batch):
    """Collate a list of dataset samples into gauge-sequence tensors.

    Each sample's node-feature matrix ``x`` (``[num_nodes, feat_dim]``) is
    transposed so that features become channels and nodes become the
    sequence axis, then all samples are stacked.

    Args:
        batch: iterable of dicts with keys ``"x"``, ``"y"``, ``"edge_index"``
            and ``"z"``, as produced by ``IcosahedralRRFDataset``.

    Returns:
        ``(x_batch, y_batch, edge_index_list, z_list)`` where ``x_batch`` is
        ``[batch_size, feat_dim, num_nodes]``, ``y_batch`` is the stacked
        labels, and the graph tensors are returned as per-sample lists
        (unstacked).
    """
    samples = list(batch)
    x_batch = torch.stack([sample["x"].T for sample in samples])
    y_batch = torch.stack([sample["y"] for sample in samples])
    edge_index_list = [sample["edge_index"] for sample in samples]
    z_list = [sample["z"] for sample in samples]
    return x_batch, y_batch, edge_index_list, z_list
# ============================================================
# 3️⃣ Nodo Gauge
# ============================================================
class SavantRRF_Gauge(nn.Module):
    """Gauge-node encoder: a stack of pointwise 1-D convolutions.

    Three ``kernel_size=1`` convolutions, each followed by ReLU, transform
    every sequence position independently. The result is averaged over the
    sequence axis and projected by a linear layer.

    Args:
        input_dim: number of input channels.
        hidden_dim: channel width of the convolutional stack.
        output_dim: size of the returned feature vector.
    """

    def __init__(self, input_dim=8, hidden_dim=16, output_dim=8):
        super().__init__()
        # Submodule names form the state_dict layout; keep them stable.
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Encode ``x`` of shape ``[batch_size, input_dim, seq_len]``.

        Returns a tensor of shape ``[batch_size, output_dim]``.
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        pooled = features.mean(dim=2)  # average over the sequence axis
        return self.fc(pooled)
# ============================================================
# 4️⃣ Dirac GNN
# ============================================================
class DiracGraphConv(nn.Module):
    """Simple message-passing layer over node features joined with ``z``.

    Node features and spectral coordinates are concatenated and embedded by a
    linear layer + ReLU. Each node then adds the mean of its neighbours'
    embeddings (uniform averaging; no attention weighting is applied). A
    second linear layer produces per-node outputs, which are averaged into a
    single graph-level vector.

    Args:
        node_dim: dimension of the node features ``x``.
        z_dim: dimension of the spectral coordinates ``z``.
        hidden_dim: width of the node embeddings.
        output_dim: dimension of the returned graph vector.
    """

    def __init__(self, node_dim=8, z_dim=4, hidden_dim=16, output_dim=8):
        super().__init__()
        self.fc1 = nn.Linear(node_dim + z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, z, edge_index):
        """Compute a graph embedding.

        Args:
            x: node features, ``[num_nodes, node_dim]``.
            z: spectral coordinates, ``[num_nodes, z_dim]``.
            edge_index: ``[2, num_edges]`` integer tensor; node
                ``edge_index[0, e]`` receives a message from node
                ``edge_index[1, e]``.

        Returns:
            Tensor of shape ``[1, output_dim]``.
        """
        h = F.relu(self.fc1(torch.cat([x, z], dim=1)))
        edge_index = edge_index.to(device=h.device, dtype=torch.long)
        target, source = edge_index
        # Sum incoming messages, then divide by in-degree to get the
        # neighbour mean. Nodes with no incoming edges keep a zero message;
        # clamping the degree to 1 avoids division by zero for them.
        summed = torch.zeros_like(h).index_add_(0, target, h[source])
        degree = torch.bincount(target, minlength=h.size(0)).clamp(min=1)
        agg = summed / degree.unsqueeze(1).to(h.dtype)
        h = h + agg
        return self.fc2(h).mean(dim=0, keepdim=True)  # [1, output_dim]
# ============================================================
# 5️⃣ Modelo IcosahedralRRF
# ============================================================
class IcosahedralRRF(nn.Module):
    """Ensemble of gauge encoders, optionally refined by a Dirac graph layer.

    ``num_gauge_nodes`` independent ``SavantRRF_Gauge`` encoders each map the
    input sequence to a ``node_output_dim`` vector, and their outputs are
    concatenated. When graph data is supplied, each sample's gauge outputs
    are used as the node features of that sample's graph. They go through
    ``DiracGraphConv``, and the resulting graph vector is added to every
    gauge slot. A final linear layer maps to ``output_dim``.

    Args:
        input_dim: channels of the input sequence.
        hidden_dim: hidden width of the gauge encoders and the GNN.
        node_output_dim: output width of each gauge encoder (GNN node width).
        num_gauge_nodes: number of gauge encoders, i.e. graph nodes.
        z_dim: dimension of the per-node spectral coordinates.
        output_dim: dimension of the model output.
    """

    def __init__(self, input_dim=8, hidden_dim=16, node_output_dim=8, num_gauge_nodes=12, z_dim=4, output_dim=1):
        super().__init__()
        self.num_nodes = num_gauge_nodes
        self.gauge_nodes = nn.ModuleList(
            [SavantRRF_Gauge(input_dim, hidden_dim, node_output_dim) for _ in range(num_gauge_nodes)]
        )
        # The GNN consumes one gauge output per node and must return a vector
        # of the same width, so it can be tiled across the concatenated gauges.
        self.gnn = DiracGraphConv(
            node_dim=node_output_dim, z_dim=z_dim, hidden_dim=hidden_dim, output_dim=node_output_dim
        )
        self.fc_final = nn.Linear(num_gauge_nodes * node_output_dim, output_dim)

    def forward(self, x, edge_index_list=None, z_list=None):
        """Run the model.

        Args:
            x: ``[batch_size, input_dim, seq_len]`` input sequences.
            edge_index_list: optional per-sample ``[2, num_edges]`` tensors.
            z_list: optional per-sample ``[num_gauge_nodes, z_dim]`` tensors.
                The GNN is applied only when both lists are given.

        Returns:
            Tensor of shape ``[batch_size, output_dim]``.

        Raises:
            ValueError: if the graph lists do not have one entry per sample.
        """
        gauge_outputs = [gauge(x) for gauge in self.gauge_nodes]  # each [B, node_output_dim]
        gauges_cat = torch.cat(gauge_outputs, dim=1)  # [B, num_nodes * node_output_dim]
        if edge_index_list is not None and z_list is not None:
            batch_size = x.size(0)
            if len(edge_index_list) != batch_size or len(z_list) != batch_size:
                raise ValueError(
                    f"expected {batch_size} graphs, got {len(edge_index_list)} edge_index "
                    f"and {len(z_list)} z tensors"
                )
            # [B, num_nodes, node_output_dim]: row b holds sample b's gauge
            # outputs, i.e. the node features of sample b's graph.
            node_features = torch.stack(gauge_outputs, dim=1)
            gnn_outs = torch.cat(
                [
                    self.gnn(
                        nodes,
                        z.to(device=x.device, dtype=x.dtype),
                        edge_index.to(x.device),
                    )
                    for nodes, z, edge_index in zip(node_features, z_list, edge_index_list)
                ],
                dim=0,
            )  # [B, node_output_dim]
            gauges_cat = gauges_cat + gnn_outs.repeat(1, self.num_nodes)
        return self.fc_final(gauges_cat)
# ============================================================
# 6️⃣ Entrenamiento / Evaluación
# ============================================================
def _prepare_batch(batch, device):
    """Collate ``batch`` and move every tensor the model consumes to ``device``.

    Same tuple as ``map_to_gauge_sequence``, but the per-sample ``edge_index``
    and ``z`` tensors are moved too, so a model on GPU never receives CPU
    graph tensors.
    """
    x_batch, y_batch, edge_index_list, z_list = map_to_gauge_sequence(batch)
    return (
        x_batch.to(device),
        y_batch.to(device),
        [edge_index.to(device) for edge_index in edge_index_list],
        [z.to(device) for z in z_list],
    )


def _mean_loss(total_loss, num_samples):
    """Return ``total_loss / num_samples``, raising ValueError when nothing was processed."""
    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot average the loss")
    return total_loss / num_samples


def train(model, dataloader, optimizer, criterion, device="cpu"):
    """Run one training epoch and return the sample-weighted mean loss.

    The per-batch loss is weighted by the batch size, which assumes
    ``criterion`` averages over the batch (e.g. ``nn.MSELoss()`` default).
    The average is taken over the samples actually processed, so it stays
    correct with ``drop_last=True`` or subset samplers.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for batch in dataloader:
        x_batch, y_batch, edge_index_list, z_list = _prepare_batch(batch, device)
        optimizer.zero_grad()
        out = model(x_batch, edge_index_list, z_list)
        loss = criterion(out, y_batch)
        loss.backward()
        optimizer.step()
        batch_size = x_batch.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    return _mean_loss(total_loss, num_samples)


def evaluate(model, dataloader, criterion, device="cpu"):
    """Evaluate without gradients and return the sample-weighted mean loss.

    Same averaging convention as ``train``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            x_batch, y_batch, edge_index_list, z_list = _prepare_batch(batch, device)
            out = model(x_batch, edge_index_list, z_list)
            loss = criterion(out, y_batch)
            batch_size = x_batch.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
    return _mean_loss(total_loss, num_samples)
# ============================================================
# 7️⃣ Hyperparámetros y ejecución ejemplo
# ============================================================
if __name__ == "__main__":
    # Example run: fit IcosahedralRRF on synthetic data for a few epochs.
    num_epochs = 5
    device = "cuda" if torch.cuda.is_available() else "cpu"

    loader = DataLoader(
        IcosahedralRRFDataset(num_samples=200),
        batch_size=8,
        shuffle=True,
        # Keep samples as a plain list of dicts; map_to_gauge_sequence does
        # the actual collation inside train()/evaluate().
        collate_fn=list,
    )
    model = IcosahedralRRF().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        loss_train = train(model, loader, optimizer, criterion, device)
        # NOTE(review): evaluation reuses the training loader, so "Eval Loss"
        # is measured on training data, not on a held-out split.
        loss_eval = evaluate(model, loader, criterion, device)
        print(f"Epoch {epoch+1}: Train Loss = {loss_train:.4f}, Eval Loss = {loss_eval:.4f}")