| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import math |
| import numpy as np |
| from typing import Dict, Optional |
| from .node import CognitiveNode |
|
|
class DynamicCognitiveNet(nn.Module):
    """Input->output network with robust connection management.

    Every input node is linked to every output node by one scalar learnable
    connection stored in ``self.connections`` under the key
    ``"<input_id>-><output_id>"``. Besides gradient training, the raw
    connection values are shifted by a reward signal after each training
    step (see ``structural_update``).
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Inputs are built with indices 0..input_size-1 and outputs with
        # input_size..input_size+output_size-1; presumably the first
        # CognitiveNode argument becomes ``node.id`` (used in connection
        # keys) -- confirm in .node.
        self.input_nodes = nn.ModuleList([
            CognitiveNode(i, 1) for i in range(input_size)
        ])
        self.output_nodes = nn.ModuleList([
            CognitiveNode(input_size + i, 1) for i in range(output_size)
        ])

        self.connections = nn.ParameterDict()
        self._init_base_connections()

        # Heuristic scalar state updated in train_step. forward() never reads
        # it, so it receives no gradient and the optimizer skips it.
        self.emotional_state = nn.Parameter(torch.tensor(0.0))
        # Only parameters existing at this point are known to the optimizer;
        # connections created later are registered in _add_connection.
        self.optimizer = optim.AdamW(self.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

    @staticmethod
    def _connection_key(src_id, tgt_id) -> str:
        """Return the ParameterDict key for the connection src_id -> tgt_id."""
        return f"{src_id}->{tgt_id}"

    def _init_base_connections(self):
        """Create one small random connection for every input/output pair."""
        for in_node in self.input_nodes:
            for out_node in self.output_nodes:
                conn_id = self._connection_key(in_node.id, out_node.id)
                self.connections[conn_id] = nn.Parameter(
                    torch.randn(1) * 0.1
                )

    def _add_connection(self, conn_id: str) -> None:
        """Create connection ``conn_id`` and register it with the optimizer."""
        # Match the device/dtype the model currently lives on instead of
        # always allocating a default CPU tensor.
        reference = self.emotional_state
        param = nn.Parameter(
            torch.randn(1, device=reference.device, dtype=reference.dtype) * 0.1
        )
        self.connections[conn_id] = param
        # Without this, a parameter created after the optimizer was built is
        # never updated by optimizer.step().
        self.optimizer.add_param_group({"params": [param]})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate one sample through the input and output nodes.

        Args:
            x: indexed as ``x[i]`` for each input node, so presumably an
                unbatched 1-D tensor of length ``input_size`` -- confirm.

        Returns:
            The output-node results stacked along a new first dimension with
            all size-1 dimensions squeezed away.
        """
        activations = {
            node.id: node(x[i].unsqueeze(0))
            for i, node in enumerate(self.input_nodes)
        }

        outputs = []
        for out_node in self.output_nodes:
            integrated = []
            for in_node in self.input_nodes:
                conn_id = self._connection_key(in_node.id, out_node.id)
                # Raw connection values are squashed into (0, 1) here.
                weight = torch.sigmoid(self.connections[conn_id])
                integrated.append(activations[in_node.id] * weight)

            # Empty only when there are no input nodes; then this output is
            # skipped (and torch.stack below fails if nothing was produced).
            if integrated:
                # 1/sqrt(fan-in) scaling, presumably to keep the summed
                # signal's magnitude independent of the number of inputs.
                combined = sum(integrated) / math.sqrt(len(integrated))
                outputs.append(out_node(combined))

        return torch.stack(outputs).squeeze()

    def structural_update(self, global_reward: float):
        """Shift every connection by the reward and possibly grow a new one.

        Each raw connection value is moved by ``0.1 * global_reward`` and
        clamped to [-1, 1]. When the reward is below -0.5, a connection
        between the least-active input and output nodes is added if it does
        not exist yet.

        Args:
            global_reward: scalar reward (train_step passes ``0.5 - loss``).
        """
        step = 0.1 * global_reward
        # Heuristic, non-gradient update: done in place under no_grad so it
        # neither builds an autograd graph nor swaps tensors via ``.data``.
        with torch.no_grad():
            for weight in self.connections.values():
                weight.add_(step).clamp_(-1.0, 1.0)

        if global_reward < -0.5:
            # NOTE(review): _init_base_connections already creates every
            # input->output key, so unless connections are removed elsewhere
            # this branch never adds anything.
            new_conn = self._find_underutilized_connection()
            if new_conn and new_conn not in self.connections:
                self._add_connection(new_conn)

    def _find_underutilized_connection(self) -> Optional[str]:
        """Propose a connection between the least-active input and output node.

        Returns:
            The key pairing the input node and the output node with the lowest
            mean ``recent_activations``, or None when no input or no output
            node has recorded activations.
        """
        input_act = {
            n.id: np.mean(n.recent_activations)
            for n in self.input_nodes if n.recent_activations
        }
        output_act = {
            n.id: np.mean(n.recent_activations)
            for n in self.output_nodes if n.recent_activations
        }

        if not input_act or not output_act:
            return None

        src = min(input_act, key=input_act.get)
        tgt = min(output_act, key=output_act.get)
        return self._connection_key(src, tgt)

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """Run one optimisation step followed by the heuristic state updates.

        Args:
            x: model input, passed straight to ``forward``.
            y: target compared with the prediction by MSE loss.

        Returns:
            The total loss (MSE + 0.01 * L1 penalty on the raw connections),
            as computed before the optimizer step, or NaN if the forward or
            backward pass raised a RuntimeError.
        """
        self.optimizer.zero_grad()

        try:
            pred = self(x)
            loss = self.loss_fn(pred, y)
        except RuntimeError as e:
            print(f"Error selama forward pass: {e}")
            return float('nan')

        # Penalises the raw (pre-sigmoid) connection values.
        reg_loss = sum(p.abs().mean() for p in self.connections.values())
        total_loss = loss + 0.01 * reg_loss

        try:
            total_loss.backward()
            self.optimizer.step()
        except RuntimeError as e:
            print(f"Error selama backpropagation: {e}")
            return float('nan')

        # Positive when the task loss is below 0.5, negative above it.
        reward = 0.5 - loss.item()
        with torch.no_grad():
            self.emotional_state.copy_(
                torch.sigmoid(self.emotional_state + reward * 0.1)
            )

        self.structural_update(reward)

        return total_loss.item()