File size: 5,791 Bytes
4344b33
 
 
 
 
 
 
 
 
a9127d4
4344b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9127d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# src/features/extractors.py — Feature extraction from design specs

from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from src.config import DesignSpec
from src.models.ml_utils import RichFeatureVector


class FeatureVector(BaseModel):
    """Numerical / categorical features extracted from a spec for downstream use.

    Every field carries a default, so the model can be constructed with no
    arguments. Within this module the fields are populated by
    ``SpecFeatureExtractor.extract``.
    """
    # Number of interfaces on the spec (``len(spec.interfaces)``).
    interface_count: int = 0
    # Signal count summed across all interfaces.
    total_signals: int = 0
    # Number of registers on the spec (``len(spec.registers)``).
    register_count: int = 0
    # Field count summed across all registers.
    total_fields: int = 0
    # True when at least one signal has direction "output".
    has_output_signals: bool = False
    # True when at least one signal has direction "input".
    has_input_signals: bool = False
    # Protocol name guessed from signal names; None when nothing matched.
    protocol_type: Optional[str] = None
    # Weighted size heuristic produced by the extractor, rounded to 2 decimals.
    complexity_score: float = 0.0

    # Model configuration: "forbid" makes validation fail on any keyword that
    # is not one of the fields declared above.
    # NOTE(review): ``model_config`` is the pydantic v2 spelling — confirm the
    # project pins pydantic >= 2.
    model_config = {"extra": "forbid"}


class SpecFeatureExtractor:
    """Extracts structured features from DesignSpec for analytics / ML.

    The extractor is stateless: ``extract`` only reads ``spec.interfaces``
    (each with ``signals``) and ``spec.registers`` (each with ``fields``).
    Subclasses may override ``PROTOCOL_SIGNATURES`` to change detection.
    """

    # Protocol name -> keywords that must ALL occur (as substrings of the
    # lower-cased signal names) for that protocol to be reported.
    # Declaration order is the lookup order; the first full match wins.
    PROTOCOL_SIGNATURES = {
        "uart": {"tx", "rx", "baud"},
        "i2c": {"scl", "sda"},
        "spi": {"mosi", "miso", "sclk", "ss_n"},
        "axi": {"awvalid", "awready", "arvalid", "arready", "wvalid", "wready", "rvalid", "rready", "bvalid", "bready"},
        "apb": {"psel", "penable", "paddr", "pwrite"},
    }

    def extract(self, spec: DesignSpec) -> FeatureVector:
        """Build a ``FeatureVector`` summarising *spec*.

        Args:
            spec: Design specification exposing ``interfaces`` and ``registers``.

        Returns:
            Counts, direction flags, the detected protocol (or ``None``) and
            the complexity score for the spec.
        """
        # Flatten the signals of every interface into one list.
        signals = [s for iface in spec.interfaces for s in iface.signals]
        # Protocol keywords are lower-case, so compare against lower-cased names.
        signal_names = {s.name.lower() for s in signals}
        directions = {s.direction for s in signals}

        return FeatureVector(
            interface_count=len(spec.interfaces),
            total_signals=len(signals),
            register_count=len(spec.registers),
            total_fields=sum(len(r.fields) for r in spec.registers),
            has_output_signals="output" in directions,
            has_input_signals="input" in directions,
            protocol_type=self._detect_protocol(signal_names),
            complexity_score=self._compute_complexity(spec),
        )

    @classmethod
    def _detect_protocol(cls, signal_names: set) -> Optional[str]:
        """Return the first protocol whose keywords all occur in *signal_names*.

        A keyword counts as present when it is a substring of at least one
        signal name. The lookup goes through ``cls`` so that a subclass
        overriding ``PROTOCOL_SIGNATURES`` is honoured.

        Args:
            signal_names: Lower-cased signal names.

        Returns:
            The protocol name, or ``None`` when no signature is fully matched.
        """
        for proto, keywords in cls.PROTOCOL_SIGNATURES.items():
            # ``all()`` of an empty set is True; an empty signature must not
            # match every design.
            if not keywords:
                continue
            if all(any(keyword in name for name in signal_names) for keyword in keywords):
                return proto
        return None

    @staticmethod
    def _compute_complexity(spec: DesignSpec) -> float:
        """Return a weighted size score for *spec*, rounded to two decimals.

        Weights: 1.5 per interface, 0.8 per signal, 2.0 per register and
        0.5 per register field.
        """
        signal_total = sum(len(iface.signals) for iface in spec.interfaces)
        field_total = sum(len(r.fields) for r in spec.registers)
        score = (
            len(spec.interfaces) * 1.5
            + signal_total * 0.8
            + len(spec.registers) * 2.0
            + field_total * 0.5
        )
        return round(score, 2)


class RichSpecFeatureExtractor:
    """Extracts rich features from DesignSpec for ML similarity matching.

    In addition to the aggregate counts, the resulting ``RichFeatureVector``
    carries per-signal and per-register details keyed by name.
    """

    # Protocol name -> characteristic keywords, matched as substrings of the
    # lower-cased signal names.
    PROTOCOL_SIGNATURES = {
        "uart": {"tx", "rx", "baud"},
        "i2c": {"scl", "sda"},
        "spi": {"mosi", "miso", "sclk", "ss_n", "cs_n"},
        "axi": {"awvalid", "awready", "arvalid", "arready", "wvalid", "wready", "rvalid", "rready", "bvalid", "bready"},
        "apb": {"psel", "penable", "paddr", "pwrite", "prdata", "pwdata"},
        "wishbone": {"wb_cyc", "wb_stb", "wb_ack", "wb_we", "wb_adr", "wb_dat"},
    }

    # Minimum fraction of a signature's keywords that must be present before
    # the protocol is considered a candidate.
    PROTOCOL_MATCH_THRESHOLD = 0.5

    def extract(self, spec: DesignSpec) -> RichFeatureVector:
        """Extract rich feature vector from a DesignSpec.

        Args:
            spec: Design specification exposing ``interfaces``, ``registers``,
                ``protocol`` and ``design_name``.

        Returns:
            A ``RichFeatureVector`` with aggregate counts plus name-keyed
            signal and register details.
        """
        signals = [s for iface in spec.interfaces for s in iface.signals]
        # Lower-cased copy used only for protocol detection; the vector keeps
        # the names as written in the spec.
        signal_names = {s.name.lower() for s in signals}

        # Dicts are keyed by signal name, so a name repeated on several
        # interfaces keeps the last occurrence.
        all_signal_names: List[str] = [s.name for s in signals]
        signal_directions: Dict[str, str] = {s.name: s.direction for s in signals}
        # A missing / zero width is recorded as 1.
        signal_widths: Dict[str, int] = {s.name: (s.width if s.width else 1) for s in signals}

        register_names: List[str] = [r.name for r in spec.registers]
        register_addresses: Dict[str, str] = {r.name: r.address for r in spec.registers}
        register_fields: Dict[str, List[str]] = {
            r.name: [f.name for f in r.fields] for r in spec.registers
        }
        # Registers without an access mode are recorded as "rw".
        register_access: Dict[str, str] = {r.name: r.access or "rw" for r in spec.registers}

        interface_names = [iface.name for iface in spec.interfaces]

        complexity = self._compute_complexity(spec)
        protocol = self._detect_protocol(signal_names, spec.protocol)

        return RichFeatureVector(
            interface_count=len(spec.interfaces),
            total_signals=len(signals),
            register_count=len(spec.registers),
            total_fields=sum(len(r.fields) for r in spec.registers),
            complexity_score=complexity,
            protocol_type=protocol,
            signal_names=all_signal_names,
            signal_directions=signal_directions,
            signal_widths=signal_widths,
            register_names=register_names,
            register_addresses=register_addresses,
            register_fields=register_fields,
            register_access=register_access,
            interface_names=interface_names,
            design_name=spec.design_name,
        )

    def _detect_protocol(self, signal_names: set, explicit_protocol: Optional[str]) -> Optional[str]:
        """Detect protocol with explicit override.

        A non-empty *explicit_protocol* is returned unchanged. Otherwise each
        signature is scored by the fraction of its keywords that occur as a
        substring of some signal name, and the best-scoring protocol that
        reaches ``PROTOCOL_MATCH_THRESHOLD`` is returned.

        Picking the best score rather than the first one over the threshold
        matters because keywords overlap: "scl" is a substring of "sclk", so
        an SPI design reaches 50% of the I2C signature and would otherwise be
        reported as I2C.

        Args:
            signal_names: Lower-cased signal names.
            explicit_protocol: Protocol declared on the spec, if any.

        Returns:
            The protocol name, or ``None`` when no signature reaches the
            threshold. Ties go to the protocol declared first.
        """
        if explicit_protocol:
            return explicit_protocol

        best_proto: Optional[str] = None
        best_ratio = 0.0
        for proto, sigs in self.PROTOCOL_SIGNATURES.items():
            if not sigs:
                # An empty signature would trivially satisfy the threshold.
                continue
            match_count = sum(1 for keyword in sigs if any(keyword in s for s in signal_names))
            if match_count < len(sigs) * self.PROTOCOL_MATCH_THRESHOLD:
                continue
            ratio = match_count / len(sigs)
            # Strict ">" keeps the earliest-declared protocol on a tie.
            if ratio > best_ratio:
                best_proto, best_ratio = proto, ratio

        return best_proto

    @staticmethod
    def _compute_complexity(spec: DesignSpec) -> float:
        """Return a weighted size score for *spec*, rounded to two decimals.

        Weights: 1.5 per interface, 0.8 per signal, 2.0 per register and
        0.5 per register field.
        """
        signal_total = sum(len(iface.signals) for iface in spec.interfaces)
        field_total = sum(len(r.fields) for r in spec.registers)
        score = (
            len(spec.interfaces) * 1.5
            + signal_total * 0.8
            + len(spec.registers) * 2.0
            + field_total * 0.5
        )
        return round(score, 2)