""" Agente del Nucleo Logico ======================== Agente de RL que aprende a: 1. Seleccionar skills relevantes 2. Reorganizar el grafo 3. Asistir con pruebas Lean Arquitectura: - Encoder de grafo (GNN con GATConv) - Encoder de query (bag-of-keywords) - Encoder de goal (hash determinista) - Fusion (concat + Linear) - Actor-Critic (PPO con GAE) """ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Optional, List, Dict, Any, Tuple import numpy as np import random import os from nucleo.types import State, Action, ActionType from nucleo.graph.category import SkillCategory from nucleo.rl.mdp import MDP, Transition, ExperienceBuffer from nucleo.rl.rewards import RewardFunction, RewardConfig # Tipos de accion en orden fijo (indices para la red neuronal) ACTION_TYPES = [ActionType.RESPONSE, ActionType.REORGANIZE, ActionType.ASSIST] @dataclass class AgentConfig: """Configuracion del agente.""" # Arquitectura (debe coincidir con los checkpoints entrenados: num_heads=4) hidden_dim: int = 256 num_heads: int = 4 num_layers: int = 3 # Entrenamiento learning_rate: float = 3e-4 gamma: float = 0.99 batch_size: int = 64 # PPO clip_range: float = 0.2 value_coef: float = 0.5 entropy_coef: float = 0.01 gae_lambda: float = 0.95 n_epochs: int = 4 # Exploracion epsilon_start: float = 1.0 epsilon_end: float = 0.1 epsilon_decay: float = 0.995 class BaseAgent(ABC): """Clase base para agentes.""" @abstractmethod def select_action(self, state: State) -> Action: """Seleccionar accion dado estado.""" pass @abstractmethod def update(self, transitions: List[Transition]) -> Dict[str, float]: """Actualizar politica con transiciones.""" pass class RandomAgent(BaseAgent): """ Agente aleatorio (baseline). Selecciona acciones uniformemente al azar. """ def __init__(self, action_space: List[ActionType]): self.action_space = action_space def select_action(self, state: State) -> Action: """Seleccionar accion aleatoria.""" action_type = random.choice(self.action_space) if action_type == ActionType.RESPONSE: return Action.response("Random response") elif action_type == ActionType.REORGANIZE: op = random.choice(["add_node", "reweight", "merge"]) return Action.reorganize(op) else: return Action.assist(tactic="simp", goal="?") def update(self, transitions: List[Transition]) -> Dict[str, float]: """Agente aleatorio no aprende.""" return {"loss": 0.0} class HeuristicAgent(BaseAgent): """ Agente heuristico. Usa reglas simples para seleccionar acciones. Sirve como baseline mejorado. """ def __init__(self, graph: SkillCategory): self.graph = graph def select_action(self, state: State) -> Action: """Seleccionar accion usando heuristicas.""" # Si hay goal de Lean, intentar asistir if state.has_active_goal: tactic = self._suggest_tactic(state.lean_goal) return Action.assist(tactic=tactic, goal=state.lean_goal) # Si el grafo esta desbalanceado, reorganizar if self._should_reorganize(): return self._suggest_reorganization() # Por defecto, responder return Action.response("Heuristic response") def _suggest_tactic(self, goal: str) -> str: """Sugerir tactica basada en patron del goal.""" if not goal: return "sorry" # Heuristicas simples if "∀" in goal or "→" in goal: return "intro" if "∧" in goal: return "constructor" if "=" in goal: return "rfl" if "∨" in goal: return "left" # o right return "simp" def _should_reorganize(self) -> bool: """Determinar si el grafo necesita reorganizacion.""" stats = self.graph.stats # Reorganizar si hay muchos skills desconectados if not self.graph.is_connected(): return True # Reorganizar si los pesos estan muy desbalanceados avg_weight = stats.get("avg_weight", 1.0) if avg_weight < 0.5 or avg_weight > 2.0: return True return False def _suggest_reorganization(self) -> Action: """Sugerir operacion de reorganizacion.""" if not self.graph.is_connected(): return Action.reorganize("add_edge") return Action.reorganize("reweight") def update(self, transitions: List[Transition]) -> Dict[str, float]: """Agente heuristico no aprende.""" return {"loss": 0.0} class NucleoAgent(BaseAgent): """ Agente principal del Nucleo Logico. Combina: - Exploracion epsilon-greedy - Politica aprendida via PPO con GNN (cuando use_neural=True) - Heuristicas como fallback """ def __init__( self, graph: SkillCategory, config: Optional[AgentConfig] = None, use_neural: bool = False, ): self.graph = graph self.config = config or AgentConfig() # Estado de entrenamiento self.epsilon = self.config.epsilon_start self.total_steps = 0 self.training = True # Agentes auxiliares self._heuristic = HeuristicAgent(graph) self._random = RandomAgent([ ActionType.RESPONSE, ActionType.REORGANIZE, ActionType.ASSIST ]) # Buffer de experiencia self.buffer = ExperienceBuffer(capacity=10000) # Metricas self.metrics = { "episodes": 0, "total_reward": 0.0, "avg_reward": 0.0, "epsilon": self.epsilon, } # Red neuronal (opt-in) self._use_neural = use_neural self._network = None self._optimizer = None self._procedural_memory = None # Set externally for memory-guided decisions if use_neural: try: self._init_network() except ImportError: self._use_neural = False def _init_network(self) -> None: """Inicializar red neuronal actor-critic.""" import torch from nucleo.rl.networks import ActorCriticNetwork, TORCH_AVAILABLE if not TORCH_AVAILABLE: raise ImportError("torch no disponible") self._network = ActorCriticNetwork( hidden_dim=self.config.hidden_dim, gnn_num_layers=self.config.num_layers, gnn_num_heads=self.config.num_heads, ) self._optimizer = torch.optim.Adam( self._network.parameters(), lr=self.config.learning_rate, ) @property def has_network(self) -> bool: """Verificar si la red neuronal esta disponible.""" return self._network is not None def select_action(self, state: State) -> Action: """ Seleccionar accion. Usa epsilon-greedy durante entrenamiento. Si hay red neuronal, la usa para explotacion. """ self.total_steps += 1 # Exploracion epsilon-greedy if self.training and random.random() < self.epsilon: return self._random.select_action(state) # Explotacion con red neuronal if self._network is not None: return self._select_neural(state) # Fallback a heuristica return self._heuristic.select_action(state) def _select_neural(self, state: State) -> Action: """Seleccionar accion usando memoria de patrones o red neuronal.""" # Check procedural memory first for proven patterns if self._procedural_memory is not None: query_text = state.lean_goal or "" best_proc = self._procedural_memory.get_best_for_query(query_text) if best_proc is not None and best_proc.success_rate >= 0.8: # Use proven action from memory action_name = best_proc.action_sequence[0] if best_proc.action_sequence else "RESPONSE" try: chosen = ActionType[action_name] except KeyError: chosen = ActionType.RESPONSE best_proc.invoke() if chosen == ActionType.RESPONSE: return Action.response("Memory pattern: respond") elif chosen == ActionType.REORGANIZE: return self._heuristic._suggest_reorganization() else: tactic = best_proc.tactic_used or self._heuristic._suggest_tactic(query_text) return Action.assist(tactic=tactic, goal=query_text) # Fall back to neural network import torch from nucleo.rl.gnn import graph_to_pyg from nucleo.rl.networks import encode_query, encode_goal self._network.eval() with torch.no_grad(): device = next(self._network.parameters()).device graph_data = graph_to_pyg(self.graph).to(device) query_text = state.lean_goal or "" query_emb = encode_query(query_text).unsqueeze(0).to(device) goal_emb = None if state.lean_goal: goal_emb = encode_goal(state.lean_goal).unsqueeze(0).to(device) output = self._network(graph_data, query_emb, goal_emb=goal_emb) probs = torch.softmax(output.action_logits, dim=-1) action_idx = torch.multinomial(probs, 1).item() chosen = ACTION_TYPES[action_idx] # Delegar parametros especificos a la heuristica if chosen == ActionType.RESPONSE: return Action.response("Neural policy: respond") elif chosen == ActionType.REORGANIZE: return self._heuristic._suggest_reorganization() else: tactic = self._heuristic._suggest_tactic(state.lean_goal or "") return Action.assist(tactic=tactic, goal=state.lean_goal or "") def update(self, transitions: List[Transition]) -> Dict[str, float]: """ Actualizar politica con transiciones. Si la red neuronal esta disponible y hay suficientes datos, ejecuta actualizacion PPO. """ # Anadir al buffer for t in transitions: self.buffer.push(t) # Calcular metricas total_reward = sum(t.reward for t in transitions) self.metrics["total_reward"] += total_reward # Decay epsilon self.epsilon = max( self.config.epsilon_end, self.epsilon * self.config.epsilon_decay ) self.metrics["epsilon"] = self.epsilon # PPO update si hay red y suficientes transiciones ppo_loss = 0.0 if self._network is not None and len(transitions) >= 2: ppo_loss = self._ppo_update(transitions) return { "loss": ppo_loss, "reward": total_reward, "buffer_size": len(self.buffer), } def _ppo_update(self, transitions: List[Transition]) -> float: """ Actualizacion PPO con GAE. 1. Calcular ventajas via GAE(lambda) 2. Multiples epocas de minibatch 3. Clipped surrogate objective 4. Value loss + entropy bonus """ import torch import torch.nn.functional as F from nucleo.rl.gnn import graph_to_pyg from nucleo.rl.networks import encode_query, encode_goal self._network.train() gamma = self.config.gamma gae_lambda = self.config.gae_lambda # Detectar dispositivo del modelo device = next(self._network.parameters()).device # Preparar datos del rollout graph_data = graph_to_pyg(self.graph) rewards = torch.tensor([t.reward for t in transitions], dtype=torch.float32).to(device) dones = torch.tensor([t.done for t in transitions], dtype=torch.float32).to(device) # Codificar queries para cada transicion query_embs = [] goal_embs = [] action_indices = [] for t in transitions: q_text = t.state.lean_goal or "" query_embs.append(encode_query(q_text)) goal_embs.append(encode_goal(q_text)) action_indices.append(ACTION_TYPES.index(t.action.action_type)) query_batch = torch.stack(query_embs).to(device) goal_batch = torch.stack(goal_embs).to(device) action_batch = torch.tensor(action_indices, dtype=torch.long).to(device) # Expandir graph_data para el batch (mismo grafo para todas las transiciones) from torch_geometric.data import Batch batch_size = len(transitions) graph_batch = Batch.from_data_list([graph_data] * batch_size).to(device) # Forward pass para obtener old log_probs y values with torch.no_grad(): output = self._network(graph_batch, query_batch, goal_emb=goal_batch) old_log_probs = F.log_softmax(output.action_logits, dim=-1) old_log_probs = old_log_probs.gather(1, action_batch.unsqueeze(1)).squeeze(1) values = output.value.squeeze(1) # Calcular GAE advantages = torch.zeros_like(rewards) returns = torch.zeros_like(rewards) gae = 0.0 next_value = 0.0 for t in reversed(range(len(transitions))): delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t] gae = delta + gamma * gae_lambda * (1 - dones[t]) * gae advantages[t] = gae returns[t] = advantages[t] + values[t] next_value = values[t] # Normalizar ventajas if advantages.std() > 0: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # PPO epochs total_loss = 0.0 for _ in range(self.config.n_epochs): output = self._network(graph_batch, query_batch, goal_emb=goal_batch) new_log_probs = F.log_softmax(output.action_logits, dim=-1) new_log_probs = new_log_probs.gather(1, action_batch.unsqueeze(1)).squeeze(1) new_values = output.value.squeeze(1) # Ratio ratio = torch.exp(new_log_probs - old_log_probs.detach()) # Clipped surrogate surr1 = ratio * advantages.detach() surr2 = torch.clamp(ratio, 1.0 - self.config.clip_range, 1.0 + self.config.clip_range) * advantages.detach() policy_loss = -torch.min(surr1, surr2).mean() # Value loss value_loss = F.mse_loss(new_values, returns.detach()) # Entropy bonus probs = torch.softmax(output.action_logits, dim=-1) entropy = -(probs * probs.clamp(min=1e-8).log()).sum(dim=-1).mean() # Loss total loss = ( policy_loss + self.config.value_coef * value_loss - self.config.entropy_coef * entropy ) self._optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self._network.parameters(), 0.5) self._optimizer.step() total_loss += loss.item() return total_loss / self.config.n_epochs def train_episode(self, mdp: MDP, max_steps: int = 100) -> Dict[str, float]: """ Entrenar un episodio completo. Args: mdp: MDP del nucleo max_steps: Pasos maximos por episodio Returns: Metricas del episodio """ state = mdp.reset() transitions = [] episode_reward = 0.0 for step in range(max_steps): # Seleccionar y ejecutar accion action = self.select_action(state) transition = mdp.step(action) transitions.append(transition) episode_reward += transition.reward if transition.done: break state = transition.next_state # Actualizar update_metrics = self.update(transitions) self.metrics["episodes"] += 1 self.metrics["avg_reward"] = ( self.metrics["total_reward"] / self.metrics["episodes"] ) return { "episode_reward": episode_reward, "episode_length": len(transitions), **update_metrics, } def eval_mode(self) -> None: """Cambiar a modo evaluacion.""" self.training = False def train_mode(self) -> None: """Cambiar a modo entrenamiento.""" self.training = True def save(self, path: str) -> None: """Guardar agente (config + metricas + pesos de red).""" import json with open(path, 'w', encoding="utf-8") as f: json.dump({ "config": self.config.__dict__, "metrics": self.metrics, "epsilon": self.epsilon, "total_steps": self.total_steps, "use_neural": self._use_neural, }, f, indent=2) # Guardar pesos de la red neuronal if self._network is not None: import torch torch.save(self._network.state_dict(), path + ".pt") @classmethod def load(cls, path: str, graph: SkillCategory) -> NucleoAgent: """Cargar agente (config + metricas + pesos de red).""" import json with open(path, encoding="utf-8") as f: data = json.load(f) use_neural = data.get("use_neural", False) config = AgentConfig(**data["config"]) agent = cls(graph, config, use_neural=use_neural) agent.metrics = data["metrics"] agent.epsilon = data["epsilon"] agent.total_steps = data["total_steps"] # Cargar pesos si existen pt_path = path + ".pt" if os.path.exists(pt_path) and agent._network is not None: import torch agent._network.load_state_dict( torch.load(pt_path, weights_only=True) ) return agent