# src/features/extractors.py — Feature extraction from design specs from __future__ import annotations from typing import Any, Dict, List, Optional from pydantic import BaseModel from src.config import DesignSpec from src.models.ml_utils import RichFeatureVector class FeatureVector(BaseModel): """Numerical / categorical features extracted from a spec for downstream use.""" interface_count: int = 0 total_signals: int = 0 register_count: int = 0 total_fields: int = 0 has_output_signals: bool = False has_input_signals: bool = False protocol_type: Optional[str] = None complexity_score: float = 0.0 model_config = {"extra": "forbid"} class SpecFeatureExtractor: """Extracts structured features from DesignSpec for analytics / ML.""" PROTOCOL_SIGNATURES = { "uart": {"tx", "rx", "baud"}, "i2c": {"scl", "sda"}, "spi": {"mosi", "miso", "sclk", "ss_n"}, "axi": {"awvalid", "awready", "arvalid", "arready", "wvalid", "wready", "rvalid", "rready", "bvalid", "bready"}, "apb": {"psel", "penable", "paddr", "pwrite"}, } def extract(self, spec: DesignSpec) -> FeatureVector: signals = [s for iface in spec.interfaces for s in iface.signals] signal_names = {s.name.lower() for s in signals} return FeatureVector( interface_count=len(spec.interfaces), total_signals=len(signals), register_count=len(spec.registers), total_fields=sum(len(r.fields) for r in spec.registers), has_output_signals=any(s.direction == "output" for s in signals), has_input_signals=any(s.direction == "input" for s in signals), protocol_type=self._detect_protocol(signal_names), complexity_score=self._compute_complexity(spec), ) @staticmethod def _detect_protocol(signal_names: set) -> Optional[str]: for proto, sigs in SpecFeatureExtractor.PROTOCOL_SIGNATURES.items(): if all(any(keyword in s for s in signal_names) for keyword in sigs): return proto return None @staticmethod def _compute_complexity(spec: DesignSpec) -> float: score = 0.0 score += len(spec.interfaces) * 1.5 score += sum(len(iface.signals) for iface in spec.interfaces) * 0.8 score += len(spec.registers) * 2.0 score += sum(len(r.fields) for r in spec.registers) * 0.5 return round(score, 2) class RichSpecFeatureExtractor: """Extracts rich features from DesignSpec for ML similarity matching.""" PROTOCOL_SIGNATURES = { "uart": {"tx", "rx", "baud"}, "i2c": {"scl", "sda"}, "spi": {"mosi", "miso", "sclk", "ss_n", "cs_n"}, "axi": {"awvalid", "awready", "arvalid", "arready", "wvalid", "wready", "rvalid", "rready", "bvalid", "bready"}, "apb": {"psel", "penable", "paddr", "pwrite", "prdata", "pwdata"}, "wishbone": {"wb_cyc", "wb_stb", "wb_ack", "wb_we", "wb_adr", "wb_dat"}, } def extract(self, spec: DesignSpec) -> RichFeatureVector: """Extract rich feature vector from a DesignSpec.""" signals = [s for iface in spec.interfaces for s in iface.signals] signal_names = {s.name.lower() for s in signals} signal_directions: Dict[str, str] = {} signal_widths: Dict[str, int] = {} all_signal_names: List[str] = [] for s in signals: all_signal_names.append(s.name) signal_directions[s.name] = s.direction signal_widths[s.name] = s.width if s.width else 1 register_names: List[str] = [] register_addresses: Dict[str, str] = {} register_fields: Dict[str, List[str]] = {} register_access: Dict[str, str] = {} for r in spec.registers: register_names.append(r.name) register_addresses[r.name] = r.address register_fields[r.name] = [f.name for f in r.fields] register_access[r.name] = r.access or "rw" interface_names = [iface.name for iface in spec.interfaces] complexity = self._compute_complexity(spec) protocol = self._detect_protocol(signal_names, spec.protocol) return RichFeatureVector( interface_count=len(spec.interfaces), total_signals=len(signals), register_count=len(spec.registers), total_fields=sum(len(r.fields) for r in spec.registers), complexity_score=complexity, protocol_type=protocol, signal_names=all_signal_names, signal_directions=signal_directions, signal_widths=signal_widths, register_names=register_names, register_addresses=register_addresses, register_fields=register_fields, register_access=register_access, interface_names=interface_names, design_name=spec.design_name, ) def _detect_protocol(self, signal_names: set, explicit_protocol: Optional[str]) -> Optional[str]: """Detect protocol with explicit override.""" if explicit_protocol: return explicit_protocol for proto, sigs in self.PROTOCOL_SIGNATURES.items(): match_count = sum(1 for keyword in sigs if any(keyword in s for s in signal_names)) if match_count >= len(sigs) * 0.5: return proto return None @staticmethod def _compute_complexity(spec: DesignSpec) -> float: score = 0.0 score += len(spec.interfaces) * 1.5 score += sum(len(iface.signals) for iface in spec.interfaces) * 0.8 score += len(spec.registers) * 2.0 score += sum(len(r.fields) for r in spec.registers) * 0.5 return round(score, 2)