verirl-env / tests /conftest.py
Supreeth's picture
Upload folder using huggingface_hub
12ac0fd verified
"""Shared pytest fixtures for VeriRL tests."""
import pytest
import shutil
from pathlib import Path
from verirl_env.server.verirl_env_environment import VerirlEnvironment
from verirl_env.server.evaluator import VerilogEvaluator
from verirl_env.models import VerirlAction
# Check if EDA tools are available
HAS_IVERILOG = shutil.which("iverilog") is not None
HAS_YOSYS = shutil.which("yosys") is not None
HAS_SBY = shutil.which("sby") is not None
@pytest.fixture
def evaluator():
"""Create a VerilogEvaluator instance."""
return VerilogEvaluator()
@pytest.fixture
def environment():
"""Create a VerirlEnvironment instance."""
return VerirlEnvironment()
@pytest.fixture
def mac_reference_verilog():
"""Reference implementation for mac_unit task."""
return """
module mac_unit (
input wire clk, rst, en, clear,
input wire signed [7:0] a, b,
output reg signed [31:0] acc_out
);
reg signed [15:0] product_s1;
reg en_s1, clear_s1;
always @(posedge clk) begin
if (rst) begin
product_s1 <= 0;
en_s1 <= 0;
clear_s1 <= 0;
end else begin
product_s1 <= a * b;
en_s1 <= en;
clear_s1 <= clear;
end
end
always @(posedge clk) begin
if (rst) acc_out <= 0;
else if (clear_s1) acc_out <= 0;
else if (en_s1) acc_out <= acc_out + product_s1;
end
endmodule
"""
@pytest.fixture
def axi_reference_verilog():
"""Reference implementation for axi_fifo task."""
return """
module axi_fifo #(parameter DATA_W = 8) (
input wire clk, rst,
input wire s_valid,
output wire s_ready,
input wire [DATA_W-1:0] s_data,
output wire m_valid,
input wire m_ready,
output wire [DATA_W-1:0] m_data,
output wire full, empty
);
localparam DEPTH = 4;
reg [DATA_W-1:0] mem [0:DEPTH-1];
reg [1:0] head, tail;
reg [2:0] count;
assign full = (count == DEPTH);
assign empty = (count == 0);
assign s_ready = !full;
assign m_valid = !empty;
assign m_data = mem[head];
wire enq = s_valid && s_ready;
wire deq = m_valid && m_ready;
always @(posedge clk) begin
if (rst) begin
head <= 0;
tail <= 0;
count <= 0;
end else begin
if (enq) begin
mem[tail] <= s_data;
tail <= tail + 1;
end
if (deq) head <= head + 1;
case ({enq, deq})
2'b10: count <= count + 1;
2'b01: count <= count - 1;
default: count <= count;
endcase
end
end
endmodule
"""
@pytest.fixture
def requires_iverilog(request):
"""Skip test if iverilog is not available."""
if not HAS_IVERILOG:
pytest.skip("iverilog not installed")
@pytest.fixture
def requires_yosys(request):
"""Skip test if yosys is not available."""
if not HAS_YOSYS:
pytest.skip("yosys not installed")
@pytest.fixture
def requires_eda_tools(request):
"""Skip test if EDA tools are not available."""
if not HAS_IVERILOG or not HAS_YOSYS:
pytest.skip("iverilog and/or yosys not installed")
@pytest.fixture
def requires_sby(request):
"""Skip test if SymbiYosys is not available."""
if not HAS_SBY:
pytest.skip("sby (SymbiYosys) not installed")
@pytest.fixture
def relu_clip_reference_verilog():
"""Reference implementation for relu_clip task."""
return """
module relu_clip #(
parameter IN_W = 8,
parameter OUT_W = 4
) (
input wire signed [IN_W-1:0] in_val,
output wire [OUT_W-1:0] out_val,
output wire saturated
);
localparam integer MAX_OUT = (1 << OUT_W) - 1;
wire neg = in_val[IN_W-1];
wire [IN_W-1:0] relu_out = neg ? {IN_W{1'b0}} : in_val;
wire pos_clip = (relu_out > MAX_OUT[IN_W-1:0]);
assign out_val = pos_clip ? MAX_OUT[OUT_W-1:0] : relu_out[OUT_W-1:0];
assign saturated = neg | pos_clip;
endmodule
"""
@pytest.fixture
def barrel_shifter_reference_verilog():
"""Reference implementation for barrel_shifter task."""
return """
module barrel_shifter #(
parameter WIDTH = 8
) (
input wire [WIDTH-1:0] data_in,
input wire [$clog2(WIDTH)-1:0] shift_amt,
input wire direction,
input wire arithmetic,
output wire [WIDTH-1:0] data_out
);
wire [WIDTH-1:0] left_out = data_in << shift_amt;
wire [WIDTH-1:0] right_log = data_in >> shift_amt;
wire [WIDTH-1:0] right_ari = $signed(data_in) >>> shift_amt;
assign data_out = direction
? (arithmetic ? right_ari : right_log)
: left_out;
endmodule
"""
@pytest.fixture
def register_file_reference_verilog():
"""Reference implementation for register_file task."""
return """
module register_file #(
parameter ADDR_W = 5,
parameter DATA_W = 32
) (
input wire clk,
input wire we,
input wire [ADDR_W-1:0] wr_addr,
input wire [DATA_W-1:0] wr_data,
input wire [ADDR_W-1:0] rd_addr_a,
input wire [ADDR_W-1:0] rd_addr_b,
output wire [DATA_W-1:0] rd_data_a,
output wire [DATA_W-1:0] rd_data_b
);
localparam DEPTH = 1 << ADDR_W;
reg [DATA_W-1:0] regs [0:DEPTH-1];
integer i;
initial begin
for (i = 0; i < DEPTH; i = i + 1) regs[i] = 0;
end
always @(posedge clk)
if (we && (wr_addr != 0)) regs[wr_addr] <= wr_data;
assign rd_data_a = (rd_addr_a == 0) ? {DATA_W{1'b0}} : regs[rd_addr_a];
assign rd_data_b = (rd_addr_b == 0) ? {DATA_W{1'b0}} : regs[rd_addr_b];
endmodule
"""
@pytest.fixture
def ring_buffer_reference_verilog():
"""Reference implementation for ring_buffer task."""
return """
`timescale 1ns/1ps
module ring_buffer #(
parameter DEPTH = 8,
parameter DATA_W = 8
) (
input wire clk,
input wire rst,
input wire push,
input wire [DATA_W-1:0] push_data,
input wire pop,
output wire [DATA_W-1:0] pop_data,
output wire full,
output wire empty,
output wire [$clog2(DEPTH):0] count
);
localparam PTR_W = $clog2(DEPTH);
reg [DATA_W-1:0] mem [0:DEPTH-1];
reg [PTR_W-1:0] head, tail;
reg [PTR_W:0] cnt;
assign full = (cnt == DEPTH);
assign empty = (cnt == 0);
assign count = cnt;
assign pop_data = mem[head];
wire do_push = push & ~full;
wire do_pop = pop & ~empty;
always @(posedge clk) begin
if (rst) begin
head <= {PTR_W{1'b0}};
tail <= {PTR_W{1'b0}};
cnt <= {(PTR_W+1){1'b0}};
end else begin
if (do_push) begin
mem[tail] <= push_data;
tail <= (tail == DEPTH-1) ? {PTR_W{1'b0}} : tail + 1'b1;
end
if (do_pop)
head <= (head == DEPTH-1) ? {PTR_W{1'b0}} : head + 1'b1;
if (do_push & ~do_pop) cnt <= cnt + 1'b1;
else if (do_pop & ~do_push) cnt <= cnt - 1'b1;
end
end
endmodule
"""
@pytest.fixture
def dot_product_reference_verilog():
"""Reference implementation for dot_product task."""
return """
`timescale 1ns/1ps
module dot_product_4 (
input wire clk,
input wire rst,
input wire valid_in,
input wire signed [7:0] a0, a1, a2, a3,
input wire signed [7:0] b0, b1, b2, b3,
output reg valid_out,
output reg signed [17:0] result
);
reg signed [15:0] p0, p1, p2, p3;
reg s1_valid;
always @(posedge clk) begin
if (rst) begin
p0 <= 0; p1 <= 0; p2 <= 0; p3 <= 0;
s1_valid <= 1'b0; result <= 0; valid_out <= 1'b0;
end else begin
p0 <= a0 * b0; p1 <= a1 * b1;
p2 <= a2 * b2; p3 <= a3 * b3;
s1_valid <= valid_in;
result <= p0 + p1 + p2 + p3;
valid_out <= s1_valid;
end
end
endmodule
"""
@pytest.fixture
def fir_filter_reference_verilog():
"""Reference implementation for fir_filter task."""
return """
`timescale 1ns/1ps
module fir3 (
input wire clk,
input wire rst,
input wire valid_in,
input wire signed [7:0] x,
input wire signed [7:0] h0, h1, h2,
output reg signed [17:0] y,
output reg valid_out
);
reg signed [7:0] x_d1, x_d2;
always @(posedge clk) begin
if (rst) begin
x_d1 <= 8'sb0; x_d2 <= 8'sb0;
y <= 18'sb0; valid_out <= 1'b0;
end else if (valid_in) begin
y <= h0*x + h1*x_d1 + h2*x_d2;
x_d1 <= x; x_d2 <= x_d1;
valid_out <= 1'b1;
end else begin
valid_out <= 1'b0;
end
end
endmodule
"""
@pytest.fixture
def fp16_adder_reference_verilog():
"""Reference implementation for fp16_adder task."""
return """
`timescale 1ns/1ps
module fp16_adder (
input wire [15:0] a,
input wire [15:0] b,
output wire [15:0] result
);
wire sa = a[15], sb = b[15];
wire [4:0] ea = a[14:10], eb = b[14:10];
wire [9:0] ma = a[9:0], mb = b[9:0];
wire a_inf = (ea == 5'h1F) & (ma == 10'h0);
wire b_inf = (eb == 5'h1F) & (mb == 10'h0);
wire a_nan = (ea == 5'h1F) & (ma != 10'h0);
wire b_nan = (eb == 5'h1F) & (mb != 10'h0);
wire a_zero = (ea == 5'h00) & (ma == 10'h0);
wire b_zero = (eb == 5'h00) & (mb == 10'h0);
wire [10:0] fa = a_zero ? 11'h0 : {1'b1, ma};
wire [10:0] fb = b_zero ? 11'h0 : {1'b1, mb};
wire [15:0] mag_a = {1'b0, a[14:0]};
wire [15:0] mag_b = {1'b0, b[14:0]};
wire swap = (mag_b > mag_a);
wire s_big = swap ? sb : sa;
wire [4:0] e_big = swap ? eb : ea;
wire [10:0] f_big = swap ? fb : fa;
wire s_sml = swap ? sa : sb;
wire [4:0] e_sml = swap ? ea : eb;
wire [10:0] f_sml = swap ? fa : fb;
wire [4:0] diff = e_big - e_sml;
wire [10:0] f_sml_sh = (diff >= 11) ? 11'h0 : (f_sml >> diff);
wire same_sign = (s_big == s_sml);
wire [11:0] sum_raw = same_sign
? ({1'b0, f_big} + {1'b0, f_sml_sh})
: ({1'b0, f_big} - {1'b0, f_sml_sh});
reg [3:0] lz;
always @(*) begin
casez (sum_raw)
12'b1??????????? : lz = 4'd0;
12'b01?????????? : lz = 4'd1;
12'b001????????? : lz = 4'd2;
12'b0001???????? : lz = 4'd3;
12'b00001??????? : lz = 4'd4;
12'b000001?????? : lz = 4'd5;
12'b0000001????? : lz = 4'd6;
12'b00000001???? : lz = 4'd7;
12'b000000001??? : lz = 4'd8;
12'b0000000001?? : lz = 4'd9;
12'b00000000001? : lz = 4'd10;
12'b000000000001 : lz = 4'd11;
default : lz = 4'd12;
endcase
end
wire [3:0] norm_shift = (lz == 0) ? 4'd0 : (lz - 4'd1);
wire [11:0] sum_norm = (lz == 0) ? (sum_raw >> 1) : (sum_raw << norm_shift);
wire [5:0] exp_adj = (sum_raw[11])
? ({1'b0, e_big} + 6'd1)
: ({1'b0, e_big} - {2'b0, lz} + 6'd1);
wire result_zero = (sum_raw == 12'h0);
wire exp_overflow = (exp_adj >= 6'd31) & ~result_zero;
wire [15:0] nan_out = 16'h7E00;
wire [15:0] inf_out = {s_big, 5'h1F, 10'h0};
wire [15:0] zero_out = 16'h0000;
wire [15:0] norm_out = {s_big, exp_adj[4:0], sum_norm[9:0]};
assign result =
(a_nan | b_nan) ? nan_out :
(a_inf & b_inf & (sa != sb)) ? nan_out :
(a_inf | b_inf) ? inf_out :
(exp_overflow) ? inf_out :
(result_zero) ? zero_out :
norm_out;
endmodule
"""
@pytest.fixture
def systolic_reference_verilog():
"""Reference implementation for systolic_array task.
Architecture: row i accumulates for exactly 4 cycles starting at cycle i
(diagonal skewing via row-gated enable, no shift registers).
output[i][j] = 4 * activations[i] * weights[i][j]
done fires at posedge 7 from start (done_cycle <= 7).
"""
return """
module systolic_array (
input wire clk, rst, load_weights, start,
input wire [63:0] weights_flat,
input wire [127:0] activations_flat,
output wire [255:0] outputs_flat,
output wire done
);
reg [3:0] weights [0:3][0:3];
reg [15:0] acc [0:3][0:3];
reg [2:0] cyc;
reg running, done_reg;
assign done = done_reg;
genvar gi, gj;
generate
for (gi = 0; gi < 4; gi = gi + 1) begin : row_out
for (gj = 0; gj < 4; gj = gj + 1) begin : col_out
assign outputs_flat[(gi*4+gj)*16 +: 16] = acc[gi][gj];
end
end
endgenerate
integer li, lj;
always @(posedge clk) begin
if (load_weights)
for (li = 0; li < 4; li = li + 1)
for (lj = 0; lj < 4; lj = lj + 1)
weights[li][lj] <= weights_flat[(li*4+lj)*4 +: 4];
end
integer ci, cj;
always @(posedge clk) begin
if (rst) begin
running <= 0; done_reg <= 0; cyc <= 0;
for (ci = 0; ci < 4; ci = ci + 1)
for (cj = 0; cj < 4; cj = cj + 1)
acc[ci][cj] <= 0;
end else if (start) begin
running <= 1; done_reg <= 0; cyc <= 0;
for (ci = 0; ci < 4; ci = ci + 1)
for (cj = 0; cj < 4; cj = cj + 1)
acc[ci][cj] <= 0;
end else if (running) begin
for (ci = 0; ci < 4; ci = ci + 1) begin
if (cyc >= ci && cyc < ci + 4) begin
for (cj = 0; cj < 4; cj = cj + 1)
acc[ci][cj] <= acc[ci][cj]
+ {{12{1'b0}}, weights[ci][cj]}
* activations_flat[ci*8 +: 8];
end
end
cyc <= cyc + 1;
if (cyc == 3'd6) begin
done_reg <= 1;
running <= 0;
end
end else begin
done_reg <= 0;
end
end
endmodule
"""