# Sai Kumar Taraka
# feat: Add actual AI/ML capabilities with LLM, semantic embeddings, and reinforcement learning
import hashlib
import json
import logging
import os
import random
import re
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
# Module-level logger; the name is fixed (not ``__name__``) so existing logging
# configuration keyed on "uvmgen.ml.learning" keeps working.
logger = logging.getLogger("uvmgen.ml.learning")
@dataclass
class ValidationFeedback:
    """Validation outcome recorded for a single generated file.

    All fields hold plain JSON-compatible values so an instance can be
    round-tripped through :meth:`to_dict` and :meth:`from_dict`.
    """

    design_name: str
    file_name: str
    file_type: str
    passed: bool
    errors: List[str]
    warnings: List[str]
    score: float
    # Creation time as an ISO-8601 string taken from ``datetime.now()``.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary snapshot of this feedback record.

        The list and dict fields are shallow-copied so that mutating the
        returned structure does not alter this instance.
        """
        return {
            "design_name": self.design_name,
            "file_name": self.file_name,
            "file_type": self.file_type,
            "passed": self.passed,
            "errors": list(self.errors),
            "warnings": list(self.warnings),
            "score": self.score,
            "timestamp": self.timestamp,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ValidationFeedback":
        """Build a record from a dictionary such as one produced by ``to_dict``.

        Missing keys fall back to neutral defaults ("unknown", ``False``,
        empty containers, ``0.0``, current time). ``None`` stored for a list,
        dict or timestamp field is treated the same as a missing key.
        """
        return cls(
            design_name=d.get("design_name", "unknown"),
            file_name=d.get("file_name", "unknown"),
            file_type=d.get("file_type", "unknown"),
            passed=d.get("passed", False),
            errors=list(d.get("errors") or []),
            warnings=list(d.get("warnings") or []),
            score=d.get("score", 0.0),
            timestamp=d.get("timestamp") or datetime.now().isoformat(),
            metadata=dict(d.get("metadata") or {}),
        )
@dataclass
class GenerationHistory:
    """Summary of one generation run together with its per-file feedback."""

    design_name: str
    generation_source: str
    # Short digest identifying the specification the run was generated from.
    spec_hash: str
    # Forward reference: the element type is declared elsewhere in this module.
    feedback_list: List["ValidationFeedback"]
    success_rate: float = 0.0
    avg_score: float = 0.0
    # Creation time as an ISO-8601 string taken from ``datetime.now()``.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary form of the run.

        Each feedback entry is serialized through its own ``to_dict``.
        """
        return {
            "design_name": self.design_name,
            "generation_source": self.generation_source,
            "spec_hash": self.spec_hash,
            "feedback_list": [f.to_dict() for f in self.feedback_list],
            "success_rate": self.success_rate,
            "avg_score": self.avg_score,
            "timestamp": self.timestamp,
        }
class PatternLearner:
    """Accumulates error-pattern counts and success statistics.

    Statistics are tracked per error pattern, per file type and per protocol,
    and can be serialized with :meth:`to_dict` / restored with
    :meth:`from_dict`.
    """

    # (compiled regex, label) pairs, checked in order. Compiled once at class
    # creation instead of on every call. Several expressions intentionally
    # share a label.
    _PATTERN_RULES = tuple(
        (re.compile(expression, re.IGNORECASE), label)
        for expression, label in (
            # UVM-related keywords.
            (r"uvm_fatal", "uvm_fatal"),
            (r"uvm_error", "uvm_error"),
            (r"uvm_component_utils", "missing_uvm_macro"),
            (r"uvm_object_utils", "missing_uvm_macro"),
            (r"build_phase", "phase_issue"),
            (r"connect_phase", "phase_issue"),
            (r"run_phase", "phase_issue"),
            # Generic syntax problems.
            (r"missing.*semicolon", "missing_semicolon"),
            (r"unbalanced.*parenthes", "unbalanced_parentheses"),
            (r"unbalanced.*brace", "unbalanced_braces"),
            (r"unbalanced.*bracket", "unbalanced_brackets"),
            (r"mismatch.*begin", "mismatched_blocks"),
            (r"syntax error", "syntax_error"),
        )
    )

    # Rate reported when nothing has been recorded yet for a key.
    _UNKNOWN_RATE = 0.5

    def __init__(self):
        self._error_patterns: Dict[str, int] = defaultdict(int)
        # Not updated by any method of this class; kept for compatibility.
        self._success_patterns: Dict[str, int] = defaultdict(int)
        self._file_type_stats: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {"success": 0, "total": 0, "errors": defaultdict(int)}
        )
        self._protocol_stats: Dict[str, Dict[str, Any]] = defaultdict(
            lambda: {"success": 0, "total": 0}
        )

    def record_error(self, error_msg: str, file_type: str = "unknown"):
        """Count the patterns found in *error_msg* and the message itself.

        Each distinct pattern label is incremented once per message. The
        message, truncated to 100 characters, is also counted under the
        given file type. The attempt totals are not touched here.
        """
        for label in self._extract_patterns(error_msg):
            self._error_patterns[label] += 1
        self._file_type_stats[file_type]["errors"][error_msg[:100]] += 1

    def record_success(self, file_type: str = "unknown", protocol: str = "unknown"):
        """Record one successful attempt for the file type and the protocol."""
        file_stats = self._file_type_stats[file_type]
        file_stats["success"] += 1
        file_stats["total"] += 1
        protocol_stats = self._protocol_stats[protocol]
        protocol_stats["success"] += 1
        protocol_stats["total"] += 1

    def record_attempt(self, file_type: str = "unknown", protocol: str = "unknown"):
        """Record one attempt (without a success) for file type and protocol."""
        self._file_type_stats[file_type]["total"] += 1
        self._protocol_stats[protocol]["total"] += 1

    def _extract_patterns(self, text: str) -> List[str]:
        """Return the distinct pattern labels whose regex matches *text*.

        Labels appear in rule order and never repeat, so a message that
        mentions e.g. both ``build_phase`` and ``run_phase`` yields a single
        ``"phase_issue"``. ``["unknown_error"]`` is returned when no rule
        matches.
        """
        labels: List[str] = []
        for regex, label in self._PATTERN_RULES:
            if label not in labels and regex.search(text):
                labels.append(label)
        return labels or ["unknown_error"]

    def get_common_errors(self, top_n: int = 10) -> List[Tuple[str, int]]:
        """Return up to *top_n* ``(pattern, count)`` pairs, most frequent first."""
        ranked = sorted(
            self._error_patterns.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        return ranked[:top_n]

    @classmethod
    def _rate(cls, stats: Dict[str, Any]) -> float:
        """Return success/total for *stats*, or the neutral rate if total is 0."""
        total = stats.get("total", 0)
        if total == 0:
            return cls._UNKNOWN_RATE
        return stats.get("success", 0) / total

    def get_file_type_success_rate(self, file_type: str) -> float:
        """Return the success ratio for *file_type* (0.5 when never attempted)."""
        # ``.get`` avoids creating a default entry for an unseen key.
        return self._rate(self._file_type_stats.get(file_type, {}))

    def get_protocol_success_rate(self, protocol: str) -> float:
        """Return the success ratio for *protocol* (0.5 when never attempted)."""
        return self._rate(self._protocol_stats.get(protocol, {}))

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the statistics.

        Nested dictionaries are copied so callers cannot mutate the live
        counters through the returned structure. ``_success_patterns`` is not
        included.
        """
        return {
            "error_patterns": dict(self._error_patterns),
            "file_type_stats": {
                file_type: {
                    "success": stats["success"],
                    "total": stats["total"],
                    "errors": dict(stats["errors"]),
                }
                for file_type, stats in self._file_type_stats.items()
            },
            "protocol_stats": {
                protocol: dict(stats)
                for protocol, stats in self._protocol_stats.items()
            },
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "PatternLearner":
        """Rebuild a learner from a dictionary produced by :meth:`to_dict`.

        Missing sections or counters default to empty / zero.
        """
        learner = cls()
        learner._error_patterns = defaultdict(int, d.get("error_patterns", {}))
        for file_type, stats in d.get("file_type_stats", {}).items():
            learner._file_type_stats[file_type] = {
                "success": stats.get("success", 0),
                "total": stats.get("total", 0),
                "errors": defaultdict(int, stats.get("errors", {})),
            }
        for protocol, stats in d.get("protocol_stats", {}).items():
            learner._protocol_stats[protocol] = {
                "success": stats.get("success", 0),
                "total": stats.get("total", 0),
            }
        return learner
class ReinforcementLearner:
    """Keeps a running value estimate per (protocol, file type, source).

    Each estimate starts at 0.5 and is moved toward observed rewards by
    ``learning_rate``. ``discount_factor`` is stored and serialized but is not
    used by the update rule in this class.
    """

    # Value assumed for a state that has never been updated.
    _INITIAL_VALUE = 0.5

    def __init__(self, learning_rate: float = 0.1, discount_factor: float = 0.9):
        self._learning_rate = learning_rate
        self._discount_factor = discount_factor
        self._q_values: Dict[str, float] = defaultdict(lambda: 0.5)
        self._visit_counts: Dict[str, int] = defaultdict(int)

    def _get_state_key(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
    ) -> str:
        """Return the ``"protocol:file_type:source"`` key used in the tables."""
        return f"{protocol}:{file_type}:{generation_source}"

    def get_action_value(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
    ) -> float:
        """Return the current estimate for the state (0.5 if never updated).

        Reading does not insert the key, so lookups of unseen states do not
        grow the table or end up in :meth:`to_dict` output.
        """
        key = self._get_state_key(protocol, file_type, generation_source)
        return self._q_values.get(key, self._INITIAL_VALUE)

    def update(
        self,
        protocol: str,
        file_type: str,
        generation_source: str,
        reward: float,
    ):
        """Move the state's estimate toward *reward* and count the visit."""
        key = self._get_state_key(protocol, file_type, generation_source)
        previous = self._q_values.get(key, self._INITIAL_VALUE)
        self._visit_counts[key] += 1
        self._q_values[key] = previous + self._learning_rate * (reward - previous)

    def select_best_action(
        self,
        protocol: str,
        file_type: str,
        available_sources: List[str],
        epsilon: float = 0.1,
    ) -> Tuple[str, float]:
        """Pick a generation source using an epsilon-greedy rule.

        With probability *epsilon* (and more than one candidate) a random
        source is returned; otherwise the source with the highest estimate,
        the earliest one winning ties.

        Returns:
            ``(source, estimate)`` for the chosen source.

        Raises:
            IndexError: if *available_sources* is empty.
        """
        if not available_sources:
            raise IndexError("available_sources must not be empty")

        # random.random() is evaluated first so RNG consumption is independent
        # of the number of candidates.
        if random.random() < epsilon and len(available_sources) > 1:
            chosen = random.choice(available_sources)
            return chosen, self.get_action_value(protocol, file_type, chosen)

        # max() returns the first maximal element and has no artificial lower
        # bound, so estimates at or below -1.0 are reported correctly.
        best_source = max(
            available_sources,
            key=lambda source: self.get_action_value(protocol, file_type, source),
        )
        return best_source, self.get_action_value(protocol, file_type, best_source)

    def to_dict(self) -> Dict[str, Any]:
        """Return the hyper-parameters and both tables as plain dictionaries."""
        return {
            "learning_rate": self._learning_rate,
            "discount_factor": self._discount_factor,
            "q_values": dict(self._q_values),
            "visit_counts": dict(self._visit_counts),
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ReinforcementLearner":
        """Rebuild a learner from a dictionary produced by :meth:`to_dict`.

        Missing hyper-parameters fall back to the constructor defaults and
        missing tables to empty ones.
        """
        learner = cls(
            learning_rate=d.get("learning_rate", 0.1),
            discount_factor=d.get("discount_factor", 0.9),
        )
        learner._q_values.update(d.get("q_values", {}))
        learner._visit_counts.update(d.get("visit_counts", {}))
        return learner
class LearningModule:
    """Collects validation feedback and derives generation guidance from it.

    Feedback is fed into a ``PatternLearner`` (error/success statistics) and a
    ``ReinforcementLearner`` (per-source value estimates). When a storage path
    is given, state is loaded on construction and saved after every recorded
    feedback.
    """

    # Rewards passed to the reinforcement learner per validated file.
    _REWARD_PASS = 1.0
    _REWARD_FAIL = -0.5

    # Only the most recent runs are written to storage.
    _MAX_SAVED_HISTORY = 100

    # (substring of the pattern label, advice); the first matching row wins.
    _RECOMMENDATION_RULES = (
        ("semicolon", "Ensure all statements end with semicolons"),
        ("parenthes", "Check for balanced parentheses"),
        ("brace", "Check for balanced begin/end blocks"),
        (
            "uvm_macro",
            "Add UVM factory registration macros (uvm_component_utils/uvm_object_utils)",
        ),
        ("phase", "Ensure proper UVM phase implementation"),
    )

    def __init__(self, storage_path: Optional[str] = None):
        self._storage_path = storage_path
        self._pattern_learner = PatternLearner()
        self._rl_learner = ReinforcementLearner()
        self._history: List[GenerationHistory] = []
        self._total_generations = 0
        self._successful_generations = 0
        if storage_path:
            self._load_from_storage()

    @staticmethod
    def _entries_from_mapping(files_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize the ``{file_name: info}`` result shape."""
        return [
            {
                "file_name": file_name,
                "file_type": info.get("type", "unknown"),
                "passed": info.get("passed", True),
                "errors": info.get("errors", []),
                "warnings": info.get("warnings", []),
                "score": info.get("score", 0.5),
            }
            for file_name, info in files_data.items()
        ]

    @staticmethod
    def _entries_from_sequence(files_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Normalize the list-of-files result shape (issues with severities)."""
        entries = []
        for info in files_data:
            errors: List[str] = []
            warnings: List[str] = []
            for issue in info.get("issues", []):
                message = issue.get("message", "")
                if issue.get("severity", "warning") == "error":
                    errors.append(message)
                else:
                    warnings.append(message)

            # A non-zero reported error count overrides the "passed" flag.
            passed = info.get("passed", True)
            if info.get("error_count", 0) > 0:
                passed = False

            if not passed:
                score = 0.3
            elif info.get("warning_count", 0) > 0:
                score = 0.7
            else:
                score = 1.0

            entries.append(
                {
                    "file_name": info.get("filename", "unknown"),
                    "file_type": info.get("file_type", "unknown"),
                    "passed": passed,
                    "errors": errors,
                    "warnings": warnings,
                    "score": score,
                }
            )
        return entries

    def _learn_from_file(
        self,
        feedback: "ValidationFeedback",
        protocol: str,
        generation_source: str,
    ) -> None:
        """Feed one file's outcome into the pattern and reinforcement learners."""
        if feedback.passed:
            self._pattern_learner.record_success(feedback.file_type, protocol)
            reward = self._REWARD_PASS
        else:
            for err in feedback.errors:
                self._pattern_learner.record_error(err, feedback.file_type)
            self._pattern_learner.record_attempt(feedback.file_type, protocol)
            reward = self._REWARD_FAIL
        self._rl_learner.update(
            protocol, feedback.file_type, generation_source, reward
        )

    def record_feedback(
        self,
        design_name: str,
        generation_source: str,
        spec_dict: Dict[str, Any],
        validation_results: Dict[str, Any],
    ):
        """Record the validation results of one generation run.

        ``validation_results["files"]`` may be a mapping of file name to info
        (keys ``type``, ``passed``, ``errors``, ``warnings``, ``score``) or a
        list of infos (keys ``filename``, ``file_type``, ``passed``,
        ``issues``, ``error_count``, ``warning_count``). Any other value is
        treated as "no files".

        A run counts as successful only when at least one file was validated
        and every file passed. State is saved afterwards if a storage path is
        configured.
        """
        spec_str = json.dumps(spec_dict, sort_keys=True)
        # md5 is used only as a short fingerprint of the spec, not for security.
        spec_hash = hashlib.md5(spec_str.encode()).hexdigest()[:12]
        protocol = spec_dict.get("protocol", "unknown")

        files_data = validation_results.get("files", [])
        if isinstance(files_data, dict):
            entries = self._entries_from_mapping(files_data)
        elif isinstance(files_data, list):
            entries = self._entries_from_sequence(files_data)
        else:
            entries = []

        feedback_list = []
        for entry in entries:
            feedback = ValidationFeedback(design_name=design_name, **entry)
            feedback_list.append(feedback)
            self._learn_from_file(feedback, protocol, generation_source)

        # all([]) is True, so an empty run must be excluded explicitly;
        # otherwise "nothing validated" would inflate the success rate.
        all_passed = bool(feedback_list) and all(f.passed for f in feedback_list)
        avg_score = (
            sum(f.score for f in feedback_list) / len(feedback_list)
            if feedback_list
            else 0.0
        )

        self._history.append(
            GenerationHistory(
                design_name=design_name,
                generation_source=generation_source,
                spec_hash=spec_hash,
                feedback_list=feedback_list,
                success_rate=1.0 if all_passed else 0.0,
                avg_score=avg_score,
            )
        )
        self._total_generations += 1
        if all_passed:
            self._successful_generations += 1

        if self._storage_path:
            self._save_to_storage()

    def select_best_generation_strategy(
        self,
        spec_dict: Dict[str, Any],
        file_type: str,
        available_sources: List[str],
    ) -> Tuple[str, float]:
        """Return ``(source, estimate)`` chosen by the reinforcement learner.

        The protocol is read from ``spec_dict["protocol"]`` (default
        ``"unknown"``); exploration uses ``epsilon=0.05``.
        """
        return self._rl_learner.select_best_action(
            protocol=spec_dict.get("protocol", "unknown"),
            file_type=file_type,
            available_sources=available_sources,
            epsilon=0.05,
        )

    def get_generation_hints(
        self,
        spec_dict: Dict[str, Any],
        file_type: str,
    ) -> Dict[str, Any]:
        """Return common errors, success rates and recommendations.

        The result has the keys ``common_errors`` (top five patterns),
        ``file_type_success_rate``, ``protocol_success_rate`` and
        ``recommendations``.
        """
        protocol = spec_dict.get("protocol", "unknown")
        common_errors = self._pattern_learner.get_common_errors(5)
        file_success_rate = self._pattern_learner.get_file_type_success_rate(file_type)
        protocol_success_rate = self._pattern_learner.get_protocol_success_rate(protocol)
        return {
            "common_errors": common_errors,
            "file_type_success_rate": file_success_rate,
            "protocol_success_rate": protocol_success_rate,
            "recommendations": self._generate_recommendations(
                common_errors,
                file_success_rate,
                protocol_success_rate,
            ),
        }

    def _generate_recommendations(
        self,
        common_errors: List[Tuple[str, int]],
        file_success_rate: float,
        protocol_success_rate: float,
    ) -> List[str]:
        """Turn error patterns and success rates into advice strings.

        Only the first three entries of *common_errors* with a positive count
        are considered, each contributing at most one piece of advice. Success
        rates below 0.7 add a generic suggestion each. A single fallback
        message is returned when nothing else applies.
        """
        recommendations = []
        for error_pattern, count in common_errors[:3]:
            if count <= 0:
                continue
            for keyword, advice in self._RECOMMENDATION_RULES:
                if keyword in error_pattern:
                    recommendations.append(advice)
                    break

        if file_success_rate < 0.7:
            recommendations.append(
                "Consider using retrieval-based generation for this file type"
            )
        if protocol_success_rate < 0.7:
            recommendations.append(
                "Add protocol-specific templates may improve quality"
            )
        if not recommendations:
            recommendations.append(
                "No specific recommendations - generation should work well"
            )
        return recommendations

    def get_stats(self) -> Dict[str, Any]:
        """Return run counters, overall success rate and pattern statistics."""
        total = self._total_generations
        return {
            "total_generations": total,
            "successful_generations": self._successful_generations,
            "success_rate": (
                self._successful_generations / total if total > 0 else 0.0
            ),
            "history_count": len(self._history),
            "pattern_stats": self._pattern_learner.to_dict(),
        }

    def _save_to_storage(self):
        """Write the current state to the storage path as JSON (best effort).

        Failures are logged as warnings and never raised.
        """
        if not self._storage_path:
            return
        try:
            # A bare filename has an empty dirname, and os.makedirs("") raises;
            # only create the directory when there is one.
            directory = os.path.dirname(self._storage_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            data = {
                "pattern_learner": self._pattern_learner.to_dict(),
                "rl_learner": self._rl_learner.to_dict(),
                "history": [
                    h.to_dict() for h in self._history[-self._MAX_SAVED_HISTORY:]
                ],
                "total_generations": self._total_generations,
                "successful_generations": self._successful_generations,
                "saved_at": datetime.now().isoformat(),
            }

            # Write to a sibling file and swap it in, so an interrupted write
            # cannot leave a truncated state file behind.
            tmp_path = f"{self._storage_path}.tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self._storage_path)
            logger.debug("Learning module saved to: %s", self._storage_path)
        except Exception as e:
            # Persistence is optional; learning continues in memory.
            logger.warning("Could not save learning module: %s", e)

    def _load_from_storage(self):
        """Restore state from the storage path if the file exists (best effort).

        Everything is parsed first and assigned only when the whole file was
        read successfully, so a malformed file leaves the current state
        untouched. Failures are logged as warnings and never raised.
        """
        if not self._storage_path or not os.path.exists(self._storage_path):
            return
        try:
            with open(self._storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            pattern_learner = PatternLearner.from_dict(
                data.get("pattern_learner", {})
            )
            rl_learner = ReinforcementLearner.from_dict(data.get("rl_learner", {}))
            history = [
                GenerationHistory(
                    design_name=h_dict.get("design_name", "unknown"),
                    generation_source=h_dict.get("generation_source", "unknown"),
                    spec_hash=h_dict.get("spec_hash", ""),
                    feedback_list=[
                        ValidationFeedback.from_dict(f)
                        for f in h_dict.get("feedback_list", [])
                    ],
                    success_rate=h_dict.get("success_rate", 0.0),
                    avg_score=h_dict.get("avg_score", 0.0),
                    timestamp=h_dict.get("timestamp", datetime.now().isoformat()),
                )
                for h_dict in data.get("history", [])
            ]
            total_generations = data.get("total_generations", 0)
            successful_generations = data.get("successful_generations", 0)
        except Exception as e:
            # A missing or corrupt state file must not prevent start-up.
            logger.warning("Could not load learning module: %s", e)
            return

        self._pattern_learner = pattern_learner
        self._rl_learner = rl_learner
        self._history = history
        self._total_generations = total_generations
        self._successful_generations = successful_generations
        logger.info("Learning module loaded from: %s", self._storage_path)