# semiconductor-pipeline/src/models/learning_module.py
# Sai Kumar Taraka
# feat: Add actual AI/ML capabilities with LLM, semantic embeddings, and reinforcement learning
# 9e8e9e2
import logging
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from collections import defaultdict
import json
import os
from datetime import datetime
# Module-level logger, registered under the fixed name "uvmgen.ml.learning".
logger = logging.getLogger("uvmgen.ml.learning")
@dataclass
class ValidationFeedback:
    """Validation outcome recorded for a single generated file.

    Holds the design and file identity, the pass flag, the error and warning
    messages, a numeric score, an ISO-format timestamp (defaults to
    ``datetime.now()`` at construction) and a free-form metadata dict.
    """

    design_name: str
    file_name: str
    file_type: str
    passed: bool
    errors: List[str]
    warnings: List[str]
    score: float
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dict keyed by field name.

        The list and dict values are the same objects held by the instance,
        not copies.
        """
        exported = (
            "design_name",
            "file_name",
            "file_type",
            "passed",
            "errors",
            "warnings",
            "score",
            "timestamp",
            "metadata",
        )
        return {name: getattr(self, name) for name in exported}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ValidationFeedback":
        """Build a record from *d*, substituting a fallback for each missing key."""
        fallbacks = {
            "design_name": "unknown",
            "file_name": "unknown",
            "file_type": "unknown",
            "passed": False,
            "errors": [],
            "warnings": [],
            "score": 0.0,
            "timestamp": datetime.now().isoformat(),
            "metadata": {},
        }
        values = {name: d.get(name, fallback) for name, fallback in fallbacks.items()}
        return cls(**values)
@dataclass
class GenerationHistory:
    """One recorded generation run together with its per-file feedback.

    ``success_rate`` and ``avg_score`` are stored as given by the caller;
    ``timestamp`` defaults to ``datetime.now()`` in ISO format.
    """

    design_name: str
    generation_source: str
    spec_hash: str
    feedback_list: List[ValidationFeedback]
    success_rate: float = 0.0
    avg_score: float = 0.0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view, serializing each feedback via its ``to_dict``."""
        serialized_feedback = [entry.to_dict() for entry in self.feedback_list]
        return dict(
            design_name=self.design_name,
            generation_source=self.generation_source,
            spec_hash=self.spec_hash,
            feedback_list=serialized_feedback,
            success_rate=self.success_rate,
            avg_score=self.avg_score,
            timestamp=self.timestamp,
        )
class PatternLearner:
    """Accumulates error-category counts and pass/fail statistics.

    Three tallies are kept in memory:

    * ``_error_patterns``: category name -> number of error messages that
      matched it (categories come from ``_extract_patterns``).
    * ``_file_type_stats``: per file type, ``success`` / ``total`` counters
      plus a count of each error message (truncated to 100 characters).
    * ``_protocol_stats``: per protocol, ``success`` / ``total`` counters.

    ``to_dict`` / ``from_dict`` convert the tallies to and from plain dicts.
    """

    # (regex, category) pairs, matched case-insensitively against error text.
    # Several regexes deliberately share one category name.
    _UVM_PATTERNS = (
        (r"uvm_fatal", "uvm_fatal"),
        (r"uvm_error", "uvm_error"),
        (r"uvm_component_utils", "missing_uvm_macro"),
        (r"uvm_object_utils", "missing_uvm_macro"),
        (r"build_phase", "phase_issue"),
        (r"connect_phase", "phase_issue"),
        (r"run_phase", "phase_issue"),
    )
    _SYNTAX_PATTERNS = (
        (r"missing.*semicolon", "missing_semicolon"),
        (r"unbalanced.*parenthes", "unbalanced_parentheses"),
        (r"unbalanced.*brace", "unbalanced_braces"),
        (r"unbalanced.*bracket", "unbalanced_brackets"),
        (r"mismatch.*begin", "mismatched_blocks"),
        (r"syntax error", "syntax_error"),
    )

    def __init__(self):
        self._error_patterns: Dict[str, int] = defaultdict(int)
        # Kept for compatibility; nothing in this class writes or persists it.
        self._success_patterns: Dict[str, int] = defaultdict(int)
        self._file_type_stats: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {"success": 0, "total": 0, "errors": defaultdict(int)}
        )
        self._protocol_stats: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {"success": 0, "total": 0}
        )

    def record_error(self, error_msg: str, file_type: str = "unknown"):
        """Count *error_msg* once per matching category and once for *file_type*.

        Does not touch the ``success`` / ``total`` counters.
        """
        for category in self._extract_patterns(error_msg):
            self._error_patterns[category] += 1
        # Truncate the key so very long messages do not bloat the stats.
        self._file_type_stats[file_type]["errors"][error_msg[:100]] += 1

    def record_success(self, file_type: str = "unknown", protocol: str = "unknown"):
        """Record a passing attempt: increments both ``success`` and ``total``.

        Because ``total`` is incremented here, also calling ``record_attempt``
        for the same attempt counts it twice.
        """
        for stats in (self._file_type_stats[file_type], self._protocol_stats[protocol]):
            stats["success"] += 1
            stats["total"] += 1

    def record_attempt(self, file_type: str = "unknown", protocol: str = "unknown"):
        """Record an attempt without a success: increments ``total`` only."""
        self._file_type_stats[file_type]["total"] += 1
        self._protocol_stats[protocol]["total"] += 1

    def _extract_patterns(self, text: str) -> List[str]:
        """Return the distinct error categories whose regex matches *text*.

        Categories appear in table order (UVM patterns first, then syntax
        patterns). Each category is reported at most once even when several
        of its regexes match, so one message cannot inflate a category count.
        Returns ``["unknown_error"]`` when nothing matches.
        """
        import re

        found: List[str] = []
        for pattern, category in self._UVM_PATTERNS + self._SYNTAX_PATTERNS:
            if category not in found and re.search(pattern, text, re.IGNORECASE):
                found.append(category)
        return found or ["unknown_error"]

    def get_common_errors(self, top_n: int = 10) -> List[Tuple[str, int]]:
        """Return up to *top_n* ``(category, count)`` pairs, highest count first."""
        ranked = sorted(
            self._error_patterns.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        return ranked[:top_n]

    def get_file_type_success_rate(self, file_type: str) -> float:
        """Return ``success / total`` for *file_type*, or 0.5 when there is no data."""
        # .get() avoids creating an empty entry for an unseen file type.
        stats = self._file_type_stats.get(file_type, {})
        total = stats.get("total", 0)
        if total == 0:
            return 0.5
        return stats.get("success", 0) / total

    def get_protocol_success_rate(self, protocol: str) -> float:
        """Return ``success / total`` for *protocol*, or 0.5 when there is no data."""
        stats = self._protocol_stats.get(protocol, {})
        total = stats.get("total", 0)
        if total == 0:
            return 0.5
        return stats.get("success", 0) / total

    def to_dict(self) -> Dict[str, Any]:
        """Return a snapshot of all tallies built from plain dicts.

        The returned structure shares no mutable containers with the learner.
        """
        return {
            "error_patterns": dict(self._error_patterns),
            "file_type_stats": {
                file_type: {
                    "success": stats["success"],
                    "total": stats["total"],
                    "errors": dict(stats["errors"]),
                }
                for file_type, stats in self._file_type_stats.items()
            },
            "protocol_stats": {
                protocol: dict(stats)
                for protocol, stats in self._protocol_stats.items()
            },
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "PatternLearner":
        """Rebuild a learner from a ``to_dict`` snapshot; missing keys default to empty/zero."""
        learner = cls()
        learner._error_patterns = defaultdict(int, d.get("error_patterns", {}))
        for file_type, stats in d.get("file_type_stats", {}).items():
            learner._file_type_stats[file_type] = {
                "success": stats.get("success", 0),
                "total": stats.get("total", 0),
                "errors": defaultdict(int, stats.get("errors", {})),
            }
        for protocol, stats in d.get("protocol_stats", {}).items():
            learner._protocol_stats[protocol] = {
                "success": stats.get("success", 0),
                "total": stats.get("total", 0),
            }
        return learner
class ReinforcementLearner:
    """Running value estimates for (protocol, file type, generation source).

    Each state key maps to a value that starts at 0.5 and is moved toward
    every observed reward by ``learning_rate``. ``discount_factor`` is stored
    and serialized but is not used by ``update``.
    """

    # Value reported for a state that has never been updated.
    _DEFAULT_VALUE = 0.5

    def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.9):
        self._learning_rate = learning_rate
        self._discount_factor = discount_factor
        self._q_values: Dict[str, float] = defaultdict(lambda: 0.5)
        self._visit_counts: Dict[str, int] = defaultdict(int)

    def _get_state_key(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
    ) -> str:
        """Return the colon-joined key used to index values and visit counts."""
        return f"{protocol}:{file_type}:{generation_source}"

    def get_action_value(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
    ) -> float:
        """Return the current value for the state, or 0.5 if it was never updated.

        This is a pure read: it uses ``.get`` so that looking up an unseen
        state does not insert a placeholder entry into ``_q_values`` (which
        would otherwise end up in ``to_dict`` output).
        """
        key = self._get_state_key(protocol, file_type, generation_source)
        return self._q_values.get(key, self._DEFAULT_VALUE)

    def update(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
        reward: float,
    ):
        """Move the state's value toward *reward* and bump its visit count.

        New value = old + learning_rate * (reward - old).
        """
        key = self._get_state_key(protocol, file_type, generation_source)
        old_value = self._q_values.get(key, self._DEFAULT_VALUE)
        self._visit_counts[key] += 1
        self._q_values[key] = old_value + self._learning_rate * (reward - old_value)

    def select_best_action(
        self,
        protocol: str,
        file_type: str,
        available_sources: List[str],
        epsilon: float = 0.1,
    ) -> Tuple[str, float]:
        """Pick a generation source epsilon-greedily.

        With probability *epsilon* (and only when more than one source is
        available) a uniformly random source is returned; otherwise the
        source with the highest current value, the earliest one winning ties.

        Returns:
            ``(source, value)`` where *value* is that source's current value.

        Raises:
            IndexError: if *available_sources* is empty.
        """
        import random

        if not available_sources:
            # Same exception type the implicit [0] lookup used to raise.
            raise IndexError("available_sources must contain at least one source")

        if random.random() < epsilon and len(available_sources) > 1:
            chosen = random.choice(available_sources)
            return chosen, self.get_action_value(protocol, file_type, chosen)

        # max() returns the first maximal element and has no artificial floor,
        # so values at or below -1 are still ranked and reported correctly.
        best_source = max(
            available_sources,
            key=lambda source: self.get_action_value(protocol, file_type, source),
        )
        return best_source, self.get_action_value(protocol, file_type, best_source)

    def to_dict(self) -> Dict[str, Any]:
        """Return hyper-parameters, values and visit counts as plain dicts."""
        return {
            "learning_rate": self._learning_rate,
            "discount_factor": self._discount_factor,
            "q_values": dict(self._q_values),
            "visit_counts": dict(self._visit_counts),
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ReinforcementLearner":
        """Rebuild a learner from a ``to_dict`` snapshot; missing keys use defaults."""
        learner = cls(
            learning_rate=d.get("learning_rate", 0.1),
            discount_factor=d.get("discount_factor", 0.9),
        )
        learner._q_values.update(d.get("q_values", {}))
        learner._visit_counts.update(d.get("visit_counts", {}))
        return learner
class LearningModule:
    """Facade that records validation feedback and derives generation hints.

    Combines a ``PatternLearner`` (error categories and success rates), a
    ``ReinforcementLearner`` (value per protocol / file type / source) and a
    list of ``GenerationHistory`` entries. When ``storage_path`` is given,
    state is loaded from that JSON file on construction and written back
    after every ``record_feedback`` call.
    """

    # Rewards fed to the reinforcement learner for a passing / failing file.
    _PASS_REWARD = 1.0
    _FAIL_REWARD = -0.5
    # Only the most recent entries are written to storage.
    _MAX_PERSISTED_HISTORY = 100
    # (substring of error category, advice); the first match per category wins.
    _ADVICE_BY_KEYWORD = (
        ("semicolon", "Ensure all statements end with semicolons"),
        ("parenthes", "Check for balanced parentheses"),
        ("brace", "Check for balanced begin/end blocks"),
        (
            "uvm_macro",
            "Add UVM factory registration macros (uvm_component_utils/uvm_object_utils)",
        ),
        ("phase", "Ensure proper UVM phase implementation"),
    )

    def __init__(self, storage_path: Optional[str] = None):
        """Create an empty module and, if *storage_path* is set, load saved state."""
        self._storage_path = storage_path
        self._pattern_learner = PatternLearner()
        self._rl_learner = ReinforcementLearner()
        self._history: List[GenerationHistory] = []
        self._total_generations = 0
        self._successful_generations = 0
        if storage_path:
            self._load_from_storage()

    def record_feedback(
        self,
        design_name: str,
        generation_source: str,
        spec_dict: Dict[str, Any],
        validation_results: Dict[str, Any],
    ):
        """Record the validation outcome of one generation run.

        ``validation_results["files"]`` may be either

        * a dict ``{file_name: {"type", "passed", "errors", "warnings",
          "score"}}``, or
        * a list of dicts with ``"filename"``, ``"file_type"``, ``"passed"``,
          ``"issues"`` (each with ``"severity"`` / ``"message"``),
          ``"error_count"`` and ``"warning_count"``.

        Any other value yields no per-file feedback. Every file updates the
        pattern and reinforcement learners; the run is then appended to the
        history and, if a storage path is configured, state is saved.

        A run counts as successful only if it has at least one file and all
        files passed.
        """
        import hashlib

        # MD5 is used only as a short, stable fingerprint of the spec.
        spec_str = json.dumps(spec_dict, sort_keys=True)
        spec_hash = hashlib.md5(spec_str.encode()).hexdigest()[:12]
        protocol = spec_dict.get("protocol", "unknown")

        files_data = validation_results.get("files", [])
        if isinstance(files_data, dict):
            feedback_iter = (
                self._feedback_from_mapping(design_name, file_name, file_info)
                for file_name, file_info in files_data.items()
            )
        elif isinstance(files_data, list):
            feedback_iter = (
                self._feedback_from_issue_list(design_name, file_info)
                for file_info in files_data
            )
        else:
            feedback_iter = iter(())

        feedback_list = []
        for feedback in feedback_iter:
            feedback_list.append(feedback)
            self._learn_from_feedback(feedback, protocol, generation_source)

        # all([]) is True; a run with no validated files is not a success.
        all_passed = bool(feedback_list) and all(f.passed for f in feedback_list)
        avg_score = (
            sum(f.score for f in feedback_list) / len(feedback_list)
            if feedback_list
            else 0.0
        )

        self._history.append(
            GenerationHistory(
                design_name=design_name,
                generation_source=generation_source,
                spec_hash=spec_hash,
                feedback_list=feedback_list,
                success_rate=1.0 if all_passed else 0.0,
                avg_score=avg_score,
            )
        )
        self._total_generations += 1
        if all_passed:
            self._successful_generations += 1

        if self._storage_path:
            self._save_to_storage()

    @staticmethod
    def _feedback_from_mapping(
        design_name: str,
        file_name: str,
        file_info: Dict[str, Any],
    ) -> "ValidationFeedback":
        """Build feedback from one entry of the dict-shaped ``files`` result."""
        return ValidationFeedback(
            design_name=design_name,
            file_name=file_name,
            file_type=file_info.get("type", "unknown"),
            passed=file_info.get("passed", True),
            errors=file_info.get("errors", []),
            warnings=file_info.get("warnings", []),
            score=file_info.get("score", 0.5),
        )

    @staticmethod
    def _feedback_from_issue_list(
        design_name: str,
        file_info: Dict[str, Any],
    ) -> "ValidationFeedback":
        """Build feedback from one entry of the list-shaped ``files`` result.

        Issues with severity ``"error"`` become errors, all others warnings.
        A positive ``error_count`` forces ``passed`` to False. The score is
        1.0 for a clean pass, 0.7 for a pass with warnings and 0.3 for a
        failure.
        """
        errors: List[str] = []
        warnings: List[str] = []
        for issue in file_info.get("issues", []):
            message = issue.get("message", "")
            if issue.get("severity", "warning") == "error":
                errors.append(message)
            else:
                warnings.append(message)

        passed = file_info.get("passed", True)
        if file_info.get("error_count", 0) > 0:
            passed = False

        if not passed:
            score = 0.3
        elif file_info.get("warning_count", 0) > 0:
            score = 0.7
        else:
            score = 1.0

        return ValidationFeedback(
            design_name=design_name,
            file_name=file_info.get("filename", "unknown"),
            file_type=file_info.get("file_type", "unknown"),
            passed=passed,
            errors=errors,
            warnings=warnings,
            score=score,
        )

    def _learn_from_feedback(
        self,
        feedback: "ValidationFeedback",
        protocol: str,
        generation_source: str,
    ):
        """Feed one file's outcome to the pattern and reinforcement learners."""
        if feedback.passed:
            self._pattern_learner.record_success(feedback.file_type, protocol)
            reward = self._PASS_REWARD
        else:
            for message in feedback.errors:
                self._pattern_learner.record_error(message, feedback.file_type)
            # record_success already increments the attempt total, so
            # record_attempt is issued for failures only; calling it for
            # every file would count each passing file twice.
            self._pattern_learner.record_attempt(feedback.file_type, protocol)
            reward = self._FAIL_REWARD
        self._rl_learner.update(
            protocol, feedback.file_type, generation_source, reward
        )

    def select_best_generation_strategy(
        self,
        spec_dict: Dict[str, Any],
        file_type: str,
        available_sources: List[str],
    ) -> Tuple[str, float]:
        """Return ``(source, value)`` chosen by the reinforcement learner.

        Uses the spec's ``"protocol"`` (default ``"unknown"``) and an
        exploration rate of 0.05.
        """
        return self._rl_learner.select_best_action(
            protocol=spec_dict.get("protocol", "unknown"),
            file_type=file_type,
            available_sources=available_sources,
            epsilon=0.05,
        )

    def get_generation_hints(
        self,
        spec_dict: Dict[str, Any],
        file_type: str,
    ) -> Dict[str, Any]:
        """Return the top five error categories, success rates and recommendations."""
        protocol = spec_dict.get("protocol", "unknown")
        common_errors = self._pattern_learner.get_common_errors(5)
        file_success_rate = self._pattern_learner.get_file_type_success_rate(file_type)
        protocol_success_rate = self._pattern_learner.get_protocol_success_rate(protocol)
        return {
            "common_errors": common_errors,
            "file_type_success_rate": file_success_rate,
            "protocol_success_rate": protocol_success_rate,
            "recommendations": self._generate_recommendations(
                common_errors,
                file_success_rate,
                protocol_success_rate,
            ),
        }

    def _generate_recommendations(
        self,
        common_errors: List[Tuple[str, int]],
        file_success_rate: float,
        protocol_success_rate: float,
    ) -> List[str]:
        """Turn error categories and success rates into advice strings.

        Only the first three entries of *common_errors* with a positive count
        are considered; each contributes at most one piece of advice. A rate
        below 0.7 adds the matching generic suggestion. If nothing applies, a
        single "no specific recommendations" message is returned.
        """
        recommendations = []
        for error_pattern, count in common_errors[:3]:
            if count <= 0:
                continue
            for keyword, advice in self._ADVICE_BY_KEYWORD:
                if keyword in error_pattern:
                    recommendations.append(advice)
                    break

        if file_success_rate < 0.7:
            recommendations.append(
                "Consider using retrieval-based generation for this file type"
            )
        if protocol_success_rate < 0.7:
            recommendations.append(
                "Add protocol-specific templates may improve quality"
            )
        if not recommendations:
            recommendations.append(
                "No specific recommendations - generation should work well"
            )
        return recommendations

    def get_stats(self) -> Dict[str, Any]:
        """Return generation counters, overall success rate and pattern statistics."""
        total = self._total_generations
        return {
            "total_generations": total,
            "successful_generations": self._successful_generations,
            "success_rate": self._successful_generations / total if total > 0 else 0.0,
            "history_count": len(self._history),
            "pattern_stats": self._pattern_learner.to_dict(),
        }

    def _save_to_storage(self):
        """Write the current state to ``storage_path`` as JSON (best effort).

        Only the last 100 history entries are written. Failures are logged
        as warnings and never raised.
        """
        if not self._storage_path:
            return
        try:
            # A bare filename has an empty dirname, and os.makedirs("")
            # raises, which used to abort the save.
            parent_dir = os.path.dirname(self._storage_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            data = {
                "pattern_learner": self._pattern_learner.to_dict(),
                "rl_learner": self._rl_learner.to_dict(),
                "history": [
                    h.to_dict()
                    for h in self._history[-self._MAX_PERSISTED_HISTORY:]
                ],
                "total_generations": self._total_generations,
                "successful_generations": self._successful_generations,
                "saved_at": datetime.now().isoformat(),
            }
            with open(self._storage_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            logger.debug("Learning module saved to: %s", self._storage_path)
        except Exception as e:
            # Deliberately broad: persistence must never break feedback recording.
            logger.warning("Could not save learning module: %s", e)

    def _load_from_storage(self):
        """Replace the current state with the contents of ``storage_path``.

        Does nothing if no path is configured or the file does not exist.
        The file is parsed completely before any attribute is assigned, so a
        corrupt or partially valid file leaves the module unchanged; the
        failure is logged as a warning and never raised.
        """
        if not self._storage_path or not os.path.exists(self._storage_path):
            return
        try:
            with open(self._storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            pattern_learner = PatternLearner.from_dict(
                data.get("pattern_learner", {})
            )
            rl_learner = ReinforcementLearner.from_dict(
                data.get("rl_learner", {})
            )
            history = [
                GenerationHistory(
                    design_name=h_dict.get("design_name", "unknown"),
                    generation_source=h_dict.get("generation_source", "unknown"),
                    spec_hash=h_dict.get("spec_hash", ""),
                    feedback_list=[
                        ValidationFeedback.from_dict(f_dict)
                        for f_dict in h_dict.get("feedback_list", [])
                    ],
                    success_rate=h_dict.get("success_rate", 0.0),
                    avg_score=h_dict.get("avg_score", 0.0),
                    timestamp=h_dict.get("timestamp", datetime.now().isoformat()),
                )
                for h_dict in data.get("history", [])
            ]
            total_generations = data.get("total_generations", 0)
            successful_generations = data.get("successful_generations", 0)
        except Exception as e:
            # Deliberately broad: an unreadable file must not prevent start-up.
            logger.warning("Could not load learning module: %s", e)
            return

        # Commit only after everything parsed, and replace (not extend) the
        # history so it stays consistent with the replaced counters.
        self._pattern_learner = pattern_learner
        self._rl_learner = rl_learner
        self._history = history
        self._total_generations = total_generations
        self._successful_generations = successful_generations
        logger.info("Learning module loaded from: %s", self._storage_path)