# granite-code-3b/shared/agent/core/meta_learner.py
# Last commit ba2ada2 by AjinkyaPagare:
#   ADAM v2.0: Advanced Agentic Mesh — DAG orchestrator, cognition, knowledge web,
#   forge tools, runtime optimization
"""
Meta-Learner — learns agent strategy preferences from execution history.
Ultra-lightweight: stores compact execution records (a hashed goal plus
success/failure counts and latency) and uses simple running statistics
instead of heavy ML frameworks. Keeps a per-goal-type score and preferred
strategy, plus counts of the most common failure patterns.
"""
import hashlib
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from typing import Optional
# JSON file holding the learner's persisted state; override with ADAM_LEARNER_PATH.
_STORE_PATH = os.environ.get("ADAM_LEARNER_PATH", "/tmp/adam_learner.json")
# Maximum number of execution records kept; override with ADAM_LEARNER_MAX.
_MAX_RECORDS = int(os.environ.get("ADAM_LEARNER_MAX", "1000"))
@dataclass
class ExecutionRecord:
    """Typed summary of one completed agent execution.

    NOTE(review): MetaLearner in this module stores its history as compact
    dicts (keys "gh", "gt", "ok", "fail", "ms", "ts") and never constructs
    this class; confirm whether other modules use it before relying on it.
    """
    goal_hash: str  # presumably a short digest of the goal text — confirm
    goal_type: str  # presumably a MetaLearner goal category such as "coding"
    node_count: int  # presumably the number of DAG nodes executed — confirm
    success_count: int  # number of successful units in the run
    fail_count: int  # number of failed units in the run
    total_latency_ms: int  # total latency in milliseconds (per the field name)
    strategies_used: list[str]  # names of the strategies applied during the run
    timestamp: float  # presumably a time.time() epoch value — confirm
class MetaLearner:
    """
    Learns from past agent executions to inform future strategy choices.

    Lightweight, dependency-free statistics only:
    - Goal classification: free-text goals are mapped to a goal type by
      keyword rules (``_classify_goal``).
    - Strategy scoring: each goal type keeps a run count, a latency moving
      average, the best weighted score seen, and the strategy credited with
      it (``get_optimal_strategy``).
    - Failure analysis: error messages are bucketed into coarse patterns and
      counted (``get_common_failures``).

    State is persisted as JSON at ``_STORE_PATH`` after every recorded event
    and reloaded on construction. Persistence is best-effort: I/O and decode
    problems are logged and the learner keeps working in memory. The file
    I/O is blocking even when triggered from the async ``record_*`` coroutines.
    """

    _log = logging.getLogger(__name__)

    # Strategy reported for a goal type until a named strategy beats its score.
    _DEFAULT_STRATEGY = "code_forge"

    # Ordered (goal_type, keywords) rules; the first rule with a keyword that
    # begins a word of the goal wins. Word-start matching keeps e.g. "how"
    # from firing on "show", "read" on "thread", or "code" on "decode".
    _GOAL_RULES = (
        ("research", ("search", "find", "look up", "research", "google")),
        ("coding", ("code", "write", "implement", "function", "script", "program")),
        ("analysis", ("analyze", "analyse", "compare", "evaluate")),
        ("creation", ("create", "generate", "build", "make", "design")),
        ("explanation", ("explain", "what", "how", "why", "describe")),
        ("debugging", ("fix", "debug", "error", "issue", "bug")),
        ("data_processing", ("data", "file", "read", "process")),
    )
    _GOAL_PATTERNS = tuple(
        (goal_type, re.compile(r"\b(?:" + "|".join(map(re.escape, words)) + ")"))
        for goal_type, words in _GOAL_RULES
    )

    # Ordered (pattern, keywords) rules for bucketing error messages. Plain
    # substring matching on purpose: exception names such as "ReadTimeout" or
    # "ClientConnectionError" embed the keyword in the middle of a word.
    _FAILURE_RULES = (
        ("timeout", ("timeout", "timed out")),
        ("import_error", ("import", "module not found", "no module")),
        ("syntax_error", ("syntaxerror", "invalid syntax")),
        ("permission", ("permission denied", "access denied")),
        ("connection", ("connection", "network", "refused")),
        ("not_found", ("not found", "no such file", "does not exist")),
        ("memory", ("memory", "oom", "out of memory")),
        ("rate_limit", ("rate limit", "too many requests", "429")),
    )

    def __init__(self):
        self._records: list[dict] = []
        self._strategy_scores: dict[str, dict] = {}
        self._failure_patterns: dict[str, int] = {}
        self._load()

    def _trim_records(self) -> None:
        """Keep only the newest ``_MAX_RECORDS`` records (none if it is <= 0)."""
        limit = max(_MAX_RECORDS, 0)
        excess = len(self._records) - limit
        if excess > 0:
            del self._records[:excess]

    def _load(self):
        """Load persisted learning data from ``_STORE_PATH`` if it exists.

        A missing file leaves the learner empty. An unreadable or malformed
        file is logged and ignored so a bad cache never blocks startup; each
        top-level section is only adopted if it has the expected JSON type.
        """
        try:
            with open(_STORE_PATH, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return
        except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode decode errors
            self._log.warning("Ignoring unreadable learner state %s: %s", _STORE_PATH, exc)
            return
        if not isinstance(data, dict):
            self._log.warning("Ignoring malformed learner state %s", _STORE_PATH)
            return
        records = data.get("records", [])
        scores = data.get("scores", {})
        failures = data.get("failures", {})
        if isinstance(records, list):
            self._records = records
            self._trim_records()
        if isinstance(scores, dict):
            self._strategy_scores = scores
        if isinstance(failures, dict):
            self._failure_patterns = failures

    def _save(self):
        """Persist learning data to ``_STORE_PATH`` (best effort).

        Writes a temporary sibling file and atomically swaps it into place
        with ``os.replace``, so a crash mid-write cannot leave a truncated
        JSON file behind. Only the 50 most frequent failure patterns are
        written. Failures are logged, never raised.
        """
        tmp_path = f"{_STORE_PATH}.{os.getpid()}.tmp"
        try:
            payload = {
                "records": self._records[-_MAX_RECORDS:] if _MAX_RECORDS > 0 else [],
                "scores": self._strategy_scores,
                "failures": dict(sorted(self._failure_patterns.items(),
                                        key=lambda x: x[1], reverse=True)[:50]),
            }
            os.makedirs(os.path.dirname(_STORE_PATH) or ".", exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(payload, f)
            os.replace(tmp_path, _STORE_PATH)
        except (OSError, TypeError, ValueError) as exc:
            self._log.warning("Could not persist learner state to %s: %s", _STORE_PATH, exc)
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file was never created or is already gone

    async def record_execution(self, goal: str, success_count: int,
                               fail_count: int, latency_ms: int,
                               strategy: Optional[str] = None):
        """Record a completed execution for learning, then persist state.

        Args:
            goal: Free-text goal; only its hash and goal type are stored.
            success_count: Number of units (presumably DAG nodes) that succeeded.
            fail_count: Number of units that failed.
            latency_ms: Total execution time in milliseconds.
            strategy: Optional name of the strategy used. When given and this
                run beats the goal type's best score, it becomes the strategy
                returned by ``get_optimal_strategy``.
        """
        # The hash is only a compact identifier, not a security primitive.
        goal_hash = hashlib.md5(goal.encode(), usedforsecurity=False).hexdigest()[:12]
        goal_type = self._classify_goal(goal)
        self._records.append({
            "gh": goal_hash,
            "gt": goal_type,
            "ok": success_count,
            "fail": fail_count,
            "ms": latency_ms,
            "ts": time.time(),
        })
        self._trim_records()  # bound memory, not just the on-disk copy
        # Binary outcome: the run counts as a success when more units
        # succeeded than failed (an empty run counts as a failure).
        outcome = 1.0 if success_count > fail_count else 0.0
        self._update_strategy_score(goal_type, outcome, latency_ms, strategy)
        self._save()

    async def record_failure(self, node_label: str, error: str):
        """Record a specific failure for pattern analysis.

        Args:
            node_label: Label of the failing node; currently not used in the
                pattern key (kept for interface compatibility).
            error: Raw error message; bucketed by ``_extract_failure_pattern``.
                Empty messages are ignored.
        """
        pattern = self._extract_failure_pattern(error)
        if pattern:
            self._failure_patterns[pattern] = self._failure_patterns.get(pattern, 0) + 1
            self._save()

    def get_optimal_strategy(self, goal: str) -> Optional[str]:
        """Return the best-known strategy for *goal*'s type, or None.

        None is returned until more than three executions of that goal type
        have been recorded (presumably to avoid acting on too few samples).
        """
        scores = self._strategy_scores.get(self._classify_goal(goal))
        if scores and scores.get("count", 0) > 3:
            return scores.get("best_strategy")
        return None

    def get_common_failures(self, top_n: int = 5) -> list[tuple[str, int]]:
        """Return up to *top_n* ``(pattern, count)`` pairs, most frequent first."""
        return sorted(self._failure_patterns.items(), key=lambda x: x[1], reverse=True)[:top_n]

    def _classify_goal(self, goal: str) -> str:
        """Classify a goal into a type for strategy selection.

        Rules in ``_GOAL_RULES`` are tried in order; a keyword matches when it
        starts a word of the lower-cased goal. Returns "general" if none match.
        """
        goal_lower = goal.lower()
        for goal_type, pattern in self._GOAL_PATTERNS:
            if pattern.search(goal_lower):
                return goal_type
        return "general"

    def _update_strategy_score(self, goal_type: str, success_ratio: float,
                               latency_ms: int, strategy: Optional[str] = None):
        """Fold one execution into the score entry for *goal_type*.

        Args:
            goal_type: Category returned by ``_classify_goal``.
            success_ratio: Outcome in [0, 1] (``record_execution`` passes 0 or 1).
            latency_ms: Execution time in milliseconds.
            strategy: Strategy used, if known. A run without a name is
                credited to the current best strategy, so it can raise
                ``best_score`` but never changes ``best_strategy``.
        """
        score = self._strategy_scores.setdefault(goal_type, {
            "best_strategy": self._DEFAULT_STRATEGY,
            "best_score": 0.0,
            "count": 0,
            "avg_latency": 0.0,
        })
        if score["count"] == 0:
            # Seed the moving average with the first sample rather than
            # blending it with the 0.0 placeholder.
            score["avg_latency"] = float(latency_ms)
        else:
            # Exponential moving average: 30% weight on the newest sample.
            score["avg_latency"] = score["avg_latency"] * 0.7 + latency_ms * 0.3
        score["count"] += 1
        # Weight: success ratio * 0.7 + speed factor * 0.3. The speed factor
        # falls linearly from 1 at 0 ms to 0 at 30 s and is clamped to [0, 1].
        latency_factor = min(1.0, max(0.0, 1.0 - (latency_ms / 30000.0)))
        weighted = success_ratio * 0.7 + latency_factor * 0.3
        if weighted > score["best_score"]:
            score["best_score"] = weighted
            if strategy is not None:
                score["best_strategy"] = strategy

    def _extract_failure_pattern(self, error: str) -> Optional[str]:
        """Map an error message to a coarse failure pattern.

        Returns None for an empty message, the first matching pattern name
        from ``_FAILURE_RULES``, or "unknown_error" if nothing matches.
        """
        if not error:
            return None
        error_lower = error.lower()
        for pattern, keywords in self._FAILURE_RULES:
            if any(k in error_lower for k in keywords):
                return pattern
        return "unknown_error"

    def get_stats(self) -> dict:
        """Return summary statistics.

        ``total_executions`` counts the execution records currently retained
        (at most ``_MAX_RECORDS``).
        """
        return {
            "total_executions": len(self._records),
            "strategies_tracked": len(self._strategy_scores),
            "failure_patterns": len(self._failure_patterns),
            "common_failures": self.get_common_failures(3),
            "goal_types": list(self._strategy_scores.keys()),
        }