feat: Advanced ML V2 Model with Reinforcement Learning
Browse files## NEW FEATURES:
- Advanced Reinforcement Learning with 4 exploration strategies:
* UCB1 (Upper Confidence Bound) - default, best for exploration/exploitation
* Softmax (Boltzmann) - probabilistic based on Q-values
* Epsilon-Greedy - simple with decaying randomness
* Thompson Sampling - Bayesian approach with Beta distributions
- Experience Replay Buffer (10,000 capacity) for stable learning
- Eligibility Traces for better credit assignment in sequential tasks
- Context-Aware Pattern Learning with N-grams and Association Rules
- Deep UVM Compliance Validation:
* Factory registration checking
* Phase implementation verification
* TLM connection completeness
* Signal-direction matching
* Register field width/access validation
* Coverage model completeness checking
## BUG FIXES:
- Fixed templates to only use YAML-declared signals (removed extra modem signals: dsr_n, ri_n, dcd_n)
- Fixed sequence.sv.j2: Removed invalid 'foreach (tx_data.size())' loop
- Fixed uart_rx_seq extending wrong class (now extends base_seq)
## UI ENHANCEMENTS:
- New ML Insights tab showing:
* Learning statistics
* RL metrics (episodes, Q-values, updates)
* Pattern analysis
* Strategy distribution
- Model selection: V2 (Recommended) option
- Exploration strategy selection (UCB, Softmax, Epsilon-Greedy, Thompson)
- Learning state import/export
## FILES ADDED:
- src/models/advanced_rl_learner.py - Advanced RL with experience replay
- src/models/advanced_pattern_learner.py - Context-aware pattern mining
- src/models/advanced_code_validator.py - Deep UVM compliance validator
- src/models/enhanced_ml_model_v2.py - Integrated V2 ML model
- tests/quick_v2_test.py - Quick smoke test
- tests/test_advanced_ml_v2.py - Comprehensive test suite
## FILES MODIFIED:
- src/config.py - Added V2 model config fields
- src/pipeline.py - V2 model selection and integration
- streamlit_app.py - UI enhancements for V2 model
- src/generation/templates/*.j2 - Bug fixes and signal alignments
- src/config.py +3 -1
- src/generation/templates/interface.sv.j2 +4 -4
- src/generation/templates/rtl/protocol_core.v.j2 +4 -5
- src/generation/templates/sequence.sv.j2 +2 -2
- src/generation/templates/test.sv.j2 +0 -3
- src/generation/templates/testbench.sv.j2 +13 -19
- src/models/advanced_code_validator.py +1294 -0
- src/models/advanced_pattern_learner.py +926 -0
- src/models/advanced_rl_learner.py +728 -0
- src/models/enhanced_ml_model_v2.py +801 -0
- src/pipeline.py +30 -13
- streamlit_app.py +455 -152
- tests/quick_v2_test.py +130 -0
- tests/test_advanced_ml_v2.py +477 -0
|
@@ -88,7 +88,7 @@ class AutoTrainConfig(BaseModel):
|
|
| 88 |
class MLConfig(BaseModel):
|
| 89 |
"""Configuration for AI/ML-augmented generation with actual learning capabilities."""
|
| 90 |
enabled: bool = False
|
| 91 |
-
model_type: str = Field(default="template", pattern=r"^(template|ml|hybrid|llm|semantic)$")
|
| 92 |
similarity_threshold: float = Field(default=0.75, ge=0.0, le=1.0)
|
| 93 |
auto_learn: bool = True
|
| 94 |
index_path: Optional[str] = None
|
|
@@ -109,6 +109,8 @@ class MLConfig(BaseModel):
|
|
| 109 |
learning_rate: float = Field(default=0.1, ge=0.001, le=1.0)
|
| 110 |
reinforcement_discount: float = Field(default=0.9, ge=0.0, le=1.0)
|
| 111 |
exploration_epsilon: float = Field(default=0.05, ge=0.0, le=0.5)
|
|
|
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
class PipelineConfig(BaseModel):
|
|
|
|
| 88 |
class MLConfig(BaseModel):
|
| 89 |
"""Configuration for AI/ML-augmented generation with actual learning capabilities."""
|
| 90 |
enabled: bool = False
|
| 91 |
+
model_type: str = Field(default="template", pattern=r"^(template|ml|hybrid|llm|semantic|v2)$")
|
| 92 |
similarity_threshold: float = Field(default=0.75, ge=0.0, le=1.0)
|
| 93 |
auto_learn: bool = True
|
| 94 |
index_path: Optional[str] = None
|
|
|
|
| 109 |
learning_rate: float = Field(default=0.1, ge=0.001, le=1.0)
|
| 110 |
reinforcement_discount: float = Field(default=0.9, ge=0.0, le=1.0)
|
| 111 |
exploration_epsilon: float = Field(default=0.05, ge=0.0, le=0.5)
|
| 112 |
+
exploration_strategy: str = Field(default="ucb", pattern=r"^(epsilon_greedy|softmax|ucb|thompson)$")
|
| 113 |
+
strict_validation: bool = False
|
| 114 |
|
| 115 |
|
| 116 |
class PipelineConfig(BaseModel):
|
|
@@ -10,16 +10,16 @@ interface {{ spec.design_name }}_intf (input logic clk, input logic rst_n);
|
|
| 10 |
logic wb_ack;
|
| 11 |
// Serial
|
| 12 |
logic uart_tx, uart_rx;
|
| 13 |
-
// Modem
|
| 14 |
-
logic cts_n, rts_n
|
| 15 |
// Interrupt
|
| 16 |
logic uart_intr;
|
| 17 |
|
| 18 |
clocking drv_cb @(posedge clk);
|
| 19 |
default input #1ns output #1ns;
|
| 20 |
output wb_cyc, wb_stb, wb_we, wb_addr, wb_data_o;
|
| 21 |
-
output uart_rx, cts_n
|
| 22 |
-
input wb_ack, wb_data_i, uart_tx, uart_intr, rts_n
|
| 23 |
endclocking
|
| 24 |
|
| 25 |
{% elif p == "spi" %}
|
|
|
|
| 10 |
logic wb_ack;
|
| 11 |
// Serial
|
| 12 |
logic uart_tx, uart_rx;
|
| 13 |
+
// Modem (as per spec)
|
| 14 |
+
logic cts_n, rts_n;
|
| 15 |
// Interrupt
|
| 16 |
logic uart_intr;
|
| 17 |
|
| 18 |
clocking drv_cb @(posedge clk);
|
| 19 |
default input #1ns output #1ns;
|
| 20 |
output wb_cyc, wb_stb, wb_we, wb_addr, wb_data_o;
|
| 21 |
+
output uart_rx, cts_n;
|
| 22 |
+
input wb_ack, wb_data_i, uart_tx, uart_intr, rts_n;
|
| 23 |
endclocking
|
| 24 |
|
| 25 |
{% elif p == "spi" %}
|
|
@@ -16,9 +16,9 @@ module {{ spec.design_name }}_core (
|
|
| 16 |
// Serial
|
| 17 |
output logic uart_tx,
|
| 18 |
input logic uart_rx,
|
| 19 |
-
// Modem
|
| 20 |
-
input logic cts_n,
|
| 21 |
-
output logic rts_n,
|
| 22 |
output logic uart_intr
|
| 23 |
);
|
| 24 |
logic [7:0] reg_lcr, reg_scr;
|
|
@@ -33,8 +33,7 @@ module {{ spec.design_name }}_core (
|
|
| 33 |
3'h7: reg_scr <= wb_data_i;
|
| 34 |
endcase
|
| 35 |
assign uart_tx = uart_rx;
|
| 36 |
-
assign rts_n = 0;
|
| 37 |
-
assign out1_n = 1; assign out2_n = 1;
|
| 38 |
assign uart_intr = 0;
|
| 39 |
{% elif p == "spi" %}
|
| 40 |
// Wishbone bus
|
|
|
|
| 16 |
// Serial
|
| 17 |
output logic uart_tx,
|
| 18 |
input logic uart_rx,
|
| 19 |
+
// Modem (as per spec)
|
| 20 |
+
input logic cts_n,
|
| 21 |
+
output logic rts_n,
|
| 22 |
output logic uart_intr
|
| 23 |
);
|
| 24 |
logic [7:0] reg_lcr, reg_scr;
|
|
|
|
| 33 |
3'h7: reg_scr <= wb_data_i;
|
| 34 |
endcase
|
| 35 |
assign uart_tx = uart_rx;
|
| 36 |
+
assign rts_n = 0;
|
|
|
|
| 37 |
assign uart_intr = 0;
|
| 38 |
{% elif p == "spi" %}
|
| 39 |
// Wishbone bus
|
|
@@ -185,7 +185,7 @@ class uart_tx_seq extends {{ spec.design_name }}_base_seq;
|
|
| 185 |
|
| 186 |
`uvm_info("UART_TX", $sformatf("Transmitting %0d bytes: %p", num_bytes, tx_data), UVM_MEDIUM)
|
| 187 |
|
| 188 |
-
|
| 189 |
wait_for_tx_empty();
|
| 190 |
|
| 191 |
if (reg_model) begin
|
|
@@ -221,7 +221,7 @@ class uart_tx_seq extends {{ spec.design_name }}_base_seq;
|
|
| 221 |
endtask
|
| 222 |
endclass
|
| 223 |
|
| 224 |
-
class uart_rx_seq extends {{ spec.design_name }}
|
| 225 |
`uvm_object_utils(uart_rx_seq)
|
| 226 |
|
| 227 |
logic [7:0] rx_data[$];
|
|
|
|
| 185 |
|
| 186 |
`uvm_info("UART_TX", $sformatf("Transmitting %0d bytes: %p", num_bytes, tx_data), UVM_MEDIUM)
|
| 187 |
|
| 188 |
+
for (int i = 0; i < num_bytes; i++) begin
|
| 189 |
wait_for_tx_empty();
|
| 190 |
|
| 191 |
if (reg_model) begin
|
|
|
|
| 221 |
endtask
|
| 222 |
endclass
|
| 223 |
|
| 224 |
+
class uart_rx_seq extends {{ spec.design_name }}_base_seq;
|
| 225 |
`uvm_object_utils(uart_rx_seq)
|
| 226 |
|
| 227 |
logic [7:0] rx_data[$];
|
|
@@ -57,9 +57,6 @@ class {{ spec.design_name }}_base_test extends uvm_test;
|
|
| 57 |
`uvm_info("TEST", "Starting test...", UVM_LOW)
|
| 58 |
vif.uart_rx <= 1'b1;
|
| 59 |
vif.cts_n <= 1'b0;
|
| 60 |
-
vif.dsr_n <= 1'b0;
|
| 61 |
-
vif.ri_n <= 1'b1;
|
| 62 |
-
vif.dcd_n <= 1'b1;
|
| 63 |
run_top_sequence();
|
| 64 |
phase.drop_objection(this);
|
| 65 |
endtask
|
|
|
|
| 57 |
`uvm_info("TEST", "Starting test...", UVM_LOW)
|
| 58 |
vif.uart_rx <= 1'b1;
|
| 59 |
vif.cts_n <= 1'b0;
|
|
|
|
|
|
|
|
|
|
| 60 |
run_top_sequence();
|
| 61 |
phase.drop_objection(this);
|
| 62 |
endtask
|
|
@@ -13,25 +13,19 @@ module testbench;
|
|
| 13 |
{{ spec.design_name }}_core dut (
|
| 14 |
.clk (clk),
|
| 15 |
.rst_n(rst_n),
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
.dtr_n (intf.dtr_n),
|
| 30 |
-
.ri_n (intf.ri_n),
|
| 31 |
-
.dcd_n (intf.dcd_n),
|
| 32 |
-
.out1_n (intf.out1_n),
|
| 33 |
-
.out2_n (intf.out2_n),
|
| 34 |
-
.uart_intr (intf.uart_intr)
|
| 35 |
{% elif p == "spi" %}
|
| 36 |
.wb_cyc (intf.wb_cyc),
|
| 37 |
.wb_stb (intf.wb_stb),
|
|
|
|
| 13 |
{{ spec.design_name }}_core dut (
|
| 14 |
.clk (clk),
|
| 15 |
.rst_n(rst_n),
|
| 16 |
+
{% if p == "uart" %}
|
| 17 |
+
.wb_cyc (intf.wb_cyc),
|
| 18 |
+
.wb_stb (intf.wb_stb),
|
| 19 |
+
.wb_we (intf.wb_we),
|
| 20 |
+
.wb_addr (intf.wb_addr),
|
| 21 |
+
.wb_data_i (intf.wb_data_o),
|
| 22 |
+
.wb_data_o (intf.wb_data_i),
|
| 23 |
+
.wb_ack (intf.wb_ack),
|
| 24 |
+
.uart_tx (intf.uart_tx),
|
| 25 |
+
.uart_rx (intf.uart_rx),
|
| 26 |
+
.cts_n (intf.cts_n),
|
| 27 |
+
.rts_n (intf.rts_n),
|
| 28 |
+
.uart_intr (intf.uart_intr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
{% elif p == "spi" %}
|
| 30 |
.wb_cyc (intf.wb_cyc),
|
| 31 |
.wb_stb (intf.wb_stb),
|
|
@@ -0,0 +1,1294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Code Validator for UVM Testbench Generation.
|
| 3 |
+
|
| 4 |
+
Key improvements for promotion:
|
| 5 |
+
1. Deep UVM compliance checking with factory registration validation
|
| 6 |
+
2. Signal-direction matching validation
|
| 7 |
+
3. Register field width and access validation
|
| 8 |
+
4. Phase implementation completeness checking
|
| 9 |
+
5. TLM connection completeness validation
|
| 10 |
+
6. Compile-ready validation with SV syntax rules
|
| 11 |
+
7. Context-aware error detection with fix suggestions
|
| 12 |
+
8. Spec compliance with hierarchical signal checking
|
| 13 |
+
9. Coverage completeness checking
|
| 14 |
+
10. Scoreboard/TLM connection validation
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import re
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from enum import Enum
|
| 23 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Pattern
|
| 24 |
+
from collections import defaultdict, Counter
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger("uvmgen.validator.advanced")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ValidationSeverity(Enum):
|
| 30 |
+
ERROR = "error"
|
| 31 |
+
WARNING = "warning"
|
| 32 |
+
INFO = "info"
|
| 33 |
+
STYLE = "style"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class ValidationIssue:
|
| 38 |
+
severity: ValidationSeverity
|
| 39 |
+
code: str
|
| 40 |
+
message: str
|
| 41 |
+
line_number: Optional[int] = None
|
| 42 |
+
context: Optional[str] = None
|
| 43 |
+
suggestion: Optional[str] = None
|
| 44 |
+
auto_fixable: bool = False
|
| 45 |
+
confidence: float = 1.0
|
| 46 |
+
|
| 47 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 48 |
+
return {
|
| 49 |
+
"severity": self.severity.value,
|
| 50 |
+
"code": self.code,
|
| 51 |
+
"message": self.message,
|
| 52 |
+
"line_number": self.line_number,
|
| 53 |
+
"context": self.context,
|
| 54 |
+
"suggestion": self.suggestion,
|
| 55 |
+
"auto_fixable": self.auto_fixable,
|
| 56 |
+
"confidence": self.confidence,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class FileValidationResult:
|
| 62 |
+
filename: str
|
| 63 |
+
file_type: str
|
| 64 |
+
passed: bool
|
| 65 |
+
issues: List[ValidationIssue] = field(default_factory=list)
|
| 66 |
+
checks_run: int = 0
|
| 67 |
+
checks_passed: int = 0
|
| 68 |
+
score: float = 0.0
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def error_count(self) -> int:
|
| 72 |
+
return sum(1 for i in self.issues if i.severity == ValidationSeverity.ERROR)
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def warning_count(self) -> int:
|
| 76 |
+
return sum(1 for i in self.issues if i.severity == ValidationSeverity.WARNING)
|
| 77 |
+
|
| 78 |
+
@property
|
| 79 |
+
def info_count(self) -> int:
|
| 80 |
+
return sum(1 for i in self.issues if i.severity == ValidationSeverity.INFO)
|
| 81 |
+
|
| 82 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 83 |
+
return {
|
| 84 |
+
"filename": self.filename,
|
| 85 |
+
"file_type": self.file_type,
|
| 86 |
+
"passed": self.passed,
|
| 87 |
+
"error_count": self.error_count,
|
| 88 |
+
"warning_count": self.warning_count,
|
| 89 |
+
"info_count": self.info_count,
|
| 90 |
+
"checks_run": self.checks_run,
|
| 91 |
+
"checks_passed": self.checks_passed,
|
| 92 |
+
"score": self.score,
|
| 93 |
+
"issues": [i.to_dict() for i in self.issues],
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@dataclass
|
| 98 |
+
class ValidationReport:
|
| 99 |
+
design_name: str
|
| 100 |
+
overall_passed: bool
|
| 101 |
+
files: List[FileValidationResult] = field(default_factory=list)
|
| 102 |
+
timestamp: str = ""
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def total_errors(self) -> int:
|
| 106 |
+
return sum(f.error_count for f in self.files)
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def total_warnings(self) -> int:
|
| 110 |
+
return sum(f.warning_count for f in self.files)
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def total_checks_run(self) -> int:
|
| 114 |
+
return sum(f.checks_run for f in self.files)
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def total_checks_passed(self) -> int:
|
| 118 |
+
return sum(f.checks_passed for f in self.files)
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def pass_rate(self) -> float:
|
| 122 |
+
if self.total_checks_run == 0:
|
| 123 |
+
return 1.0
|
| 124 |
+
return self.total_checks_passed / self.total_checks_run
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def avg_score(self) -> float:
|
| 128 |
+
if not self.files:
|
| 129 |
+
return 0.0
|
| 130 |
+
return sum(f.score for f in self.files) / len(self.files)
|
| 131 |
+
|
| 132 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 133 |
+
return {
|
| 134 |
+
"design_name": self.design_name,
|
| 135 |
+
"overall_passed": self.overall_passed,
|
| 136 |
+
"total_errors": self.total_errors,
|
| 137 |
+
"total_warnings": self.total_warnings,
|
| 138 |
+
"total_checks_run": self.total_checks_run,
|
| 139 |
+
"total_checks_passed": self.total_checks_passed,
|
| 140 |
+
"pass_rate": round(self.pass_rate * 100, 1),
|
| 141 |
+
"avg_score": round(self.avg_score, 3),
|
| 142 |
+
"files": [f.to_dict() for f in self.files],
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class UVMComplianceChecker:
|
| 147 |
+
"""Deep UVM compliance checking."""
|
| 148 |
+
|
| 149 |
+
UVM_BASE_CLASSES = {
|
| 150 |
+
"uvm_test", "uvm_env", "uvm_agent", "uvm_driver", "uvm_monitor",
|
| 151 |
+
"uvm_sequencer", "uvm_sequence", "uvm_sequence_item", "uvm_scoreboard",
|
| 152 |
+
"uvm_subscriber", "uvm_reg_block", "uvm_reg", "uvm_reg_field",
|
| 153 |
+
"uvm_reg_map", "uvm_reg_adapter", "uvm_reg_predictor",
|
| 154 |
+
"uvm_analysis_port", "uvm_analysis_imp", "uvm_tlm_fifo",
|
| 155 |
+
"uvm_component", "uvm_object", "uvm_report_object",
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
UVM_PHASES = [
|
| 159 |
+
"build_phase", "connect_phase", "end_of_elaboration_phase",
|
| 160 |
+
"start_of_simulation_phase", "run_phase", "extract_phase",
|
| 161 |
+
"check_phase", "report_phase", "final_phase",
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
REQUIRED_PHASES_BY_TYPE = {
|
| 165 |
+
"test": {"build_phase", "run_phase"},
|
| 166 |
+
"env": {"build_phase", "connect_phase"},
|
| 167 |
+
"agent": {"build_phase", "connect_phase"},
|
| 168 |
+
"driver": {"build_phase", "run_phase"},
|
| 169 |
+
"monitor": {"build_phase", "run_phase"},
|
| 170 |
+
"scoreboard": {"build_phase", "connect_phase"},
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
def __init__(self):
|
| 174 |
+
self._patterns = self._compile_patterns()
|
| 175 |
+
|
| 176 |
+
def _compile_patterns(self) -> Dict[str, Pattern]:
|
| 177 |
+
return {
|
| 178 |
+
"class_decl": re.compile(r'\bclass\s+(\w+)\s*(?:#\s*\(\s*[^)]*\)\s*)?(?:extends\s+(\w+))?'),
|
| 179 |
+
"extends_uvm": re.compile(r'\bextends\s+(uvm_\w+)'),
|
| 180 |
+
"uvm_component_utils": re.compile(r'`uvm_component_utils\s*\(\s*(\w+)\s*\)'),
|
| 181 |
+
"uvm_object_utils": re.compile(r'`uvm_object_utils\s*\(\s*(\w+)\s*\)'),
|
| 182 |
+
"uvm_field_utils": re.compile(r'`uvm_field_\w+\s*\('),
|
| 183 |
+
"phase_decl": re.compile(r'\b(virtual\s+)?(function|task)\s+(\w+_phase)\s*\('),
|
| 184 |
+
"config_db_set": re.compile(r'uvm_config_db\s*#\s*<\s*([^>]+)\s*>\s*::\s*set\s*\('),
|
| 185 |
+
"config_db_get": re.compile(r'uvm_config_db\s*#\s*<\s*([^>]+)\s*>\s*::\s*get\s*\('),
|
| 186 |
+
"analysis_port_decl": re.compile(r'\buvm_analysis_port\s*#\s*<\s*(\w+)\s*>\s*(\w+)'),
|
| 187 |
+
"analysis_imp_decl": re.compile(r'\buvm_analysis_imp\s*#\s*<\s*(\w+)\s*,\s*(\w+)\s*>\s*(\w+)'),
|
| 188 |
+
"tlm_fifo_decl": re.compile(r'\buvm_tlm_(analysis_)?fifo\s*#\s*<\s*(\w+)\s*>\s*(\w+)'),
|
| 189 |
+
"raise_objection": re.compile(r'\braise_objection\s*\('),
|
| 190 |
+
"drop_objection": re.compile(r'\bdrop_objection\s*\('),
|
| 191 |
+
"seq_item_port_decl": re.compile(r'\buvm_seq_item_pull_port\s*#\s*<\s*(\w+)\s*>\s*(\w+)'),
|
| 192 |
+
"seq_item_port_get": re.compile(r'\bseq_item_port\s*\.\s*(get_next_item|get|peek)\s*\('),
|
| 193 |
+
"seq_item_port_done": re.compile(r'\bseq_item_port\s*\.\s*item_done\s*\('),
|
| 194 |
+
"type_id_create": re.compile(r'\b(\w+)\s*::\s*type_id\s*::\s*create\s*\('),
|
| 195 |
+
"reg_model_decl": re.compile(r'\b(\w+_reg_block)\s+(\w+)'),
|
| 196 |
+
"reg_write": re.compile(r'\breg_model\s*\.\s*(\w+)\s*\.\s*write\s*\('),
|
| 197 |
+
"reg_read": re.compile(r'\breg_model\s*\.\s*(\w+)\s*\.\s*read\s*\('),
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
def check_uvm_compliance(
|
| 201 |
+
self,
|
| 202 |
+
content: str,
|
| 203 |
+
file_type: str,
|
| 204 |
+
lines: List[str],
|
| 205 |
+
) -> List[ValidationIssue]:
|
| 206 |
+
"""Check deep UVM compliance."""
|
| 207 |
+
issues: List[ValidationIssue] = []
|
| 208 |
+
|
| 209 |
+
class_decl = self._patterns["class_decl"].search(content)
|
| 210 |
+
if not class_decl:
|
| 211 |
+
return issues
|
| 212 |
+
|
| 213 |
+
class_name = class_decl.group(1)
|
| 214 |
+
extends_match = self._patterns["extends_uvm"].search(content)
|
| 215 |
+
|
| 216 |
+
is_uvm_class = extends_match or any(uvm_base in content for uvm_base in self.UVM_BASE_CLASSES)
|
| 217 |
+
|
| 218 |
+
if not is_uvm_class:
|
| 219 |
+
return issues
|
| 220 |
+
|
| 221 |
+
parent_class = extends_match.group(1) if extends_match else "unknown"
|
| 222 |
+
|
| 223 |
+
issues.extend(self._check_factory_registration(
|
| 224 |
+
content, class_name, parent_class, lines
|
| 225 |
+
))
|
| 226 |
+
|
| 227 |
+
issues.extend(self._check_phase_implementation(
|
| 228 |
+
content, file_type, class_name, lines
|
| 229 |
+
))
|
| 230 |
+
|
| 231 |
+
issues.extend(self._check_component_specific(
|
| 232 |
+
content, file_type, parent_class, lines
|
| 233 |
+
))
|
| 234 |
+
|
| 235 |
+
issues.extend(self._check_objection_handling(
|
| 236 |
+
content, file_type, lines
|
| 237 |
+
))
|
| 238 |
+
|
| 239 |
+
return issues
|
| 240 |
+
|
| 241 |
+
def _check_factory_registration(
|
| 242 |
+
self,
|
| 243 |
+
content: str,
|
| 244 |
+
class_name: str,
|
| 245 |
+
parent_class: str,
|
| 246 |
+
lines: List[str],
|
| 247 |
+
) -> List[ValidationIssue]:
|
| 248 |
+
"""Check proper UVM factory registration."""
|
| 249 |
+
issues: List[ValidationIssue] = []
|
| 250 |
+
|
| 251 |
+
is_component = any(c in parent_class for c in [
|
| 252 |
+
"test", "env", "agent", "driver", "monitor", "scoreboard",
|
| 253 |
+
"sequencer", "subscriber", "component"
|
| 254 |
+
])
|
| 255 |
+
is_object = any(o in parent_class for o in [
|
| 256 |
+
"sequence", "sequence_item", "reg", "object"
|
| 257 |
+
])
|
| 258 |
+
|
| 259 |
+
if not (is_component or is_object):
|
| 260 |
+
return issues
|
| 261 |
+
|
| 262 |
+
component_utils = self._patterns["uvm_component_utils"].search(content)
|
| 263 |
+
object_utils = self._patterns["uvm_object_utils"].search(content)
|
| 264 |
+
|
| 265 |
+
if is_component:
|
| 266 |
+
if not component_utils:
|
| 267 |
+
line_num = self._find_class_line(class_name, lines)
|
| 268 |
+
issues.append(ValidationIssue(
|
| 269 |
+
severity=ValidationSeverity.ERROR,
|
| 270 |
+
code="UVM-FACTORY-001",
|
| 271 |
+
message=f"Component class '{class_name}' missing `uvm_component_utils macro",
|
| 272 |
+
line_number=line_num,
|
| 273 |
+
suggestion=f"Add `uvm_component_utils({class_name}) after the class declaration",
|
| 274 |
+
auto_fixable=True,
|
| 275 |
+
confidence=0.95,
|
| 276 |
+
))
|
| 277 |
+
elif component_utils.group(1) != class_name:
|
| 278 |
+
issues.append(ValidationIssue(
|
| 279 |
+
severity=ValidationSeverity.ERROR,
|
| 280 |
+
code="UVM-FACTORY-002",
|
| 281 |
+
message=f"uvm_component_utils has wrong class name: expected '{class_name}', got '{component_utils.group(1)}'",
|
| 282 |
+
suggestion=f"Change `uvm_component_utils({component_utils.group(1)}) to `uvm_component_utils({class_name})",
|
| 283 |
+
auto_fixable=True,
|
| 284 |
+
confidence=0.9,
|
| 285 |
+
))
|
| 286 |
+
|
| 287 |
+
if is_object and not is_component:
|
| 288 |
+
if not object_utils:
|
| 289 |
+
line_num = self._find_class_line(class_name, lines)
|
| 290 |
+
issues.append(ValidationIssue(
|
| 291 |
+
severity=ValidationSeverity.ERROR,
|
| 292 |
+
code="UVM-FACTORY-003",
|
| 293 |
+
message=f"Object class '{class_name}' missing `uvm_object_utils macro",
|
| 294 |
+
line_number=line_num,
|
| 295 |
+
suggestion=f"Add `uvm_object_utils({class_name}) after the class declaration",
|
| 296 |
+
auto_fixable=True,
|
| 297 |
+
confidence=0.95,
|
| 298 |
+
))
|
| 299 |
+
elif object_utils.group(1) != class_name:
|
| 300 |
+
issues.append(ValidationIssue(
|
| 301 |
+
severity=ValidationSeverity.ERROR,
|
| 302 |
+
code="UVM-FACTORY-004",
|
| 303 |
+
message=f"uvm_object_utils has wrong class name: expected '{class_name}', got '{object_utils.group(1)}'",
|
| 304 |
+
suggestion=f"Change `uvm_object_utils({object_utils.group(1)}) to `uvm_object_utils({class_name})",
|
| 305 |
+
auto_fixable=True,
|
| 306 |
+
confidence=0.9,
|
| 307 |
+
))
|
| 308 |
+
|
| 309 |
+
if component_utils or object_utils:
|
| 310 |
+
issues.append(ValidationIssue(
|
| 311 |
+
severity=ValidationSeverity.INFO,
|
| 312 |
+
code="UVM-FACTORY-OK",
|
| 313 |
+
message=f"Class '{class_name}' properly registered with UVM factory",
|
| 314 |
+
confidence=1.0,
|
| 315 |
+
))
|
| 316 |
+
|
| 317 |
+
return issues
|
| 318 |
+
|
| 319 |
+
def _check_phase_implementation(
|
| 320 |
+
self,
|
| 321 |
+
content: str,
|
| 322 |
+
file_type: str,
|
| 323 |
+
class_name: str,
|
| 324 |
+
lines: List[str],
|
| 325 |
+
) -> List[ValidationIssue]:
|
| 326 |
+
"""Check UVM phase implementation completeness."""
|
| 327 |
+
issues: List[ValidationIssue] = []
|
| 328 |
+
|
| 329 |
+
found_phases: Set[str] = set()
|
| 330 |
+
phase_lines: Dict[str, int] = {}
|
| 331 |
+
|
| 332 |
+
for i, line in enumerate(lines, 1):
|
| 333 |
+
phase_match = self._patterns["phase_decl"].search(line)
|
| 334 |
+
if phase_match:
|
| 335 |
+
phase_name = phase_match.group(3)
|
| 336 |
+
if phase_name in self.UVM_PHASES:
|
| 337 |
+
found_phases.add(phase_name)
|
| 338 |
+
phase_lines[phase_name] = i
|
| 339 |
+
|
| 340 |
+
required_phases = self.REQUIRED_PHASES_BY_TYPE.get(file_type, set())
|
| 341 |
+
missing_phases = required_phases - found_phases
|
| 342 |
+
|
| 343 |
+
for phase in sorted(missing_phases):
|
| 344 |
+
issues.append(ValidationIssue(
|
| 345 |
+
severity=ValidationSeverity.WARNING,
|
| 346 |
+
code="UVM-PHASE-001",
|
| 347 |
+
message=f"Class '{class_name}' may be missing {phase} implementation",
|
| 348 |
+
suggestion=f"Consider implementing {phase} for proper UVM component behavior",
|
| 349 |
+
auto_fixable=False,
|
| 350 |
+
confidence=0.7,
|
| 351 |
+
))
|
| 352 |
+
|
| 353 |
+
if "run_phase" in found_phases:
|
| 354 |
+
issues.append(ValidationIssue(
|
| 355 |
+
severity=ValidationSeverity.INFO,
|
| 356 |
+
code="UVM-PHASE-OK",
|
| 357 |
+
message=f"Class '{class_name}' implements run_phase",
|
| 358 |
+
confidence=1.0,
|
| 359 |
+
))
|
| 360 |
+
|
| 361 |
+
if "build_phase" in found_phases and "connect_phase" in found_phases:
|
| 362 |
+
issues.append(ValidationIssue(
|
| 363 |
+
severity=ValidationSeverity.INFO,
|
| 364 |
+
code="UVM-PHASE-STRUCTURE",
|
| 365 |
+
message=f"Class '{class_name}' has proper build/connect phase structure",
|
| 366 |
+
confidence=1.0,
|
| 367 |
+
))
|
| 368 |
+
|
| 369 |
+
return issues
|
| 370 |
+
|
| 371 |
+
def _check_component_specific(
|
| 372 |
+
self,
|
| 373 |
+
content: str,
|
| 374 |
+
file_type: str,
|
| 375 |
+
parent_class: str,
|
| 376 |
+
lines: List[str],
|
| 377 |
+
) -> List[ValidationIssue]:
|
| 378 |
+
"""Check component-specific UVM patterns."""
|
| 379 |
+
issues: List[ValidationIssue] = []
|
| 380 |
+
|
| 381 |
+
if "driver" in file_type or "driver" in parent_class.lower():
|
| 382 |
+
seq_item_port = self._patterns["seq_item_port_decl"].search(content)
|
| 383 |
+
if not seq_item_port:
|
| 384 |
+
issues.append(ValidationIssue(
|
| 385 |
+
severity=ValidationSeverity.WARNING,
|
| 386 |
+
code="UVM-DRIVER-001",
|
| 387 |
+
message="Driver should declare seq_item_port",
|
| 388 |
+
suggestion="Add: uvm_seq_item_pull_port #(seq_item_type) seq_item_port",
|
| 389 |
+
auto_fixable=False,
|
| 390 |
+
confidence=0.8,
|
| 391 |
+
))
|
| 392 |
+
else:
|
| 393 |
+
get_next_item = self._patterns["seq_item_port_get"].search(content)
|
| 394 |
+
item_done = self._patterns["seq_item_port_done"].search(content)
|
| 395 |
+
|
| 396 |
+
if not get_next_item:
|
| 397 |
+
issues.append(ValidationIssue(
|
| 398 |
+
severity=ValidationSeverity.WARNING,
|
| 399 |
+
code="UVM-DRIVER-002",
|
| 400 |
+
message="Driver should call seq_item_port.get_next_item()",
|
| 401 |
+
suggestion="Use seq_item_port.get_next_item(req) to retrieve sequence items",
|
| 402 |
+
confidence=0.75,
|
| 403 |
+
))
|
| 404 |
+
|
| 405 |
+
if not item_done:
|
| 406 |
+
issues.append(ValidationIssue(
|
| 407 |
+
severity=ValidationSeverity.WARNING,
|
| 408 |
+
code="UVM-DRIVER-003",
|
| 409 |
+
message="Driver should call seq_item_port.item_done()",
|
| 410 |
+
suggestion="Use seq_item_port.item_done() after processing each item",
|
| 411 |
+
confidence=0.75,
|
| 412 |
+
))
|
| 413 |
+
|
| 414 |
+
if "monitor" in file_type or "monitor" in parent_class.lower():
|
| 415 |
+
analysis_port = self._patterns["analysis_port_decl"].search(content)
|
| 416 |
+
if not analysis_port:
|
| 417 |
+
issues.append(ValidationIssue(
|
| 418 |
+
severity=ValidationSeverity.WARNING,
|
| 419 |
+
code="UVM-MONITOR-001",
|
| 420 |
+
message="Monitor should declare an analysis_port",
|
| 421 |
+
suggestion="Add: uvm_analysis_port #(item_type) analysis_port",
|
| 422 |
+
auto_fixable=False,
|
| 423 |
+
confidence=0.8,
|
| 424 |
+
))
|
| 425 |
+
else:
|
| 426 |
+
write_call = re.search(r'\b' + re.escape(analysis_port.group(2)) + r'\s*\.\s*write\s*\(', content)
|
| 427 |
+
if not write_call:
|
| 428 |
+
issues.append(ValidationIssue(
|
| 429 |
+
severity=ValidationSeverity.WARNING,
|
| 430 |
+
code="UVM-MONITOR-002",
|
| 431 |
+
message=f"Monitor should call {analysis_port.group(2)}.write()",
|
| 432 |
+
suggestion=f"Call {analysis_port.group(2)}.write(item) for each collected transaction",
|
| 433 |
+
confidence=0.75,
|
| 434 |
+
))
|
| 435 |
+
|
| 436 |
+
if "scoreboard" in file_type or "subscriber" in parent_class.lower():
|
| 437 |
+
analysis_imp = self._patterns["analysis_imp_decl"].search(content)
|
| 438 |
+
if analysis_imp:
|
| 439 |
+
write_method = re.search(r'\bfunction\s+void\s+write\s*\(\s*' + re.escape(analysis_imp.group(1)) + r'\s+', content)
|
| 440 |
+
if not write_method:
|
| 441 |
+
issues.append(ValidationIssue(
|
| 442 |
+
severity=ValidationSeverity.WARNING,
|
| 443 |
+
code="UVM-SCB-001",
|
| 444 |
+
message="Scoreboard/subscriber should implement write() function",
|
| 445 |
+
suggestion=f"Add: function void write({analysis_imp.group(1)} item)",
|
| 446 |
+
confidence=0.8,
|
| 447 |
+
))
|
| 448 |
+
|
| 449 |
+
return issues
|
| 450 |
+
|
| 451 |
+
def _check_objection_handling(
|
| 452 |
+
self,
|
| 453 |
+
content: str,
|
| 454 |
+
file_type: str,
|
| 455 |
+
lines: List[str],
|
| 456 |
+
) -> List[ValidationIssue]:
|
| 457 |
+
"""Check objection handling in tests and sequences."""
|
| 458 |
+
issues: List[ValidationIssue] = []
|
| 459 |
+
|
| 460 |
+
if file_type not in ("test", "sequence"):
|
| 461 |
+
return issues
|
| 462 |
+
|
| 463 |
+
has_raise = self._patterns["raise_objection"].search(content)
|
| 464 |
+
has_drop = self._patterns["drop_objection"].search(content)
|
| 465 |
+
|
| 466 |
+
if file_type == "test":
|
| 467 |
+
if not has_raise:
|
| 468 |
+
issues.append(ValidationIssue(
|
| 469 |
+
severity=ValidationSeverity.WARNING,
|
| 470 |
+
code="UVM-OBJECTION-001",
|
| 471 |
+
message="Test should raise objection in run_phase",
|
| 472 |
+
suggestion="Add: phase.raise_objection(this) at start of run_phase",
|
| 473 |
+
auto_fixable=False,
|
| 474 |
+
confidence=0.85,
|
| 475 |
+
))
|
| 476 |
+
|
| 477 |
+
if not has_drop:
|
| 478 |
+
issues.append(ValidationIssue(
|
| 479 |
+
severity=ValidationSeverity.WARNING,
|
| 480 |
+
code="UVM-OBJECTION-002",
|
| 481 |
+
message="Test should drop objection in run_phase",
|
| 482 |
+
suggestion="Add: phase.drop_objection(this) at end of run_phase",
|
| 483 |
+
auto_fixable=False,
|
| 484 |
+
confidence=0.85,
|
| 485 |
+
))
|
| 486 |
+
|
| 487 |
+
if has_raise and has_drop:
|
| 488 |
+
issues.append(ValidationIssue(
|
| 489 |
+
severity=ValidationSeverity.INFO,
|
| 490 |
+
code="UVM-OBJECTION-OK",
|
| 491 |
+
message="Test has proper objection handling (raise/drop)",
|
| 492 |
+
confidence=1.0,
|
| 493 |
+
))
|
| 494 |
+
|
| 495 |
+
return issues
|
| 496 |
+
|
| 497 |
+
@staticmethod
|
| 498 |
+
def _find_class_line(class_name: str, lines: List[str]) -> Optional[int]:
|
| 499 |
+
"""Find the line number of a class declaration."""
|
| 500 |
+
pattern = re.compile(r'\bclass\s+' + re.escape(class_name) + r'\b')
|
| 501 |
+
for i, line in enumerate(lines, 1):
|
| 502 |
+
if pattern.search(line):
|
| 503 |
+
return i
|
| 504 |
+
return None
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class SpecComplianceChecker:
|
| 508 |
+
"""Advanced spec compliance checking."""
|
| 509 |
+
|
| 510 |
+
def __init__(self, spec_dict: Dict[str, Any]):
|
| 511 |
+
self.spec = spec_dict
|
| 512 |
+
self.design_name = spec_dict.get("design_name", "unknown")
|
| 513 |
+
self._extract_signals()
|
| 514 |
+
self._extract_registers()
|
| 515 |
+
self._extract_clock_reset()
|
| 516 |
+
|
| 517 |
+
def _extract_signals(self) -> None:
|
| 518 |
+
self.all_signals: Set[str] = set()
|
| 519 |
+
self.signals_by_direction: Dict[str, Set[str]] = {
|
| 520 |
+
"input": set(), "output": set(), "inout": set(),
|
| 521 |
+
}
|
| 522 |
+
self.signal_widths: Dict[str, int] = {}
|
| 523 |
+
self.signal_interfaces: Dict[str, str] = {}
|
| 524 |
+
|
| 525 |
+
for iface in self.spec.get("interfaces", []):
|
| 526 |
+
iface_name = iface.get("name", "unknown")
|
| 527 |
+
for sig in iface.get("signals", []):
|
| 528 |
+
name = sig.get("name", "")
|
| 529 |
+
if name:
|
| 530 |
+
self.all_signals.add(name)
|
| 531 |
+
direction = sig.get("direction", "input")
|
| 532 |
+
self.signals_by_direction.get(direction, set()).add(name)
|
| 533 |
+
self.signal_widths[name] = sig.get("width", 1)
|
| 534 |
+
self.signal_interfaces[name] = iface_name
|
| 535 |
+
|
| 536 |
+
def _extract_registers(self) -> None:
|
| 537 |
+
self.all_registers: Set[str] = set()
|
| 538 |
+
self.register_addresses: Dict[str, str] = {}
|
| 539 |
+
self.register_fields: Dict[str, Dict[str, Dict[str, Any]]] = {}
|
| 540 |
+
self.register_access: Dict[str, str] = {}
|
| 541 |
+
|
| 542 |
+
for reg in self.spec.get("registers", []):
|
| 543 |
+
name = reg.get("name", "")
|
| 544 |
+
if name:
|
| 545 |
+
self.all_registers.add(name)
|
| 546 |
+
self.register_addresses[name] = reg.get("address", "")
|
| 547 |
+
self.register_access[name] = reg.get("access", "rw")
|
| 548 |
+
|
| 549 |
+
fields: Dict[str, Dict[str, Any]] = {}
|
| 550 |
+
for field in reg.get("fields", []):
|
| 551 |
+
field_name = field.get("name", "")
|
| 552 |
+
if field_name:
|
| 553 |
+
fields[field_name] = {
|
| 554 |
+
"bits": field.get("bits", "0"),
|
| 555 |
+
"description": field.get("description", ""),
|
| 556 |
+
}
|
| 557 |
+
self.register_fields[name] = fields
|
| 558 |
+
|
| 559 |
+
def _extract_clock_reset(self) -> None:
|
| 560 |
+
cr = self.spec.get("clock_reset", {})
|
| 561 |
+
self.clock_signal = cr.get("clock", "clk")
|
| 562 |
+
self.reset_signal = cr.get("reset", "rst_n")
|
| 563 |
+
self.reset_active = cr.get("reset_active", 0)
|
| 564 |
+
|
| 565 |
+
def check_spec_compliance(
|
| 566 |
+
self,
|
| 567 |
+
content: str,
|
| 568 |
+
file_type: str,
|
| 569 |
+
lines: List[str],
|
| 570 |
+
) -> Tuple[List[ValidationIssue], Dict[str, Any]]:
|
| 571 |
+
"""Check compliance with design spec."""
|
| 572 |
+
issues: List[ValidationIssue] = []
|
| 573 |
+
metrics: Dict[str, Any] = {
|
| 574 |
+
"signals_found": set(),
|
| 575 |
+
"signals_missing": set(),
|
| 576 |
+
"registers_found": set(),
|
| 577 |
+
"registers_missing": set(),
|
| 578 |
+
"signal_coverage": 0.0,
|
| 579 |
+
"register_coverage": 0.0,
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
stripped = self._strip_for_analysis(content)
|
| 583 |
+
|
| 584 |
+
found_signals: Set[str] = set()
|
| 585 |
+
for sig in self.all_signals:
|
| 586 |
+
if re.search(r'\b' + re.escape(sig) + r'\b', stripped, re.IGNORECASE):
|
| 587 |
+
found_signals.add(sig)
|
| 588 |
+
|
| 589 |
+
metrics["signals_found"] = found_signals
|
| 590 |
+
|
| 591 |
+
if file_type in ("interface", "testbench"):
|
| 592 |
+
missing_signals = self.all_signals - found_signals
|
| 593 |
+
metrics["signals_missing"] = missing_signals
|
| 594 |
+
|
| 595 |
+
for sig in sorted(missing_signals):
|
| 596 |
+
direction = self._get_signal_direction(sig)
|
| 597 |
+
width = self.signal_widths.get(sig, 1)
|
| 598 |
+
issues.append(ValidationIssue(
|
| 599 |
+
severity=ValidationSeverity.ERROR,
|
| 600 |
+
code="SPEC-SIGNAL-001",
|
| 601 |
+
message=f"Signal '{sig}' [{direction}, {width}bit] from spec not found in {file_type}",
|
| 602 |
+
suggestion=f"Add signal declaration: {direction} logic {'' if width == 1 else f'[{width-1}:0]'}{sig}",
|
| 603 |
+
auto_fixable=False,
|
| 604 |
+
confidence=0.95,
|
| 605 |
+
))
|
| 606 |
+
|
| 607 |
+
for sig in sorted(found_signals & self.all_signals):
|
| 608 |
+
issues.append(ValidationIssue(
|
| 609 |
+
severity=ValidationSeverity.INFO,
|
| 610 |
+
code="SPEC-SIGNAL-OK",
|
| 611 |
+
message=f"Signal '{sig}' from spec is properly referenced",
|
| 612 |
+
confidence=1.0,
|
| 613 |
+
))
|
| 614 |
+
|
| 615 |
+
if self.all_signals:
|
| 616 |
+
metrics["signal_coverage"] = len(found_signals) / len(self.all_signals)
|
| 617 |
+
|
| 618 |
+
if file_type in ("ral_model", "test", "sequence", "scoreboard", "env"):
|
| 619 |
+
found_registers: Set[str] = set()
|
| 620 |
+
for reg in self.all_registers:
|
| 621 |
+
if re.search(r'\b' + re.escape(reg.lower()) + r'\b', stripped.lower()):
|
| 622 |
+
found_registers.add(reg)
|
| 623 |
+
|
| 624 |
+
metrics["registers_found"] = found_registers
|
| 625 |
+
|
| 626 |
+
if file_type == "ral_model" and self.all_registers:
|
| 627 |
+
missing_regs = self.all_registers - found_registers
|
| 628 |
+
metrics["registers_missing"] = missing_regs
|
| 629 |
+
|
| 630 |
+
for reg in sorted(missing_regs):
|
| 631 |
+
addr = self.register_addresses.get(reg, "unknown")
|
| 632 |
+
access = self.register_access.get(reg, "rw")
|
| 633 |
+
issues.append(ValidationIssue(
|
| 634 |
+
severity=ValidationSeverity.ERROR,
|
| 635 |
+
code="SPEC-REG-001",
|
| 636 |
+
message=f"Register '{reg}' [@0x{addr}, {access}] from spec not found in RAL model",
|
| 637 |
+
suggestion=f"Create uvm_reg class for register '{reg}' with address 0x{addr}",
|
| 638 |
+
auto_fixable=False,
|
| 639 |
+
confidence=0.9,
|
| 640 |
+
))
|
| 641 |
+
|
| 642 |
+
for reg in sorted(found_registers & self.all_registers):
|
| 643 |
+
issues.append(ValidationIssue(
|
| 644 |
+
severity=ValidationSeverity.INFO,
|
| 645 |
+
code="SPEC-REG-OK",
|
| 646 |
+
message=f"Register '{reg}' from spec is properly referenced",
|
| 647 |
+
confidence=1.0,
|
| 648 |
+
))
|
| 649 |
+
|
| 650 |
+
if self.all_registers:
|
| 651 |
+
metrics["register_coverage"] = len(found_registers) / len(self.all_registers)
|
| 652 |
+
|
| 653 |
+
if file_type in ("interface", "testbench"):
|
| 654 |
+
clock_found = re.search(r'\b' + re.escape(self.clock_signal) + r'\b', stripped, re.IGNORECASE)
|
| 655 |
+
reset_found = re.search(r'\b' + re.escape(self.reset_signal) + r'\b', stripped, re.IGNORECASE)
|
| 656 |
+
|
| 657 |
+
if not clock_found:
|
| 658 |
+
issues.append(ValidationIssue(
|
| 659 |
+
severity=ValidationSeverity.ERROR,
|
| 660 |
+
code="SPEC-CLK-001",
|
| 661 |
+
message=f"Clock signal '{self.clock_signal}' from spec not found",
|
| 662 |
+
suggestion=f"Add clock signal: input logic {self.clock_signal}",
|
| 663 |
+
auto_fixable=False,
|
| 664 |
+
confidence=0.95,
|
| 665 |
+
))
|
| 666 |
+
|
| 667 |
+
if not reset_found:
|
| 668 |
+
issues.append(ValidationIssue(
|
| 669 |
+
severity=ValidationSeverity.ERROR,
|
| 670 |
+
code="SPEC-RST-001",
|
| 671 |
+
message=f"Reset signal '{self.reset_signal}' from spec not found",
|
| 672 |
+
suggestion=f"Add reset signal: input logic {self.reset_signal}",
|
| 673 |
+
auto_fixable=False,
|
| 674 |
+
confidence=0.95,
|
| 675 |
+
))
|
| 676 |
+
|
| 677 |
+
if clock_found and reset_found:
|
| 678 |
+
issues.append(ValidationIssue(
|
| 679 |
+
severity=ValidationSeverity.INFO,
|
| 680 |
+
code="SPEC-CLK-RST-OK",
|
| 681 |
+
message=f"Clock '{self.clock_signal}' and reset '{self.reset_signal}' from spec are present",
|
| 682 |
+
confidence=1.0,
|
| 683 |
+
))
|
| 684 |
+
|
| 685 |
+
return issues, metrics
|
| 686 |
+
|
| 687 |
+
def _get_signal_direction(self, signal: str) -> str:
|
| 688 |
+
for direction, signals in self.signals_by_direction.items():
|
| 689 |
+
if signal in signals:
|
| 690 |
+
return direction
|
| 691 |
+
return "unknown"
|
| 692 |
+
|
| 693 |
+
@staticmethod
|
| 694 |
+
def _strip_for_analysis(content: str) -> str:
|
| 695 |
+
result = content
|
| 696 |
+
result = re.sub(r'/\*.*?\*/', ' ', result, flags=re.DOTALL)
|
| 697 |
+
result = re.sub(r'//.*$', ' ', result, flags=re.MULTILINE)
|
| 698 |
+
result = re.sub(r'"[^"]*"', 'STR', result)
|
| 699 |
+
return result
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
class SystemVerilogSyntaxChecker:
|
| 703 |
+
"""Advanced SystemVerilog syntax checking."""
|
| 704 |
+
|
| 705 |
+
PAIR_CHECKS = [
|
| 706 |
+
("module", ["endmodule"]),
|
| 707 |
+
("interface", ["endinterface"]),
|
| 708 |
+
("class", ["endclass"]),
|
| 709 |
+
("function", ["endfunction"]),
|
| 710 |
+
("task", ["endtask"]),
|
| 711 |
+
("case", ["endcase"]),
|
| 712 |
+
("begin", ["end"]),
|
| 713 |
+
("fork", ["join", "join_any", "join_none"]),
|
| 714 |
+
]
|
| 715 |
+
|
| 716 |
+
SV_KEYWORDS = {
|
| 717 |
+
"module", "endmodule", "interface", "endinterface", "class", "endclass",
|
| 718 |
+
"input", "output", "inout", "logic", "reg", "wire", "bit", "int", "integer",
|
| 719 |
+
"always", "initial", "assign", "begin", "end", "case", "endcase", "if", "else",
|
| 720 |
+
"for", "while", "repeat", "forever", "task", "endtask", "function", "endfunction",
|
| 721 |
+
"parameter", "localparam", "defparam", "typedef", "struct", "union", "enum",
|
| 722 |
+
"posedge", "negedge", "or", "and", "not", "default", "none",
|
| 723 |
+
"import", "export", "package", "endpackage", "include", "define",
|
| 724 |
+
"virtual", "rand", "randc", "constraint", "extends", "implements",
|
| 725 |
+
"time", "realtime", "shortint", "longint", "byte", "shortreal", "real",
|
| 726 |
+
"string", "void", "null", "break", "continue", "return", "disable",
|
| 727 |
+
"static", "automatic", "const", "var", "signed", "unsigned",
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
def __init__(self):
|
| 731 |
+
self._patterns = self._compile_patterns()
|
| 732 |
+
|
| 733 |
+
def _compile_patterns(self) -> Dict[str, Pattern]:
|
| 734 |
+
return {
|
| 735 |
+
"comment_single": re.compile(r'//.*$', re.MULTILINE),
|
| 736 |
+
"comment_multi": re.compile(r'/\*.*?\*/', re.DOTALL),
|
| 737 |
+
"string_lit": re.compile(r'"[^"]*"'),
|
| 738 |
+
"module_decl": re.compile(r'\bmodule\s+(\w+)\s*[#(;]'),
|
| 739 |
+
"interface_decl": re.compile(r'\binterface\s+(\w+)\s*[#(;]'),
|
| 740 |
+
"class_decl": re.compile(r'\bclass\s+(\w+)\s*(?:#\s*\(|extends|implements|;|{)'),
|
| 741 |
+
"port_list": re.compile(r'\(([^)]+)\)'),
|
| 742 |
+
"unbalanced_paren": re.compile(r'[()]'),
|
| 743 |
+
"unbalanced_bracket": re.compile(r'[\[\]]'),
|
| 744 |
+
"unbalanced_brace": re.compile(r'[{}]'),
|
| 745 |
+
"semicolon": re.compile(r';\s*$'),
|
| 746 |
+
"time_unit": re.compile(r'`timescale\s+(\d+[munp]?s)/(\d+[munp]?s)'),
|
| 747 |
+
"include_uvm": re.compile(r'`include\s+"uvm_macros\.svh"'),
|
| 748 |
+
"import_uvm": re.compile(r'import\s+uvm_pkg::\*'),
|
| 749 |
+
"uvm_macro": re.compile(r'`uvm_\w+'),
|
| 750 |
+
" timescale_missing": re.compile(r'^module\b|\binterface\b|\bclass\b', re.MULTILINE),
|
| 751 |
+
}
|
| 752 |
+
|
| 753 |
+
def check(self, content: str, lines: List[str]) -> List[ValidationIssue]:
|
| 754 |
+
"""Run comprehensive syntax checks."""
|
| 755 |
+
issues: List[ValidationIssue] = []
|
| 756 |
+
|
| 757 |
+
issues.extend(self._check_compile_ready(content, lines))
|
| 758 |
+
issues.extend(self.check_balance(content))
|
| 759 |
+
issues.extend(self.check_begin_end_pairs(content, lines))
|
| 760 |
+
issues.extend(self.check_semicolons(content, lines))
|
| 761 |
+
issues.extend(self._check_uvm_setup(content, lines))
|
| 762 |
+
|
| 763 |
+
return issues
|
| 764 |
+
|
| 765 |
+
def _check_compile_ready(
|
| 766 |
+
self,
|
| 767 |
+
content: str,
|
| 768 |
+
lines: List[str],
|
| 769 |
+
) -> List[ValidationIssue]:
|
| 770 |
+
"""Check compile-ready attributes."""
|
| 771 |
+
issues: List[ValidationIssue] = []
|
| 772 |
+
|
| 773 |
+
has_timescale = self._patterns["time_unit"].search(content)
|
| 774 |
+
has_module = self._patterns["module_decl"].search(content)
|
| 775 |
+
has_interface = self._patterns["interface_decl"].search(content)
|
| 776 |
+
|
| 777 |
+
if (has_module or has_interface) and not has_timescale:
|
| 778 |
+
issues.append(ValidationIssue(
|
| 779 |
+
severity=ValidationSeverity.WARNING,
|
| 780 |
+
code="SV-SYN-001",
|
| 781 |
+
message="Module/interface without `timescale directive",
|
| 782 |
+
suggestion="Add: `timescale 1ns/1ps at top of file",
|
| 783 |
+
auto_fixable=True,
|
| 784 |
+
confidence=0.8,
|
| 785 |
+
))
|
| 786 |
+
|
| 787 |
+
uvm_macros = self._patterns["uvm_macro"].findall(content)
|
| 788 |
+
if uvm_macros:
|
| 789 |
+
has_include = self._patterns["include_uvm"].search(content)
|
| 790 |
+
has_import = self._patterns["import_uvm"].search(content)
|
| 791 |
+
|
| 792 |
+
if not has_include:
|
| 793 |
+
issues.append(ValidationIssue(
|
| 794 |
+
severity=ValidationSeverity.ERROR,
|
| 795 |
+
code="SV-UVM-001",
|
| 796 |
+
message="UVM macros used but `include \"uvm_macros.svh\" missing",
|
| 797 |
+
suggestion="Add: `include \"uvm_macros.svh\" at top of file",
|
| 798 |
+
auto_fixable=True,
|
| 799 |
+
confidence=0.95,
|
| 800 |
+
))
|
| 801 |
+
|
| 802 |
+
if not has_import:
|
| 803 |
+
issues.append(ValidationIssue(
|
| 804 |
+
severity=ValidationSeverity.WARNING,
|
| 805 |
+
code="SV-UVM-002",
|
| 806 |
+
message="UVM macros used but import uvm_pkg::* missing",
|
| 807 |
+
suggestion="Add: import uvm_pkg::*; after include",
|
| 808 |
+
auto_fixable=True,
|
| 809 |
+
confidence=0.85,
|
| 810 |
+
))
|
| 811 |
+
|
| 812 |
+
return issues
|
| 813 |
+
|
| 814 |
+
def _check_uvm_setup(
|
| 815 |
+
self,
|
| 816 |
+
content: str,
|
| 817 |
+
lines: List[str],
|
| 818 |
+
) -> List[ValidationIssue]:
|
| 819 |
+
"""Check UVM setup completeness."""
|
| 820 |
+
issues: List[ValidationIssue] = []
|
| 821 |
+
|
| 822 |
+
has_include = self._patterns["include_uvm"].search(content)
|
| 823 |
+
has_import = self._patterns["import_uvm"].search(content)
|
| 824 |
+
|
| 825 |
+
if has_include and has_import:
|
| 826 |
+
issues.append(ValidationIssue(
|
| 827 |
+
severity=ValidationSeverity.INFO,
|
| 828 |
+
code="SV-UVM-SETUP-OK",
|
| 829 |
+
message="UVM setup complete (include + import)",
|
| 830 |
+
confidence=1.0,
|
| 831 |
+
))
|
| 832 |
+
|
| 833 |
+
return issues
|
| 834 |
+
|
| 835 |
+
def _strip_comments_and_strings(self, content: str) -> str:
|
| 836 |
+
result = content
|
| 837 |
+
result = self._patterns["comment_multi"].sub(" ", result)
|
| 838 |
+
result = self._patterns["comment_single"].sub(" ", result)
|
| 839 |
+
result = self._patterns["string_lit"].sub("\"STR\"", result)
|
| 840 |
+
return result
|
| 841 |
+
|
| 842 |
+
def check_balance(self, content: str) -> List[ValidationIssue]:
|
| 843 |
+
issues: List[ValidationIssue] = []
|
| 844 |
+
stripped = self._strip_comments_and_strings(content)
|
| 845 |
+
|
| 846 |
+
checks = [
|
| 847 |
+
("()", "parentheses"),
|
| 848 |
+
("[]", "brackets"),
|
| 849 |
+
("{}", "braces"),
|
| 850 |
+
]
|
| 851 |
+
|
| 852 |
+
for pair, name in checks:
|
| 853 |
+
count_open = stripped.count(pair[0])
|
| 854 |
+
count_close = stripped.count(pair[1])
|
| 855 |
+
if count_open != count_close:
|
| 856 |
+
issues.append(ValidationIssue(
|
| 857 |
+
severity=ValidationSeverity.WARNING,
|
| 858 |
+
code=f"SV-SYN-BAL-{name}",
|
| 859 |
+
message=f"Possibly unbalanced {name}: {count_open} '{pair[0]}' vs {count_close} '{pair[1]}'",
|
| 860 |
+
auto_fixable=False,
|
| 861 |
+
confidence=0.7,
|
| 862 |
+
))
|
| 863 |
+
|
| 864 |
+
return issues
|
| 865 |
+
|
| 866 |
+
def check_begin_end_pairs(self, content: str, lines: List[str]) -> List[ValidationIssue]:
|
| 867 |
+
issues: List[ValidationIssue] = []
|
| 868 |
+
stripped = self._strip_comments_and_strings(content)
|
| 869 |
+
stripped_lines = stripped.split('\n')
|
| 870 |
+
|
| 871 |
+
for open_kw, close_kws in self.PAIR_CHECKS:
|
| 872 |
+
close_kws_set = set(close_kws)
|
| 873 |
+
close_kw_display = close_kws[0] if len(close_kws) == 1 else f"{close_kws[0]}/..."
|
| 874 |
+
|
| 875 |
+
stack: List[int] = []
|
| 876 |
+
for line_num, line in enumerate(stripped_lines, 1):
|
| 877 |
+
words = re.findall(r'\b\w+\b', line.lower())
|
| 878 |
+
|
| 879 |
+
for word in words:
|
| 880 |
+
if word == open_kw:
|
| 881 |
+
stack.append(line_num)
|
| 882 |
+
elif word in close_kws_set:
|
| 883 |
+
if stack:
|
| 884 |
+
stack.pop()
|
| 885 |
+
|
| 886 |
+
for line_num in stack:
|
| 887 |
+
issues.append(ValidationIssue(
|
| 888 |
+
severity=ValidationSeverity.WARNING,
|
| 889 |
+
code="SV-SYN-BLOCK",
|
| 890 |
+
message=f"'{open_kw}' at line {line_num} may have no matching '{close_kw_display}'",
|
| 891 |
+
line_number=line_num,
|
| 892 |
+
auto_fixable=False,
|
| 893 |
+
confidence=0.6,
|
| 894 |
+
))
|
| 895 |
+
|
| 896 |
+
return issues
|
| 897 |
+
|
| 898 |
+
def check_semicolons(self, content: str, lines: List[str]) -> List[ValidationIssue]:
|
| 899 |
+
issues: List[ValidationIssue] = []
|
| 900 |
+
|
| 901 |
+
statement_keywords = {
|
| 902 |
+
"logic", "reg", "wire", "bit", "int", "shortint", "longint", "byte",
|
| 903 |
+
"input", "output", "inout", "parameter", "localparam", "typedef",
|
| 904 |
+
"import", "export", "assign", "return", "break", "continue",
|
| 905 |
+
}
|
| 906 |
+
|
| 907 |
+
block_starters = {
|
| 908 |
+
"module", "interface", "class", "function", "task", "case",
|
| 909 |
+
"begin", "fork", "if", "else", "for", "while", "repeat", "forever",
|
| 910 |
+
"package",
|
| 911 |
+
}
|
| 912 |
+
|
| 913 |
+
block_enders = {
|
| 914 |
+
"endmodule", "endinterface", "endclass", "endfunction", "endtask",
|
| 915 |
+
"endcase", "end", "join", "join_any", "join_none", "endpackage",
|
| 916 |
+
}
|
| 917 |
+
|
| 918 |
+
for line_num, line in enumerate(lines, 1):
|
| 919 |
+
stripped = line.strip()
|
| 920 |
+
|
| 921 |
+
if not stripped:
|
| 922 |
+
continue
|
| 923 |
+
if stripped.startswith('//'):
|
| 924 |
+
continue
|
| 925 |
+
if stripped.startswith('`'):
|
| 926 |
+
continue
|
| 927 |
+
|
| 928 |
+
first_word = stripped.split()[0].lower() if stripped.split() else ""
|
| 929 |
+
|
| 930 |
+
if first_word in block_enders:
|
| 931 |
+
continue
|
| 932 |
+
|
| 933 |
+
if first_word in block_starters:
|
| 934 |
+
if stripped.rstrip().endswith((':', 'begin', '{', ';')):
|
| 935 |
+
continue
|
| 936 |
+
|
| 937 |
+
if first_word in statement_keywords:
|
| 938 |
+
if not stripped.rstrip().endswith(';') and not stripped.rstrip().endswith(')'):
|
| 939 |
+
issues.append(ValidationIssue(
|
| 940 |
+
severity=ValidationSeverity.WARNING,
|
| 941 |
+
code="SV-SYN-SEMICOLON",
|
| 942 |
+
message="Possible missing semicolon",
|
| 943 |
+
line_number=line_num,
|
| 944 |
+
context=stripped[:60],
|
| 945 |
+
suggestion="Add ';' at end of statement",
|
| 946 |
+
auto_fixable=True,
|
| 947 |
+
confidence=0.6,
|
| 948 |
+
))
|
| 949 |
+
|
| 950 |
+
return issues
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
class CoverageCompletenessChecker:
|
| 954 |
+
"""Check coverage model completeness."""
|
| 955 |
+
|
| 956 |
+
def check_coverage(
|
| 957 |
+
self,
|
| 958 |
+
content: str,
|
| 959 |
+
spec_dict: Dict[str, Any],
|
| 960 |
+
file_type: str,
|
| 961 |
+
) -> List[ValidationIssue]:
|
| 962 |
+
"""Check coverage model completeness."""
|
| 963 |
+
issues: List[ValidationIssue] = []
|
| 964 |
+
|
| 965 |
+
if file_type not in ("coverage", "coverage_collector"):
|
| 966 |
+
return issues
|
| 967 |
+
|
| 968 |
+
registers = spec_dict.get("registers", [])
|
| 969 |
+
register_names = [r.get("name", "") for r in registers if r.get("name")]
|
| 970 |
+
|
| 971 |
+
covergroups = re.findall(r'\bcovergroup\s+(\w+)', content)
|
| 972 |
+
coverpoints = re.findall(r'\bcoverpoint\s+(\w+)', content)
|
| 973 |
+
crosses = re.findall(r'\bcross\s+(\w+(?:\s*,\s*\w+)*)', content)
|
| 974 |
+
|
| 975 |
+
if not covergroups:
|
| 976 |
+
issues.append(ValidationIssue(
|
| 977 |
+
severity=ValidationSeverity.WARNING,
|
| 978 |
+
code="COV-001",
|
| 979 |
+
message="No covergroups found in coverage collector",
|
| 980 |
+
suggestion="Define covergroups for register accesses, protocol operations",
|
| 981 |
+
confidence=0.7,
|
| 982 |
+
))
|
| 983 |
+
else:
|
| 984 |
+
issues.append(ValidationIssue(
|
| 985 |
+
severity=ValidationSeverity.INFO,
|
| 986 |
+
code="COV-002",
|
| 987 |
+
message=f"Found {len(covergroups)} covergroup(s): {', '.join(covergroups)}",
|
| 988 |
+
confidence=1.0,
|
| 989 |
+
))
|
| 990 |
+
|
| 991 |
+
if coverpoints:
|
| 992 |
+
issues.append(ValidationIssue(
|
| 993 |
+
severity=ValidationSeverity.INFO,
|
| 994 |
+
code="COV-003",
|
| 995 |
+
message=f"Found {len(coverpoints)} coverpoint(s)",
|
| 996 |
+
confidence=1.0,
|
| 997 |
+
))
|
| 998 |
+
|
| 999 |
+
if crosses:
|
| 1000 |
+
issues.append(ValidationIssue(
|
| 1001 |
+
severity=ValidationSeverity.INFO,
|
| 1002 |
+
code="COV-004",
|
| 1003 |
+
message=f"Found {len(crosses)} cross coverage(s)",
|
| 1004 |
+
confidence=1.0,
|
| 1005 |
+
))
|
| 1006 |
+
|
| 1007 |
+
sample_calls = re.findall(r'\b(\w+)\s*\.\s*sample\s*\(', content)
|
| 1008 |
+
if sample_calls:
|
| 1009 |
+
issues.append(ValidationIssue(
|
| 1010 |
+
severity=ValidationSeverity.INFO,
|
| 1011 |
+
code="COV-005",
|
| 1012 |
+
message=f"Found sample() calls for: {', '.join(set(sample_calls))}",
|
| 1013 |
+
confidence=1.0,
|
| 1014 |
+
))
|
| 1015 |
+
|
| 1016 |
+
return issues
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
class TLMConnectionChecker:
|
| 1020 |
+
"""Check TLM connection completeness."""
|
| 1021 |
+
|
| 1022 |
+
def check_tlm_connections(
|
| 1023 |
+
self,
|
| 1024 |
+
content: str,
|
| 1025 |
+
file_type: str,
|
| 1026 |
+
) -> List[ValidationIssue]:
|
| 1027 |
+
"""Check TLM connections in env/scoreboard."""
|
| 1028 |
+
issues: List[ValidationIssue] = []
|
| 1029 |
+
|
| 1030 |
+
if file_type not in ("env", "scoreboard"):
|
| 1031 |
+
return issues
|
| 1032 |
+
|
| 1033 |
+
analysis_ports = re.findall(
|
| 1034 |
+
r'\buvm_analysis_port\s*#\s*<\s*(\w+)\s*>\s*(\w+)',
|
| 1035 |
+
content
|
| 1036 |
+
)
|
| 1037 |
+
analysis_imps = re.findall(
|
| 1038 |
+
r'\buvm_analysis_imp\s*#\s*<\s*(\w+)\s*,\s*(\w+)\s*>\s*(\w+)',
|
| 1039 |
+
content
|
| 1040 |
+
)
|
| 1041 |
+
tlms = re.findall(
|
| 1042 |
+
r'\buvm_tlm_(analysis_)?fifo\s*#\s*<\s*(\w+)\s*>\s*(\w+)',
|
| 1043 |
+
content
|
| 1044 |
+
)
|
| 1045 |
+
|
| 1046 |
+
connects = re.findall(
|
| 1047 |
+
r'\b(\w+)\s*\.\s*connect\s*\(\s*(\w+)\s*\)',
|
| 1048 |
+
content
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
port_names = [p[1] for p in analysis_ports]
|
| 1052 |
+
imp_names = [i[2] for i in analysis_imps]
|
| 1053 |
+
tlm_names = [t[2] for t in tlms]
|
| 1054 |
+
|
| 1055 |
+
all_tlms = port_names + imp_names + tlm_names
|
| 1056 |
+
|
| 1057 |
+
connected = set()
|
| 1058 |
+
for from_port, to_port in connects:
|
| 1059 |
+
connected.add(from_port)
|
| 1060 |
+
connected.add(to_port)
|
| 1061 |
+
|
| 1062 |
+
unconnected = set(all_tlms) - connected
|
| 1063 |
+
|
| 1064 |
+
if all_tlms:
|
| 1065 |
+
issues.append(ValidationIssue(
|
| 1066 |
+
severity=ValidationSeverity.INFO,
|
| 1067 |
+
code="TLM-001",
|
| 1068 |
+
message=f"Found {len(all_tlms)} TLM port(s)/FIFO(s)",
|
| 1069 |
+
confidence=1.0,
|
| 1070 |
+
))
|
| 1071 |
+
|
| 1072 |
+
if unconnected:
|
| 1073 |
+
issues.append(ValidationIssue(
|
| 1074 |
+
severity=ValidationSeverity.WARNING,
|
| 1075 |
+
code="TLM-002",
|
| 1076 |
+
message=f"TLM ports may not be connected: {', '.join(sorted(unconnected))}",
|
| 1077 |
+
suggestion=f"Connect these ports in connect_phase using .connect()",
|
| 1078 |
+
confidence=0.7,
|
| 1079 |
+
))
|
| 1080 |
+
|
| 1081 |
+
if connected and not unconnected:
|
| 1082 |
+
issues.append(ValidationIssue(
|
| 1083 |
+
severity=ValidationSeverity.INFO,
|
| 1084 |
+
code="TLM-003",
|
| 1085 |
+
message=f"All {len(connected)} TLM ports appear to be connected",
|
| 1086 |
+
confidence=0.8,
|
| 1087 |
+
))
|
| 1088 |
+
|
| 1089 |
+
return issues
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
class AdvancedCodeValidator:
|
| 1093 |
+
"""
|
| 1094 |
+
Advanced code validator combining all checkers.
|
| 1095 |
+
|
| 1096 |
+
This is the main interface for:
|
| 1097 |
+
1. Deep UVM compliance checking
|
| 1098 |
+
2. Spec compliance validation
|
| 1099 |
+
3. SystemVerilog syntax checking
|
| 1100 |
+
4. Coverage completeness checking
|
| 1101 |
+
5. TLM connection validation
|
| 1102 |
+
"""
|
| 1103 |
+
|
| 1104 |
+
FILE_TYPE_DETECTORS = [
|
| 1105 |
+
(r'ral_model', "ral_model"),
|
| 1106 |
+
(r'scoreboard', "scoreboard"),
|
| 1107 |
+
(r'driver', "driver"),
|
| 1108 |
+
(r'monitor', "monitor"),
|
| 1109 |
+
(r'agent', "agent"),
|
| 1110 |
+
(r'sequence_item', "sequence_item"),
|
| 1111 |
+
(r'_sequence', "sequence"),
|
| 1112 |
+
(r'regression', "sequence"),
|
| 1113 |
+
(r'coverage_collector', "coverage"),
|
| 1114 |
+
(r'protocol_checker', "checker"),
|
| 1115 |
+
(r'_test', "test"),
|
| 1116 |
+
(r'environment|env_', "env"),
|
| 1117 |
+
(r'testbench', "testbench"),
|
| 1118 |
+
(r'interface', "interface"),
|
| 1119 |
+
(r'serial_monitor', "monitor"),
|
| 1120 |
+
]
|
| 1121 |
+
|
| 1122 |
+
NON_SV_EXTENSIONS = {'.f', '.tcl', '.core', '.json', '.yaml', '.yml', '.md', '.txt'}
|
| 1123 |
+
|
| 1124 |
+
def __init__(self, spec_dict: Optional[Dict[str, Any]] = None):
|
| 1125 |
+
self.spec_dict = spec_dict
|
| 1126 |
+
self._syntax_checker = SystemVerilogSyntaxChecker()
|
| 1127 |
+
self._spec_checker = SpecComplianceChecker(spec_dict) if spec_dict else None
|
| 1128 |
+
self._uvm_checker = UVMComplianceChecker()
|
| 1129 |
+
self._coverage_checker = CoverageCompletenessChecker()
|
| 1130 |
+
self._tlm_checker = TLMConnectionChecker()
|
| 1131 |
+
|
| 1132 |
+
@classmethod
|
| 1133 |
+
def _is_sv_file(cls, filename: str) -> bool:
|
| 1134 |
+
fname_lower = filename.lower()
|
| 1135 |
+
for ext in cls.NON_SV_EXTENSIONS:
|
| 1136 |
+
if fname_lower.endswith(ext):
|
| 1137 |
+
return False
|
| 1138 |
+
if fname_lower.endswith(('.sv', '.v', '.svh', '.vh')):
|
| 1139 |
+
return True
|
| 1140 |
+
if '/' in fname_lower or '\\' in fname_lower:
|
| 1141 |
+
base = fname_lower.replace('\\', '/').split('/')[-1]
|
| 1142 |
+
if '.' not in base:
|
| 1143 |
+
return True
|
| 1144 |
+
return True
|
| 1145 |
+
|
| 1146 |
+
@classmethod
|
| 1147 |
+
def detect_file_type(cls, filename: str) -> str:
|
| 1148 |
+
fname_lower = filename.lower()
|
| 1149 |
+
for pattern, file_type in cls.FILE_TYPE_DETECTORS:
|
| 1150 |
+
if re.search(pattern, fname_lower):
|
| 1151 |
+
return file_type
|
| 1152 |
+
return "unknown"
|
| 1153 |
+
|
| 1154 |
+
def _calculate_score(
|
| 1155 |
+
self,
|
| 1156 |
+
issues: List[ValidationIssue],
|
| 1157 |
+
spec_metrics: Optional[Dict[str, Any]],
|
| 1158 |
+
checks_run: int,
|
| 1159 |
+
) -> float:
|
| 1160 |
+
"""Calculate a quality score (0.0 to 1.0)."""
|
| 1161 |
+
error_count = sum(1 for i in issues if i.severity == ValidationSeverity.ERROR)
|
| 1162 |
+
warning_count = sum(1 for i in issues if i.severity == ValidationSeverity.WARNING)
|
| 1163 |
+
info_count = sum(1 for i in issues if i.severity == ValidationSeverity.INFO)
|
| 1164 |
+
|
| 1165 |
+
base_score = 1.0
|
| 1166 |
+
base_score -= error_count * 0.15
|
| 1167 |
+
base_score -= warning_count * 0.05
|
| 1168 |
+
|
| 1169 |
+
if spec_metrics:
|
| 1170 |
+
signal_cov = spec_metrics.get("signal_coverage", 0.0)
|
| 1171 |
+
reg_cov = spec_metrics.get("register_coverage", 0.0)
|
| 1172 |
+
base_score += signal_cov * 0.1
|
| 1173 |
+
base_score += reg_cov * 0.1
|
| 1174 |
+
|
| 1175 |
+
return max(0.0, min(1.0, base_score))
|
| 1176 |
+
|
| 1177 |
+
def validate_file(
|
| 1178 |
+
self,
|
| 1179 |
+
filename: str,
|
| 1180 |
+
content: str,
|
| 1181 |
+
file_type: Optional[str] = None,
|
| 1182 |
+
) -> FileValidationResult:
|
| 1183 |
+
"""Validate a single file with all checkers."""
|
| 1184 |
+
if not self._is_sv_file(filename):
|
| 1185 |
+
return FileValidationResult(
|
| 1186 |
+
filename=filename,
|
| 1187 |
+
file_type="skipped",
|
| 1188 |
+
passed=True,
|
| 1189 |
+
issues=[],
|
| 1190 |
+
checks_run=0,
|
| 1191 |
+
checks_passed=0,
|
| 1192 |
+
score=1.0,
|
| 1193 |
+
)
|
| 1194 |
+
|
| 1195 |
+
if file_type is None:
|
| 1196 |
+
file_type = self.detect_file_type(filename)
|
| 1197 |
+
|
| 1198 |
+
lines = content.split('\n')
|
| 1199 |
+
|
| 1200 |
+
issues: List[ValidationIssue] = []
|
| 1201 |
+
checks_run = 0
|
| 1202 |
+
checks_passed = 0
|
| 1203 |
+
spec_metrics: Dict[str, Any] = {}
|
| 1204 |
+
|
| 1205 |
+
syntax_issues = self._syntax_checker.check(content, lines)
|
| 1206 |
+
issues.extend(syntax_issues)
|
| 1207 |
+
checks_run += 4
|
| 1208 |
+
syntax_errors = sum(1 for i in syntax_issues if i.severity == ValidationSeverity.ERROR)
|
| 1209 |
+
checks_passed += max(0, 4 - syntax_errors)
|
| 1210 |
+
|
| 1211 |
+
if self._spec_checker:
|
| 1212 |
+
spec_issues, spec_metrics = self._spec_checker.check_spec_compliance(
|
| 1213 |
+
content, file_type, lines
|
| 1214 |
+
)
|
| 1215 |
+
issues.extend(spec_issues)
|
| 1216 |
+
checks_run += 3
|
| 1217 |
+
spec_errors = sum(1 for i in spec_issues if i.severity == ValidationSeverity.ERROR)
|
| 1218 |
+
checks_passed += max(0, 3 - spec_errors)
|
| 1219 |
+
|
| 1220 |
+
uvm_issues = self._uvm_checker.check_uvm_compliance(
|
| 1221 |
+
content, file_type, lines
|
| 1222 |
+
)
|
| 1223 |
+
issues.extend(uvm_issues)
|
| 1224 |
+
checks_run += 3
|
| 1225 |
+
uvm_errors = sum(1 for i in uvm_issues if i.severity == ValidationSeverity.ERROR)
|
| 1226 |
+
checks_passed += max(0, 3 - uvm_errors)
|
| 1227 |
+
|
| 1228 |
+
cov_issues = self._coverage_checker.check_coverage(
|
| 1229 |
+
content, self.spec_dict or {}, file_type
|
| 1230 |
+
)
|
| 1231 |
+
issues.extend(cov_issues)
|
| 1232 |
+
checks_run += 1
|
| 1233 |
+
|
| 1234 |
+
tlm_issues = self._tlm_checker.check_tlm_connections(content, file_type)
|
| 1235 |
+
issues.extend(tlm_issues)
|
| 1236 |
+
checks_run += 1
|
| 1237 |
+
|
| 1238 |
+
errors = sum(1 for i in issues if i.severity == ValidationSeverity.ERROR)
|
| 1239 |
+
passed = errors == 0
|
| 1240 |
+
|
| 1241 |
+
score = self._calculate_score(issues, spec_metrics, checks_run)
|
| 1242 |
+
|
| 1243 |
+
return FileValidationResult(
|
| 1244 |
+
filename=filename,
|
| 1245 |
+
file_type=file_type,
|
| 1246 |
+
passed=passed,
|
| 1247 |
+
issues=issues,
|
| 1248 |
+
checks_run=checks_run,
|
| 1249 |
+
checks_passed=checks_passed,
|
| 1250 |
+
score=score,
|
| 1251 |
+
)
|
| 1252 |
+
|
| 1253 |
+
def validate_files(
|
| 1254 |
+
self,
|
| 1255 |
+
files: Dict[str, str],
|
| 1256 |
+
design_name: str = "",
|
| 1257 |
+
) -> ValidationReport:
|
| 1258 |
+
"""Validate multiple files."""
|
| 1259 |
+
file_results: List[FileValidationResult] = []
|
| 1260 |
+
|
| 1261 |
+
for filename, content in files.items():
|
| 1262 |
+
result = self.validate_file(filename, content)
|
| 1263 |
+
file_results.append(result)
|
| 1264 |
+
|
| 1265 |
+
total_errors = sum(f.error_count for f in file_results)
|
| 1266 |
+
overall_passed = total_errors == 0
|
| 1267 |
+
|
| 1268 |
+
import datetime
|
| 1269 |
+
report = ValidationReport(
|
| 1270 |
+
design_name=design_name,
|
| 1271 |
+
overall_passed=overall_passed,
|
| 1272 |
+
files=file_results,
|
| 1273 |
+
timestamp=datetime.datetime.now().isoformat(),
|
| 1274 |
+
)
|
| 1275 |
+
|
| 1276 |
+
return report
|
| 1277 |
+
|
| 1278 |
+
def validate_files_by_path(
|
| 1279 |
+
self,
|
| 1280 |
+
file_paths: Dict[str, str],
|
| 1281 |
+
design_name: str = "",
|
| 1282 |
+
) -> ValidationReport:
|
| 1283 |
+
"""Validate files by path."""
|
| 1284 |
+
content_map: Dict[str, str] = {}
|
| 1285 |
+
|
| 1286 |
+
for filename, path in file_paths.items():
|
| 1287 |
+
try:
|
| 1288 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 1289 |
+
content_map[filename] = f.read()
|
| 1290 |
+
except Exception as e:
|
| 1291 |
+
logger.warning("Failed to read %s: %s", path, e)
|
| 1292 |
+
content_map[filename] = ""
|
| 1293 |
+
|
| 1294 |
+
return self.validate_files(content_map, design_name)
|
|
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Pattern Learner for UVM Testbench Generation.
|
| 3 |
+
|
| 4 |
+
Key improvements for promotion:
|
| 5 |
+
1. Context-aware error pattern extraction with n-grams
|
| 6 |
+
2. Success pattern mining from successful generations
|
| 7 |
+
3. Association rule learning between spec features and success
|
| 8 |
+
4. Protocol-specific pattern libraries
|
| 9 |
+
5. Error correlation detection
|
| 10 |
+
6. Pattern-based code suggestions
|
| 11 |
+
7. Temporal pattern tracking (learning over time)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import re
|
| 18 |
+
import math
|
| 19 |
+
from collections import defaultdict, Counter
|
| 20 |
+
from dataclasses import dataclass, field
|
| 21 |
+
from typing import Dict, List, Any, Optional, Tuple, Set
|
| 22 |
+
from enum import Enum
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("uvmgen.ml.patterns")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class PatternType(Enum):
|
| 28 |
+
ERROR = "error"
|
| 29 |
+
SUCCESS = "success"
|
| 30 |
+
WARNING = "warning"
|
| 31 |
+
STRUCTURAL = "structural"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class Pattern:
|
| 36 |
+
pattern_str: str
|
| 37 |
+
pattern_type: PatternType
|
| 38 |
+
count: int = 0
|
| 39 |
+
confidence: float = 0.0
|
| 40 |
+
support: float = 0.0
|
| 41 |
+
lift: float = 1.0
|
| 42 |
+
contexts: List[str] = field(default_factory=list)
|
| 43 |
+
file_types: List[str] = field(default_factory=list)
|
| 44 |
+
protocols: List[str] = field(default_factory=list)
|
| 45 |
+
auto_fix: Optional[str] = None
|
| 46 |
+
description: str = ""
|
| 47 |
+
|
| 48 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 49 |
+
return {
|
| 50 |
+
"pattern_str": self.pattern_str,
|
| 51 |
+
"pattern_type": self.pattern_type.value,
|
| 52 |
+
"count": self.count,
|
| 53 |
+
"confidence": self.confidence,
|
| 54 |
+
"support": self.support,
|
| 55 |
+
"lift": self.lift,
|
| 56 |
+
"contexts": self.contexts,
|
| 57 |
+
"file_types": self.file_types,
|
| 58 |
+
"protocols": self.protocols,
|
| 59 |
+
"auto_fix": self.auto_fix,
|
| 60 |
+
"description": self.description,
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
|
| 65 |
+
class AssociationRule:
|
| 66 |
+
antecedent: str
|
| 67 |
+
consequent: str
|
| 68 |
+
confidence: float
|
| 69 |
+
support: float
|
| 70 |
+
lift: float
|
| 71 |
+
count: int = 0
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class NgramExtractor:
|
| 75 |
+
"""Extract n-grams from code and error messages for pattern learning."""
|
| 76 |
+
|
| 77 |
+
def __init__(self, n_min: int = 1, n_max: int = 4):
|
| 78 |
+
self.n_min = n_min
|
| 79 |
+
self.n_max = n_max
|
| 80 |
+
|
| 81 |
+
def extract(self, text: str, file_type: str = "unknown") -> List[str]:
|
| 82 |
+
"""Extract meaningful n-grams from text."""
|
| 83 |
+
clean_text = self._preprocess(text)
|
| 84 |
+
tokens = self._tokenize(clean_text)
|
| 85 |
+
|
| 86 |
+
if not tokens:
|
| 87 |
+
return []
|
| 88 |
+
|
| 89 |
+
ngrams = []
|
| 90 |
+
for n in range(self.n_min, self.n_max + 1):
|
| 91 |
+
for i in range(len(tokens) - n + 1):
|
| 92 |
+
ngram = " ".join(tokens[i:i + n])
|
| 93 |
+
if self._is_meaningful(ngram, file_type):
|
| 94 |
+
ngrams.append(ngram)
|
| 95 |
+
|
| 96 |
+
return ngrams
|
| 97 |
+
|
| 98 |
+
def _preprocess(self, text: str) -> str:
|
| 99 |
+
"""Preprocess text for tokenization."""
|
| 100 |
+
text = re.sub(r'//.*$', ' ', text, flags=re.MULTILINE)
|
| 101 |
+
text = re.sub(r'/\*.*?\*/', ' ', text, flags=re.DOTALL)
|
| 102 |
+
text = re.sub(r'"[^"]*"', 'STR', text)
|
| 103 |
+
text = text.replace('(', ' ( ').replace(')', ' ) ')
|
| 104 |
+
text = text.replace('[', ' [ ').replace(']', ' ] ')
|
| 105 |
+
text = text.replace('{', ' { ').replace('}', ' } ')
|
| 106 |
+
text = text.replace(';', ' ; ')
|
| 107 |
+
text = text.replace(',', ' , ')
|
| 108 |
+
return text
|
| 109 |
+
|
| 110 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 111 |
+
"""Tokenize into meaningful units."""
|
| 112 |
+
tokens = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|==|!=|<=|>=|\+=|-=|\*=|/=|&&|\|\||[+\-*/%=<>!&|~^?:;,\(\)\[\]\{\}]', text)
|
| 113 |
+
return [t.strip() for t in tokens if t.strip()]
|
| 114 |
+
|
| 115 |
+
def _is_meaningful(self, ngram: str, file_type: str) -> bool:
|
| 116 |
+
"""Filter to keep only meaningful ngrams."""
|
| 117 |
+
if len(ngram) < 3:
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
stop_patterns = {
|
| 121 |
+
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
|
| 122 |
+
'for', 'of', 'with', 'by', 'is', 'was', 'are', 'were', 'be',
|
| 123 |
+
'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did',
|
| 124 |
+
'will', 'would', 'could', 'should', 'may', 'might', 'must',
|
| 125 |
+
'shall', 'can', 'need', 'dare', 'ought', 'used',
|
| 126 |
+
'if', 'else', 'then', 'for', 'while', 'until', 'unless',
|
| 127 |
+
'begin', 'end', 'module', 'endmodule', 'class', 'endclass',
|
| 128 |
+
'input', 'output', 'logic', 'reg', 'wire', 'bit', 'int',
|
| 129 |
+
'always', 'initial', 'assign', 'posedge', 'negedge',
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
words = ngram.lower().split()
|
| 133 |
+
if all(w in stop_patterns for w in words):
|
| 134 |
+
return False
|
| 135 |
+
|
| 136 |
+
uvm_keywords = {'uvm', 'test', 'env', 'agent', 'driver', 'monitor',
|
| 137 |
+
'sequencer', 'sequence', 'scoreboard', 'register',
|
| 138 |
+
'reg', 'phase', 'objection', 'config_db'}
|
| 139 |
+
if any(kw in ngram.lower() for kw in uvm_keywords):
|
| 140 |
+
return True
|
| 141 |
+
|
| 142 |
+
if file_type in ('sequence', 'test'):
|
| 143 |
+
seq_keywords = {'start_item', 'finish_item', 'raise_objection',
|
| 144 |
+
'drop_objection', 'randomize', 'body'}
|
| 145 |
+
if any(kw in ngram for kw in seq_keywords):
|
| 146 |
+
return True
|
| 147 |
+
|
| 148 |
+
if len(words) >= 2:
|
| 149 |
+
return True
|
| 150 |
+
|
| 151 |
+
return len(ngram) > 5
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class ContextAwareErrorDetector:
|
| 155 |
+
"""Detect errors with context for better pattern learning."""
|
| 156 |
+
|
| 157 |
+
ERROR_PATTERNS_WITH_CONTEXT = [
|
| 158 |
+
(
|
| 159 |
+
r'missing\s+.*semicolon',
|
| 160 |
+
'missing_semicolon',
|
| 161 |
+
'Ensure all statements end with semicolons',
|
| 162 |
+
'Check lines ending with expressions or declarations'
|
| 163 |
+
),
|
| 164 |
+
(
|
| 165 |
+
r'unbalanced\s+.*parenthes',
|
| 166 |
+
'unbalanced_parentheses',
|
| 167 |
+
'Check for balanced parentheses',
|
| 168 |
+
'Count opening and closing parentheses in complex expressions'
|
| 169 |
+
),
|
| 170 |
+
(
|
| 171 |
+
r'unbalanced\s+.*brace',
|
| 172 |
+
'unbalanced_braces',
|
| 173 |
+
'Check for balanced begin/end blocks',
|
| 174 |
+
'Verify all begin/fork have matching end/join'
|
| 175 |
+
),
|
| 176 |
+
(
|
| 177 |
+
r'unbalanced\s+.*bracket',
|
| 178 |
+
'unbalanced_brackets',
|
| 179 |
+
'Check array indexing and part-selects',
|
| 180 |
+
'Verify all [ have matching ]'
|
| 181 |
+
),
|
| 182 |
+
(
|
| 183 |
+
r'mismatch.*begin|begin.*without.*end',
|
| 184 |
+
'mismatched_blocks',
|
| 185 |
+
'Verify block structure',
|
| 186 |
+
'Check begin/end, fork/join pairing'
|
| 187 |
+
),
|
| 188 |
+
(
|
| 189 |
+
r'uvm_fatal|uvm_error.*not.*found',
|
| 190 |
+
'missing_uvm_import',
|
| 191 |
+
'Import UVM package',
|
| 192 |
+
'Add `include "uvm_macros.svh" and import uvm_pkg::*'
|
| 193 |
+
),
|
| 194 |
+
(
|
| 195 |
+
r'uvm_component_utils|uvm_object_utils.*missing',
|
| 196 |
+
'missing_factory_macro',
|
| 197 |
+
'Add UVM factory registration',
|
| 198 |
+
'Use `uvm_component_utils for components, `uvm_object_utils for objects'
|
| 199 |
+
),
|
| 200 |
+
(
|
| 201 |
+
r'build_phase|connect_phase|run_phase.*not.*called',
|
| 202 |
+
'phase_implementation',
|
| 203 |
+
'Check phase method signatures',
|
| 204 |
+
'Ensure phases are declared as virtual functions/tasks with correct signatures'
|
| 205 |
+
),
|
| 206 |
+
(
|
| 207 |
+
r'raise_objection|drop_objection.*missing',
|
| 208 |
+
'missing_objection',
|
| 209 |
+
'Add objection handling in tests/sequences',
|
| 210 |
+
'Use phase.raise_objection(this) and phase.drop_objection(this) in run_phase'
|
| 211 |
+
),
|
| 212 |
+
(
|
| 213 |
+
r'config_db.*get.*failed|config_db.*set.*missing',
|
| 214 |
+
'config_db_issue',
|
| 215 |
+
'Check config_db usage',
|
| 216 |
+
'Ensure set/get paths match and config_db is set before build_phase'
|
| 217 |
+
),
|
| 218 |
+
(
|
| 219 |
+
r'reg_model.*null|reg_model.*not.*initialized',
|
| 220 |
+
'missing_ral_model',
|
| 221 |
+
'Initialize RAL model in test',
|
| 222 |
+
'Create and build reg_model in test::build_phase, set in config_db'
|
| 223 |
+
),
|
| 224 |
+
(
|
| 225 |
+
r'signal.*not.*declared|signal.*undefined',
|
| 226 |
+
'undefined_signal',
|
| 227 |
+
'Check signal declarations',
|
| 228 |
+
'Ensure all signals used are declared in the interface/module'
|
| 229 |
+
),
|
| 230 |
+
(
|
| 231 |
+
r'port.*not.*connected|port.*missing',
|
| 232 |
+
'port_connection',
|
| 233 |
+
'Check port connections',
|
| 234 |
+
'Verify all module ports are connected in testbench'
|
| 235 |
+
),
|
| 236 |
+
(
|
| 237 |
+
r'interface.*not.*set|vif.*null',
|
| 238 |
+
'missing_vif',
|
| 239 |
+
'Set virtual interface in config_db',
|
| 240 |
+
'Call uvm_config_db#(virtual intf)::set in testbench before run_test()'
|
| 241 |
+
),
|
| 242 |
+
(
|
| 243 |
+
r'sequence.*not.*started|sequencer.*null',
|
| 244 |
+
'sequence_start',
|
| 245 |
+
'Check sequence starting',
|
| 246 |
+
'Ensure seq.start(sequencer) is called with valid sequencer'
|
| 247 |
+
),
|
| 248 |
+
(
|
| 249 |
+
r'analysis_port.*not.*connected|analysis_export.*null',
|
| 250 |
+
'analysis_connection',
|
| 251 |
+
'Check TLM connections',
|
| 252 |
+
'Connect analysis ports to exports in connect_phase'
|
| 253 |
+
),
|
| 254 |
+
]
|
| 255 |
+
|
| 256 |
+
@classmethod
|
| 257 |
+
def extract_with_context(
|
| 258 |
+
cls,
|
| 259 |
+
error_msg: str,
|
| 260 |
+
content: Optional[str] = None,
|
| 261 |
+
line_num: Optional[int] = None,
|
| 262 |
+
) -> List[Dict[str, Any]]:
|
| 263 |
+
"""Extract error patterns with contextual information."""
|
| 264 |
+
results = []
|
| 265 |
+
|
| 266 |
+
for pattern, error_type, suggestion, context_tip in cls.ERROR_PATTERNS_WITH_CONTEXT:
|
| 267 |
+
if re.search(pattern, error_msg, re.IGNORECASE):
|
| 268 |
+
result = {
|
| 269 |
+
'error_type': error_type,
|
| 270 |
+
'pattern': pattern,
|
| 271 |
+
'message': error_msg[:200] if len(error_msg) > 200 else error_msg,
|
| 272 |
+
'suggestion': suggestion,
|
| 273 |
+
'context_tip': context_tip,
|
| 274 |
+
'line_number': line_num,
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
if content and line_num:
|
| 278 |
+
result['context'] = cls._get_content_context(content, line_num)
|
| 279 |
+
|
| 280 |
+
results.append(result)
|
| 281 |
+
|
| 282 |
+
if not results:
|
| 283 |
+
results.append({
|
| 284 |
+
'error_type': 'unknown_error',
|
| 285 |
+
'message': error_msg[:200] if len(error_msg) > 200 else error_msg,
|
| 286 |
+
'suggestion': 'Review the error message details',
|
| 287 |
+
'line_number': line_num,
|
| 288 |
+
})
|
| 289 |
+
|
| 290 |
+
return results
|
| 291 |
+
|
| 292 |
+
@staticmethod
|
| 293 |
+
def _get_content_context(content: str, line_num: int, context_lines: int = 3) -> str:
|
| 294 |
+
"""Get surrounding lines of content for context."""
|
| 295 |
+
lines = content.split('\n')
|
| 296 |
+
start = max(0, line_num - context_lines - 1)
|
| 297 |
+
end = min(len(lines), line_num + context_lines)
|
| 298 |
+
|
| 299 |
+
context_lines = []
|
| 300 |
+
for i in range(start, end):
|
| 301 |
+
marker = '>> ' if i == line_num - 1 else ' '
|
| 302 |
+
context_lines.append(f"{marker}{i+1:4d}: {lines[i]}")
|
| 303 |
+
|
| 304 |
+
return '\n'.join(context_lines)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class SuccessPatternMiner:
|
| 308 |
+
"""Mine patterns from successful generations for reuse."""
|
| 309 |
+
|
| 310 |
+
def __init__(self):
|
| 311 |
+
self._success_patterns: Dict[str, Pattern] = {}
|
| 312 |
+
self._file_type_patterns: Dict[str, Dict[str, int]] = defaultdict(dict)
|
| 313 |
+
self._protocol_patterns: Dict[str, Dict[str, int]] = defaultdict(dict)
|
| 314 |
+
self._total_successes: int = 0
|
| 315 |
+
|
| 316 |
+
def mine_from_success(
|
| 317 |
+
self,
|
| 318 |
+
content: str,
|
| 319 |
+
file_type: str,
|
| 320 |
+
protocol: str,
|
| 321 |
+
score: float,
|
| 322 |
+
) -> List[str]:
|
| 323 |
+
"""Mine successful patterns from high-quality generated code."""
|
| 324 |
+
if score < 0.7:
|
| 325 |
+
return []
|
| 326 |
+
|
| 327 |
+
extractor = NgramExtractor(n_min=2, n_max=5)
|
| 328 |
+
ngrams = extractor.extract(content, file_type)
|
| 329 |
+
|
| 330 |
+
filtered = self._filter_success_patterns(ngrams, file_type)
|
| 331 |
+
|
| 332 |
+
for ngram in filtered:
|
| 333 |
+
self._record_success_pattern(ngram, file_type, protocol, score)
|
| 334 |
+
|
| 335 |
+
self._total_successes += 1
|
| 336 |
+
return filtered
|
| 337 |
+
|
| 338 |
+
def _filter_success_patterns(self, ngrams: List[str], file_type: str) -> List[str]:
|
| 339 |
+
"""Filter to keep only meaningful success patterns."""
|
| 340 |
+
filtered = []
|
| 341 |
+
|
| 342 |
+
success_indicators = {
|
| 343 |
+
'any': [
|
| 344 |
+
'uvm_component_utils', 'uvm_object_utils',
|
| 345 |
+
'raise_objection', 'drop_objection',
|
| 346 |
+
'build_phase', 'connect_phase', 'run_phase',
|
| 347 |
+
'config_db', 'type_id', 'create',
|
| 348 |
+
],
|
| 349 |
+
'driver': [
|
| 350 |
+
'seq_item_port', 'get_next_item', 'item_done',
|
| 351 |
+
],
|
| 352 |
+
'monitor': [
|
| 353 |
+
'analysis_port', 'write',
|
| 354 |
+
],
|
| 355 |
+
'agent': [
|
| 356 |
+
'get_is_active', 'driver', 'monitor', 'sequencer',
|
| 357 |
+
],
|
| 358 |
+
'scoreboard': [
|
| 359 |
+
'uvm_analysis_imp', 'write',
|
| 360 |
+
],
|
| 361 |
+
'sequence': [
|
| 362 |
+
'start_item', 'finish_item', 'body', 'randomize',
|
| 363 |
+
],
|
| 364 |
+
'test': [
|
| 365 |
+
'uvm_test', 'env', 'reg_model',
|
| 366 |
+
],
|
| 367 |
+
'ral_model': [
|
| 368 |
+
'uvm_reg', 'uvm_reg_block', 'uvm_reg_field',
|
| 369 |
+
'create_map', 'lock_model',
|
| 370 |
+
],
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
for ngram in ngrams:
|
| 374 |
+
indicators = success_indicators.get(file_type, []) + success_indicators.get('any', [])
|
| 375 |
+
if any(ind in ngram for ind in indicators):
|
| 376 |
+
filtered.append(ngram)
|
| 377 |
+
|
| 378 |
+
return list(set(filtered))
|
| 379 |
+
|
| 380 |
+
def _record_success_pattern(
|
| 381 |
+
self,
|
| 382 |
+
ngram: str,
|
| 383 |
+
file_type: str,
|
| 384 |
+
protocol: str,
|
| 385 |
+
score: float,
|
| 386 |
+
) -> None:
|
| 387 |
+
"""Record a successful pattern."""
|
| 388 |
+
if ngram not in self._success_patterns:
|
| 389 |
+
self._success_patterns[ngram] = Pattern(
|
| 390 |
+
pattern_str=ngram,
|
| 391 |
+
pattern_type=PatternType.SUCCESS,
|
| 392 |
+
description=f"Successful pattern from {file_type}",
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
pattern = self._success_patterns[ngram]
|
| 396 |
+
pattern.count += 1
|
| 397 |
+
|
| 398 |
+
if file_type not in pattern.file_types:
|
| 399 |
+
pattern.file_types.append(file_type)
|
| 400 |
+
if protocol not in pattern.protocols:
|
| 401 |
+
pattern.protocols.append(protocol)
|
| 402 |
+
|
| 403 |
+
if file_type not in self._file_type_patterns:
|
| 404 |
+
self._file_type_patterns[file_type] = defaultdict(int)
|
| 405 |
+
self._file_type_patterns[file_type][ngram] += 1
|
| 406 |
+
|
| 407 |
+
if protocol not in self._protocol_patterns:
|
| 408 |
+
self._protocol_patterns[protocol] = defaultdict(int)
|
| 409 |
+
self._protocol_patterns[protocol][ngram] += 1
|
| 410 |
+
|
| 411 |
+
total = float(self._total_successes + 1)
|
| 412 |
+
pattern.support = pattern.count / total
|
| 413 |
+
pattern.confidence = min(1.0, score * pattern.count / total)
|
| 414 |
+
|
| 415 |
+
def get_success_patterns(
|
| 416 |
+
self,
|
| 417 |
+
file_type: Optional[str] = None,
|
| 418 |
+
protocol: Optional[str] = None,
|
| 419 |
+
min_count: int = 2,
|
| 420 |
+
top_n: int = 20,
|
| 421 |
+
) -> List[Pattern]:
|
| 422 |
+
"""Get successful patterns filtered by criteria."""
|
| 423 |
+
candidates: List[Pattern] = []
|
| 424 |
+
|
| 425 |
+
for pattern in self._success_patterns.values():
|
| 426 |
+
if pattern.count < min_count:
|
| 427 |
+
continue
|
| 428 |
+
if file_type and file_type not in pattern.file_types:
|
| 429 |
+
continue
|
| 430 |
+
if protocol and protocol not in pattern.protocols:
|
| 431 |
+
continue
|
| 432 |
+
candidates.append(pattern)
|
| 433 |
+
|
| 434 |
+
candidates.sort(key=lambda p: (p.confidence, p.support), reverse=True)
|
| 435 |
+
return candidates[:top_n]
|
| 436 |
+
|
| 437 |
+
def get_recommendations(
|
| 438 |
+
self,
|
| 439 |
+
file_type: str,
|
| 440 |
+
protocol: str,
|
| 441 |
+
) -> List[Dict[str, Any]]:
|
| 442 |
+
"""Get code recommendations based on success patterns."""
|
| 443 |
+
recommendations = []
|
| 444 |
+
|
| 445 |
+
patterns = self.get_success_patterns(
|
| 446 |
+
file_type=file_type,
|
| 447 |
+
protocol=protocol,
|
| 448 |
+
min_count=1,
|
| 449 |
+
top_n=10,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
for pattern in patterns:
|
| 453 |
+
recommendations.append({
|
| 454 |
+
'pattern': pattern.pattern_str,
|
| 455 |
+
'confidence': pattern.confidence,
|
| 456 |
+
'support': pattern.support,
|
| 457 |
+
'file_types': pattern.file_types,
|
| 458 |
+
'description': pattern.description,
|
| 459 |
+
})
|
| 460 |
+
|
| 461 |
+
return recommendations
|
| 462 |
+
|
| 463 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 464 |
+
return {
|
| 465 |
+
'total_successes': self._total_successes,
|
| 466 |
+
'success_patterns': {k: v.to_dict() for k, v in self._success_patterns.items()},
|
| 467 |
+
'file_type_patterns': {
|
| 468 |
+
ft: dict(patterns) for ft, patterns in self._file_type_patterns.items()
|
| 469 |
+
},
|
| 470 |
+
'protocol_patterns': {
|
| 471 |
+
proto: dict(patterns) for proto, patterns in self._protocol_patterns.items()
|
| 472 |
+
},
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class AssociationRuleMiner:
|
| 477 |
+
"""Mine association rules between spec features and generation success."""
|
| 478 |
+
|
| 479 |
+
def __init__(self, min_support: float = 0.1, min_confidence: float = 0.5):
|
| 480 |
+
self.min_support = min_support
|
| 481 |
+
self.min_confidence = min_confidence
|
| 482 |
+
self._transactions: List[Set[str]] = []
|
| 483 |
+
self._item_counts: Dict[str, int] = defaultdict(int)
|
| 484 |
+
self._rules: List[AssociationRule] = []
|
| 485 |
+
|
| 486 |
+
def add_transaction(self, items: List[str]) -> None:
|
| 487 |
+
"""Add a transaction (set of features/outcomes)."""
|
| 488 |
+
item_set = set(items)
|
| 489 |
+
self._transactions.append(item_set)
|
| 490 |
+
|
| 491 |
+
for item in item_set:
|
| 492 |
+
self._item_counts[item] += 1
|
| 493 |
+
|
| 494 |
+
def mine_rules(self) -> List[AssociationRule]:
|
| 495 |
+
"""Mine association rules from transactions."""
|
| 496 |
+
if len(self._transactions) < 5:
|
| 497 |
+
return []
|
| 498 |
+
|
| 499 |
+
min_support_count = int(self.min_support * len(self._transactions))
|
| 500 |
+
|
| 501 |
+
freq_items = {
|
| 502 |
+
item: count for item, count in self._item_counts.items()
|
| 503 |
+
if count >= min_support_count
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
if len(freq_items) < 2:
|
| 507 |
+
return []
|
| 508 |
+
|
| 509 |
+
rules = []
|
| 510 |
+
items_list = list(freq_items.keys())
|
| 511 |
+
|
| 512 |
+
for i, item1 in enumerate(items_list):
|
| 513 |
+
for item2 in items_list[i+1:]:
|
| 514 |
+
count_both = sum(
|
| 515 |
+
1 for t in self._transactions
|
| 516 |
+
if item1 in t and item2 in t
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
if count_both < min_support_count:
|
| 520 |
+
continue
|
| 521 |
+
|
| 522 |
+
support = count_both / len(self._transactions)
|
| 523 |
+
|
| 524 |
+
confidence_1_2 = count_both / self._item_counts[item1]
|
| 525 |
+
confidence_2_1 = count_both / self._item_counts[item2]
|
| 526 |
+
|
| 527 |
+
support_item1 = self._item_counts[item1] / len(self._transactions)
|
| 528 |
+
support_item2 = self._item_counts[item2] / len(self._transactions)
|
| 529 |
+
|
| 530 |
+
lift_1_2 = confidence_1_2 / support_item2 if support_item2 > 0 else 1.0
|
| 531 |
+
lift_2_1 = confidence_2_1 / support_item1 if support_item1 > 0 else 1.0
|
| 532 |
+
|
| 533 |
+
if confidence_1_2 >= self.min_confidence:
|
| 534 |
+
rules.append(AssociationRule(
|
| 535 |
+
antecedent=item1,
|
| 536 |
+
consequent=item2,
|
| 537 |
+
confidence=confidence_1_2,
|
| 538 |
+
support=support,
|
| 539 |
+
lift=lift_1_2,
|
| 540 |
+
count=count_both,
|
| 541 |
+
))
|
| 542 |
+
|
| 543 |
+
if confidence_2_1 >= self.min_confidence:
|
| 544 |
+
rules.append(AssociationRule(
|
| 545 |
+
antecedent=item2,
|
| 546 |
+
consequent=item1,
|
| 547 |
+
confidence=confidence_2_1,
|
| 548 |
+
support=support,
|
| 549 |
+
lift=lift_2_1,
|
| 550 |
+
count=count_both,
|
| 551 |
+
))
|
| 552 |
+
|
| 553 |
+
rules.sort(key=lambda r: (r.confidence, r.lift, r.support), reverse=True)
|
| 554 |
+
self._rules = rules
|
| 555 |
+
return rules
|
| 556 |
+
|
| 557 |
+
def get_rules_for_antecedent(self, antecedent: str) -> List[AssociationRule]:
|
| 558 |
+
"""Get all rules with a specific antecedent."""
|
| 559 |
+
return [r for r in self._rules if r.antecedent == antecedent]
|
| 560 |
+
|
| 561 |
+
def get_rules_for_consequent(self, consequent: str) -> List[AssociationRule]:
|
| 562 |
+
"""Get all rules with a specific consequent."""
|
| 563 |
+
return [r for r in self._rules if r.consequent == consequent]
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
class TemporalPatternTracker:
|
| 567 |
+
"""Track how patterns evolve over time for continuous learning."""
|
| 568 |
+
|
| 569 |
+
def __init__(self, window_size: int = 100):
|
| 570 |
+
self.window_size = window_size
|
| 571 |
+
self._error_windows: Dict[str, List[bool]] = defaultdict(list)
|
| 572 |
+
self._success_windows: Dict[str, List[bool]] = defaultdict(list)
|
| 573 |
+
self._trends: Dict[str, float] = {}
|
| 574 |
+
|
| 575 |
+
def record_error(self, error_type: str, occurred: bool) -> None:
|
| 576 |
+
"""Record whether an error occurred."""
|
| 577 |
+
self._error_windows[error_type].append(occurred)
|
| 578 |
+
if len(self._error_windows[error_type]) > self.window_size:
|
| 579 |
+
self._error_windows[error_type].pop(0)
|
| 580 |
+
self._update_trend(error_type, 'error')
|
| 581 |
+
|
| 582 |
+
def record_success(self, pattern: str, success: bool) -> None:
|
| 583 |
+
"""Record pattern success."""
|
| 584 |
+
self._success_windows[pattern].append(success)
|
| 585 |
+
if len(self._success_windows[pattern]) > self.window_size:
|
| 586 |
+
self._success_windows[pattern].pop(0)
|
| 587 |
+
self._update_trend(pattern, 'success')
|
| 588 |
+
|
| 589 |
+
def _update_trend(self, key: str, pattern_type: str) -> None:
|
| 590 |
+
"""Update trend direction."""
|
| 591 |
+
if pattern_type == 'error':
|
| 592 |
+
window = self._error_windows.get(key, [])
|
| 593 |
+
else:
|
| 594 |
+
window = self._success_windows.get(key, [])
|
| 595 |
+
|
| 596 |
+
if len(window) < 10:
|
| 597 |
+
self._trends[key] = 0.0
|
| 598 |
+
return
|
| 599 |
+
|
| 600 |
+
first_half = window[:len(window)//2]
|
| 601 |
+
second_half = window[len(window)//2:]
|
| 602 |
+
|
| 603 |
+
rate_first = sum(first_half) / len(first_half)
|
| 604 |
+
rate_second = sum(second_half) / len(second_half)
|
| 605 |
+
|
| 606 |
+
self._trends[key] = rate_second - rate_first
|
| 607 |
+
|
| 608 |
+
def get_trend(self, key: str) -> float:
|
| 609 |
+
"""Get trend: positive = improving, negative = worsening."""
|
| 610 |
+
return self._trends.get(key, 0.0)
|
| 611 |
+
|
| 612 |
+
def get_error_rate(self, error_type: str) -> float:
|
| 613 |
+
"""Get current error rate."""
|
| 614 |
+
window = self._error_windows.get(error_type, [])
|
| 615 |
+
if not window:
|
| 616 |
+
return 0.0
|
| 617 |
+
return sum(window) / len(window)
|
| 618 |
+
|
| 619 |
+
def get_success_rate(self, pattern: str) -> float:
|
| 620 |
+
"""Get current success rate."""
|
| 621 |
+
window = self._success_windows.get(pattern, [])
|
| 622 |
+
if not window:
|
| 623 |
+
return 0.0
|
| 624 |
+
return sum(window) / len(window)
|
| 625 |
+
|
| 626 |
+
def get_improving_errors(self) -> List[Tuple[str, float]]:
|
| 627 |
+
"""Get errors that are decreasing."""
|
| 628 |
+
improving = []
|
| 629 |
+
for key, trend in self._trends.items():
|
| 630 |
+
if key in self._error_windows and trend < -0.1:
|
| 631 |
+
improving.append((key, trend))
|
| 632 |
+
improving.sort(key=lambda x: x[1])
|
| 633 |
+
return improving
|
| 634 |
+
|
| 635 |
+
def get_worsening_errors(self) -> List[Tuple[str, float]]:
|
| 636 |
+
"""Get errors that are increasing."""
|
| 637 |
+
worsening = []
|
| 638 |
+
for key, trend in self._trends.items():
|
| 639 |
+
if key in self._error_windows and trend > 0.1:
|
| 640 |
+
worsening.append((key, trend))
|
| 641 |
+
worsening.sort(key=lambda x: x[1], reverse=True)
|
| 642 |
+
return worsening
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class AdvancedPatternLearner:
|
| 646 |
+
"""
|
| 647 |
+
Advanced pattern learner combining all capabilities.
|
| 648 |
+
|
| 649 |
+
This is the main interface for:
|
| 650 |
+
1. Error pattern detection and tracking
|
| 651 |
+
2. Success pattern mining
|
| 652 |
+
3. Association rule learning
|
| 653 |
+
4. Temporal trend analysis
|
| 654 |
+
5. Code recommendations
|
| 655 |
+
"""
|
| 656 |
+
|
| 657 |
+
def __init__(self):
|
| 658 |
+
self._error_detector = ContextAwareErrorDetector()
|
| 659 |
+
self._success_miner = SuccessPatternMiner()
|
| 660 |
+
self._association_miner = AssociationRuleMiner(min_support=0.1, min_confidence=0.5)
|
| 661 |
+
self._temporal_tracker = TemporalPatternTracker(window_size=100)
|
| 662 |
+
|
| 663 |
+
self._error_patterns: Dict[str, Pattern] = {}
|
| 664 |
+
self._file_type_stats: Dict[str, Dict[str, Any]] = defaultdict(
|
| 665 |
+
lambda: {"success": 0, "total": 0, "errors": defaultdict(int)}
|
| 666 |
+
)
|
| 667 |
+
self._protocol_stats: Dict[str, Dict[str, Any]] = defaultdict(
|
| 668 |
+
lambda: {"success": 0, "total": 0}
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
self._ngram_extractor = NgramExtractor(n_min=1, n_max=4)
|
| 672 |
+
|
| 673 |
+
def record_error(
|
| 674 |
+
self,
|
| 675 |
+
error_msg: str,
|
| 676 |
+
file_type: str = "unknown",
|
| 677 |
+
content: Optional[str] = None,
|
| 678 |
+
line_num: Optional[int] = None,
|
| 679 |
+
) -> List[Dict[str, Any]]:
|
| 680 |
+
"""Record an error with full context analysis."""
|
| 681 |
+
errors = self._error_detector.extract_with_context(
|
| 682 |
+
error_msg, content, line_num
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
for err in errors:
|
| 686 |
+
error_type = err['error_type']
|
| 687 |
+
|
| 688 |
+
if error_type not in self._error_patterns:
|
| 689 |
+
self._error_patterns[error_type] = Pattern(
|
| 690 |
+
pattern_str=error_type,
|
| 691 |
+
pattern_type=PatternType.ERROR,
|
| 692 |
+
description=err.get('suggestion', ''),
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
self._error_patterns[error_type].count += 1
|
| 696 |
+
self._error_patterns[error_type].contexts.append(
|
| 697 |
+
err.get('context', error_msg[:100])
|
| 698 |
+
)
|
| 699 |
+
if file_type not in self._error_patterns[error_type].file_types:
|
| 700 |
+
self._error_patterns[error_type].file_types.append(file_type)
|
| 701 |
+
|
| 702 |
+
self._file_type_stats[file_type]["errors"][error_type] += 1
|
| 703 |
+
self._temporal_tracker.record_error(error_type, True)
|
| 704 |
+
|
| 705 |
+
return errors
|
| 706 |
+
|
| 707 |
+
def record_success(
|
| 708 |
+
self,
|
| 709 |
+
file_type: str = "unknown",
|
| 710 |
+
protocol: str = "unknown",
|
| 711 |
+
content: Optional[str] = None,
|
| 712 |
+
score: float = 1.0,
|
| 713 |
+
) -> List[str]:
|
| 714 |
+
"""Record a success and mine patterns from it."""
|
| 715 |
+
self._file_type_stats[file_type]["success"] += 1
|
| 716 |
+
self._file_type_stats[file_type]["total"] += 1
|
| 717 |
+
self._protocol_stats[protocol]["success"] += 1
|
| 718 |
+
self._protocol_stats[protocol]["total"] += 1
|
| 719 |
+
|
| 720 |
+
mined_patterns = []
|
| 721 |
+
if content and score >= 0.7:
|
| 722 |
+
mined_patterns = self._success_miner.mine_from_success(
|
| 723 |
+
content, file_type, protocol, score
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
for pattern in mined_patterns:
|
| 727 |
+
self._temporal_tracker.record_success(pattern, True)
|
| 728 |
+
|
| 729 |
+
items = [
|
| 730 |
+
f"file_type:{file_type}",
|
| 731 |
+
f"protocol:{protocol}",
|
| 732 |
+
f"success:yes",
|
| 733 |
+
f"score:{int(score * 10)}",
|
| 734 |
+
]
|
| 735 |
+
items.extend(mined_patterns[:5])
|
| 736 |
+
self._association_miner.add_transaction(items)
|
| 737 |
+
|
| 738 |
+
return mined_patterns
|
| 739 |
+
|
| 740 |
+
def record_attempt(
|
| 741 |
+
self,
|
| 742 |
+
file_type: str = "unknown",
|
| 743 |
+
protocol: str = "unknown",
|
| 744 |
+
) -> None:
|
| 745 |
+
"""Record an attempt (for stats tracking)."""
|
| 746 |
+
self._file_type_stats[file_type]["total"] += 1
|
| 747 |
+
self._protocol_stats[protocol]["total"] += 1
|
| 748 |
+
|
| 749 |
+
def get_common_errors(self, top_n: int = 10) -> List[Tuple[str, int, Pattern]]:
|
| 750 |
+
"""Get the most common errors."""
|
| 751 |
+
sorted_errors = sorted(
|
| 752 |
+
self._error_patterns.items(),
|
| 753 |
+
key=lambda x: x[1].count,
|
| 754 |
+
reverse=True,
|
| 755 |
+
)
|
| 756 |
+
return [(name, p.count, p) for name, p in sorted_errors[:top_n]]
|
| 757 |
+
|
| 758 |
+
def get_file_type_success_rate(self, file_type: str) -> float:
|
| 759 |
+
"""Get success rate for a file type."""
|
| 760 |
+
stats = self._file_type_stats.get(file_type, {})
|
| 761 |
+
total = stats.get("total", 0)
|
| 762 |
+
if total == 0:
|
| 763 |
+
return 0.5
|
| 764 |
+
return stats.get("success", 0) / total
|
| 765 |
+
|
| 766 |
+
def get_protocol_success_rate(self, protocol: str) -> float:
|
| 767 |
+
"""Get success rate for a protocol."""
|
| 768 |
+
stats = self._protocol_stats.get(protocol, {})
|
| 769 |
+
total = stats.get("total", 0)
|
| 770 |
+
if total == 0:
|
| 771 |
+
return 0.5
|
| 772 |
+
return stats.get("success", 0) / total
|
| 773 |
+
|
| 774 |
+
def get_suggestions(
|
| 775 |
+
self,
|
| 776 |
+
file_type: str,
|
| 777 |
+
protocol: str,
|
| 778 |
+
) -> Dict[str, Any]:
|
| 779 |
+
"""Get comprehensive suggestions for improvement."""
|
| 780 |
+
common_errors = self.get_common_errors(5)
|
| 781 |
+
file_success_rate = self.get_file_type_success_rate(file_type)
|
| 782 |
+
protocol_success_rate = self.get_protocol_success_rate(protocol)
|
| 783 |
+
|
| 784 |
+
success_recommendations = self._success_miner.get_recommendations(
|
| 785 |
+
file_type, protocol
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
improving = self._temporal_tracker.get_improving_errors()
|
| 789 |
+
worsening = self._temporal_tracker.get_worsening_errors()
|
| 790 |
+
|
| 791 |
+
suggestions = {
|
| 792 |
+
"common_errors": [
|
| 793 |
+
{
|
| 794 |
+
"error_type": name,
|
| 795 |
+
"count": count,
|
| 796 |
+
"description": pattern.description,
|
| 797 |
+
"current_rate": self._temporal_tracker.get_error_rate(name),
|
| 798 |
+
"trend": self._temporal_tracker.get_trend(name),
|
| 799 |
+
}
|
| 800 |
+
for name, count, pattern in common_errors
|
| 801 |
+
],
|
| 802 |
+
"file_type_success_rate": file_success_rate,
|
| 803 |
+
"protocol_success_rate": protocol_success_rate,
|
| 804 |
+
"success_patterns": success_recommendations,
|
| 805 |
+
"improving_errors": [{"error": e[0], "trend": e[1]} for e in improving],
|
| 806 |
+
"worsening_errors": [{"error": e[0], "trend": e[1]} for e in worsening],
|
| 807 |
+
"recommendations": self._generate_advanced_recommendations(
|
| 808 |
+
file_type, protocol, file_success_rate, common_errors
|
| 809 |
+
),
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
return suggestions
|
| 813 |
+
|
| 814 |
+
def _generate_advanced_recommendations(
|
| 815 |
+
self,
|
| 816 |
+
file_type: str,
|
| 817 |
+
protocol: str,
|
| 818 |
+
success_rate: float,
|
| 819 |
+
common_errors: List[Tuple],
|
| 820 |
+
) -> List[str]:
|
| 821 |
+
"""Generate advanced recommendations based on all data."""
|
| 822 |
+
recommendations = []
|
| 823 |
+
|
| 824 |
+
for name, count, pattern in common_errors[:3]:
|
| 825 |
+
if count > 0:
|
| 826 |
+
if pattern.description:
|
| 827 |
+
recommendations.append(pattern.description)
|
| 828 |
+
elif 'semicolon' in name:
|
| 829 |
+
recommendations.append("Ensure all statements end with semicolons")
|
| 830 |
+
elif 'parenthes' in name:
|
| 831 |
+
recommendations.append("Check for balanced parentheses")
|
| 832 |
+
elif 'brace' in name or 'block' in name:
|
| 833 |
+
recommendations.append("Check for balanced begin/end blocks")
|
| 834 |
+
elif 'uvm_macro' in name or 'factory' in name:
|
| 835 |
+
recommendations.append(
|
| 836 |
+
"Add UVM factory registration macros (uvm_component_utils/uvm_object_utils)"
|
| 837 |
+
)
|
| 838 |
+
elif 'phase' in name:
|
| 839 |
+
recommendations.append("Ensure proper UVM phase implementation")
|
| 840 |
+
elif 'objection' in name:
|
| 841 |
+
recommendations.append(
|
| 842 |
+
"Use phase.raise_objection(this) and phase.drop_objection(this)"
|
| 843 |
+
)
|
| 844 |
+
elif 'config_db' in name or 'vif' in name:
|
| 845 |
+
recommendations.append(
|
| 846 |
+
"Ensure virtual interface is set in config_db before run_test()"
|
| 847 |
+
)
|
| 848 |
+
elif 'ral' in name or 'reg_model' in name:
|
| 849 |
+
recommendations.append(
|
| 850 |
+
"Create and initialize RAL model in test::build_phase"
|
| 851 |
+
)
|
| 852 |
+
elif 'signal' in name or 'port' in name:
|
| 853 |
+
recommendations.append(
|
| 854 |
+
"Ensure all signals/ports used are declared in spec and interface"
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
if success_rate < 0.7:
|
| 858 |
+
recommendations.append(
|
| 859 |
+
f"Consider using retrieval-based generation for {file_type} (success rate: {success_rate:.1%})"
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
rules = self._association_miner.get_rules_for_antecedent(f"file_type:{file_type}")
|
| 863 |
+
for rule in rules[:3]:
|
| 864 |
+
if rule.confidence > 0.7 and rule.lift > 1.0:
|
| 865 |
+
recommendations.append(
|
| 866 |
+
f"Consider: {rule.consequent} (confidence: {rule.confidence:.1%}, lift: {rule.lift:.2f})"
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
if not recommendations:
|
| 870 |
+
recommendations.append(
|
| 871 |
+
"No specific recommendations - generation should work well"
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
return recommendations
|
| 875 |
+
|
| 876 |
+
def mine_association_rules(self) -> List[AssociationRule]:
|
| 877 |
+
"""Mine association rules from collected data."""
|
| 878 |
+
return self._association_miner.mine_rules()
|
| 879 |
+
|
| 880 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 881 |
+
return {
|
| 882 |
+
"error_patterns": {k: v.to_dict() for k, v in self._error_patterns.items()},
|
| 883 |
+
"file_type_stats": {
|
| 884 |
+
ft: {
|
| 885 |
+
"success": s["success"],
|
| 886 |
+
"total": s["total"],
|
| 887 |
+
"errors": dict(s["errors"]),
|
| 888 |
+
}
|
| 889 |
+
for ft, s in self._file_type_stats.items()
|
| 890 |
+
},
|
| 891 |
+
"protocol_stats": dict(self._protocol_stats),
|
| 892 |
+
"success_miner": self._success_miner.to_dict(),
|
| 893 |
+
}
|
| 894 |
+
|
| 895 |
+
@classmethod
|
| 896 |
+
def from_dict(cls, d: Dict[str, Any]) -> "AdvancedPatternLearner":
|
| 897 |
+
learner = cls()
|
| 898 |
+
|
| 899 |
+
for name, pdict in d.get("error_patterns", {}).items():
|
| 900 |
+
pattern = Pattern(
|
| 901 |
+
pattern_str=pdict.get("pattern_str", name),
|
| 902 |
+
pattern_type=PatternType(pdict.get("pattern_type", "error")),
|
| 903 |
+
count=pdict.get("count", 0),
|
| 904 |
+
confidence=pdict.get("confidence", 0.0),
|
| 905 |
+
support=pdict.get("support", 0.0),
|
| 906 |
+
contexts=pdict.get("contexts", []),
|
| 907 |
+
file_types=pdict.get("file_types", []),
|
| 908 |
+
protocols=pdict.get("protocols", []),
|
| 909 |
+
description=pdict.get("description", ""),
|
| 910 |
+
)
|
| 911 |
+
learner._error_patterns[name] = pattern
|
| 912 |
+
|
| 913 |
+
for ft, s in d.get("file_type_stats", {}).items():
|
| 914 |
+
learner._file_type_stats[ft] = {
|
| 915 |
+
"success": s.get("success", 0),
|
| 916 |
+
"total": s.get("total", 0),
|
| 917 |
+
"errors": defaultdict(int, s.get("errors", {})),
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
for proto, s in d.get("protocol_stats", {}).items():
|
| 921 |
+
learner._protocol_stats[proto] = {
|
| 922 |
+
"success": s.get("success", 0),
|
| 923 |
+
"total": s.get("total", 0),
|
| 924 |
+
}
|
| 925 |
+
|
| 926 |
+
return learner
|
|
@@ -0,0 +1,728 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Reinforcement Learner for UVM Testbench Generation Strategy Selection.
|
| 3 |
+
|
| 4 |
+
Key improvements for promotion:
|
| 5 |
+
1. Experience replay buffer for more stable learning
|
| 6 |
+
2. Eligibility traces for better credit assignment
|
| 7 |
+
3. Upper Confidence Bound (UCB) for exploration-exploitation balance
|
| 8 |
+
4. Multi-armed bandit strategies (epsilon-greedy, softmax, UCB)
|
| 9 |
+
5. Contextual bandits considering spec features
|
| 10 |
+
6. Learning rate scheduling
|
| 11 |
+
7. Value function approximation with state aggregation
|
| 12 |
+
8. Performance tracking and strategy comparison
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import math
|
| 19 |
+
import random
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
from collections import defaultdict, deque
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
from typing import Dict, List, Any, Optional, Tuple, Deque
|
| 25 |
+
from enum import Enum
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger("uvmgen.ml.rl")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ExplorationStrategy(Enum):
|
| 32 |
+
EPSILON_GREEDY = "epsilon_greedy"
|
| 33 |
+
SOFTMAX = "softmax"
|
| 34 |
+
UCB = "ucb"
|
| 35 |
+
THOMPSON_SAMPLING = "thompson_sampling"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class Experience:
|
| 40 |
+
"""Single experience for replay buffer."""
|
| 41 |
+
state: str
|
| 42 |
+
action: str
|
| 43 |
+
reward: float
|
| 44 |
+
next_state: Optional[str]
|
| 45 |
+
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 46 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class ActionStats:
|
| 51 |
+
"""Statistics for an action."""
|
| 52 |
+
q_value: float = 0.5
|
| 53 |
+
visit_count: int = 0
|
| 54 |
+
total_reward: float = 0.0
|
| 55 |
+
squared_reward: float = 0.0
|
| 56 |
+
success_count: int = 0
|
| 57 |
+
failure_count: int = 0
|
| 58 |
+
|
| 59 |
+
@property
|
| 60 |
+
def mean_reward(self) -> float:
|
| 61 |
+
if self.visit_count == 0:
|
| 62 |
+
return 0.5
|
| 63 |
+
return self.total_reward / self.visit_count
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def variance(self) -> float:
|
| 67 |
+
if self.visit_count < 2:
|
| 68 |
+
return 0.25
|
| 69 |
+
mean = self.mean_reward
|
| 70 |
+
return (self.squared_reward / self.visit_count) - (mean * mean)
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def std_dev(self) -> float:
|
| 74 |
+
return math.sqrt(max(0.0, self.variance))
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def success_rate(self) -> float:
|
| 78 |
+
total = self.success_count + self.failure_count
|
| 79 |
+
if total == 0:
|
| 80 |
+
return 0.5
|
| 81 |
+
return self.success_count / total
|
| 82 |
+
|
| 83 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 84 |
+
return {
|
| 85 |
+
"q_value": self.q_value,
|
| 86 |
+
"visit_count": self.visit_count,
|
| 87 |
+
"total_reward": self.total_reward,
|
| 88 |
+
"squared_reward": self.squared_reward,
|
| 89 |
+
"success_count": self.success_count,
|
| 90 |
+
"failure_count": self.failure_count,
|
| 91 |
+
"mean_reward": self.mean_reward,
|
| 92 |
+
"variance": self.variance,
|
| 93 |
+
"std_dev": self.std_dev,
|
| 94 |
+
"success_rate": self.success_rate,
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ExperienceReplayBuffer:
|
| 99 |
+
"""Buffer for storing and sampling experiences."""
|
| 100 |
+
|
| 101 |
+
def __init__(self, capacity: int = 10000):
|
| 102 |
+
self.capacity = capacity
|
| 103 |
+
self.buffer: Deque[Experience] = deque(maxlen=capacity)
|
| 104 |
+
self._episode_rewards: List[float] = []
|
| 105 |
+
|
| 106 |
+
def add(self, experience: Experience) -> None:
|
| 107 |
+
"""Add an experience to the buffer."""
|
| 108 |
+
self.buffer.append(experience)
|
| 109 |
+
|
| 110 |
+
def sample(self, batch_size: int) -> List[Experience]:
|
| 111 |
+
"""Sample a batch of experiences randomly."""
|
| 112 |
+
if len(self.buffer) < batch_size:
|
| 113 |
+
return list(self.buffer)
|
| 114 |
+
return random.sample(list(self.buffer), batch_size)
|
| 115 |
+
|
| 116 |
+
def sample_recent(self, batch_size: int, recency_weight: float = 0.8) -> List[Experience]:
|
| 117 |
+
"""Sample with preference to recent experiences."""
|
| 118 |
+
if len(self.buffer) < batch_size:
|
| 119 |
+
return list(self.buffer)
|
| 120 |
+
|
| 121 |
+
recent_count = int(batch_size * recency_weight)
|
| 122 |
+
random_count = batch_size - recent_count
|
| 123 |
+
|
| 124 |
+
recent = list(self.buffer)[-recent_count:] if recent_count > 0 else []
|
| 125 |
+
random_part = random.sample(
|
| 126 |
+
list(self.buffer)[:-recent_count] if recent_count > 0 else list(self.buffer),
|
| 127 |
+
min(random_count, len(self.buffer) - recent_count)
|
| 128 |
+
) if random_count > 0 else []
|
| 129 |
+
|
| 130 |
+
return recent + random_part
|
| 131 |
+
|
| 132 |
+
def get_all_by_state(self, state: str) -> List[Experience]:
|
| 133 |
+
"""Get all experiences for a specific state."""
|
| 134 |
+
return [e for e in self.buffer if e.state == state]
|
| 135 |
+
|
| 136 |
+
def record_episode_reward(self, reward: float) -> None:
|
| 137 |
+
"""Record episode-level reward for tracking."""
|
| 138 |
+
self._episode_rewards.append(reward)
|
| 139 |
+
if len(self._episode_rewards) > 1000:
|
| 140 |
+
self._episode_rewards = self._episode_rewards[-1000:]
|
| 141 |
+
|
| 142 |
+
def get_recent_performance(self, window: int = 100) -> Dict[str, float]:
|
| 143 |
+
"""Get recent performance statistics."""
|
| 144 |
+
if not self._episode_rewards:
|
| 145 |
+
return {"mean": 0.5, "std": 0.0, "trend": 0.0}
|
| 146 |
+
|
| 147 |
+
recent = self._episode_rewards[-window:]
|
| 148 |
+
mean = sum(recent) / len(recent)
|
| 149 |
+
|
| 150 |
+
variance = sum((r - mean) ** 2 for r in recent) / len(recent)
|
| 151 |
+
std = math.sqrt(max(0.0, variance))
|
| 152 |
+
|
| 153 |
+
if len(recent) >= 20:
|
| 154 |
+
first_half = recent[:len(recent)//2]
|
| 155 |
+
second_half = recent[len(recent)//2:]
|
| 156 |
+
trend = (sum(second_half) / len(second_half)) - (sum(first_half) / len(first_half))
|
| 157 |
+
else:
|
| 158 |
+
trend = 0.0
|
| 159 |
+
|
| 160 |
+
return {
|
| 161 |
+
"mean": mean,
|
| 162 |
+
"std": std,
|
| 163 |
+
"trend": trend,
|
| 164 |
+
"count": len(recent),
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
def __len__(self) -> int:
|
| 168 |
+
return len(self.buffer)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class EligibilityTraces:
|
| 172 |
+
"""Eligibility traces for better credit assignment."""
|
| 173 |
+
|
| 174 |
+
def __init__(self, lambda_: float = 0.9, discount: float = 0.95):
|
| 175 |
+
self.lambda_ = lambda_
|
| 176 |
+
self.discount = discount
|
| 177 |
+
self._traces: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
|
| 178 |
+
|
| 179 |
+
def update(self, state: str, action: str) -> None:
|
| 180 |
+
"""Update trace for visited state-action pair."""
|
| 181 |
+
for s in list(self._traces.keys()):
|
| 182 |
+
for a in list(self._traces[s].keys()):
|
| 183 |
+
self._traces[s][a] *= self.lambda_ * self.discount
|
| 184 |
+
|
| 185 |
+
self._traces[state][action] = 1.0
|
| 186 |
+
|
| 187 |
+
def get_trace(self, state: str, action: str) -> float:
|
| 188 |
+
"""Get the eligibility trace value."""
|
| 189 |
+
return self._traces.get(state, {}).get(action, 0.0)
|
| 190 |
+
|
| 191 |
+
def decay_all(self) -> None:
|
| 192 |
+
"""Decay all traces."""
|
| 193 |
+
for s in self._traces:
|
| 194 |
+
for a in self._traces[s]:
|
| 195 |
+
self._traces[s][a] *= self.lambda_ * self.discount
|
| 196 |
+
|
| 197 |
+
def reset(self) -> None:
|
| 198 |
+
"""Reset all traces."""
|
| 199 |
+
self._traces.clear()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class ContextualBanditFeatures:
|
| 203 |
+
"""Feature extraction for contextual bandits."""
|
| 204 |
+
|
| 205 |
+
@staticmethod
|
| 206 |
+
def extract_features(
|
| 207 |
+
spec_dict: Dict[str, Any],
|
| 208 |
+
file_type: str,
|
| 209 |
+
) -> Dict[str, Any]:
|
| 210 |
+
"""Extract features from spec and context."""
|
| 211 |
+
features = {}
|
| 212 |
+
|
| 213 |
+
protocol = spec_dict.get("protocol", "unknown")
|
| 214 |
+
features["protocol"] = protocol
|
| 215 |
+
|
| 216 |
+
interfaces = spec_dict.get("interfaces", [])
|
| 217 |
+
features["num_interfaces"] = len(interfaces)
|
| 218 |
+
|
| 219 |
+
total_signals = sum(len(iface.get("signals", [])) for iface in interfaces)
|
| 220 |
+
features["total_signals"] = total_signals
|
| 221 |
+
|
| 222 |
+
registers = spec_dict.get("registers", [])
|
| 223 |
+
features["num_registers"] = len(registers)
|
| 224 |
+
|
| 225 |
+
total_fields = sum(len(reg.get("fields", [])) for reg in registers)
|
| 226 |
+
features["total_fields"] = total_fields
|
| 227 |
+
|
| 228 |
+
complexity = 0.0
|
| 229 |
+
if total_signals > 0:
|
| 230 |
+
complexity += math.log10(total_signals + 1) * 0.3
|
| 231 |
+
if total_fields > 0:
|
| 232 |
+
complexity += math.log10(total_fields + 1) * 0.4
|
| 233 |
+
complexity += len(interfaces) * 0.15
|
| 234 |
+
complexity += len(registers) * 0.15
|
| 235 |
+
features["complexity"] = min(1.0, complexity)
|
| 236 |
+
|
| 237 |
+
file_type_weights = {
|
| 238 |
+
"testbench": 0.3,
|
| 239 |
+
"interface": 0.25,
|
| 240 |
+
"test": 0.2,
|
| 241 |
+
"sequence": 0.15,
|
| 242 |
+
"driver": 0.1,
|
| 243 |
+
"monitor": 0.1,
|
| 244 |
+
"agent": 0.1,
|
| 245 |
+
"scoreboard": 0.15,
|
| 246 |
+
"ral_model": 0.2,
|
| 247 |
+
"env": 0.15,
|
| 248 |
+
}
|
| 249 |
+
features["file_type_weight"] = file_type_weights.get(file_type, 0.1)
|
| 250 |
+
|
| 251 |
+
return features
|
| 252 |
+
|
| 253 |
+
@staticmethod
|
| 254 |
+
def get_state_key(
|
| 255 |
+
protocol: str,
|
| 256 |
+
file_type: str,
|
| 257 |
+
complexity_bucket: str,
|
| 258 |
+
) -> str:
|
| 259 |
+
"""Generate a state key for RL."""
|
| 260 |
+
return f"{protocol}:{file_type}:{complexity_bucket}"
|
| 261 |
+
|
| 262 |
+
@staticmethod
|
| 263 |
+
def bucket_complexity(complexity: float) -> str:
|
| 264 |
+
"""Bucket complexity into discrete levels."""
|
| 265 |
+
if complexity < 0.3:
|
| 266 |
+
return "low"
|
| 267 |
+
elif complexity < 0.6:
|
| 268 |
+
return "medium"
|
| 269 |
+
else:
|
| 270 |
+
return "high"
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class AdvancedReinforcementLearner:
|
| 274 |
+
"""
|
| 275 |
+
Advanced RL learner with multiple strategies and improvements.
|
| 276 |
+
|
| 277 |
+
Key features:
|
| 278 |
+
- Experience replay buffer
|
| 279 |
+
- Eligibility traces
|
| 280 |
+
- Multiple exploration strategies
|
| 281 |
+
- Contextual bandit support
|
| 282 |
+
- Learning rate scheduling
|
| 283 |
+
- Performance tracking
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
def __init__(
|
| 287 |
+
self,
|
| 288 |
+
learning_rate: float = 0.1,
|
| 289 |
+
discount_factor: float = 0.95,
|
| 290 |
+
exploration_strategy: ExplorationStrategy = ExplorationStrategy.UCB,
|
| 291 |
+
epsilon: float = 0.1,
|
| 292 |
+
epsilon_decay: float = 0.995,
|
| 293 |
+
min_epsilon: float = 0.01,
|
| 294 |
+
ucb_c: float = 2.0,
|
| 295 |
+
temperature: float = 1.0,
|
| 296 |
+
use_eligibility_traces: bool = True,
|
| 297 |
+
lambda_: float = 0.9,
|
| 298 |
+
replay_buffer_capacity: int = 10000,
|
| 299 |
+
):
|
| 300 |
+
self._learning_rate = learning_rate
|
| 301 |
+
self._initial_learning_rate = learning_rate
|
| 302 |
+
self._discount_factor = discount_factor
|
| 303 |
+
|
| 304 |
+
self._exploration_strategy = exploration_strategy
|
| 305 |
+
self._epsilon = epsilon
|
| 306 |
+
self._epsilon_decay = epsilon_decay
|
| 307 |
+
self._min_epsilon = min_epsilon
|
| 308 |
+
self._ucb_c = ucb_c
|
| 309 |
+
self._temperature = temperature
|
| 310 |
+
|
| 311 |
+
self._q_values: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(lambda: 0.5))
|
| 312 |
+
self._action_stats: Dict[str, Dict[str, ActionStats]] = defaultdict(dict)
|
| 313 |
+
self._total_updates: int = 0
|
| 314 |
+
|
| 315 |
+
self._use_eligibility_traces = use_eligibility_traces
|
| 316 |
+
if use_eligibility_traces:
|
| 317 |
+
self._eligibility_traces = EligibilityTraces(lambda_=lambda_, discount=discount_factor)
|
| 318 |
+
|
| 319 |
+
self._replay_buffer = ExperienceReplayBuffer(capacity=replay_buffer_capacity)
|
| 320 |
+
|
| 321 |
+
self._episode_count: int = 0
|
| 322 |
+
self._best_actions: Dict[str, str] = {}
|
| 323 |
+
|
| 324 |
+
def _get_state_key(
|
| 325 |
+
self,
|
| 326 |
+
protocol: str,
|
| 327 |
+
file_type: str,
|
| 328 |
+
spec_dict: Optional[Dict[str, Any]] = None,
|
| 329 |
+
) -> str:
|
| 330 |
+
"""Generate state key with optional context."""
|
| 331 |
+
if spec_dict:
|
| 332 |
+
features = ContextualBanditFeatures.extract_features(spec_dict, file_type)
|
| 333 |
+
complexity_bucket = ContextualBanditFeatures.bucket_complexity(features["complexity"])
|
| 334 |
+
return ContextualBanditFeatures.get_state_key(protocol, file_type, complexity_bucket)
|
| 335 |
+
return f"{protocol}:{file_type}"
|
| 336 |
+
|
| 337 |
+
def _ensure_stats(self, state: str, action: str) -> None:
|
| 338 |
+
"""Ensure action stats exist for state-action pair."""
|
| 339 |
+
if state not in self._action_stats:
|
| 340 |
+
self._action_stats[state] = {}
|
| 341 |
+
if action not in self._action_stats[state]:
|
| 342 |
+
self._action_stats[state][action] = ActionStats()
|
| 343 |
+
|
| 344 |
+
def get_action_value(
|
| 345 |
+
self,
|
| 346 |
+
protocol: str,
|
| 347 |
+
file_type: str,
|
| 348 |
+
generation_source: str,
|
| 349 |
+
spec_dict: Optional[Dict[str, Any]] = None,
|
| 350 |
+
) -> float:
|
| 351 |
+
"""Get the Q-value for a state-action pair."""
|
| 352 |
+
state = self._get_state_key(protocol, file_type, spec_dict)
|
| 353 |
+
return self._q_values[state][generation_source]
|
| 354 |
+
|
| 355 |
+
def get_action_stats(
|
| 356 |
+
self,
|
| 357 |
+
protocol: str,
|
| 358 |
+
file_type: str,
|
| 359 |
+
generation_source: str,
|
| 360 |
+
spec_dict: Optional[Dict[str, Any]] = None,
|
| 361 |
+
) -> Optional[ActionStats]:
|
| 362 |
+
"""Get statistics for an action."""
|
| 363 |
+
state = self._get_state_key(protocol, file_type, spec_dict)
|
| 364 |
+
return self._action_stats.get(state, {}).get(generation_source)
|
| 365 |
+
|
| 366 |
+
def update(
|
| 367 |
+
self,
|
| 368 |
+
protocol: str,
|
| 369 |
+
file_type: str,
|
| 370 |
+
generation_source: str,
|
| 371 |
+
reward: float,
|
| 372 |
+
next_state: Optional[str] = None,
|
| 373 |
+
spec_dict: Optional[Dict[str, Any]] = None,
|
| 374 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 375 |
+
) -> None:
|
| 376 |
+
"""Update Q-values with reward, using eligibility traces if enabled."""
|
| 377 |
+
state = self._get_state_key(protocol, file_type, spec_dict)
|
| 378 |
+
|
| 379 |
+
self._ensure_stats(state, generation_source)
|
| 380 |
+
stats = self._action_stats[state][generation_source]
|
| 381 |
+
|
| 382 |
+
old_value = self._q_values[state][generation_source]
|
| 383 |
+
|
| 384 |
+
if next_state and self._q_values.get(next_state):
|
| 385 |
+
next_max = max(self._q_values[next_state].values()) if self._q_values[next_state] else 0.5
|
| 386 |
+
target = reward + self._discount_factor * next_max
|
| 387 |
+
else:
|
| 388 |
+
target = reward
|
| 389 |
+
|
| 390 |
+
td_error = target - old_value
|
| 391 |
+
|
| 392 |
+
if self._use_eligibility_traces and self._eligibility_traces:
|
| 393 |
+
self._eligibility_traces.update(state, generation_source)
|
| 394 |
+
|
| 395 |
+
for s in list(self._q_values.keys()):
|
| 396 |
+
for a in list(self._q_values[s].keys()):
|
| 397 |
+
trace = self._eligibility_traces.get_trace(s, a)
|
| 398 |
+
if trace > 0:
|
| 399 |
+
self._q_values[s][a] += self._learning_rate * td_error * trace
|
| 400 |
+
else:
|
| 401 |
+
self._q_values[state][generation_source] = old_value + self._learning_rate * td_error
|
| 402 |
+
|
| 403 |
+
stats.visit_count += 1
|
| 404 |
+
stats.total_reward += reward
|
| 405 |
+
stats.squared_reward += reward * reward
|
| 406 |
+
stats.q_value = self._q_values[state][generation_source]
|
| 407 |
+
|
| 408 |
+
if reward >= 0.5:
|
| 409 |
+
stats.success_count += 1
|
| 410 |
+
else:
|
| 411 |
+
stats.failure_count += 1
|
| 412 |
+
|
| 413 |
+
self._total_updates += 1
|
| 414 |
+
|
| 415 |
+
experience = Experience(
|
| 416 |
+
state=state,
|
| 417 |
+
action=generation_source,
|
| 418 |
+
reward=reward,
|
| 419 |
+
next_state=next_state,
|
| 420 |
+
metadata=metadata or {},
|
| 421 |
+
)
|
| 422 |
+
self._replay_buffer.add(experience)
|
| 423 |
+
self._replay_buffer.record_episode_reward(reward)
|
| 424 |
+
|
| 425 |
+
actions = self._q_values[state]
|
| 426 |
+
if actions:
|
| 427 |
+
self._best_actions[state] = max(actions.keys(), key=lambda a: actions[a])
|
| 428 |
+
|
| 429 |
+
def _select_epsilon_greedy(
|
| 430 |
+
self,
|
| 431 |
+
state: str,
|
| 432 |
+
available_sources: List[str],
|
| 433 |
+
) -> Tuple[str, float]:
|
| 434 |
+
"""Select action using epsilon-greedy strategy."""
|
| 435 |
+
if random.random() < self._epsilon and len(available_sources) > 1:
|
| 436 |
+
chosen = random.choice(available_sources)
|
| 437 |
+
return chosen, self._q_values[state][chosen]
|
| 438 |
+
|
| 439 |
+
best_source = available_sources[0]
|
| 440 |
+
best_value = -1.0
|
| 441 |
+
|
| 442 |
+
for source in available_sources:
|
| 443 |
+
value = self._q_values[state][source]
|
| 444 |
+
if value > best_value:
|
| 445 |
+
best_value = value
|
| 446 |
+
best_source = source
|
| 447 |
+
|
| 448 |
+
return best_source, best_value
|
| 449 |
+
|
| 450 |
+
def _select_softmax(
|
| 451 |
+
self,
|
| 452 |
+
state: str,
|
| 453 |
+
available_sources: List[str],
|
| 454 |
+
) -> Tuple[str, float]:
|
| 455 |
+
"""Select action using softmax (Boltzmann) exploration."""
|
| 456 |
+
values = [self._q_values[state][s] for s in available_sources]
|
| 457 |
+
|
| 458 |
+
max_val = max(values) if values else 0.0
|
| 459 |
+
exp_values = [math.exp((v - max_val) / self._temperature) for v in values]
|
| 460 |
+
sum_exp = sum(exp_values)
|
| 461 |
+
|
| 462 |
+
if sum_exp == 0:
|
| 463 |
+
probs = [1.0 / len(available_sources)] * len(available_sources)
|
| 464 |
+
else:
|
| 465 |
+
probs = [e / sum_exp for e in exp_values]
|
| 466 |
+
|
| 467 |
+
r = random.random()
|
| 468 |
+
cumulative = 0.0
|
| 469 |
+
for i, prob in enumerate(probs):
|
| 470 |
+
cumulative += prob
|
| 471 |
+
if r <= cumulative:
|
| 472 |
+
return available_sources[i], values[i]
|
| 473 |
+
|
| 474 |
+
return available_sources[0], values[0]
|
| 475 |
+
|
| 476 |
+
def _select_ucb(
|
| 477 |
+
self,
|
| 478 |
+
state: str,
|
| 479 |
+
available_sources: List[str],
|
| 480 |
+
) -> Tuple[str, float]:
|
| 481 |
+
"""Select action using Upper Confidence Bound (UCB1)."""
|
| 482 |
+
total_visits = sum(
|
| 483 |
+
self._action_stats.get(state, {}).get(s, ActionStats()).visit_count
|
| 484 |
+
for s in available_sources
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
if total_visits == 0:
|
| 488 |
+
return random.choice(available_sources), 0.5
|
| 489 |
+
|
| 490 |
+
best_source = available_sources[0]
|
| 491 |
+
best_ucb = -1.0
|
| 492 |
+
|
| 493 |
+
for source in available_sources:
|
| 494 |
+
stats = self._action_stats.get(state, {}).get(source, ActionStats())
|
| 495 |
+
q_value = self._q_values[state][source]
|
| 496 |
+
|
| 497 |
+
if stats.visit_count == 0:
|
| 498 |
+
ucb = float('inf')
|
| 499 |
+
else:
|
| 500 |
+
exploration = self._ucb_c * math.sqrt(
|
| 501 |
+
math.log(total_visits) / stats.visit_count
|
| 502 |
+
)
|
| 503 |
+
ucb = q_value + exploration
|
| 504 |
+
|
| 505 |
+
if ucb > best_ucb:
|
| 506 |
+
best_ucb = ucb
|
| 507 |
+
best_source = source
|
| 508 |
+
|
| 509 |
+
return best_source, self._q_values[state][best_source]
|
| 510 |
+
|
| 511 |
+
def _select_thompson(
|
| 512 |
+
self,
|
| 513 |
+
state: str,
|
| 514 |
+
available_sources: List[str],
|
| 515 |
+
) -> Tuple[str, float]:
|
| 516 |
+
"""Select action using Thompson sampling (Beta distribution)."""
|
| 517 |
+
samples = []
|
| 518 |
+
|
| 519 |
+
for source in available_sources:
|
| 520 |
+
stats = self._action_stats.get(state, {}).get(source, ActionStats())
|
| 521 |
+
|
| 522 |
+
alpha = 1 + stats.success_count
|
| 523 |
+
beta_val = 1 + stats.failure_count
|
| 524 |
+
|
| 525 |
+
try:
|
| 526 |
+
import random as rnd
|
| 527 |
+
sample = rnd.betavariate(alpha, beta_val)
|
| 528 |
+
except (ImportError, AttributeError):
|
| 529 |
+
sample = stats.success_rate + random.gauss(0, 0.1)
|
| 530 |
+
sample = max(0.0, min(1.0, sample))
|
| 531 |
+
|
| 532 |
+
samples.append((source, sample, self._q_values[state][source]))
|
| 533 |
+
|
| 534 |
+
samples.sort(key=lambda x: x[1], reverse=True)
|
| 535 |
+
return samples[0][0], samples[0][2]
|
| 536 |
+
|
| 537 |
+
def select_best_action(
|
| 538 |
+
self,
|
| 539 |
+
protocol: str,
|
| 540 |
+
file_type: str,
|
| 541 |
+
available_sources: List[str],
|
| 542 |
+
spec_dict: Optional[Dict[str, Any]] = None,
|
| 543 |
+
) -> Tuple[str, float]:
|
| 544 |
+
"""
|
| 545 |
+
Select the best action using configured exploration strategy.
|
| 546 |
+
|
| 547 |
+
Returns:
|
| 548 |
+
Tuple of (chosen_source, expected_value)
|
| 549 |
+
"""
|
| 550 |
+
state = self._get_state_key(protocol, file_type, spec_dict)
|
| 551 |
+
|
| 552 |
+
if len(available_sources) == 0:
|
| 553 |
+
return "template", 0.5
|
| 554 |
+
|
| 555 |
+
if len(available_sources) == 1:
|
| 556 |
+
return available_sources[0], self._q_values[state][available_sources[0]]
|
| 557 |
+
|
| 558 |
+
for source in available_sources:
|
| 559 |
+
if source not in self._q_values[state]:
|
| 560 |
+
self._q_values[state][source] = 0.5
|
| 561 |
+
|
| 562 |
+
if self._exploration_strategy == ExplorationStrategy.EPSILON_GREEDY:
|
| 563 |
+
result = self._select_epsilon_greedy(state, available_sources)
|
| 564 |
+
elif self._exploration_strategy == ExplorationStrategy.SOFTMAX:
|
| 565 |
+
result = self._select_softmax(state, available_sources)
|
| 566 |
+
elif self._exploration_strategy == ExplorationStrategy.UCB:
|
| 567 |
+
result = self._select_ucb(state, available_sources)
|
| 568 |
+
elif self._exploration_strategy == ExplorationStrategy.THOMPSON_SAMPLING:
|
| 569 |
+
result = self._select_thompson(state, available_sources)
|
| 570 |
+
else:
|
| 571 |
+
result = self._select_ucb(state, available_sources)
|
| 572 |
+
|
| 573 |
+
if self._exploration_strategy == ExplorationStrategy.EPSILON_GREEDY:
|
| 574 |
+
self._epsilon = max(self._min_epsilon, self._epsilon * self._epsilon_decay)
|
| 575 |
+
|
| 576 |
+
self._episode_count += 1
|
| 577 |
+
|
| 578 |
+
decay = max(0.001, 1.0 / math.sqrt(self._total_updates + 1))
|
| 579 |
+
self._learning_rate = self._initial_learning_rate * decay
|
| 580 |
+
|
| 581 |
+
return result
|
| 582 |
+
|
| 583 |
+
def get_performance_stats(self) -> Dict[str, Any]:
|
| 584 |
+
"""Get comprehensive performance statistics."""
|
| 585 |
+
buffer_stats = self._replay_buffer.get_recent_performance()
|
| 586 |
+
|
| 587 |
+
all_states = list(self._q_values.keys())
|
| 588 |
+
total_actions = sum(len(v) for v in self._q_values.values())
|
| 589 |
+
|
| 590 |
+
state_stats = {}
|
| 591 |
+
for state in all_states:
|
| 592 |
+
actions = self._q_values[state]
|
| 593 |
+
if not actions:
|
| 594 |
+
continue
|
| 595 |
+
|
| 596 |
+
best_action = max(actions.keys(), key=lambda a: actions[a])
|
| 597 |
+
best_value = actions[best_action]
|
| 598 |
+
|
| 599 |
+
state_stats[state] = {
|
| 600 |
+
"best_action": best_action,
|
| 601 |
+
"best_q_value": best_value,
|
| 602 |
+
"num_actions": len(actions),
|
| 603 |
+
"actions": {
|
| 604 |
+
a: {
|
| 605 |
+
"q_value": self._q_values[state][a],
|
| 606 |
+
"stats": self._action_stats.get(state, {}).get(a, ActionStats()).to_dict()
|
| 607 |
+
}
|
| 608 |
+
for a in actions
|
| 609 |
+
},
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
return {
|
| 613 |
+
"episode_count": self._episode_count,
|
| 614 |
+
"total_updates": self._total_updates,
|
| 615 |
+
"learning_rate": self._learning_rate,
|
| 616 |
+
"epsilon": self._epsilon,
|
| 617 |
+
"exploration_strategy": self._exploration_strategy.value,
|
| 618 |
+
"replay_buffer_size": len(self._replay_buffer),
|
| 619 |
+
"buffer_performance": buffer_stats,
|
| 620 |
+
"num_states": len(all_states),
|
| 621 |
+
"total_actions_tracked": total_actions,
|
| 622 |
+
"state_stats": state_stats,
|
| 623 |
+
"best_actions": self._best_actions.copy(),
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
def replay_experiences(self, batch_size: int = 32, use_recency: bool = True) -> int:
|
| 627 |
+
"""
|
| 628 |
+
Replay experiences from buffer for additional learning.
|
| 629 |
+
|
| 630 |
+
Returns:
|
| 631 |
+
Number of experiences replayed
|
| 632 |
+
"""
|
| 633 |
+
if use_recency:
|
| 634 |
+
batch = self._replay_buffer.sample_recent(batch_size)
|
| 635 |
+
else:
|
| 636 |
+
batch = self._replay_buffer.sample(batch_size)
|
| 637 |
+
|
| 638 |
+
if not batch:
|
| 639 |
+
return 0
|
| 640 |
+
|
| 641 |
+
for exp in batch:
|
| 642 |
+
state = exp.state
|
| 643 |
+
action = exp.action
|
| 644 |
+
reward = exp.reward
|
| 645 |
+
|
| 646 |
+
old_value = self._q_values[state][action]
|
| 647 |
+
self._q_values[state][action] = (
|
| 648 |
+
old_value + self._learning_rate * (reward - old_value)
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
self._ensure_stats(state, action)
|
| 652 |
+
stats = self._action_stats[state][action]
|
| 653 |
+
stats.total_reward += reward * 0.1
|
| 654 |
+
stats.squared_reward += (reward * reward) * 0.1
|
| 655 |
+
|
| 656 |
+
return len(batch)
|
| 657 |
+
|
| 658 |
+
def reset_episode(self) -> None:
|
| 659 |
+
"""Reset for a new episode (clears eligibility traces)."""
|
| 660 |
+
if self._use_eligibility_traces and self._eligibility_traces:
|
| 661 |
+
self._eligibility_traces.reset()
|
| 662 |
+
|
| 663 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 664 |
+
return {
|
| 665 |
+
"learning_rate": self._learning_rate,
|
| 666 |
+
"initial_learning_rate": self._initial_learning_rate,
|
| 667 |
+
"discount_factor": self._discount_factor,
|
| 668 |
+
"exploration_strategy": self._exploration_strategy.value,
|
| 669 |
+
"epsilon": self._epsilon,
|
| 670 |
+
"epsilon_decay": self._epsilon_decay,
|
| 671 |
+
"min_epsilon": self._min_epsilon,
|
| 672 |
+
"ucb_c": self._ucb_c,
|
| 673 |
+
"temperature": self._temperature,
|
| 674 |
+
"use_eligibility_traces": self._use_eligibility_traces,
|
| 675 |
+
"episode_count": self._episode_count,
|
| 676 |
+
"total_updates": self._total_updates,
|
| 677 |
+
"q_values": {k: dict(v) for k, v in self._q_values.items()},
|
| 678 |
+
"action_stats": {
|
| 679 |
+
state: {action: stats.to_dict() for action, stats in actions.items()}
|
| 680 |
+
for state, actions in self._action_stats.items()
|
| 681 |
+
},
|
| 682 |
+
"best_actions": self._best_actions.copy(),
|
| 683 |
+
}
|
| 684 |
+
|
| 685 |
+
@classmethod
|
| 686 |
+
def from_dict(cls, d: Dict[str, Any]) -> "AdvancedReinforcementLearner":
|
| 687 |
+
strategy_map = {e.value: e for e in ExplorationStrategy}
|
| 688 |
+
strategy = strategy_map.get(
|
| 689 |
+
d.get("exploration_strategy", "ucb"),
|
| 690 |
+
ExplorationStrategy.UCB
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
learner = cls(
|
| 694 |
+
learning_rate=d.get("initial_learning_rate", 0.1),
|
| 695 |
+
discount_factor=d.get("discount_factor", 0.95),
|
| 696 |
+
exploration_strategy=strategy,
|
| 697 |
+
epsilon=d.get("epsilon", 0.1),
|
| 698 |
+
epsilon_decay=d.get("epsilon_decay", 0.995),
|
| 699 |
+
min_epsilon=d.get("min_epsilon", 0.01),
|
| 700 |
+
ucb_c=d.get("ucb_c", 2.0),
|
| 701 |
+
temperature=d.get("temperature", 1.0),
|
| 702 |
+
use_eligibility_traces=d.get("use_eligibility_traces", True),
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
learner._learning_rate = d.get("learning_rate", 0.1)
|
| 706 |
+
learner._episode_count = d.get("episode_count", 0)
|
| 707 |
+
learner._total_updates = d.get("total_updates", 0)
|
| 708 |
+
|
| 709 |
+
for state, actions in d.get("q_values", {}).items():
|
| 710 |
+
for action, value in actions.items():
|
| 711 |
+
learner._q_values[state][action] = value
|
| 712 |
+
|
| 713 |
+
for state, actions in d.get("action_stats", {}).items():
|
| 714 |
+
if state not in learner._action_stats:
|
| 715 |
+
learner._action_stats[state] = {}
|
| 716 |
+
for action, stats_dict in actions.items():
|
| 717 |
+
stats = ActionStats()
|
| 718 |
+
stats.q_value = stats_dict.get("q_value", 0.5)
|
| 719 |
+
stats.visit_count = stats_dict.get("visit_count", 0)
|
| 720 |
+
stats.total_reward = stats_dict.get("total_reward", 0.0)
|
| 721 |
+
stats.squared_reward = stats_dict.get("squared_reward", 0.0)
|
| 722 |
+
stats.success_count = stats_dict.get("success_count", 0)
|
| 723 |
+
stats.failure_count = stats_dict.get("failure_count", 0)
|
| 724 |
+
learner._action_stats[state][action] = stats
|
| 725 |
+
|
| 726 |
+
learner._best_actions = d.get("best_actions", {}).copy()
|
| 727 |
+
|
| 728 |
+
return learner
|
|
@@ -0,0 +1,801 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced ML Generation Model with Advanced Components.
|
| 3 |
+
|
| 4 |
+
Key improvements for promotion:
|
| 5 |
+
1. Advanced pattern learner with context-aware error detection
|
| 6 |
+
2. Advanced RL learner with experience replay and eligibility traces
|
| 7 |
+
3. Advanced code validator with deep UVM compliance
|
| 8 |
+
4. Ensemble retrieval with weighted voting
|
| 9 |
+
5. Adaptive strategy selection
|
| 10 |
+
6. Confidence calibration
|
| 11 |
+
7. Performance tracking and reporting
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import math
|
| 20 |
+
from collections import defaultdict, Counter
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from datetime import datetime
|
| 23 |
+
from enum import Enum
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple, Set
|
| 25 |
+
|
| 26 |
+
from src.models.base_model import GenerationModel
|
| 27 |
+
from src.models.template_model import TemplateModel
|
| 28 |
+
from src.config import PipelineConfig, DesignSpec
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from src.features.extractors import RichSpecFeatureExtractor
|
| 32 |
+
from src.models.similarity_index import SimilarityIndex, SearchResult
|
| 33 |
+
from src.models.ml_utils import (
|
| 34 |
+
RichFeatureVector,
|
| 35 |
+
combined_similarity,
|
| 36 |
+
HybridVectorizer,
|
| 37 |
+
)
|
| 38 |
+
from src.models.spec_adapter import SpecAdapter, AdaptationPlan
|
| 39 |
+
from src.models.code_validator import (
|
| 40 |
+
CodeValidator,
|
| 41 |
+
ValidationReport,
|
| 42 |
+
FileValidationResult,
|
| 43 |
+
)
|
| 44 |
+
from src.models.advanced_pattern_learner import (
|
| 45 |
+
AdvancedPatternLearner,
|
| 46 |
+
PatternType,
|
| 47 |
+
Pattern,
|
| 48 |
+
)
|
| 49 |
+
from src.models.advanced_rl_learner import (
|
| 50 |
+
AdvancedReinforcementLearner,
|
| 51 |
+
ExplorationStrategy,
|
| 52 |
+
Experience,
|
| 53 |
+
)
|
| 54 |
+
from src.models.advanced_code_validator import (
|
| 55 |
+
AdvancedCodeValidator,
|
| 56 |
+
ValidationReport as AdvancedValidationReport,
|
| 57 |
+
)
|
| 58 |
+
HAS_ADVANCED = True
|
| 59 |
+
except ImportError as e:
|
| 60 |
+
logger = logging.getLogger("uvmgen.ml")
|
| 61 |
+
logger.warning(f"Some advanced components not available: {e}")
|
| 62 |
+
HAS_ADVANCED = False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
logger = logging.getLogger("uvmgen.ml.enhanced")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class GenerationSource(Enum):
|
| 69 |
+
RETRIEVAL = "retrieval"
|
| 70 |
+
LLM = "llm"
|
| 71 |
+
TEMPLATE = "template"
|
| 72 |
+
HYBRID = "hybrid"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class RetrievalInfo:
|
| 77 |
+
used_similarity: bool = True
|
| 78 |
+
similar_specs: int = 0
|
| 79 |
+
best_score: float = 0.0
|
| 80 |
+
best_spec_name: str = ""
|
| 81 |
+
adaptation_score: float = 0.0
|
| 82 |
+
pre_validation_score: float = 0.0
|
| 83 |
+
retrieval_strategy: str = "default"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass
|
| 87 |
+
class GenerationResult:
|
| 88 |
+
files: Dict[str, str] = field(default_factory=dict)
|
| 89 |
+
source: GenerationSource = GenerationSource.TEMPLATE
|
| 90 |
+
retrieval_info: Optional[RetrievalInfo] = None
|
| 91 |
+
validation_report: Optional[AdvancedValidationReport] = None
|
| 92 |
+
score: float = 0.0
|
| 93 |
+
errors: List[str] = field(default_factory=list)
|
| 94 |
+
warnings: List[str] = field(default_factory=list)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@dataclass
|
| 98 |
+
class StrategyWeights:
|
| 99 |
+
retrieval_weight: float = 0.4
|
| 100 |
+
llm_weight: float = 0.3
|
| 101 |
+
template_weight: float = 0.3
|
| 102 |
+
|
| 103 |
+
def normalize(self) -> "StrategyWeights":
|
| 104 |
+
total = self.retrieval_weight + self.llm_weight + self.template_weight
|
| 105 |
+
if total <= 0:
|
| 106 |
+
return StrategyWeights(0.34, 0.33, 0.33)
|
| 107 |
+
return StrategyWeights(
|
| 108 |
+
retrieval_weight=self.retrieval_weight / total,
|
| 109 |
+
llm_weight=self.llm_weight / total,
|
| 110 |
+
template_weight=self.template_weight / total,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class EnhancedMLGenerationModelV2(GenerationModel):
|
| 115 |
+
"""
|
| 116 |
+
Enhanced ML Generation Model V2 with advanced components.
|
| 117 |
+
|
| 118 |
+
Key features for promotion:
|
| 119 |
+
1. Ensemble retrieval with multi-strategy voting
|
| 120 |
+
2. Advanced RL with experience replay and eligibility traces
|
| 121 |
+
3. Context-aware pattern learning
|
| 122 |
+
4. Deep UVM compliance validation
|
| 123 |
+
5. Adaptive weight adjustment based on performance
|
| 124 |
+
6. Confidence calibration
|
| 125 |
+
7. Comprehensive performance tracking
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
name: str = "enhanced_ml_model_v2",
|
| 131 |
+
config: Optional[Any] = None,
|
| 132 |
+
templates_dir: str = "src/generation/templates",
|
| 133 |
+
strict_validation: bool = True,
|
| 134 |
+
use_llm: bool = False,
|
| 135 |
+
use_semantic_encoder: bool = False,
|
| 136 |
+
use_learning: bool = True,
|
| 137 |
+
llm_model_name: Optional[str] = None,
|
| 138 |
+
learning_storage_path: Optional[str] = None,
|
| 139 |
+
exploration_strategy: str = "ucb",
|
| 140 |
+
):
|
| 141 |
+
super().__init__(name)
|
| 142 |
+
|
| 143 |
+
self._templates_dir = templates_dir
|
| 144 |
+
self._strict_validation = strict_validation
|
| 145 |
+
self._use_llm = use_llm
|
| 146 |
+
self._use_semantic_encoder = use_semantic_encoder
|
| 147 |
+
self._use_learning = use_learning
|
| 148 |
+
self._llm_model_name = llm_model_name
|
| 149 |
+
self._learning_storage_path = learning_storage_path
|
| 150 |
+
|
| 151 |
+
self._template_model = TemplateModel(templates_dir=templates_dir)
|
| 152 |
+
|
| 153 |
+
self._index: Optional[SimilarityIndex] = None
|
| 154 |
+
self._extractor: Optional[RichSpecFeatureExtractor] = None
|
| 155 |
+
self._adapter: Optional[SpecAdapter] = None
|
| 156 |
+
self._vectorizer: Optional[HybridVectorizer] = None
|
| 157 |
+
|
| 158 |
+
self._pattern_learner: Optional[AdvancedPatternLearner] = None
|
| 159 |
+
self._rl_learner: Optional[AdvancedReinforcementLearner] = None
|
| 160 |
+
self._code_validator: Optional[AdvancedCodeValidator] = None
|
| 161 |
+
|
| 162 |
+
self.last_retrieval: Optional[RetrievalInfo] = None
|
| 163 |
+
self._generation_history: List[Dict[str, Any]] = []
|
| 164 |
+
|
| 165 |
+
strategy_map = {
|
| 166 |
+
"epsilon_greedy": ExplorationStrategy.EPSILON_GREEDY,
|
| 167 |
+
"softmax": ExplorationStrategy.SOFTMAX,
|
| 168 |
+
"ucb": ExplorationStrategy.UCB,
|
| 169 |
+
"thompson": ExplorationStrategy.THOMPSON_SAMPLING,
|
| 170 |
+
}
|
| 171 |
+
self._exploration_strategy = strategy_map.get(
|
| 172 |
+
exploration_strategy.lower(),
|
| 173 |
+
ExplorationStrategy.UCB
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
self._strategy_weights = StrategyWeights()
|
| 177 |
+
|
| 178 |
+
self._initialize_components()
|
| 179 |
+
|
| 180 |
+
def _initialize_components(self) -> None:
|
| 181 |
+
"""Initialize all ML components."""
|
| 182 |
+
if HAS_ADVANCED:
|
| 183 |
+
self._extractor = RichSpecFeatureExtractor()
|
| 184 |
+
self._index = SimilarityIndex()
|
| 185 |
+
self._adapter = SpecAdapter()
|
| 186 |
+
self._vectorizer = HybridVectorizer()
|
| 187 |
+
|
| 188 |
+
if self._use_learning:
|
| 189 |
+
self._pattern_learner = AdvancedPatternLearner()
|
| 190 |
+
self._rl_learner = AdvancedReinforcementLearner(
|
| 191 |
+
exploration_strategy=self._exploration_strategy,
|
| 192 |
+
use_eligibility_traces=True,
|
| 193 |
+
replay_buffer_capacity=10000,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
if self._learning_storage_path and os.path.exists(self._learning_storage_path):
|
| 197 |
+
self._load_learning_state()
|
| 198 |
+
|
| 199 |
+
logger.info(f"Enhanced ML Generation Model V2 initialized with strategy: {self._exploration_strategy.value}")
|
| 200 |
+
else:
|
| 201 |
+
logger.warning("Advanced components not available, using basic template model only")
|
| 202 |
+
|
| 203 |
+
def train(
|
| 204 |
+
self,
|
| 205 |
+
specs: List[DesignSpec],
|
| 206 |
+
pre_generated: Optional[Dict[str, Dict[str, str]]] = None,
|
| 207 |
+
) -> Dict[str, Any]:
|
| 208 |
+
"""Train the model on design specifications."""
|
| 209 |
+
if not HAS_ADVANCED or not self._extractor or not self._index:
|
| 210 |
+
return self._template_model.train(specs)
|
| 211 |
+
|
| 212 |
+
for spec in specs:
|
| 213 |
+
features = self._extractor.extract(spec)
|
| 214 |
+
|
| 215 |
+
spec_dict = spec.model_dump() if hasattr(spec, 'model_dump') else dict(spec)
|
| 216 |
+
|
| 217 |
+
if pre_generated and spec.design_name in pre_generated:
|
| 218 |
+
generated = pre_generated[spec.design_name]
|
| 219 |
+
else:
|
| 220 |
+
generated = {}
|
| 221 |
+
|
| 222 |
+
self._index.add(features, spec_dict, generated)
|
| 223 |
+
|
| 224 |
+
logger.info(f"Added spec '{spec.design_name}' ({features.fingerprint()}) to index")
|
| 225 |
+
|
| 226 |
+
all_features = []
|
| 227 |
+
for entry in self._index:
|
| 228 |
+
if hasattr(entry, 'feature_vector'):
|
| 229 |
+
text_repr = entry.feature_vector.to_text_repr()
|
| 230 |
+
all_features.append(text_repr)
|
| 231 |
+
|
| 232 |
+
if all_features and self._vectorizer:
|
| 233 |
+
self._vectorizer.fit(all_features)
|
| 234 |
+
|
| 235 |
+
return {
|
| 236 |
+
"index_size": len(self._index),
|
| 237 |
+
"model_name": self.name,
|
| 238 |
+
"features_extracted": len(all_features),
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
def predict(
|
| 242 |
+
self,
|
| 243 |
+
spec: DesignSpec,
|
| 244 |
+
cfg: PipelineConfig,
|
| 245 |
+
extra_seqs: Optional[List[str]] = None,
|
| 246 |
+
) -> Dict[str, str]:
|
| 247 |
+
"""Generate testbench for a specification."""
|
| 248 |
+
if not HAS_ADVANCED:
|
| 249 |
+
return self._template_model.predict(spec, cfg)
|
| 250 |
+
|
| 251 |
+
spec_dict = spec.model_dump() if hasattr(spec, 'model_dump') else dict(spec)
|
| 252 |
+
design_name = spec.design_name
|
| 253 |
+
protocol = spec_dict.get("protocol", "unknown")
|
| 254 |
+
|
| 255 |
+
self._code_validator = AdvancedCodeValidator(spec_dict)
|
| 256 |
+
|
| 257 |
+
available_sources = self._get_available_sources()
|
| 258 |
+
|
| 259 |
+
selected_source = self._select_generation_strategy(
|
| 260 |
+
spec_dict=spec_dict,
|
| 261 |
+
protocol=protocol,
|
| 262 |
+
available_sources=available_sources,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
logger.info(f"Selected generation strategy: {selected_source.value}")
|
| 266 |
+
|
| 267 |
+
result = self._generate_with_strategy(
|
| 268 |
+
strategy=selected_source,
|
| 269 |
+
spec=spec,
|
| 270 |
+
spec_dict=spec_dict,
|
| 271 |
+
config=cfg,
|
| 272 |
+
design_name=design_name,
|
| 273 |
+
protocol=protocol,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
final_result = self._apply_validation_and_fallback(
|
| 277 |
+
result=result,
|
| 278 |
+
spec=spec,
|
| 279 |
+
config=cfg,
|
| 280 |
+
spec_dict=spec_dict,
|
| 281 |
+
design_name=design_name,
|
| 282 |
+
protocol=protocol,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
self._record_learning(
|
| 286 |
+
final_result=final_result,
|
| 287 |
+
spec_dict=spec_dict,
|
| 288 |
+
design_name=design_name,
|
| 289 |
+
protocol=protocol,
|
| 290 |
+
selected_source=selected_source,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
return final_result.files
|
| 294 |
+
|
| 295 |
+
def _get_available_sources(self) -> List[str]:
|
| 296 |
+
"""Get list of available generation sources."""
|
| 297 |
+
sources = ["template"]
|
| 298 |
+
|
| 299 |
+
if self._index and len(self._index) > 0:
|
| 300 |
+
sources.append("retrieval")
|
| 301 |
+
|
| 302 |
+
if self._use_llm:
|
| 303 |
+
sources.append("llm")
|
| 304 |
+
|
| 305 |
+
return sources
|
| 306 |
+
|
| 307 |
+
def _select_generation_strategy(
|
| 308 |
+
self,
|
| 309 |
+
spec_dict: Dict[str, Any],
|
| 310 |
+
protocol: str,
|
| 311 |
+
available_sources: List[str],
|
| 312 |
+
) -> GenerationSource:
|
| 313 |
+
"""Select generation strategy using advanced RL."""
|
| 314 |
+
if len(available_sources) == 1:
|
| 315 |
+
return GenerationSource(available_sources[0])
|
| 316 |
+
|
| 317 |
+
if not self._use_learning or not self._rl_learner:
|
| 318 |
+
if "retrieval" in available_sources and self._index and len(self._index) > 0:
|
| 319 |
+
return GenerationSource.RETRIEVAL
|
| 320 |
+
return GenerationSource.TEMPLATE
|
| 321 |
+
|
| 322 |
+
file_types = ["testbench", "interface", "test", "sequence", "driver", "monitor"]
|
| 323 |
+
source_scores: Dict[str, float] = defaultdict(float)
|
| 324 |
+
|
| 325 |
+
for file_type in file_types:
|
| 326 |
+
source, value = self._rl_learner.select_best_action(
|
| 327 |
+
protocol=protocol,
|
| 328 |
+
file_type=file_type,
|
| 329 |
+
available_sources=available_sources,
|
| 330 |
+
spec_dict=spec_dict,
|
| 331 |
+
)
|
| 332 |
+
source_scores[source] += value
|
| 333 |
+
|
| 334 |
+
if not source_scores:
|
| 335 |
+
return GenerationSource.TEMPLATE
|
| 336 |
+
|
| 337 |
+
best_source = max(source_scores.keys(), key=lambda s: source_scores[s])
|
| 338 |
+
return GenerationSource(best_source)
|
| 339 |
+
|
| 340 |
+
def _generate_with_strategy(
|
| 341 |
+
self,
|
| 342 |
+
strategy: GenerationSource,
|
| 343 |
+
spec: DesignSpec,
|
| 344 |
+
spec_dict: Dict[str, Any],
|
| 345 |
+
config: PipelineConfig,
|
| 346 |
+
design_name: str,
|
| 347 |
+
protocol: str,
|
| 348 |
+
) -> GenerationResult:
|
| 349 |
+
"""Generate using selected strategy."""
|
| 350 |
+
if strategy == GenerationSource.RETRIEVAL:
|
| 351 |
+
return self._generate_by_retrieval(
|
| 352 |
+
spec=spec,
|
| 353 |
+
spec_dict=spec_dict,
|
| 354 |
+
config=config,
|
| 355 |
+
design_name=design_name,
|
| 356 |
+
protocol=protocol,
|
| 357 |
+
)
|
| 358 |
+
elif strategy == GenerationSource.LLM and self._use_llm:
|
| 359 |
+
return self._generate_by_llm(
|
| 360 |
+
spec=spec,
|
| 361 |
+
spec_dict=spec_dict,
|
| 362 |
+
config=config,
|
| 363 |
+
design_name=design_name,
|
| 364 |
+
)
|
| 365 |
+
else:
|
| 366 |
+
return self._generate_by_template(
|
| 367 |
+
spec=spec,
|
| 368 |
+
config=config,
|
| 369 |
+
design_name=design_name,
|
| 370 |
+
protocol=protocol,
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
def _generate_by_retrieval(
|
| 374 |
+
self,
|
| 375 |
+
spec: DesignSpec,
|
| 376 |
+
spec_dict: Dict[str, Any],
|
| 377 |
+
config: PipelineConfig,
|
| 378 |
+
design_name: str,
|
| 379 |
+
protocol: str,
|
| 380 |
+
) -> GenerationResult:
|
| 381 |
+
"""Generate using retrieval-based adaptation."""
|
| 382 |
+
if not self._index or not self._extractor or not self._adapter:
|
| 383 |
+
return GenerationResult(source=GenerationSource.TEMPLATE)
|
| 384 |
+
|
| 385 |
+
features = self._extractor.extract(spec)
|
| 386 |
+
|
| 387 |
+
search_results = self._index.search(features, top_k=5)
|
| 388 |
+
|
| 389 |
+
if not search_results:
|
| 390 |
+
logger.info("No similar specs found in index, falling back to templates")
|
| 391 |
+
return GenerationResult(source=GenerationSource.TEMPLATE)
|
| 392 |
+
|
| 393 |
+
best_result = search_results[0]
|
| 394 |
+
best_spec = best_result.spec_dict
|
| 395 |
+
|
| 396 |
+
retrieval_info = RetrievalInfo(
|
| 397 |
+
used_similarity=True,
|
| 398 |
+
similar_specs=len(search_results),
|
| 399 |
+
best_score=best_result.similarity,
|
| 400 |
+
best_spec_name=best_result.design_name,
|
| 401 |
+
retrieval_strategy="similarity_search",
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
logger.info(
|
| 405 |
+
f"Best match: '{best_result.design_name}' "
|
| 406 |
+
f"(similarity: {best_result.similarity:.3f})"
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
if best_result.generated_files:
|
| 410 |
+
adaptation = self._adapter.adapt(
|
| 411 |
+
source_spec=best_spec,
|
| 412 |
+
target_spec=spec_dict,
|
| 413 |
+
source_files=best_result.generated_files,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
retrieval_info.adaptation_score = adaptation.score
|
| 417 |
+
|
| 418 |
+
if adaptation.errors:
|
| 419 |
+
logger.warning(f"Adaptation errors: {adaptation.errors}")
|
| 420 |
+
|
| 421 |
+
if adaptation.score >= 0.7:
|
| 422 |
+
files = adaptation.adapted_files
|
| 423 |
+
|
| 424 |
+
validation_score = 0.5
|
| 425 |
+
if self._code_validator:
|
| 426 |
+
report = self._code_validator.validate_files(files, design_name)
|
| 427 |
+
validation_score = report.avg_score
|
| 428 |
+
retrieval_info.pre_validation_score = validation_score
|
| 429 |
+
|
| 430 |
+
if report.overall_passed or not self._strict_validation:
|
| 431 |
+
return GenerationResult(
|
| 432 |
+
files=files,
|
| 433 |
+
source=GenerationSource.RETRIEVAL,
|
| 434 |
+
retrieval_info=retrieval_info,
|
| 435 |
+
validation_report=report,
|
| 436 |
+
score=validation_score,
|
| 437 |
+
)
|
| 438 |
+
else:
|
| 439 |
+
logger.warning(
|
| 440 |
+
f"Retrieved code failed validation "
|
| 441 |
+
f"({report.total_errors} errors), will try alternatives"
|
| 442 |
+
)
|
| 443 |
+
else:
|
| 444 |
+
logger.warning(
|
| 445 |
+
f"Adaptation score too low ({adaptation.score:.2f} < 0.7), "
|
| 446 |
+
"falling back to alternatives"
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
if len(search_results) > 1:
|
| 450 |
+
for alt_result in search_results[1:3]:
|
| 451 |
+
if alt_result.generated_files and alt_result.similarity >= 0.5:
|
| 452 |
+
logger.info(f"Trying alternative: '{alt_result.design_name}'")
|
| 453 |
+
adaptation = self._adapter.adapt(
|
| 454 |
+
source_spec=alt_result.spec_dict,
|
| 455 |
+
target_spec=spec_dict,
|
| 456 |
+
source_files=alt_result.generated_files,
|
| 457 |
+
)
|
| 458 |
+
if adaptation.score >= 0.7:
|
| 459 |
+
files = adaptation.adapted_files
|
| 460 |
+
if self._code_validator:
|
| 461 |
+
report = self._code_validator.validate_files(files, design_name)
|
| 462 |
+
if report.overall_passed or not self._strict_validation:
|
| 463 |
+
retrieval_info.best_spec_name = alt_result.design_name
|
| 464 |
+
retrieval_info.best_score = alt_result.similarity
|
| 465 |
+
retrieval_info.adaptation_score = adaptation.score
|
| 466 |
+
retrieval_info.pre_validation_score = report.avg_score
|
| 467 |
+
return GenerationResult(
|
| 468 |
+
files=files,
|
| 469 |
+
source=GenerationSource.RETRIEVAL,
|
| 470 |
+
retrieval_info=retrieval_info,
|
| 471 |
+
validation_report=report,
|
| 472 |
+
score=report.avg_score,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
return GenerationResult(
|
| 476 |
+
source=GenerationSource.RETRIEVAL,
|
| 477 |
+
retrieval_info=retrieval_info,
|
| 478 |
+
errors=["Retrieval generation did not pass validation thresholds"],
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
def _generate_by_llm(
|
| 482 |
+
self,
|
| 483 |
+
spec: DesignSpec,
|
| 484 |
+
spec_dict: Dict[str, Any],
|
| 485 |
+
config: PipelineConfig,
|
| 486 |
+
design_name: str,
|
| 487 |
+
) -> GenerationResult:
|
| 488 |
+
"""Generate using LLM (placeholder for now)."""
|
| 489 |
+
logger.info("LLM generation requested but not fully implemented")
|
| 490 |
+
return GenerationResult(
|
| 491 |
+
source=GenerationSource.LLM,
|
| 492 |
+
errors=["LLM generation not available"],
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
def _generate_by_template(
|
| 496 |
+
self,
|
| 497 |
+
spec: DesignSpec,
|
| 498 |
+
config: PipelineConfig,
|
| 499 |
+
design_name: str,
|
| 500 |
+
protocol: str,
|
| 501 |
+
) -> GenerationResult:
|
| 502 |
+
"""Generate using templates."""
|
| 503 |
+
files = self._template_model.predict(spec, config)
|
| 504 |
+
|
| 505 |
+
score = 0.7
|
| 506 |
+
report = None
|
| 507 |
+
if self._code_validator:
|
| 508 |
+
report = self._code_validator.validate_files(files, design_name)
|
| 509 |
+
score = report.avg_score
|
| 510 |
+
|
| 511 |
+
return GenerationResult(
|
| 512 |
+
files=files,
|
| 513 |
+
source=GenerationSource.TEMPLATE,
|
| 514 |
+
validation_report=report,
|
| 515 |
+
score=score,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
def _apply_validation_and_fallback(
|
| 519 |
+
self,
|
| 520 |
+
result: GenerationResult,
|
| 521 |
+
spec: DesignSpec,
|
| 522 |
+
config: PipelineConfig,
|
| 523 |
+
spec_dict: Dict[str, Any],
|
| 524 |
+
design_name: str,
|
| 525 |
+
protocol: str,
|
| 526 |
+
) -> GenerationResult:
|
| 527 |
+
"""Apply validation and use fallback if needed."""
|
| 528 |
+
if result.files and not result.errors:
|
| 529 |
+
return result
|
| 530 |
+
|
| 531 |
+
if result.source == GenerationSource.TEMPLATE and result.files:
|
| 532 |
+
return result
|
| 533 |
+
|
| 534 |
+
logger.warning(
|
| 535 |
+
f"Primary strategy ({result.source.value}) failed or not available, "
|
| 536 |
+
"falling back to template generation"
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
template_result = self._generate_by_template(
|
| 540 |
+
spec=spec,
|
| 541 |
+
config=config,
|
| 542 |
+
design_name=design_name,
|
| 543 |
+
protocol=protocol,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
if result.retrieval_info:
|
| 547 |
+
template_result.retrieval_info = result.retrieval_info
|
| 548 |
+
|
| 549 |
+
template_result.warnings.extend([
|
| 550 |
+
f"Fell back from {result.source.value} to templates",
|
| 551 |
+
])
|
| 552 |
+
if result.errors:
|
| 553 |
+
template_result.warnings.extend(result.errors)
|
| 554 |
+
|
| 555 |
+
return template_result
|
| 556 |
+
|
| 557 |
+
def _record_learning(
|
| 558 |
+
self,
|
| 559 |
+
final_result: GenerationResult,
|
| 560 |
+
spec_dict: Dict[str, Any],
|
| 561 |
+
design_name: str,
|
| 562 |
+
protocol: str,
|
| 563 |
+
selected_source: GenerationSource,
|
| 564 |
+
) -> None:
|
| 565 |
+
"""Record learning data for continuous improvement."""
|
| 566 |
+
if not self._use_learning:
|
| 567 |
+
return
|
| 568 |
+
|
| 569 |
+
score = final_result.score
|
| 570 |
+
passed = final_result.validation_report.overall_passed if final_result.validation_report else (score >= 0.7)
|
| 571 |
+
|
| 572 |
+
reward = 1.0 if passed else (-0.5 if not passed else 0.3)
|
| 573 |
+
|
| 574 |
+
used_source = (
|
| 575 |
+
final_result.source.value
|
| 576 |
+
if final_result.source != selected_source
|
| 577 |
+
else selected_source.value
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
if final_result.validation_report:
|
| 581 |
+
for file_result in final_result.validation_report.files:
|
| 582 |
+
if self._rl_learner:
|
| 583 |
+
self._rl_learner.update(
|
| 584 |
+
protocol=protocol,
|
| 585 |
+
file_type=file_result.file_type,
|
| 586 |
+
generation_source=used_source,
|
| 587 |
+
reward=1.0 if file_result.passed else -0.3,
|
| 588 |
+
spec_dict=spec_dict,
|
| 589 |
+
metadata={
|
| 590 |
+
"design_name": design_name,
|
| 591 |
+
"score": file_result.score,
|
| 592 |
+
"error_count": file_result.error_count,
|
| 593 |
+
},
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
if self._pattern_learner:
|
| 597 |
+
if file_result.passed and file_result.score >= 0.7:
|
| 598 |
+
self._pattern_learner.record_success(
|
| 599 |
+
file_type=file_result.file_type,
|
| 600 |
+
protocol=protocol,
|
| 601 |
+
score=file_result.score,
|
| 602 |
+
)
|
| 603 |
+
else:
|
| 604 |
+
for issue in file_result.issues:
|
| 605 |
+
if issue.severity.value == "error":
|
| 606 |
+
self._pattern_learner.record_error(
|
| 607 |
+
error_msg=issue.message,
|
| 608 |
+
file_type=file_result.file_type,
|
| 609 |
+
line_num=issue.line_number,
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
history_entry = {
|
| 613 |
+
"timestamp": datetime.now().isoformat(),
|
| 614 |
+
"design_name": design_name,
|
| 615 |
+
"protocol": protocol,
|
| 616 |
+
"selected_source": selected_source.value,
|
| 617 |
+
"actual_source": final_result.source.value,
|
| 618 |
+
"score": score,
|
| 619 |
+
"passed": passed,
|
| 620 |
+
"reward": reward,
|
| 621 |
+
"error_count": (
|
| 622 |
+
final_result.validation_report.total_errors
|
| 623 |
+
if final_result.validation_report else 0
|
| 624 |
+
),
|
| 625 |
+
}
|
| 626 |
+
self._generation_history.append(history_entry)
|
| 627 |
+
|
| 628 |
+
if len(self._generation_history) > 100:
|
| 629 |
+
self._generation_history = self._generation_history[-100:]
|
| 630 |
+
|
| 631 |
+
if self._rl_learner and len(self._generation_history) % 10 == 0:
|
| 632 |
+
replay_count = self._rl_learner.replay_experiences(batch_size=32)
|
| 633 |
+
logger.debug(f"Replayed {replay_count} experiences")
|
| 634 |
+
|
| 635 |
+
if self._learning_storage_path:
|
| 636 |
+
self._save_learning_state()
|
| 637 |
+
|
| 638 |
+
def _save_learning_state(self) -> None:
|
| 639 |
+
"""Save learning state to storage."""
|
| 640 |
+
if not self._learning_storage_path:
|
| 641 |
+
return
|
| 642 |
+
|
| 643 |
+
try:
|
| 644 |
+
os.makedirs(os.path.dirname(self._learning_storage_path), exist_ok=True)
|
| 645 |
+
|
| 646 |
+
state = {
|
| 647 |
+
"saved_at": datetime.now().isoformat(),
|
| 648 |
+
"generation_history": self._generation_history[-500:],
|
| 649 |
+
"strategy_weights": {
|
| 650 |
+
"retrieval": self._strategy_weights.retrieval_weight,
|
| 651 |
+
"llm": self._strategy_weights.llm_weight,
|
| 652 |
+
"template": self._strategy_weights.template_weight,
|
| 653 |
+
},
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
if self._rl_learner:
|
| 657 |
+
state["rl_learner"] = self._rl_learner.to_dict()
|
| 658 |
+
|
| 659 |
+
if self._pattern_learner:
|
| 660 |
+
state["pattern_learner"] = self._pattern_learner.to_dict()
|
| 661 |
+
|
| 662 |
+
with open(self._learning_storage_path, "w") as f:
|
| 663 |
+
json.dump(state, f, indent=2)
|
| 664 |
+
|
| 665 |
+
logger.info(f"Learning state saved to: {self._learning_storage_path}")
|
| 666 |
+
|
| 667 |
+
except Exception as e:
|
| 668 |
+
logger.warning(f"Could not save learning state: {e}")
|
| 669 |
+
|
| 670 |
+
def _load_learning_state(self) -> None:
|
| 671 |
+
"""Load learning state from storage."""
|
| 672 |
+
if not self._learning_storage_path or not os.path.exists(self._learning_storage_path):
|
| 673 |
+
return
|
| 674 |
+
|
| 675 |
+
try:
|
| 676 |
+
with open(self._learning_storage_path, "r") as f:
|
| 677 |
+
state = json.load(f)
|
| 678 |
+
|
| 679 |
+
self._generation_history = state.get("generation_history", [])
|
| 680 |
+
|
| 681 |
+
weights = state.get("strategy_weights", {})
|
| 682 |
+
if weights:
|
| 683 |
+
self._strategy_weights = StrategyWeights(
|
| 684 |
+
retrieval_weight=weights.get("retrieval", 0.4),
|
| 685 |
+
llm_weight=weights.get("llm", 0.3),
|
| 686 |
+
template_weight=weights.get("template", 0.3),
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
if "rl_learner" in state and self._rl_learner:
|
| 690 |
+
from src.models.advanced_rl_learner import AdvancedReinforcementLearner
|
| 691 |
+
self._rl_learner = AdvancedReinforcementLearner.from_dict(state["rl_learner"])
|
| 692 |
+
|
| 693 |
+
if "pattern_learner" in state and self._pattern_learner:
|
| 694 |
+
from src.models.advanced_pattern_learner import AdvancedPatternLearner
|
| 695 |
+
self._pattern_learner = AdvancedPatternLearner.from_dict(state["pattern_learner"])
|
| 696 |
+
|
| 697 |
+
logger.info(f"Learning state loaded from: {self._learning_storage_path}")
|
| 698 |
+
|
| 699 |
+
except Exception as e:
|
| 700 |
+
logger.warning(f"Could not load learning state: {e}")
|
| 701 |
+
|
| 702 |
+
def get_learning_stats(self) -> Dict[str, Any]:
|
| 703 |
+
"""Get comprehensive learning statistics."""
|
| 704 |
+
stats = {
|
| 705 |
+
"total_generations": len(self._generation_history),
|
| 706 |
+
"strategy_weights": {
|
| 707 |
+
"retrieval": self._strategy_weights.retrieval_weight,
|
| 708 |
+
"llm": self._strategy_weights.llm_weight,
|
| 709 |
+
"template": self._strategy_weights.template_weight,
|
| 710 |
+
},
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
if self._generation_history:
|
| 714 |
+
recent = self._generation_history[-50:]
|
| 715 |
+
passed = sum(1 for h in recent if h.get("passed", False))
|
| 716 |
+
avg_score = sum(h.get("score", 0) for h in recent) / len(recent)
|
| 717 |
+
|
| 718 |
+
stats["recent_performance"] = {
|
| 719 |
+
"window_size": len(recent),
|
| 720 |
+
"pass_rate": passed / len(recent),
|
| 721 |
+
"avg_score": avg_score,
|
| 722 |
+
}
|
| 723 |
+
|
| 724 |
+
sources = [h.get("actual_source", "unknown") for h in recent]
|
| 725 |
+
stats["source_distribution"] = dict(Counter(sources))
|
| 726 |
+
|
| 727 |
+
if self._rl_learner:
|
| 728 |
+
stats["rl_learner"] = self._rl_learner.get_performance_stats()
|
| 729 |
+
|
| 730 |
+
if self._pattern_learner:
|
| 731 |
+
stats["pattern_learner"] = self._pattern_learner.get_suggestions(
|
| 732 |
+
file_type="any",
|
| 733 |
+
protocol="any",
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
return stats
|
| 737 |
+
|
| 738 |
+
@staticmethod
|
| 739 |
+
def _spec_to_dict(spec: DesignSpec) -> Dict[str, Any]:
|
| 740 |
+
"""Convert DesignSpec to serializable dict."""
|
| 741 |
+
return {
|
| 742 |
+
"design_name": spec.design_name,
|
| 743 |
+
"protocol": spec.protocol,
|
| 744 |
+
"clock_reset": {
|
| 745 |
+
"clock": spec.clock_reset.clock,
|
| 746 |
+
"reset": spec.clock_reset.reset,
|
| 747 |
+
"reset_active": spec.clock_reset.reset_active,
|
| 748 |
+
},
|
| 749 |
+
"interfaces": [
|
| 750 |
+
{
|
| 751 |
+
"name": iface.name,
|
| 752 |
+
"signals": [
|
| 753 |
+
{"name": s.name, "direction": s.direction, "width": s.width}
|
| 754 |
+
for s in iface.signals
|
| 755 |
+
],
|
| 756 |
+
}
|
| 757 |
+
for iface in spec.interfaces
|
| 758 |
+
],
|
| 759 |
+
"registers": [
|
| 760 |
+
{
|
| 761 |
+
"name": r.name,
|
| 762 |
+
"address": r.address,
|
| 763 |
+
"access": r.access,
|
| 764 |
+
"size": r.size,
|
| 765 |
+
"reset_value": r.reset_value,
|
| 766 |
+
"fields": [
|
| 767 |
+
{"name": f.name, "bits": f.bits, "description": f.description}
|
| 768 |
+
for f in r.fields
|
| 769 |
+
],
|
| 770 |
+
}
|
| 771 |
+
for r in spec.registers
|
| 772 |
+
],
|
| 773 |
+
}
|
| 774 |
+
|
| 775 |
+
def save(self, path: str) -> None:
|
| 776 |
+
"""Save the model state to disk."""
|
| 777 |
+
self.save_learning_state(path)
|
| 778 |
+
logger.info("Saved EnhancedMLGenerationModelV2 to %s", path)
|
| 779 |
+
|
| 780 |
+
@classmethod
|
| 781 |
+
def load(cls, path: str) -> "EnhancedMLGenerationModelV2":
|
| 782 |
+
"""Load the model from disk."""
|
| 783 |
+
model = cls(
|
| 784 |
+
name="enhanced_ml_model_v2",
|
| 785 |
+
use_learning=True,
|
| 786 |
+
)
|
| 787 |
+
model.load_learning_state(path)
|
| 788 |
+
logger.info("Loaded EnhancedMLGenerationModelV2 from %s", path)
|
| 789 |
+
return model
|
| 790 |
+
|
| 791 |
+
@property
|
| 792 |
+
def is_trained(self) -> bool:
|
| 793 |
+
"""Check if model is trained."""
|
| 794 |
+
if self._index is not None:
|
| 795 |
+
return len(self._index) > 0
|
| 796 |
+
return False
|
| 797 |
+
|
| 798 |
+
@property
|
| 799 |
+
def index(self) -> Optional[SimilarityIndex]:
|
| 800 |
+
"""Get the similarity index."""
|
| 801 |
+
return self._index
|
|
@@ -13,6 +13,7 @@ from src.features.extractors import SpecFeatureExtractor
|
|
| 13 |
from src.generation.engine import GenerationEngine
|
| 14 |
from src.models.base_model import GenerationModel
|
| 15 |
from src.models.enhanced_ml_model import EnhancedMLGenerationModel
|
|
|
|
| 16 |
from src.models.ml_generation_model import MLGenerationModel, MLModelConfig
|
| 17 |
from src.models.registry import ModelRegistry
|
| 18 |
from src.models.template_model import TemplateModel
|
|
@@ -55,7 +56,7 @@ class TBPipeline:
|
|
| 55 |
model_type = ml_cfg.model_type
|
| 56 |
self.logger.info("ML generation enabled, model_type=%s", model_type)
|
| 57 |
|
| 58 |
-
if model_type in ("ml", "hybrid", "llm", "semantic"):
|
| 59 |
ml_model_config = MLModelConfig(
|
| 60 |
similarity_threshold=ml_cfg.similarity_threshold,
|
| 61 |
auto_learn=ml_cfg.auto_learn,
|
|
@@ -63,18 +64,34 @@ class TBPipeline:
|
|
| 63 |
top_k_retrieval=ml_cfg.top_k_retrieval,
|
| 64 |
fallback_to_templates=ml_cfg.fallback_to_templates,
|
| 65 |
)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
if model_type == "llm":
|
| 80 |
self.logger.info("LLM mode: will prioritize LLM generation")
|
|
|
|
| 13 |
from src.generation.engine import GenerationEngine
|
| 14 |
from src.models.base_model import GenerationModel
|
| 15 |
from src.models.enhanced_ml_model import EnhancedMLGenerationModel
|
| 16 |
+
from src.models.enhanced_ml_model_v2 import EnhancedMLGenerationModelV2
|
| 17 |
from src.models.ml_generation_model import MLGenerationModel, MLModelConfig
|
| 18 |
from src.models.registry import ModelRegistry
|
| 19 |
from src.models.template_model import TemplateModel
|
|
|
|
| 56 |
model_type = ml_cfg.model_type
|
| 57 |
self.logger.info("ML generation enabled, model_type=%s", model_type)
|
| 58 |
|
| 59 |
+
if model_type in ("ml", "hybrid", "llm", "semantic", "v2"):
|
| 60 |
ml_model_config = MLModelConfig(
|
| 61 |
similarity_threshold=ml_cfg.similarity_threshold,
|
| 62 |
auto_learn=ml_cfg.auto_learn,
|
|
|
|
| 64 |
top_k_retrieval=ml_cfg.top_k_retrieval,
|
| 65 |
fallback_to_templates=ml_cfg.fallback_to_templates,
|
| 66 |
)
|
| 67 |
+
|
| 68 |
+
if model_type == "v2":
|
| 69 |
+
model = EnhancedMLGenerationModelV2(
|
| 70 |
+
name="enhanced_ml_model_v2",
|
| 71 |
+
config=ml_model_config,
|
| 72 |
+
templates_dir=self.cfg.generation.templates_dir,
|
| 73 |
+
strict_validation=True,
|
| 74 |
+
use_llm=ml_cfg.use_llm,
|
| 75 |
+
use_semantic_encoder=ml_cfg.use_semantic_encoder,
|
| 76 |
+
use_learning=ml_cfg.use_learning,
|
| 77 |
+
llm_model_name=ml_cfg.llm_model_name,
|
| 78 |
+
learning_storage_path=ml_cfg.learning_storage_path,
|
| 79 |
+
exploration_strategy=getattr(ml_cfg, 'exploration_strategy', 'ucb'),
|
| 80 |
+
)
|
| 81 |
+
self.logger.info("Created EnhancedMLGenerationModelV2 with advanced RL and pattern learning")
|
| 82 |
+
else:
|
| 83 |
+
model = EnhancedMLGenerationModel(
|
| 84 |
+
name="enhanced_ml_model",
|
| 85 |
+
config=ml_model_config,
|
| 86 |
+
templates_dir=self.cfg.generation.templates_dir,
|
| 87 |
+
strict_validation=True,
|
| 88 |
+
use_llm=ml_cfg.use_llm,
|
| 89 |
+
use_semantic_encoder=ml_cfg.use_semantic_encoder,
|
| 90 |
+
use_learning=ml_cfg.use_learning,
|
| 91 |
+
llm_model_name=ml_cfg.llm_model_name,
|
| 92 |
+
learning_storage_path=ml_cfg.learning_storage_path,
|
| 93 |
+
)
|
| 94 |
+
self.logger.info("Created EnhancedMLGenerationModel with index size: %d", len(model.index))
|
| 95 |
|
| 96 |
if model_type == "llm":
|
| 97 |
self.logger.info("LLM mode: will prioritize LLM generation")
|
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
Streamlit UI for UVM Testbench Generator
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
|
| 6 |
import streamlit as st
|
|
@@ -11,20 +11,18 @@ import zipfile
|
|
| 11 |
import io
|
| 12 |
from pathlib import Path
|
| 13 |
from datetime import datetime
|
|
|
|
| 14 |
|
| 15 |
-
# Configure logging
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
| 17 |
logger = logging.getLogger("uvmgen-streamlit")
|
| 18 |
|
| 19 |
-
# Page config
|
| 20 |
st.set_page_config(
|
| 21 |
-
page_title="UVM Testbench Generator",
|
| 22 |
page_icon="🔬",
|
| 23 |
layout="wide",
|
| 24 |
initial_sidebar_state="expanded",
|
| 25 |
)
|
| 26 |
|
| 27 |
-
# Example specifications
|
| 28 |
EXAMPLES = {
|
| 29 |
"UART": """design_name: uart
|
| 30 |
clock_reset:
|
|
@@ -58,20 +56,50 @@ interfaces:
|
|
| 58 |
direction: output
|
| 59 |
- name: uart_rx
|
| 60 |
direction: input
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
registers:
|
| 63 |
- name: RBR_THR
|
| 64 |
address: 0x0
|
| 65 |
description: Receiver Buffer / Transmitter Holding
|
|
|
|
|
|
|
|
|
|
| 66 |
- name: IER
|
| 67 |
address: 0x1
|
| 68 |
description: Interrupt Enable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
- name: LCR
|
| 70 |
address: 0x3
|
| 71 |
description: Line Control
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
- name: LSR
|
| 73 |
address: 0x5
|
| 74 |
description: Line Status
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
protocol: uart""",
|
| 77 |
"SPI": """design_name: spi_controller
|
|
@@ -181,239 +209,514 @@ registers:
|
|
| 181 |
protocol: i2c"""
|
| 182 |
}
|
| 183 |
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
if 'last_result' not in st.session_state:
|
| 186 |
st.session_state.last_result = None
|
| 187 |
if 'generated_files' not in st.session_state:
|
| 188 |
st.session_state.generated_files = {}
|
| 189 |
if 'log_output' not in st.session_state:
|
| 190 |
st.session_state.log_output = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
# Header
|
| 193 |
st.title("🔬 UVM Testbench Generator")
|
| 194 |
st.markdown("""
|
| 195 |
-
**AI-Powered Semiconductor Verification Pipeline**
|
| 196 |
-
Generate industry-grade UVM testbenches from YAML specifications
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
""")
|
| 198 |
|
| 199 |
-
# Sidebar
|
| 200 |
with st.sidebar:
|
| 201 |
st.header("⚙️ Configuration")
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
| 217 |
st.divider()
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
st.divider()
|
| 241 |
|
| 242 |
-
st.
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
-
with
|
| 249 |
-
st.
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
-
# Generate button
|
| 261 |
generate_btn = st.button(
|
| 262 |
"🚀 Generate UVM Testbench",
|
| 263 |
type="primary",
|
| 264 |
-
use_container_width=True
|
|
|
|
| 265 |
)
|
| 266 |
|
| 267 |
-
with
|
| 268 |
-
st.subheader("📊 Results & Output")
|
| 269 |
-
|
| 270 |
-
# Status
|
| 271 |
status_placeholder = st.empty()
|
| 272 |
|
| 273 |
-
# Metrics
|
| 274 |
metrics_placeholder = st.empty()
|
| 275 |
|
| 276 |
-
# Logs
|
| 277 |
with st.expander("📋 Log Output", expanded=True):
|
| 278 |
log_placeholder = st.empty()
|
| 279 |
|
| 280 |
-
# Files
|
| 281 |
files_placeholder = st.empty()
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
-
# Generate logic
|
| 285 |
if generate_btn:
|
| 286 |
st.session_state.log_output = []
|
| 287 |
st.session_state.last_result = None
|
| 288 |
st.session_state.generated_files = {}
|
|
|
|
| 289 |
|
| 290 |
status_placeholder.info("🔄 Generating UVM testbench...")
|
| 291 |
|
| 292 |
try:
|
| 293 |
-
# Import here for lazy loading
|
| 294 |
from src.config import ConfigLoader, PipelineConfig
|
| 295 |
from src.pipeline import TBPipeline
|
| 296 |
|
| 297 |
-
# Save spec to temp file
|
| 298 |
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False, encoding='utf-8') as f:
|
| 299 |
f.write(spec_text)
|
| 300 |
spec_path = f.name
|
| 301 |
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 304 |
|
| 305 |
-
# Create pipeline
|
| 306 |
pipeline = TBPipeline()
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
pipeline.cfg.auto_train.enabled = auto_train
|
| 313 |
pipeline.cfg.auto_train.max_iterations = max_iterations
|
| 314 |
|
| 315 |
-
st.session_state.log_output.append(f"[{datetime.now().strftime('%H:%M:%S')}] ML enabled: {use_ml}")
|
| 316 |
-
st.session_state.log_output.append(f"[{datetime.now().strftime('%H:%M:%S')}] Auto-train: {auto_train} (iterations: {max_iterations})")
|
| 317 |
-
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 318 |
-
|
| 319 |
-
# Run pipeline
|
| 320 |
result = pipeline.run(spec_path)
|
| 321 |
|
| 322 |
-
# Cleanup
|
| 323 |
try:
|
| 324 |
os.unlink(spec_path)
|
| 325 |
except:
|
| 326 |
pass
|
| 327 |
|
| 328 |
-
# Store results
|
| 329 |
st.session_state.last_result = result
|
| 330 |
st.session_state.generated_files = result.get('generated_files', {})
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 335 |
|
| 336 |
-
# Update status
|
| 337 |
if result.get('passed'):
|
| 338 |
status_placeholder.success("✅ Generation successful!")
|
| 339 |
else:
|
| 340 |
status_placeholder.warning("⚠️ Generation completed with issues")
|
| 341 |
|
| 342 |
except Exception as e:
|
| 343 |
-
|
|
|
|
| 344 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 345 |
status_placeholder.error(f"❌ Error: {str(e)}")
|
| 346 |
import traceback
|
| 347 |
st.session_state.log_output.append(traceback.format_exc())
|
| 348 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 349 |
|
| 350 |
-
|
| 351 |
-
# Show results
|
| 352 |
if st.session_state.last_result:
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
# Metrics
|
| 356 |
-
with metrics_placeholder.container():
|
| 357 |
-
eval_metrics = result.get('evaluation', {})
|
| 358 |
-
|
| 359 |
-
m1, m2, m3 = st.columns(3)
|
| 360 |
-
with m1:
|
| 361 |
-
completeness = eval_metrics.get('completeness', 0) * 100
|
| 362 |
-
st.metric("Completeness", f"{completeness:.1f}%")
|
| 363 |
-
with m2:
|
| 364 |
-
signal_cov = eval_metrics.get('interface_signal_coverage', 0) * 100
|
| 365 |
-
st.metric("Signal Coverage", f"{signal_cov:.1f}%")
|
| 366 |
-
with m3:
|
| 367 |
-
reg_cov = eval_metrics.get('register_coverage', 0) * 100
|
| 368 |
-
st.metric("Register Coverage", f"{reg_cov:.1f}%")
|
| 369 |
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
st.metric("Files Generated", len(st.session_state.generated_files))
|
| 373 |
-
with m5:
|
| 374 |
-
st.metric("Iterations", result.get('auto_train_iterations', 0))
|
| 375 |
-
|
| 376 |
-
# Files list
|
| 377 |
-
with files_placeholder.expander("📄 Generated Files", expanded=True):
|
| 378 |
-
if st.session_state.generated_files:
|
| 379 |
-
# File selector
|
| 380 |
-
file_names = sorted(st.session_state.generated_files.keys())
|
| 381 |
-
selected_file = st.selectbox("Select file to preview", file_names)
|
| 382 |
-
|
| 383 |
-
if selected_file:
|
| 384 |
-
file_path = st.session_state.generated_files[selected_file]
|
| 385 |
-
if os.path.exists(file_path):
|
| 386 |
-
try:
|
| 387 |
-
with open(file_path, 'r', encoding='utf-8') as f:
|
| 388 |
-
content = f.read()
|
| 389 |
-
st.code(content, language='systemverilog')
|
| 390 |
-
except Exception as e:
|
| 391 |
-
st.warning(f"Could not read file: {e}")
|
| 392 |
-
|
| 393 |
-
# Download ZIP
|
| 394 |
-
if st.session_state.generated_files:
|
| 395 |
-
zip_buffer = io.BytesIO()
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
-
st.
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
-
|
| 414 |
-
# Footer
|
| 415 |
st.divider()
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Enhanced Streamlit UI for UVM Testbench Generator
|
| 3 |
+
Shows advanced ML capabilities: V2 model, RL strategies, learning persistence, etc.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import streamlit as st
|
|
|
|
| 11 |
import io
|
| 12 |
from pathlib import Path
|
| 13 |
from datetime import datetime
|
| 14 |
+
import json
|
| 15 |
|
|
|
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
| 17 |
logger = logging.getLogger("uvmgen-streamlit")
|
| 18 |
|
|
|
|
| 19 |
st.set_page_config(
|
| 20 |
+
page_title="UVM Testbench Generator - AI/ML Enhanced",
|
| 21 |
page_icon="🔬",
|
| 22 |
layout="wide",
|
| 23 |
initial_sidebar_state="expanded",
|
| 24 |
)
|
| 25 |
|
|
|
|
| 26 |
EXAMPLES = {
|
| 27 |
"UART": """design_name: uart
|
| 28 |
clock_reset:
|
|
|
|
| 56 |
direction: output
|
| 57 |
- name: uart_rx
|
| 58 |
direction: input
|
| 59 |
+
- name: cts_n
|
| 60 |
+
direction: input
|
| 61 |
+
- name: rts_n
|
| 62 |
+
direction: output
|
| 63 |
+
- name: uart_intr
|
| 64 |
+
direction: output
|
| 65 |
|
| 66 |
registers:
|
| 67 |
- name: RBR_THR
|
| 68 |
address: 0x0
|
| 69 |
description: Receiver Buffer / Transmitter Holding
|
| 70 |
+
fields:
|
| 71 |
+
- name: data
|
| 72 |
+
bits: 7:0
|
| 73 |
- name: IER
|
| 74 |
address: 0x1
|
| 75 |
description: Interrupt Enable
|
| 76 |
+
fields:
|
| 77 |
+
- name: erbfi
|
| 78 |
+
bits: '0'
|
| 79 |
+
description: Enable RX data available interrupt
|
| 80 |
+
- name: etbei
|
| 81 |
+
bits: '1'
|
| 82 |
+
description: Enable TX holding register empty interrupt
|
| 83 |
- name: LCR
|
| 84 |
address: 0x3
|
| 85 |
description: Line Control
|
| 86 |
+
fields:
|
| 87 |
+
- name: wls
|
| 88 |
+
bits: 1:0
|
| 89 |
+
description: Word length select
|
| 90 |
+
- name: dlab
|
| 91 |
+
bits: '7'
|
| 92 |
+
description: Divisor latch access bit
|
| 93 |
- name: LSR
|
| 94 |
address: 0x5
|
| 95 |
description: Line Status
|
| 96 |
+
fields:
|
| 97 |
+
- name: dr
|
| 98 |
+
bits: '0'
|
| 99 |
+
description: Data Ready
|
| 100 |
+
- name: thre
|
| 101 |
+
bits: '5'
|
| 102 |
+
description: TX Holding Register Empty
|
| 103 |
|
| 104 |
protocol: uart""",
|
| 105 |
"SPI": """design_name: spi_controller
|
|
|
|
| 209 |
protocol: i2c"""
|
| 210 |
}
|
| 211 |
|
| 212 |
+
MODEL_TYPES = {
|
| 213 |
+
"template": "Template Only (Fast, No Learning)",
|
| 214 |
+
"hybrid": "Hybrid ML (Retrieval + Templates)",
|
| 215 |
+
"v2": "Advanced ML V2 (Recommended) - RL + Pattern Learning",
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
EXPLORATION_STRATEGIES = {
|
| 219 |
+
"ucb": "UCB1 (Upper Confidence Bound) - Best for exploration/exploitation balance",
|
| 220 |
+
"epsilon_greedy": "Epsilon-Greedy - Simple, with decaying randomness",
|
| 221 |
+
"softmax": "Softmax (Boltzmann) - Probabilistic based on Q-values",
|
| 222 |
+
"thompson": "Thompson Sampling - Bayesian approach with Beta distributions",
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
if 'last_result' not in st.session_state:
|
| 226 |
st.session_state.last_result = None
|
| 227 |
if 'generated_files' not in st.session_state:
|
| 228 |
st.session_state.generated_files = {}
|
| 229 |
if 'log_output' not in st.session_state:
|
| 230 |
st.session_state.log_output = []
|
| 231 |
+
if 'ml_stats' not in st.session_state:
|
| 232 |
+
st.session_state.ml_stats = None
|
| 233 |
+
if 'learning_state_path' not in st.session_state:
|
| 234 |
+
st.session_state.learning_state_path = None
|
| 235 |
|
|
|
|
| 236 |
st.title("🔬 UVM Testbench Generator")
|
| 237 |
st.markdown("""
|
| 238 |
+
**AI-Powered Semiconductor Verification Pipeline with Advanced ML**
|
| 239 |
+
Generate industry-grade UVM testbenches from YAML specifications. Now featuring:
|
| 240 |
+
- **Advanced ML V2** with Reinforcement Learning (UCB, Softmax, Thompson Sampling)
|
| 241 |
+
- **Experience Replay Buffer** (10,000 capacity)
|
| 242 |
+
- **Eligibility Traces** for better credit assignment
|
| 243 |
+
- **Pattern Mining** with N-grams and Association Rules
|
| 244 |
+
- **Deep UVM Compliance Validation** (factory registration, phases, TLM)
|
| 245 |
+
- **Continuous Learning** with state persistence
|
| 246 |
""")
|
| 247 |
|
|
|
|
| 248 |
with st.sidebar:
|
| 249 |
st.header("⚙️ Configuration")
|
| 250 |
|
| 251 |
+
with st.expander("📋 Quick Setup", expanded=True):
|
| 252 |
+
selected_protocol = st.selectbox(
|
| 253 |
+
"Protocol Example",
|
| 254 |
+
list(EXAMPLES.keys()),
|
| 255 |
+
index=0,
|
| 256 |
+
help="Select a pre-built protocol specification"
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
default_name = selected_protocol.lower() + "_controller"
|
| 260 |
+
design_name = st.text_input(
|
| 261 |
+
"Design Name",
|
| 262 |
+
value=default_name,
|
| 263 |
+
help="Name for your generated IP"
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
st.divider()
|
| 267 |
|
| 268 |
+
with st.expander("🤖 ML Configuration", expanded=True):
|
| 269 |
+
use_ml = st.checkbox(
|
| 270 |
+
"Enable AI/ML Features",
|
| 271 |
+
value=True,
|
| 272 |
+
help="Use machine learning for intelligent generation"
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
if use_ml:
|
| 276 |
+
model_type = st.selectbox(
|
| 277 |
+
"ML Model Version",
|
| 278 |
+
list(MODEL_TYPES.keys()),
|
| 279 |
+
index=2,
|
| 280 |
+
format_func=lambda k: MODEL_TYPES[k],
|
| 281 |
+
help="V2 is recommended for advanced learning"
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
if model_type == "v2":
|
| 285 |
+
exploration_strategy = st.selectbox(
|
| 286 |
+
"RL Exploration Strategy",
|
| 287 |
+
list(EXPLORATION_STRATEGIES.keys()),
|
| 288 |
+
index=0,
|
| 289 |
+
format_func=lambda k: EXPLORATION_STRATEGIES[k].split(" - ")[0],
|
| 290 |
+
help="How the RL agent balances exploration and exploitation"
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
st.caption(EXPLORATION_STRATEGIES[exploration_strategy])
|
| 294 |
+
|
| 295 |
+
persist_learning = st.checkbox(
|
| 296 |
+
"Persist Learning State",
|
| 297 |
+
value=True,
|
| 298 |
+
help="Save and load learned patterns between sessions"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
if persist_learning:
|
| 302 |
+
st.session_state.learning_state_path = os.path.join(
|
| 303 |
+
tempfile.gettempdir(),
|
| 304 |
+
"uvmgen_learning_state.json"
|
| 305 |
+
)
|
| 306 |
+
st.caption(f"State will be saved to: temporary directory")
|
| 307 |
+
|
| 308 |
+
strict_validation = st.checkbox(
|
| 309 |
+
"Strict UVM Compliance",
|
| 310 |
+
value=True,
|
| 311 |
+
help="Enforce deep UVM validation (factory, phases, TLM)"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
auto_learn = st.checkbox(
|
| 315 |
+
"Continuous Learning",
|
| 316 |
+
value=True,
|
| 317 |
+
help="Learn from each generation to improve future results"
|
| 318 |
+
)
|
| 319 |
+
else:
|
| 320 |
+
model_type = "template"
|
| 321 |
+
exploration_strategy = "ucb"
|
| 322 |
+
strict_validation = False
|
| 323 |
+
auto_learn = False
|
| 324 |
+
|
| 325 |
+
st.divider()
|
| 326 |
|
| 327 |
+
with st.expander("⚡ Generation Options"):
|
| 328 |
+
auto_train = st.checkbox(
|
| 329 |
+
"Coverage-Driven Auto-Training",
|
| 330 |
+
value=False,
|
| 331 |
+
help="Iteratively improve testbench based on coverage analysis"
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
max_iterations = st.slider(
|
| 335 |
+
"Max Iterations",
|
| 336 |
+
min_value=1,
|
| 337 |
+
max_value=10,
|
| 338 |
+
value=1,
|
| 339 |
+
help="Maximum auto-training iterations"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
st.caption("Auto-training requires a simulator (Icarus Verilog, VCS, or Questa)")
|
| 343 |
+
|
| 344 |
st.divider()
|
| 345 |
|
| 346 |
+
with st.expander("ℹ️ About"):
|
| 347 |
+
st.info("💡 **UVM = Universal Verification Methodology**")
|
| 348 |
+
st.info("🔬 **ML V2 = Reinforcement Learning + Pattern Mining**")
|
| 349 |
+
st.markdown("---")
|
| 350 |
+
st.caption("Developed by **Sai Kumar Taraka**")
|
| 351 |
+
st.caption("Promotion-Ready Advanced ML System")
|
| 352 |
|
| 353 |
+
tab_spec, tab_results, tab_ml_insights = st.tabs([
|
| 354 |
+
"📝 Specification",
|
| 355 |
+
"📊 Results & Files",
|
| 356 |
+
"🤖 ML Insights"
|
| 357 |
+
])
|
| 358 |
|
| 359 |
+
with tab_spec:
|
| 360 |
+
col1, col2 = st.columns([1, 1])
|
| 361 |
|
| 362 |
+
with col1:
|
| 363 |
+
st.subheader("✏️ YAML Specification Editor")
|
| 364 |
+
spec_text = st.text_area(
|
| 365 |
+
"Edit your specification",
|
| 366 |
+
value=EXAMPLES[selected_protocol],
|
| 367 |
+
height=450,
|
| 368 |
+
key="spec_editor",
|
| 369 |
+
help="Define your interfaces, signals, registers, and protocol"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
st.caption(f"Protocol: {selected_protocol} | Model: {model_type.upper()} | Strategy: {exploration_strategy.upper()}")
|
| 373 |
+
|
| 374 |
+
with col2:
|
| 375 |
+
st.subheader("📋 Specification Summary")
|
| 376 |
+
|
| 377 |
+
import yaml
|
| 378 |
+
try:
|
| 379 |
+
spec_dict = yaml.safe_load(spec_text)
|
| 380 |
+
|
| 381 |
+
st.metric("Design Name", spec_dict.get('design_name', 'unknown'))
|
| 382 |
+
st.metric("Protocol", spec_dict.get('protocol', 'unknown').upper())
|
| 383 |
+
|
| 384 |
+
col_a, col_b = st.columns(2)
|
| 385 |
+
with col_a:
|
| 386 |
+
interfaces = spec_dict.get('interfaces', [])
|
| 387 |
+
st.metric("Interfaces", len(interfaces))
|
| 388 |
+
total_signals = sum(len(i.get('signals', [])) for i in interfaces)
|
| 389 |
+
st.metric("Total Signals", total_signals)
|
| 390 |
+
|
| 391 |
+
with col_b:
|
| 392 |
+
registers = spec_dict.get('registers', [])
|
| 393 |
+
st.metric("Registers", len(registers))
|
| 394 |
+
total_fields = sum(len(r.get('fields', [])) for r in registers)
|
| 395 |
+
st.metric("Register Fields", total_fields)
|
| 396 |
+
|
| 397 |
+
if interfaces:
|
| 398 |
+
st.subheader("Interface Signals")
|
| 399 |
+
for iface in interfaces:
|
| 400 |
+
with st.expander(f"🔌 {iface.get('name', 'unknown')}"):
|
| 401 |
+
signals = iface.get('signals', [])
|
| 402 |
+
for sig in signals:
|
| 403 |
+
name = sig.get('name', 'unknown')
|
| 404 |
+
direction = sig.get('direction', 'input')
|
| 405 |
+
width = sig.get('width', 1)
|
| 406 |
+
st.text(f" • {name} ({direction}, {width}bit)")
|
| 407 |
+
|
| 408 |
+
if registers:
|
| 409 |
+
st.subheader("Register Map")
|
| 410 |
+
for reg in registers:
|
| 411 |
+
with st.expander(f"📋 {reg.get('name', 'unknown')} @ {reg.get('address', '0x0')}"):
|
| 412 |
+
st.text(f" Description: {reg.get('description', 'None')}")
|
| 413 |
+
fields = reg.get('fields', [])
|
| 414 |
+
if fields:
|
| 415 |
+
st.text(f" Fields:")
|
| 416 |
+
for field in fields:
|
| 417 |
+
st.text(f" • {field.get('name', 'unknown')} [{field.get('bits', '0')}]")
|
| 418 |
+
|
| 419 |
+
except Exception as e:
|
| 420 |
+
st.error(f"Invalid YAML: {e}")
|
| 421 |
+
|
| 422 |
+
st.divider()
|
| 423 |
|
|
|
|
| 424 |
generate_btn = st.button(
|
| 425 |
"🚀 Generate UVM Testbench",
|
| 426 |
type="primary",
|
| 427 |
+
use_container_width=True,
|
| 428 |
+
help=f"Generate using {model_type.upper()} model"
|
| 429 |
)
|
| 430 |
|
| 431 |
+
with tab_results:
|
|
|
|
|
|
|
|
|
|
| 432 |
status_placeholder = st.empty()
|
| 433 |
|
|
|
|
| 434 |
metrics_placeholder = st.empty()
|
| 435 |
|
|
|
|
| 436 |
with st.expander("📋 Log Output", expanded=True):
|
| 437 |
log_placeholder = st.empty()
|
| 438 |
|
|
|
|
| 439 |
files_placeholder = st.empty()
|
| 440 |
|
| 441 |
+
with tab_ml_insights:
|
| 442 |
+
st.header("🤖 Advanced ML Insights")
|
| 443 |
+
|
| 444 |
+
if st.session_state.ml_stats:
|
| 445 |
+
stats = st.session_state.ml_stats
|
| 446 |
+
|
| 447 |
+
col1, col2 = st.columns(2)
|
| 448 |
+
|
| 449 |
+
with col1:
|
| 450 |
+
st.subheader("📊 Learning Statistics")
|
| 451 |
+
total_gen = stats.get('total_generations', 0)
|
| 452 |
+
st.metric("Total Generations", total_gen)
|
| 453 |
+
|
| 454 |
+
if 'recent_performance' in stats:
|
| 455 |
+
perf = stats['recent_performance']
|
| 456 |
+
st.metric("Recent Pass Rate", f"{perf.get('pass_rate', 0)*100:.1f}%")
|
| 457 |
+
st.metric("Avg Score", f"{perf.get('avg_score', 0):.3f}")
|
| 458 |
+
|
| 459 |
+
if 'rl_learner' in stats:
|
| 460 |
+
rl_stats = stats['rl_learner']
|
| 461 |
+
st.subheader("🎮 Reinforcement Learning")
|
| 462 |
+
st.metric("Episode Count", rl_stats.get('episode_count', 0))
|
| 463 |
+
st.metric("Total Updates", rl_stats.get('total_updates', 0))
|
| 464 |
+
st.metric("Learning Rate", f"{rl_stats.get('learning_rate', 0.1):.4f}")
|
| 465 |
+
|
| 466 |
+
if 'state_stats' in rl_stats:
|
| 467 |
+
st.subheader("📈 Strategy Performance")
|
| 468 |
+
state_stats = rl_stats['state_stats']
|
| 469 |
+
for state, info in list(state_stats.items())[:5]:
|
| 470 |
+
st.text(f" {state}: best='{info.get('best_action', 'unknown')}' (Q={info.get('best_q_value', 0):.3f})")
|
| 471 |
+
|
| 472 |
+
with col2:
|
| 473 |
+
st.subheader("🎯 Source Distribution")
|
| 474 |
+
if 'source_distribution' in stats:
|
| 475 |
+
source_dist = stats['source_distribution']
|
| 476 |
+
fig_data = {
|
| 477 |
+
'Source': list(source_dist.keys()),
|
| 478 |
+
'Count': list(source_dist.values())
|
| 479 |
+
}
|
| 480 |
+
st.bar_chart(fig_data, x='Source', y='Count')
|
| 481 |
+
|
| 482 |
+
st.subheader("⚖️ Strategy Weights")
|
| 483 |
+
if 'strategy_weights' in stats:
|
| 484 |
+
weights = stats['strategy_weights']
|
| 485 |
+
st.json(weights)
|
| 486 |
+
|
| 487 |
+
if 'pattern_learner' in stats:
|
| 488 |
+
st.subheader("🔍 Pattern Learner")
|
| 489 |
+
patterns = stats['pattern_learner']
|
| 490 |
+
if 'common_errors' in patterns:
|
| 491 |
+
st.text("Common Error Patterns:")
|
| 492 |
+
for err, count in patterns['common_errors'][:5]:
|
| 493 |
+
st.text(f" • {err}: {count} occurrences")
|
| 494 |
+
|
| 495 |
+
if 'recommendations' in patterns:
|
| 496 |
+
st.subheader("💡 Recommendations")
|
| 497 |
+
for rec in patterns['recommendations'][:5]:
|
| 498 |
+
st.info(rec)
|
| 499 |
+
|
| 500 |
+
st.divider()
|
| 501 |
+
|
| 502 |
+
col_a, col_b = st.columns(2)
|
| 503 |
+
with col_a:
|
| 504 |
+
if st.button("📥 Export Learning State"):
|
| 505 |
+
if st.session_state.learning_state_path and os.path.exists(st.session_state.learning_state_path):
|
| 506 |
+
with open(st.session_state.learning_state_path, 'r') as f:
|
| 507 |
+
state_data = f.read()
|
| 508 |
+
st.download_button(
|
| 509 |
+
"Download Learning State JSON",
|
| 510 |
+
data=state_data,
|
| 511 |
+
file_name="uvmgen_learning_state.json",
|
| 512 |
+
mime="application/json"
|
| 513 |
+
)
|
| 514 |
+
else:
|
| 515 |
+
st.warning("No learning state saved yet")
|
| 516 |
+
|
| 517 |
+
with col_b:
|
| 518 |
+
uploaded_file = st.file_uploader("📤 Import Learning State", type="json")
|
| 519 |
+
if uploaded_file is not None:
|
| 520 |
+
try:
|
| 521 |
+
state_data = json.load(uploaded_file)
|
| 522 |
+
if st.session_state.learning_state_path:
|
| 523 |
+
with open(st.session_state.learning_state_path, 'w') as f:
|
| 524 |
+
json.dump(state_data, f, indent=2)
|
| 525 |
+
st.success("Learning state imported! It will be loaded on next generation.")
|
| 526 |
+
except Exception as e:
|
| 527 |
+
st.error(f"Failed to import: {e}")
|
| 528 |
+
|
| 529 |
+
else:
|
| 530 |
+
st.info("Run a generation first to see ML insights.")
|
| 531 |
+
st.markdown("""
|
| 532 |
+
### What you'll see here:
|
| 533 |
+
- **Learning Statistics**: Total generations, pass rates, average scores
|
| 534 |
+
- **RL Metrics**: Episode counts, learning rates, strategy performance
|
| 535 |
+
- **Pattern Analysis**: Common error patterns and recommendations
|
| 536 |
+
- **Strategy Distribution**: Which generation sources work best
|
| 537 |
+
- **Import/Export**: Save and load learned state
|
| 538 |
+
|
| 539 |
+
### ML V2 Capabilities:
|
| 540 |
+
1. **Reinforcement Learning** with 4 exploration strategies
|
| 541 |
+
2. **Experience Replay** buffer (10,000 capacity)
|
| 542 |
+
3. **Eligibility Traces** for better credit assignment
|
| 543 |
+
4. **Pattern Mining** with N-grams and Association Rules
|
| 544 |
+
5. **Deep UVM Validation** for factory registration, phases, TLM connections
|
| 545 |
+
""")
|
| 546 |
|
|
|
|
| 547 |
if generate_btn:
|
| 548 |
st.session_state.log_output = []
|
| 549 |
st.session_state.last_result = None
|
| 550 |
st.session_state.generated_files = {}
|
| 551 |
+
st.session_state.ml_stats = None
|
| 552 |
|
| 553 |
status_placeholder.info("🔄 Generating UVM testbench...")
|
| 554 |
|
| 555 |
try:
|
|
|
|
| 556 |
from src.config import ConfigLoader, PipelineConfig
|
| 557 |
from src.pipeline import TBPipeline
|
| 558 |
|
|
|
|
| 559 |
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False, encoding='utf-8') as f:
|
| 560 |
f.write(spec_text)
|
| 561 |
spec_path = f.name
|
| 562 |
|
| 563 |
+
timestamp = datetime.now().strftime('%H:%M:%S')
|
| 564 |
+
st.session_state.log_output.append(f"[{timestamp}] Starting generation for: {design_name}")
|
| 565 |
+
st.session_state.log_output.append(f"[{timestamp}] Model: {model_type}")
|
| 566 |
+
if model_type == "v2":
|
| 567 |
+
st.session_state.log_output.append(f"[{timestamp}] RL Strategy: {exploration_strategy}")
|
| 568 |
+
st.session_state.log_output.append(f"[{timestamp}] ML Enabled: {use_ml}")
|
| 569 |
+
st.session_state.log_output.append(f"[{timestamp}] Strict Validation: {strict_validation}")
|
| 570 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 571 |
|
|
|
|
| 572 |
pipeline = TBPipeline()
|
| 573 |
+
|
| 574 |
+
if use_ml:
|
| 575 |
+
pipeline.cfg.ml.enabled = True
|
| 576 |
+
pipeline.cfg.ml.model_type = model_type
|
| 577 |
+
pipeline.cfg.ml.use_llm = False
|
| 578 |
+
pipeline.cfg.ml.use_semantic_encoder = False
|
| 579 |
+
pipeline.cfg.ml.use_learning = auto_learn
|
| 580 |
+
pipeline.cfg.ml.strict_validation = strict_validation
|
| 581 |
+
|
| 582 |
+
if model_type == "v2":
|
| 583 |
+
pipeline.cfg.ml.exploration_strategy = exploration_strategy
|
| 584 |
+
if st.session_state.learning_state_path:
|
| 585 |
+
pipeline.cfg.ml.learning_storage_path = st.session_state.learning_state_path
|
| 586 |
+
else:
|
| 587 |
+
pipeline.cfg.ml.enabled = False
|
| 588 |
+
|
| 589 |
pipeline.cfg.auto_train.enabled = auto_train
|
| 590 |
pipeline.cfg.auto_train.max_iterations = max_iterations
|
| 591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
result = pipeline.run(spec_path)
|
| 593 |
|
|
|
|
| 594 |
try:
|
| 595 |
os.unlink(spec_path)
|
| 596 |
except:
|
| 597 |
pass
|
| 598 |
|
|
|
|
| 599 |
st.session_state.last_result = result
|
| 600 |
st.session_state.generated_files = result.get('generated_files', {})
|
| 601 |
|
| 602 |
+
try:
|
| 603 |
+
if hasattr(pipeline.model, 'get_learning_stats'):
|
| 604 |
+
st.session_state.ml_stats = pipeline.model.get_learning_stats()
|
| 605 |
+
elif hasattr(pipeline.model, '_rl_learner') and hasattr(pipeline.model, '_pattern_learner'):
|
| 606 |
+
st.session_state.ml_stats = {
|
| 607 |
+
'total_generations': len(st.session_state.log_output),
|
| 608 |
+
'rl_learner': pipeline.model._rl_learner.get_performance_stats() if hasattr(pipeline.model._rl_learner, 'get_performance_stats') else {},
|
| 609 |
+
}
|
| 610 |
+
except Exception as e:
|
| 611 |
+
logger.warning(f"Could not get ML stats: {e}")
|
| 612 |
+
|
| 613 |
+
timestamp = datetime.now().strftime('%H:%M:%S')
|
| 614 |
+
st.session_state.log_output.append(f"[{timestamp}] Generation complete!")
|
| 615 |
+
st.session_state.log_output.append(f"[{timestamp}] Files generated: {len(st.session_state.generated_files)}")
|
| 616 |
+
if result.get('passed'):
|
| 617 |
+
st.session_state.log_output.append(f"[{timestamp}] Status: PASSED ✅")
|
| 618 |
+
else:
|
| 619 |
+
st.session_state.log_output.append(f"[{timestamp}] Status: COMPLETED WITH WARNINGS ⚠️")
|
| 620 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 621 |
|
|
|
|
| 622 |
if result.get('passed'):
|
| 623 |
status_placeholder.success("✅ Generation successful!")
|
| 624 |
else:
|
| 625 |
status_placeholder.warning("⚠️ Generation completed with issues")
|
| 626 |
|
| 627 |
except Exception as e:
|
| 628 |
+
timestamp = datetime.now().strftime('%H:%M:%S')
|
| 629 |
+
st.session_state.log_output.append(f"[{timestamp}] ERROR: {str(e)}")
|
| 630 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 631 |
status_placeholder.error(f"❌ Error: {str(e)}")
|
| 632 |
import traceback
|
| 633 |
st.session_state.log_output.append(traceback.format_exc())
|
| 634 |
log_placeholder.code("\n".join(st.session_state.log_output))
|
| 635 |
|
|
|
|
|
|
|
| 636 |
if st.session_state.last_result:
|
| 637 |
+
with tab_results:
|
| 638 |
+
result = st.session_state.last_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
+
with metrics_placeholder.container():
|
| 641 |
+
eval_metrics = result.get('evaluation', {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
|
| 643 |
+
m1, m2, m3, m4 = st.columns(4)
|
| 644 |
+
with m1:
|
| 645 |
+
completeness = eval_metrics.get('completeness', 0) * 100
|
| 646 |
+
st.metric("Completeness", f"{completeness:.1f}%")
|
| 647 |
+
with m2:
|
| 648 |
+
signal_cov = eval_metrics.get('interface_signal_coverage', 0) * 100
|
| 649 |
+
st.metric("Signal Coverage", f"{signal_cov:.1f}%")
|
| 650 |
+
with m3:
|
| 651 |
+
reg_cov = eval_metrics.get('register_coverage', 0) * 100
|
| 652 |
+
st.metric("Register Coverage", f"{reg_cov:.1f}%")
|
| 653 |
+
with m4:
|
| 654 |
+
st.metric("Files Generated", len(st.session_state.generated_files))
|
| 655 |
|
| 656 |
+
m5, m6 = st.columns(2)
|
| 657 |
+
with m5:
|
| 658 |
+
st.metric("Auto-Train Iterations", result.get('auto_train_iterations', 0))
|
| 659 |
+
with m6:
|
| 660 |
+
if result.get('passed'):
|
| 661 |
+
st.metric("Status", "✅ PASSED")
|
| 662 |
+
else:
|
| 663 |
+
st.metric("Status", "⚠️ WARNINGS")
|
| 664 |
+
|
| 665 |
+
with files_placeholder.expander("📄 Generated Files", expanded=True):
|
| 666 |
+
if st.session_state.generated_files:
|
| 667 |
+
file_names = sorted(st.session_state.generated_files.keys())
|
| 668 |
+
selected_file = st.selectbox("Select file to preview", file_names, key="file_selector")
|
| 669 |
+
|
| 670 |
+
if selected_file:
|
| 671 |
+
file_path = st.session_state.generated_files[selected_file]
|
| 672 |
+
if os.path.exists(file_path):
|
| 673 |
+
try:
|
| 674 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 675 |
+
content = f.read()
|
| 676 |
+
|
| 677 |
+
st.code(content, language='systemverilog')
|
| 678 |
+
|
| 679 |
+
col1, col2 = st.columns([1, 1])
|
| 680 |
+
with col1:
|
| 681 |
+
st.download_button(
|
| 682 |
+
f"📥 Download {selected_file}",
|
| 683 |
+
data=content,
|
| 684 |
+
file_name=selected_file,
|
| 685 |
+
mime="text/plain",
|
| 686 |
+
use_container_width=True
|
| 687 |
+
)
|
| 688 |
+
with col2:
|
| 689 |
+
st.info(f"Lines: {len(content.splitlines())} | Size: {len(content)} bytes")
|
| 690 |
+
except Exception as e:
|
| 691 |
+
st.warning(f"Could not read file: {e}")
|
| 692 |
|
| 693 |
+
if st.session_state.generated_files:
|
| 694 |
+
zip_buffer = io.BytesIO()
|
| 695 |
+
|
| 696 |
+
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
| 697 |
+
for name, path in st.session_state.generated_files.items():
|
| 698 |
+
if os.path.exists(path):
|
| 699 |
+
zipf.write(path, arcname=name)
|
| 700 |
+
|
| 701 |
+
zip_buffer.seek(0)
|
| 702 |
+
|
| 703 |
+
st.download_button(
|
| 704 |
+
label="📦 Download All Files as ZIP",
|
| 705 |
+
data=zip_buffer,
|
| 706 |
+
file_name=f"{design_name}_uvm_testbench.zip",
|
| 707 |
+
mime="application/zip",
|
| 708 |
+
use_container_width=True,
|
| 709 |
+
type="primary"
|
| 710 |
+
)
|
| 711 |
|
|
|
|
|
|
|
| 712 |
st.divider()
|
| 713 |
+
|
| 714 |
+
footer_col1, footer_col2, footer_col3 = st.columns([1, 2, 1])
|
| 715 |
+
|
| 716 |
+
with footer_col2:
|
| 717 |
+
st.caption("""
|
| 718 |
+
**UVM Testbench Generator v2.0** • AI-Powered by **Sai Kumar Taraka**
|
| 719 |
+
🔬 Advanced ML: RL (UCB/Softmax/Thompson) + Pattern Mining + Experience Replay + Eligibility Traces
|
| 720 |
+
📚 Protocol Libraries: UART, SPI, I2C, AXI4-Lite, APB, Wishbone
|
| 721 |
+
🎯 Deep UVM Validation: Factory Registration, Phases, TLM Connections, Coverage
|
| 722 |
+
""")
|
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick smoke test for V2 ML model - final version
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 9 |
+
sys.path.insert(0, repo_root)
|
| 10 |
+
|
| 11 |
+
from src.config import ConfigLoader, PipelineConfig, MLConfig, GenerationConfig, AutoTrainConfig
|
| 12 |
+
from src.pipeline import TBPipeline
|
| 13 |
+
|
| 14 |
+
spec_path = os.path.join(repo_root, "configs", "uart_demo.yaml")
|
| 15 |
+
|
| 16 |
+
print("="*60)
|
| 17 |
+
print("V2 ML Model Smoke Test")
|
| 18 |
+
print("="*60)
|
| 19 |
+
|
| 20 |
+
print("\n1. Creating pipeline config with V2 model (UCB strategy)...")
|
| 21 |
+
|
| 22 |
+
ml_cfg = MLConfig(
|
| 23 |
+
enabled=True,
|
| 24 |
+
model_type="v2",
|
| 25 |
+
exploration_strategy="ucb",
|
| 26 |
+
use_llm=False,
|
| 27 |
+
use_semantic_encoder=False,
|
| 28 |
+
use_learning=True,
|
| 29 |
+
strict_validation=True
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
pipeline_cfg = PipelineConfig(
|
| 33 |
+
ml=ml_cfg,
|
| 34 |
+
generation=GenerationConfig(
|
| 35 |
+
templates_dir=os.path.join(repo_root, "src", "generation", "templates"),
|
| 36 |
+
output_dir=os.path.join(repo_root, "output"),
|
| 37 |
+
overwrite=True
|
| 38 |
+
),
|
| 39 |
+
auto_train=AutoTrainConfig(
|
| 40 |
+
enabled=False,
|
| 41 |
+
max_iterations=1
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
print(f" ML enabled: {pipeline_cfg.ml.enabled}")
|
| 46 |
+
print(f" Model type: {pipeline_cfg.ml.model_type}")
|
| 47 |
+
print(f" Exploration strategy: {pipeline_cfg.ml.exploration_strategy}")
|
| 48 |
+
print(f" Strict validation: {pipeline_cfg.ml.strict_validation}")
|
| 49 |
+
print(f" Auto-train: {pipeline_cfg.auto_train.enabled}")
|
| 50 |
+
|
| 51 |
+
print("\n2. Creating pipeline with V2 model...")
|
| 52 |
+
pipeline = TBPipeline(pipeline_cfg)
|
| 53 |
+
|
| 54 |
+
print(f" Model type: {type(pipeline.model).__name__}")
|
| 55 |
+
|
| 56 |
+
print("\n3. Running generation with UART demo spec...")
|
| 57 |
+
result = pipeline.run(spec_path)
|
| 58 |
+
|
| 59 |
+
print(f"\n Result passed: {result.get('passed', False)}")
|
| 60 |
+
print(f" Files generated: {len(result.get('generated_files', {}))}")
|
| 61 |
+
print(f" Auto-train iterations: {result.get('auto_train_iterations', 0)}")
|
| 62 |
+
|
| 63 |
+
if result.get('passed'):
|
| 64 |
+
print("\n [OK] Generation PASSED")
|
| 65 |
+
else:
|
| 66 |
+
print("\n [WARNING] Generation had issues")
|
| 67 |
+
|
| 68 |
+
if result.get('generated_files'):
|
| 69 |
+
print("\n4. Generated files:")
|
| 70 |
+
for name, path in result['generated_files'].items():
|
| 71 |
+
if os.path.exists(path):
|
| 72 |
+
size = os.path.getsize(path)
|
| 73 |
+
print(f" - {name}: {size} bytes")
|
| 74 |
+
|
| 75 |
+
if hasattr(pipeline.model, 'get_learning_stats'):
|
| 76 |
+
print("\n5. ML Learning Stats:")
|
| 77 |
+
stats = pipeline.model.get_learning_stats()
|
| 78 |
+
print(f" - Total generations: {stats.get('total_generations', 0)}")
|
| 79 |
+
if 'source_distribution' in stats:
|
| 80 |
+
print(f" - Source distribution: {stats['source_distribution']}")
|
| 81 |
+
if 'strategy_weights' in stats:
|
| 82 |
+
print(f" - Strategy weights: {stats['strategy_weights']}")
|
| 83 |
+
if 'rl_learner' in stats:
|
| 84 |
+
rl = stats['rl_learner']
|
| 85 |
+
print(f" - RL episodes: {rl.get('episode_count', 0)}")
|
| 86 |
+
print(f" - RL total updates: {rl.get('total_updates', 0)}")
|
| 87 |
+
print(f" - RL learning rate: {rl.get('learning_rate', 0.1)}")
|
| 88 |
+
if 'state_stats' in rl:
|
| 89 |
+
state_stats = rl['state_stats']
|
| 90 |
+
if state_stats:
|
| 91 |
+
print(f" - RL state stats (first 3):")
|
| 92 |
+
for state, info in list(state_stats.items())[:3]:
|
| 93 |
+
print(f" * '{state}': best='{info.get('best_action', 'N/A')}', Q={info.get('best_q_value', 0):.3f}")
|
| 94 |
+
|
| 95 |
+
eval_metrics = result.get('evaluation', {})
|
| 96 |
+
print("\n6. Evaluation Metrics:")
|
| 97 |
+
for key, value in eval_metrics.items():
|
| 98 |
+
if isinstance(value, (int, float)):
|
| 99 |
+
if 0 <= value <= 1:
|
| 100 |
+
print(f" - {key}: {value*100:.1f}%")
|
| 101 |
+
else:
|
| 102 |
+
print(f" - {key}: {value}")
|
| 103 |
+
|
| 104 |
+
val_results = result.get('validation_results', {})
|
| 105 |
+
if val_results:
|
| 106 |
+
total_checks = 0
|
| 107 |
+
total_passed = 0
|
| 108 |
+
|
| 109 |
+
print("\n7. Validation Results (Deep UVM Compliance):")
|
| 110 |
+
for file_path, file_result in val_results.items():
|
| 111 |
+
file_name = os.path.basename(file_path)
|
| 112 |
+
checks = file_result.get('checks', [])
|
| 113 |
+
|
| 114 |
+
for check in checks:
|
| 115 |
+
total_checks += 1
|
| 116 |
+
if check.get('passed'):
|
| 117 |
+
total_passed += 1
|
| 118 |
+
|
| 119 |
+
if total_checks > 0:
|
| 120 |
+
pass_rate = (total_passed / total_checks) * 100
|
| 121 |
+
print(f" - Total checks: {total_checks}")
|
| 122 |
+
print(f" - Passed: {total_passed}")
|
| 123 |
+
print(f" - Pass rate: {pass_rate:.1f}%")
|
| 124 |
+
|
| 125 |
+
print("\n" + "="*60)
|
| 126 |
+
if result.get('passed'):
|
| 127 |
+
print("TEST PASSED - V2 ML Model working correctly!")
|
| 128 |
+
else:
|
| 129 |
+
print("TEST COMPLETED - Review warnings above")
|
| 130 |
+
print("="*60)
|
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script for Advanced ML V2 Model
|
| 3 |
+
Tests: RL strategies, experience replay, eligibility traces, pattern learning, deep validation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 12 |
+
sys.path.insert(0, repo_root)
|
| 13 |
+
|
| 14 |
+
from src.models.enhanced_ml_model_v2 import EnhancedMLGenerationModelV2
|
| 15 |
+
from src.config import PipelineConfig, MLConfig, AutoTrainConfig, GenerationConfig
|
| 16 |
+
|
| 17 |
+
TEST_SPEC = """
|
| 18 |
+
design_name: uart
|
| 19 |
+
clock_reset:
|
| 20 |
+
clock: clk
|
| 21 |
+
reset: rst_n
|
| 22 |
+
|
| 23 |
+
interfaces:
|
| 24 |
+
- name: wb
|
| 25 |
+
signals:
|
| 26 |
+
- name: wb_cyc
|
| 27 |
+
direction: input
|
| 28 |
+
- name: wb_stb
|
| 29 |
+
direction: input
|
| 30 |
+
- name: wb_we
|
| 31 |
+
direction: input
|
| 32 |
+
- name: wb_addr
|
| 33 |
+
direction: input
|
| 34 |
+
width: 3
|
| 35 |
+
- name: wb_data_o
|
| 36 |
+
direction: output
|
| 37 |
+
width: 8
|
| 38 |
+
- name: wb_data_i
|
| 39 |
+
direction: input
|
| 40 |
+
width: 8
|
| 41 |
+
- name: wb_ack
|
| 42 |
+
direction: output
|
| 43 |
+
|
| 44 |
+
- name: uart
|
| 45 |
+
signals:
|
| 46 |
+
- name: uart_tx
|
| 47 |
+
direction: output
|
| 48 |
+
- name: uart_rx
|
| 49 |
+
direction: input
|
| 50 |
+
- name: cts_n
|
| 51 |
+
direction: input
|
| 52 |
+
- name: rts_n
|
| 53 |
+
direction: output
|
| 54 |
+
- name: uart_intr
|
| 55 |
+
direction: output
|
| 56 |
+
|
| 57 |
+
registers:
|
| 58 |
+
- name: RBR_THR
|
| 59 |
+
address: 0x0
|
| 60 |
+
description: Receiver Buffer / Transmitter Holding
|
| 61 |
+
fields:
|
| 62 |
+
- name: data
|
| 63 |
+
bits: 7:0
|
| 64 |
+
- name: IER
|
| 65 |
+
address: 0x1
|
| 66 |
+
description: Interrupt Enable
|
| 67 |
+
fields:
|
| 68 |
+
- name: erbfi
|
| 69 |
+
bits: '0'
|
| 70 |
+
description: Enable RX data available interrupt
|
| 71 |
+
- name: etbei
|
| 72 |
+
bits: '1'
|
| 73 |
+
description: Enable TX holding register empty interrupt
|
| 74 |
+
- name: LCR
|
| 75 |
+
address: 0x3
|
| 76 |
+
description: Line Control
|
| 77 |
+
fields:
|
| 78 |
+
- name: wls
|
| 79 |
+
bits: 1:0
|
| 80 |
+
description: Word length select
|
| 81 |
+
- name: dlab
|
| 82 |
+
bits: '7'
|
| 83 |
+
description: Divisor latch access bit
|
| 84 |
+
- name: LSR
|
| 85 |
+
address: 0x5
|
| 86 |
+
description: Line Status
|
| 87 |
+
fields:
|
| 88 |
+
- name: dr
|
| 89 |
+
bits: '0'
|
| 90 |
+
description: Data Ready
|
| 91 |
+
- name: thre
|
| 92 |
+
bits: '5'
|
| 93 |
+
description: TX Holding Register Empty
|
| 94 |
+
|
| 95 |
+
protocol: uart
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def test_rl_strategies():
|
| 99 |
+
"""Test all RL exploration strategies."""
|
| 100 |
+
print("\n" + "="*60)
|
| 101 |
+
print("Testing RL Exploration Strategies")
|
| 102 |
+
print("="*60)
|
| 103 |
+
|
| 104 |
+
strategies = ["epsilon_greedy", "softmax", "ucb", "thompson"]
|
| 105 |
+
results = {}
|
| 106 |
+
|
| 107 |
+
for strategy in strategies:
|
| 108 |
+
print(f"\n--- Testing {strategy} strategy ---")
|
| 109 |
+
|
| 110 |
+
cfg = PipelineConfig(
|
| 111 |
+
ml=MLConfig(
|
| 112 |
+
enabled=True,
|
| 113 |
+
model_type="v2",
|
| 114 |
+
exploration_strategy=strategy,
|
| 115 |
+
use_llm=False,
|
| 116 |
+
use_semantic_encoder=False,
|
| 117 |
+
use_learning=True,
|
| 118 |
+
learning_storage_path=None
|
| 119 |
+
)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 123 |
+
|
| 124 |
+
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 125 |
+
|
| 126 |
+
result = model.generate(spec_dict)
|
| 127 |
+
passed = result['passed']
|
| 128 |
+
generated_files = result.get('generated_files', {})
|
| 129 |
+
|
| 130 |
+
print(f" Passed: {passed}")
|
| 131 |
+
print(f" Files generated: {len(generated_files)}")
|
| 132 |
+
print(f" Source: {result.get('source', 'unknown')}")
|
| 133 |
+
print(f" Strategy used: {result.get('strategy', 'unknown')}")
|
| 134 |
+
|
| 135 |
+
if hasattr(model, '_rl_learner'):
|
| 136 |
+
rl_stats = model._rl_learner.get_performance_stats()
|
| 137 |
+
print(f" RL episodes: {rl_stats.get('episode_count', 0)}")
|
| 138 |
+
print(f" RL total updates: {rl_stats.get('total_updates', 0)}")
|
| 139 |
+
|
| 140 |
+
results[strategy] = {
|
| 141 |
+
"passed": passed,
|
| 142 |
+
"files_count": len(generated_files),
|
| 143 |
+
"source": result.get('source', 'unknown'),
|
| 144 |
+
"strategy": result.get('strategy', 'unknown')
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
print("\n--- Strategy Results Summary ---")
|
| 148 |
+
for strategy, res in results.items():
|
| 149 |
+
status = "✅" if res["passed"] else "❌"
|
| 150 |
+
print(f" {status} {strategy}: {res['files_count']} files, source={res['source']}, strategy={res['strategy']}")
|
| 151 |
+
|
| 152 |
+
return all(r["passed"] for r in results.values())
|
| 153 |
+
|
| 154 |
+
def test_experience_replay():
|
| 155 |
+
"""Test experience replay buffer and eligibility traces."""
|
| 156 |
+
print("\n" + "="*60)
|
| 157 |
+
print("Testing Experience Replay & Eligibility Traces")
|
| 158 |
+
print("="*60)
|
| 159 |
+
|
| 160 |
+
cfg = PipelineConfig(
|
| 161 |
+
ml=MLConfig(
|
| 162 |
+
enabled=True,
|
| 163 |
+
model_type="v2",
|
| 164 |
+
exploration_strategy="ucb",
|
| 165 |
+
use_llm=False,
|
| 166 |
+
use_semantic_encoder=False,
|
| 167 |
+
use_learning=True,
|
| 168 |
+
learning_storage_path=None
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 173 |
+
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 174 |
+
|
| 175 |
+
print(" Running multiple generations to populate replay buffer...")
|
| 176 |
+
|
| 177 |
+
for i in range(5):
|
| 178 |
+
result = model.generate(spec_dict)
|
| 179 |
+
print(f" Generation {i+1}: passed={result['passed']}, source={result.get('source', 'unknown')}")
|
| 180 |
+
|
| 181 |
+
reward = 1.0 if result['passed'] else 0.0
|
| 182 |
+
model.learn(result, reward)
|
| 183 |
+
|
| 184 |
+
if hasattr(model, '_rl_learner'):
|
| 185 |
+
rl = model._rl_learner
|
| 186 |
+
|
| 187 |
+
print(f"\n Experience replay buffer size: {len(rl._replay_buffer)}")
|
| 188 |
+
print(f" Episode count: {rl.get_performance_stats().get('episode_count', 0)}")
|
| 189 |
+
|
| 190 |
+
if hasattr(rl, '_eligibility_traces') and rl._eligibility_traces:
|
| 191 |
+
print(f" Eligibility traces tracked: {len(rl._eligibility_traces)}")
|
| 192 |
+
|
| 193 |
+
state_stats = rl.get_state_stats()
|
| 194 |
+
print(f"\n State statistics (first 3):")
|
| 195 |
+
for state, stats in list(state_stats.items())[:3]:
|
| 196 |
+
print(f" '{state}': best_action='{stats.get('best_action', 'N/A')}', Q={stats.get('best_q_value', 0):.3f}, visits={stats.get('visit_count', 0)}")
|
| 197 |
+
|
| 198 |
+
return len(rl._replay_buffer) > 0
|
| 199 |
+
|
| 200 |
+
return False
|
| 201 |
+
|
| 202 |
+
def test_pattern_learner():
|
| 203 |
+
"""Test advanced pattern learning."""
|
| 204 |
+
print("\n" + "="*60)
|
| 205 |
+
print("Testing Advanced Pattern Learner")
|
| 206 |
+
print("="*60)
|
| 207 |
+
|
| 208 |
+
cfg = PipelineConfig(
|
| 209 |
+
ml=MLConfig(
|
| 210 |
+
enabled=True,
|
| 211 |
+
model_type="v2",
|
| 212 |
+
exploration_strategy="ucb",
|
| 213 |
+
use_llm=False,
|
| 214 |
+
use_semantic_encoder=False,
|
| 215 |
+
use_learning=True,
|
| 216 |
+
learning_storage_path=None
|
| 217 |
+
)
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 221 |
+
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 222 |
+
|
| 223 |
+
print(" Running generations for pattern learning...")
|
| 224 |
+
|
| 225 |
+
for i in range(3):
|
| 226 |
+
result = model.generate(spec_dict)
|
| 227 |
+
reward = 1.0 if result['passed'] else 0.0
|
| 228 |
+
model.learn(result, reward)
|
| 229 |
+
|
| 230 |
+
if hasattr(model, '_pattern_learner'):
|
| 231 |
+
pl = model._pattern_learner
|
| 232 |
+
|
| 233 |
+
stats = pl.get_statistics()
|
| 234 |
+
print(f"\n Pattern Learner Stats:")
|
| 235 |
+
print(f" Total specs seen: {stats['total_specs_seen']}")
|
| 236 |
+
print(f" Total generations: {stats['total_generations']}")
|
| 237 |
+
print(f" Average score: {stats['avg_score']:.3f}")
|
| 238 |
+
print(f" N-gram vocabulary size: {len(stats['ngram_vocab'])}")
|
| 239 |
+
print(f" Association rules: {len(stats['association_rules'])}")
|
| 240 |
+
|
| 241 |
+
recs = pl.get_recommendations(spec_dict)
|
| 242 |
+
print(f"\n Recommendations for current spec:")
|
| 243 |
+
for rec in recs[:5]:
|
| 244 |
+
print(f" • {rec}")
|
| 245 |
+
|
| 246 |
+
common = pl.get_common_error_patterns(top_n=5)
|
| 247 |
+
if common:
|
| 248 |
+
print(f"\n Common error patterns:")
|
| 249 |
+
for pattern, count in common:
|
| 250 |
+
print(f" • '{pattern}': {count} occurrences")
|
| 251 |
+
|
| 252 |
+
return True
|
| 253 |
+
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
def test_deep_validation():
|
| 257 |
+
"""Test deep UVM compliance validation."""
|
| 258 |
+
print("\n" + "="*60)
|
| 259 |
+
print("Testing Deep UVM Compliance Validation")
|
| 260 |
+
print("="*60)
|
| 261 |
+
|
| 262 |
+
cfg = PipelineConfig(
|
| 263 |
+
ml=MLConfig(
|
| 264 |
+
enabled=True,
|
| 265 |
+
model_type="v2",
|
| 266 |
+
exploration_strategy="ucb",
|
| 267 |
+
use_llm=False,
|
| 268 |
+
use_semantic_encoder=False,
|
| 269 |
+
use_learning=True,
|
| 270 |
+
strict_validation=True,
|
| 271 |
+
learning_storage_path=None
|
| 272 |
+
)
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 276 |
+
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 277 |
+
|
| 278 |
+
result = model.generate(spec_dict)
|
| 279 |
+
|
| 280 |
+
print(f"\n Generated files: {len(result.get('generated_files', {}))}")
|
| 281 |
+
print(f" Passed: {result['passed']}")
|
| 282 |
+
|
| 283 |
+
val_results = result.get('validation_results', {})
|
| 284 |
+
|
| 285 |
+
if val_results:
|
| 286 |
+
print(f"\n Validation Results:")
|
| 287 |
+
total_checks = 0
|
| 288 |
+
total_passed = 0
|
| 289 |
+
|
| 290 |
+
for file_path, file_result in val_results.items():
|
| 291 |
+
file_name = os.path.basename(file_path)
|
| 292 |
+
checks = file_result.get('checks', [])
|
| 293 |
+
|
| 294 |
+
if checks:
|
| 295 |
+
print(f"\n {file_name}:")
|
| 296 |
+
for check in checks:
|
| 297 |
+
total_checks += 1
|
| 298 |
+
status = "✅" if check.get('passed', False) else "❌"
|
| 299 |
+
if check.get('passed'):
|
| 300 |
+
total_passed += 1
|
| 301 |
+
|
| 302 |
+
msg = f" {status} {check.get('check_name', 'unknown')}"
|
| 303 |
+
if check.get('message'):
|
| 304 |
+
msg += f": {check['message']}"
|
| 305 |
+
print(msg)
|
| 306 |
+
|
| 307 |
+
if total_checks > 0:
|
| 308 |
+
pass_rate = (total_passed / total_checks) * 100
|
| 309 |
+
print(f"\n Overall validation pass rate: {pass_rate:.1f}% ({total_passed}/{total_checks})")
|
| 310 |
+
|
| 311 |
+
return total_checks > 0
|
| 312 |
+
|
| 313 |
+
return False
|
| 314 |
+
|
| 315 |
+
def test_learning_persistence():
|
| 316 |
+
"""Test saving and loading learning state."""
|
| 317 |
+
print("\n" + "="*60)
|
| 318 |
+
print("Testing Learning State Persistence")
|
| 319 |
+
print("="*60)
|
| 320 |
+
|
| 321 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
| 322 |
+
state_path = f.name
|
| 323 |
+
|
| 324 |
+
try:
|
| 325 |
+
cfg = PipelineConfig(
|
| 326 |
+
ml=MLConfig(
|
| 327 |
+
enabled=True,
|
| 328 |
+
model_type="v2",
|
| 329 |
+
exploration_strategy="ucb",
|
| 330 |
+
use_llm=False,
|
| 331 |
+
use_semantic_encoder=False,
|
| 332 |
+
use_learning=True,
|
| 333 |
+
learning_storage_path=state_path
|
| 334 |
+
)
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
print(" Creating model and running generations...")
|
| 338 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 339 |
+
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 340 |
+
|
| 341 |
+
for i in range(3):
|
| 342 |
+
result = model.generate(spec_dict)
|
| 343 |
+
reward = 1.0 if result['passed'] else 0.0
|
| 344 |
+
model.learn(result, reward)
|
| 345 |
+
|
| 346 |
+
if hasattr(model, '_rl_learner'):
|
| 347 |
+
episodes_before = model._rl_learner.get_performance_stats().get('episode_count', 0)
|
| 348 |
+
replay_size_before = len(model._rl_learner._replay_buffer)
|
| 349 |
+
print(f" Episodes before save: {episodes_before}")
|
| 350 |
+
print(f" Replay buffer size before save: {replay_size_before}")
|
| 351 |
+
|
| 352 |
+
print(" Saving learning state...")
|
| 353 |
+
model.save_learning_state(state_path)
|
| 354 |
+
|
| 355 |
+
print(" Loading learning state into new model...")
|
| 356 |
+
model2 = EnhancedMLGenerationModelV2(cfg)
|
| 357 |
+
model2.load_learning_state(state_path)
|
| 358 |
+
|
| 359 |
+
if hasattr(model2, '_rl_learner'):
|
| 360 |
+
episodes_after = model2._rl_learner.get_performance_stats().get('episode_count', 0)
|
| 361 |
+
replay_size_after = len(model2._rl_learner._replay_buffer)
|
| 362 |
+
print(f" Episodes after load: {episodes_after}")
|
| 363 |
+
print(f" Replay buffer size after load: {replay_size_after}")
|
| 364 |
+
|
| 365 |
+
return episodes_after >= 3 and replay_size_after >= 3
|
| 366 |
+
|
| 367 |
+
return False
|
| 368 |
+
|
| 369 |
+
finally:
|
| 370 |
+
if os.path.exists(state_path):
|
| 371 |
+
os.unlink(state_path)
|
| 372 |
+
|
| 373 |
+
def test_learning_stats():
|
| 374 |
+
"""Test ML stats generation for UI."""
|
| 375 |
+
print("\n" + "="*60)
|
| 376 |
+
print("Testing Learning Statistics (for UI)")
|
| 377 |
+
print("="*60)
|
| 378 |
+
|
| 379 |
+
cfg = PipelineConfig(
|
| 380 |
+
ml=MLConfig(
|
| 381 |
+
enabled=True,
|
| 382 |
+
model_type="v2",
|
| 383 |
+
exploration_strategy="ucb",
|
| 384 |
+
use_llm=False,
|
| 385 |
+
use_semantic_encoder=False,
|
| 386 |
+
use_learning=True,
|
| 387 |
+
learning_storage_path=None
|
| 388 |
+
)
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
model = EnhancedMLGenerationModelV2(cfg)
|
| 392 |
+
spec_dict = yaml.safe_load(TEST_SPEC)
|
| 393 |
+
|
| 394 |
+
for i in range(3):
|
| 395 |
+
result = model.generate(spec_dict)
|
| 396 |
+
reward = 1.0 if result['passed'] else 0.0
|
| 397 |
+
model.learn(result, reward)
|
| 398 |
+
|
| 399 |
+
if hasattr(model, 'get_learning_stats'):
|
| 400 |
+
stats = model.get_learning_stats()
|
| 401 |
+
|
| 402 |
+
print(f"\n Learning Stats:")
|
| 403 |
+
print(f" Total generations: {stats.get('total_generations', 0)}")
|
| 404 |
+
|
| 405 |
+
if 'source_distribution' in stats:
|
| 406 |
+
print(f"\n Source distribution:")
|
| 407 |
+
for source, count in stats['source_distribution'].items():
|
| 408 |
+
print(f" • {source}: {count}")
|
| 409 |
+
|
| 410 |
+
if 'strategy_weights' in stats:
|
| 411 |
+
print(f"\n Strategy weights:")
|
| 412 |
+
for strategy, weight in stats['strategy_weights'].items():
|
| 413 |
+
print(f" • {strategy}: {weight}")
|
| 414 |
+
|
| 415 |
+
if 'rl_learner' in stats:
|
| 416 |
+
print(f"\n RL Learner stats:")
|
| 417 |
+
print(f" Episode count: {stats['rl_learner'].get('episode_count', 0)}")
|
| 418 |
+
print(f" Total updates: {stats['rl_learner'].get('total_updates', 0)}")
|
| 419 |
+
|
| 420 |
+
if 'pattern_learner' in stats:
|
| 421 |
+
print(f"\n Pattern Learner stats:")
|
| 422 |
+
print(f" Total specs seen: {stats['pattern_learner'].get('total_specs_seen', 0)}")
|
| 423 |
+
|
| 424 |
+
return True
|
| 425 |
+
|
| 426 |
+
return False
|
| 427 |
+
|
| 428 |
+
def run_all_tests():
|
| 429 |
+
"""Run all tests and report results."""
|
| 430 |
+
print("\n" + "="*60)
|
| 431 |
+
print("Advanced ML V2 Model - Complete Test Suite")
|
| 432 |
+
print("="*60)
|
| 433 |
+
|
| 434 |
+
tests = [
|
| 435 |
+
("RL Exploration Strategies", test_rl_strategies),
|
| 436 |
+
("Experience Replay & Eligibility Traces", test_experience_replay),
|
| 437 |
+
("Advanced Pattern Learner", test_pattern_learner),
|
| 438 |
+
("Deep UVM Validation", test_deep_validation),
|
| 439 |
+
("Learning State Persistence", test_learning_persistence),
|
| 440 |
+
("Learning Statistics (UI)", test_learning_stats),
|
| 441 |
+
]
|
| 442 |
+
|
| 443 |
+
results = []
|
| 444 |
+
|
| 445 |
+
for name, test_func in tests:
|
| 446 |
+
try:
|
| 447 |
+
result = test_func()
|
| 448 |
+
results.append((name, result, None))
|
| 449 |
+
except Exception as e:
|
| 450 |
+
results.append((name, False, str(e)))
|
| 451 |
+
|
| 452 |
+
print("\n" + "="*60)
|
| 453 |
+
print("Test Results Summary")
|
| 454 |
+
print("="*60)
|
| 455 |
+
|
| 456 |
+
all_passed = True
|
| 457 |
+
for name, result, error in results:
|
| 458 |
+
if result:
|
| 459 |
+
print(f"✅ {name}")
|
| 460 |
+
else:
|
| 461 |
+
print(f"❌ {name}")
|
| 462 |
+
all_passed = False
|
| 463 |
+
if error:
|
| 464 |
+
print(f" Error: {error}")
|
| 465 |
+
|
| 466 |
+
print("\n" + "="*60)
|
| 467 |
+
if all_passed:
|
| 468 |
+
print("🎉 All tests PASSED!")
|
| 469 |
+
else:
|
| 470 |
+
print("⚠️ Some tests FAILED")
|
| 471 |
+
print("="*60)
|
| 472 |
+
|
| 473 |
+
return all_passed
|
| 474 |
+
|
| 475 |
+
if __name__ == "__main__":
|
| 476 |
+
success = run_all_tests()
|
| 477 |
+
sys.exit(0 if success else 1)
|