Spaces:
Sleeping
Sleeping
File size: 5,791 Bytes
# src/features/extractors.py — Feature extraction from design specs
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from src.config import DesignSpec
from src.models.ml_utils import RichFeatureVector
class FeatureVector(BaseModel):
    """Flat numeric / categorical summary of a design spec.

    Every field has a default, so an instance can be created without
    arguments and filled in selectively.
    """

    # Unknown keyword arguments are rejected instead of being silently kept.
    model_config = {"extra": "forbid"}

    # --- size counters ---
    interface_count: int = 0  # number of interfaces
    total_signals: int = 0  # signals summed over all interfaces
    register_count: int = 0  # number of registers
    total_fields: int = 0  # fields summed over all registers

    # --- direction flags ---
    has_output_signals: bool = False
    has_input_signals: bool = False

    # --- derived values ---
    protocol_type: Optional[str] = None  # None when no protocol was identified
    complexity_score: float = 0.0
class SpecFeatureExtractor:
    """Build a flat FeatureVector from a DesignSpec for analytics / ML."""

    # Protocol name -> keywords. A protocol is reported only when EVERY
    # keyword occurs as a substring of at least one lower-cased signal name.
    PROTOCOL_SIGNATURES = {
        "uart": {"tx", "rx", "baud"},
        "i2c": {"scl", "sda"},
        "spi": {"mosi", "miso", "sclk", "ss_n"},
        "axi": {"awvalid", "awready", "arvalid", "arready", "wvalid", "wready", "rvalid", "rready", "bvalid", "bready"},
        "apb": {"psel", "penable", "paddr", "pwrite"},
    }

    def extract(self, spec: DesignSpec) -> FeatureVector:
        """Summarise *spec* as counts, direction flags, protocol and score.

        Signals from all interfaces are flattened into one list first; the
        counts and flags are computed over that flattened list.
        """
        flat_signals = []
        for iface in spec.interfaces:
            flat_signals.extend(iface.signals)

        lowered_names = set()
        directions = []
        for sig in flat_signals:
            # Protocol keywords are lower-case, so names are normalised here.
            lowered_names.add(sig.name.lower())
            directions.append(sig.direction)

        field_total = 0
        for reg in spec.registers:
            field_total += len(reg.fields)

        saw_output = any(direction == "output" for direction in directions)
        saw_input = any(direction == "input" for direction in directions)
        detected = self._detect_protocol(lowered_names)
        score = self._compute_complexity(spec)

        return FeatureVector(
            interface_count=len(spec.interfaces),
            total_signals=len(flat_signals),
            register_count=len(spec.registers),
            total_fields=field_total,
            has_output_signals=saw_output,
            has_input_signals=saw_input,
            protocol_type=detected,
            complexity_score=score,
        )

    @staticmethod
    def _detect_protocol(signal_names: set) -> Optional[str]:
        """Return the first protocol whose keywords are all present, else None.

        Protocols are tried in the declaration order of PROTOCOL_SIGNATURES.
        A keyword is "present" when it is a substring of any signal name.
        """

        def keyword_present(keyword):
            return any(keyword in name for name in signal_names)

        fully_matched = (
            proto
            for proto, keywords in SpecFeatureExtractor.PROTOCOL_SIGNATURES.items()
            if all(keyword_present(keyword) for keyword in keywords)
        )
        return next(fully_matched, None)

    @staticmethod
    def _compute_complexity(spec: DesignSpec) -> float:
        """Return a weighted size score for *spec*, rounded to 2 decimals.

        Weights: interface 1.5, signal 0.8, register 2.0, register field 0.5.
        """
        n_interfaces = len(spec.interfaces)
        n_signals = sum(len(iface.signals) for iface in spec.interfaces)
        n_registers = len(spec.registers)
        n_fields = sum(len(reg.fields) for reg in spec.registers)
        weighted = (
            n_interfaces * 1.5
            + n_signals * 0.8
            + n_registers * 2.0
            + n_fields * 0.5
        )
        return round(weighted, 2)
class RichSpecFeatureExtractor:
    """Extracts rich features from DesignSpec for ML similarity matching."""

    # Protocol name -> keywords looked for (as substrings) in the lower-cased
    # signal names.
    PROTOCOL_SIGNATURES = {
        "uart": {"tx", "rx", "baud"},
        "i2c": {"scl", "sda"},
        "spi": {"mosi", "miso", "sclk", "ss_n", "cs_n"},
        "axi": {"awvalid", "awready", "arvalid", "arready", "wvalid", "wready", "rvalid", "rready", "bvalid", "bready"},
        "apb": {"psel", "penable", "paddr", "pwrite", "prdata", "pwdata"},
        "wishbone": {"wb_cyc", "wb_stb", "wb_ack", "wb_we", "wb_adr", "wb_dat"},
    }

    # Minimum fraction of a protocol's keywords that must be found before the
    # protocol is considered a candidate at all.
    MATCH_THRESHOLD = 0.5

    def extract(self, spec: DesignSpec) -> RichFeatureVector:
        """Extract a rich feature vector from a DesignSpec.

        Signals of all interfaces are flattened into one list. The per-signal
        and per-register dicts are keyed by name, so if two entries share a
        name the later one overwrites the earlier one.
        """
        signals = [s for iface in spec.interfaces for s in iface.signals]
        registers = spec.registers

        # Protocol keywords are lower-case, so detection runs on lowered names;
        # the vector itself keeps the names as they were written.
        lowered_names = {s.name.lower() for s in signals}

        return RichFeatureVector(
            interface_count=len(spec.interfaces),
            total_signals=len(signals),
            register_count=len(registers),
            total_fields=sum(len(r.fields) for r in registers),
            complexity_score=self._compute_complexity(spec),
            protocol_type=self._detect_protocol(lowered_names, spec.protocol),
            signal_names=[s.name for s in signals],
            signal_directions={s.name: s.direction for s in signals},
            # A missing or zero width falls back to 1.
            signal_widths={s.name: s.width or 1 for s in signals},
            register_names=[r.name for r in registers],
            register_addresses={r.name: r.address for r in registers},
            register_fields={r.name: [f.name for f in r.fields] for r in registers},
            # A missing or empty access mode falls back to "rw".
            register_access={r.name: r.access or "rw" for r in registers},
            interface_names=[iface.name for iface in spec.interfaces],
            design_name=spec.design_name,
        )

    def _detect_protocol(self, signal_names: set, explicit_protocol: Optional[str]) -> Optional[str]:
        """Detect the protocol, letting an explicit value override detection.

        Args:
            signal_names: Lower-cased signal names of the design.
            explicit_protocol: Protocol stated by the spec; returned unchanged
                when truthy.

        Returns:
            The explicit protocol if given; otherwise the protocol with the
            highest fraction of matched keywords, provided that fraction
            reaches MATCH_THRESHOLD; otherwise None.
        """
        if explicit_protocol:
            return explicit_protocol

        best_proto: Optional[str] = None
        best_ratio = 0.0
        for proto, keywords in self.PROTOCOL_SIGNATURES.items():
            matched = sum(
                1 for keyword in keywords
                if any(keyword in name for name in signal_names)
            )
            if matched < len(keywords) * self.MATCH_THRESHOLD:
                continue
            # Rank candidates instead of returning the first one: matching is
            # by substring, so e.g. "scl" is found inside "sclk" and an SPI
            # bus would otherwise be reported as the earlier-declared "i2c".
            # The strict ">" keeps the earlier-declared protocol on a tie.
            ratio = matched / len(keywords)
            if ratio > best_ratio:
                best_proto, best_ratio = proto, ratio
        return best_proto

    @staticmethod
    def _compute_complexity(spec: DesignSpec) -> float:
        """Return a weighted size score for *spec*, rounded to 2 decimals.

        Weights: interface 1.5, signal 0.8, register 2.0, register field 0.5.
        """
        score = 0.0
        score += len(spec.interfaces) * 1.5
        score += sum(len(iface.signals) for iface in spec.interfaces) * 0.8
        score += len(spec.registers) * 2.0
        score += sum(len(r.fields) for r in spec.registers) * 0.5
        return round(score, 2)
|