"""Shared pytest fixtures for VeriRL tests."""
import shutil
from pathlib import Path

import pytest

from verirl_env.models import VerirlAction
from verirl_env.server.evaluator import VerilogEvaluator
from verirl_env.server.verirl_env_environment import VerirlEnvironment
# Probe the PATH once, at import time, for the external EDA executables.
# ``shutil.which`` returns None when the executable cannot be found, so each
# flag is a plain bool that the skip helpers below can test.
HAS_IVERILOG = shutil.which("iverilog") is not None
HAS_YOSYS = shutil.which("yosys") is not None
HAS_SBY = shutil.which("sby") is not None
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def evaluator():
    """Create a VerilogEvaluator instance.

    Returns:
        A new ``VerilogEvaluator`` built with its default constructor
        arguments; a separate object is created on every call.
    """
    return VerilogEvaluator()
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def environment():
    """Create a VerirlEnvironment instance.

    Returns:
        A new ``VerirlEnvironment`` built with its default constructor
        arguments; a separate object is created on every call.
    """
    return VerirlEnvironment()
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def mac_reference_verilog() -> str:
    """Reference implementation for mac_unit task.

    The returned Verilog is a two-stage signed multiply-accumulate unit:

    * Stage 1 registers the 16-bit product ``a * b`` together with the
      ``en`` and ``clear`` controls.
    * Stage 2 updates the 32-bit accumulator from the registered values;
      ``clear_s1`` takes priority over ``en_s1``.

    ``rst`` is synchronous and zeroes both stages.

    Returns:
        Verilog source text for module ``mac_unit``.
    """
    return """
module mac_unit (
    input wire clk, rst, en, clear,
    input wire signed [7:0] a, b,
    output reg signed [31:0] acc_out
);
    reg signed [15:0] product_s1;
    reg en_s1, clear_s1;
    always @(posedge clk) begin
        if (rst) begin
            product_s1 <= 0;
            en_s1 <= 0;
            clear_s1 <= 0;
        end else begin
            product_s1 <= a * b;
            en_s1 <= en;
            clear_s1 <= clear;
        end
    end
    always @(posedge clk) begin
        if (rst) acc_out <= 0;
        else if (clear_s1) acc_out <= 0;
        else if (en_s1) acc_out <= acc_out + product_s1;
    end
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def axi_reference_verilog() -> str:
    """Reference implementation for axi_fifo task.

    The returned Verilog is a 4-entry FIFO with valid/ready handshakes on
    both sides. ``s_ready`` is the inverse of ``full`` and ``m_valid`` the
    inverse of ``empty``. ``m_data`` is read combinationally from the head
    entry. A simultaneous enqueue and dequeue leaves ``count`` unchanged.
    ``rst`` is synchronous.

    Returns:
        Verilog source text for module ``axi_fifo``.
    """
    return """
module axi_fifo #(parameter DATA_W = 8) (
    input wire clk, rst,
    input wire s_valid,
    output wire s_ready,
    input wire [DATA_W-1:0] s_data,
    output wire m_valid,
    input wire m_ready,
    output wire [DATA_W-1:0] m_data,
    output wire full, empty
);
    localparam DEPTH = 4;
    reg [DATA_W-1:0] mem [0:DEPTH-1];
    reg [1:0] head, tail;
    reg [2:0] count;
    assign full = (count == DEPTH);
    assign empty = (count == 0);
    assign s_ready = !full;
    assign m_valid = !empty;
    assign m_data = mem[head];
    wire enq = s_valid && s_ready;
    wire deq = m_valid && m_ready;
    always @(posedge clk) begin
        if (rst) begin
            head <= 0;
            tail <= 0;
            count <= 0;
        end else begin
            if (enq) begin
                mem[tail] <= s_data;
                tail <= tail + 1;
            end
            if (deq) head <= head + 1;
            case ({enq, deq})
                2'b10: count <= count + 1;
                2'b01: count <= count - 1;
                default: count <= count;
            endcase
        end
    end
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def requires_iverilog(request):
    """Skip test if iverilog is not available.

    Args:
        request: Accepted but not read by this function.
    """
    if not HAS_IVERILOG:
        pytest.skip("iverilog not installed")
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def requires_yosys(request):
    """Skip test if yosys is not available.

    Args:
        request: Accepted but not read by this function.
    """
    if not HAS_YOSYS:
        pytest.skip("yosys not installed")
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def requires_eda_tools(request):
    """Skip test if EDA tools are not available.

    Both iverilog and yosys must be present; if either is missing the
    test is skipped.

    Args:
        request: Accepted but not read by this function.
    """
    if not (HAS_IVERILOG and HAS_YOSYS):
        pytest.skip("iverilog and/or yosys not installed")
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def requires_sby(request):
    """Skip test if SymbiYosys is not available.

    Args:
        request: Accepted but not read by this function.
    """
    if not HAS_SBY:
        pytest.skip("sby (SymbiYosys) not installed")
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def relu_clip_reference_verilog() -> str:
    """Reference implementation for relu_clip task.

    The returned Verilog is purely combinational. A negative input (sign
    bit set) is forced to zero. A positive input above ``(1 << OUT_W) - 1``
    is clamped to that maximum. ``saturated`` is asserted in either case.

    Returns:
        Verilog source text for module ``relu_clip``.
    """
    return """
module relu_clip #(
    parameter IN_W = 8,
    parameter OUT_W = 4
) (
    input wire signed [IN_W-1:0] in_val,
    output wire [OUT_W-1:0] out_val,
    output wire saturated
);
    localparam integer MAX_OUT = (1 << OUT_W) - 1;
    wire neg = in_val[IN_W-1];
    wire [IN_W-1:0] relu_out = neg ? {IN_W{1'b0}} : in_val;
    wire pos_clip = (relu_out > MAX_OUT[IN_W-1:0]);
    assign out_val = pos_clip ? MAX_OUT[OUT_W-1:0] : relu_out[OUT_W-1:0];
    assign saturated = neg | pos_clip;
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def barrel_shifter_reference_verilog() -> str:
    """Reference implementation for barrel_shifter task.

    The returned Verilog is purely combinational. ``direction`` low
    selects a left shift. ``direction`` high selects a right shift, which
    is arithmetic (``>>>`` on ``$signed(data_in)``) when ``arithmetic`` is
    high and logical otherwise.

    Returns:
        Verilog source text for module ``barrel_shifter``.
    """
    return """
module barrel_shifter #(
    parameter WIDTH = 8
) (
    input wire [WIDTH-1:0] data_in,
    input wire [$clog2(WIDTH)-1:0] shift_amt,
    input wire direction,
    input wire arithmetic,
    output wire [WIDTH-1:0] data_out
);
    wire [WIDTH-1:0] left_out = data_in << shift_amt;
    wire [WIDTH-1:0] right_log = data_in >> shift_amt;
    wire [WIDTH-1:0] right_ari = $signed(data_in) >>> shift_amt;
    assign data_out = direction
        ? (arithmetic ? right_ari : right_log)
        : left_out;
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def register_file_reference_verilog() -> str:
    """Reference implementation for register_file task.

    The returned Verilog has one synchronous write port and two
    combinational read ports. Address 0 is hard-wired to zero: writes to
    it are dropped and reads of it return all zeros. An ``initial`` block
    clears every entry.

    Returns:
        Verilog source text for module ``register_file``.
    """
    return """
module register_file #(
    parameter ADDR_W = 5,
    parameter DATA_W = 32
) (
    input wire clk,
    input wire we,
    input wire [ADDR_W-1:0] wr_addr,
    input wire [DATA_W-1:0] wr_data,
    input wire [ADDR_W-1:0] rd_addr_a,
    input wire [ADDR_W-1:0] rd_addr_b,
    output wire [DATA_W-1:0] rd_data_a,
    output wire [DATA_W-1:0] rd_data_b
);
    localparam DEPTH = 1 << ADDR_W;
    reg [DATA_W-1:0] regs [0:DEPTH-1];
    integer i;
    initial begin
        for (i = 0; i < DEPTH; i = i + 1) regs[i] = 0;
    end
    always @(posedge clk)
        if (we && (wr_addr != 0)) regs[wr_addr] <= wr_data;
    assign rd_data_a = (rd_addr_a == 0) ? {DATA_W{1'b0}} : regs[rd_addr_a];
    assign rd_data_b = (rd_addr_b == 0) ? {DATA_W{1'b0}} : regs[rd_addr_b];
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def ring_buffer_reference_verilog() -> str:
    """Reference implementation for ring_buffer task.

    The returned Verilog is a parameterised circular buffer. A push is
    ignored while ``full`` and a pop is ignored while ``empty``. The head
    and tail pointers wrap explicitly at ``DEPTH-1``. ``count`` changes
    only when exactly one of push or pop is accepted. ``pop_data`` is read
    combinationally from the head entry, and ``rst`` is synchronous.

    Returns:
        Verilog source text for module ``ring_buffer``.
    """
    return """
`timescale 1ns/1ps
module ring_buffer #(
    parameter DEPTH = 8,
    parameter DATA_W = 8
) (
    input wire clk,
    input wire rst,
    input wire push,
    input wire [DATA_W-1:0] push_data,
    input wire pop,
    output wire [DATA_W-1:0] pop_data,
    output wire full,
    output wire empty,
    output wire [$clog2(DEPTH):0] count
);
    localparam PTR_W = $clog2(DEPTH);
    reg [DATA_W-1:0] mem [0:DEPTH-1];
    reg [PTR_W-1:0] head, tail;
    reg [PTR_W:0] cnt;
    assign full = (cnt == DEPTH);
    assign empty = (cnt == 0);
    assign count = cnt;
    assign pop_data = mem[head];
    wire do_push = push & ~full;
    wire do_pop = pop & ~empty;
    always @(posedge clk) begin
        if (rst) begin
            head <= {PTR_W{1'b0}};
            tail <= {PTR_W{1'b0}};
            cnt <= {(PTR_W+1){1'b0}};
        end else begin
            if (do_push) begin
                mem[tail] <= push_data;
                tail <= (tail == DEPTH-1) ? {PTR_W{1'b0}} : tail + 1'b1;
            end
            if (do_pop)
                head <= (head == DEPTH-1) ? {PTR_W{1'b0}} : head + 1'b1;
            if (do_push & ~do_pop) cnt <= cnt + 1'b1;
            else if (do_pop & ~do_push) cnt <= cnt - 1'b1;
        end
    end
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def dot_product_reference_verilog() -> str:
    """Reference implementation for dot_product task.

    The returned Verilog is a two-stage pipeline over four signed 8-bit
    operand pairs:

    * Stage 1 registers the four 16-bit products and ``valid_in``.
    * Stage 2 registers their 18-bit sum and the delayed valid.

    Both stages advance on every clock while ``rst`` is low; ``rst`` is
    synchronous.

    Returns:
        Verilog source text for module ``dot_product_4``.
    """
    return """
`timescale 1ns/1ps
module dot_product_4 (
    input wire clk,
    input wire rst,
    input wire valid_in,
    input wire signed [7:0] a0, a1, a2, a3,
    input wire signed [7:0] b0, b1, b2, b3,
    output reg valid_out,
    output reg signed [17:0] result
);
    reg signed [15:0] p0, p1, p2, p3;
    reg s1_valid;
    always @(posedge clk) begin
        if (rst) begin
            p0 <= 0; p1 <= 0; p2 <= 0; p3 <= 0;
            s1_valid <= 1'b0; result <= 0; valid_out <= 1'b0;
        end else begin
            p0 <= a0 * b0; p1 <= a1 * b1;
            p2 <= a2 * b2; p3 <= a3 * b3;
            s1_valid <= valid_in;
            result <= p0 + p1 + p2 + p3;
            valid_out <= s1_valid;
        end
    end
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def fir_filter_reference_verilog() -> str:
    """Reference implementation for fir_filter task.

    The returned Verilog is a 3-tap FIR filter with run-time coefficients
    ``h0``..``h2``. The delay line and the output register advance only on
    cycles where ``valid_in`` is high. ``valid_out`` is high for the cycle
    after each accepted sample and low otherwise. ``rst`` is synchronous.

    Returns:
        Verilog source text for module ``fir3``.
    """
    return """
`timescale 1ns/1ps
module fir3 (
    input wire clk,
    input wire rst,
    input wire valid_in,
    input wire signed [7:0] x,
    input wire signed [7:0] h0, h1, h2,
    output reg signed [17:0] y,
    output reg valid_out
);
    reg signed [7:0] x_d1, x_d2;
    always @(posedge clk) begin
        if (rst) begin
            x_d1 <= 8'sb0; x_d2 <= 8'sb0;
            y <= 18'sb0; valid_out <= 1'b0;
        end else if (valid_in) begin
            y <= h0*x + h1*x_d1 + h2*x_d2;
            x_d1 <= x; x_d2 <= x_d1;
            valid_out <= 1'b1;
        end else begin
            valid_out <= 1'b0;
        end
    end
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def fp16_adder_reference_verilog() -> str:
    """Reference implementation for fp16_adder task.

    The returned Verilog is a combinational half-precision adder. Its
    datapath, in order:

    1. Split each operand into sign, 5-bit exponent and 10-bit mantissa,
       and classify it as infinity, NaN or zero.
    2. Order the operands by magnitude, then right-shift the smaller
       significand by the exponent difference (it becomes zero when the
       difference is 11 or more).
    3. Add or subtract the significands depending on whether the signs
       match.
    4. Normalise using a leading-zero count and adjust the exponent.
    5. Select the result: NaN (``16'h7E00``) for NaN inputs or opposite
       infinities, infinity for infinite inputs or exponent overflow,
       ``16'h0000`` for a zero sum, otherwise the normalised value.

    Every operand that is not exactly zero receives an implicit leading
    one, and the low result bits are taken directly from the normalised
    sum (no rounding logic is present).

    Returns:
        Verilog source text for module ``fp16_adder``.
    """
    return """
`timescale 1ns/1ps
module fp16_adder (
    input wire [15:0] a,
    input wire [15:0] b,
    output wire [15:0] result
);
    wire sa = a[15], sb = b[15];
    wire [4:0] ea = a[14:10], eb = b[14:10];
    wire [9:0] ma = a[9:0], mb = b[9:0];
    wire a_inf = (ea == 5'h1F) & (ma == 10'h0);
    wire b_inf = (eb == 5'h1F) & (mb == 10'h0);
    wire a_nan = (ea == 5'h1F) & (ma != 10'h0);
    wire b_nan = (eb == 5'h1F) & (mb != 10'h0);
    wire a_zero = (ea == 5'h00) & (ma == 10'h0);
    wire b_zero = (eb == 5'h00) & (mb == 10'h0);
    wire [10:0] fa = a_zero ? 11'h0 : {1'b1, ma};
    wire [10:0] fb = b_zero ? 11'h0 : {1'b1, mb};
    wire [15:0] mag_a = {1'b0, a[14:0]};
    wire [15:0] mag_b = {1'b0, b[14:0]};
    wire swap = (mag_b > mag_a);
    wire s_big = swap ? sb : sa;
    wire [4:0] e_big = swap ? eb : ea;
    wire [10:0] f_big = swap ? fb : fa;
    wire s_sml = swap ? sa : sb;
    wire [4:0] e_sml = swap ? ea : eb;
    wire [10:0] f_sml = swap ? fa : fb;
    wire [4:0] diff = e_big - e_sml;
    wire [10:0] f_sml_sh = (diff >= 11) ? 11'h0 : (f_sml >> diff);
    wire same_sign = (s_big == s_sml);
    wire [11:0] sum_raw = same_sign
        ? ({1'b0, f_big} + {1'b0, f_sml_sh})
        : ({1'b0, f_big} - {1'b0, f_sml_sh});
    reg [3:0] lz;
    always @(*) begin
        casez (sum_raw)
            12'b1??????????? : lz = 4'd0;
            12'b01?????????? : lz = 4'd1;
            12'b001????????? : lz = 4'd2;
            12'b0001???????? : lz = 4'd3;
            12'b00001??????? : lz = 4'd4;
            12'b000001?????? : lz = 4'd5;
            12'b0000001????? : lz = 4'd6;
            12'b00000001???? : lz = 4'd7;
            12'b000000001??? : lz = 4'd8;
            12'b0000000001?? : lz = 4'd9;
            12'b00000000001? : lz = 4'd10;
            12'b000000000001 : lz = 4'd11;
            default : lz = 4'd12;
        endcase
    end
    wire [3:0] norm_shift = (lz == 0) ? 4'd0 : (lz - 4'd1);
    wire [11:0] sum_norm = (lz == 0) ? (sum_raw >> 1) : (sum_raw << norm_shift);
    wire [5:0] exp_adj = (sum_raw[11])
        ? ({1'b0, e_big} + 6'd1)
        : ({1'b0, e_big} - {2'b0, lz} + 6'd1);
    wire result_zero = (sum_raw == 12'h0);
    wire exp_overflow = (exp_adj >= 6'd31) & ~result_zero;
    wire [15:0] nan_out = 16'h7E00;
    wire [15:0] inf_out = {s_big, 5'h1F, 10'h0};
    wire [15:0] zero_out = 16'h0000;
    wire [15:0] norm_out = {s_big, exp_adj[4:0], sum_norm[9:0]};
    assign result =
        (a_nan | b_nan) ? nan_out :
        (a_inf & b_inf & (sa != sb)) ? nan_out :
        (a_inf | b_inf) ? inf_out :
        (exp_overflow) ? inf_out :
        (result_zero) ? zero_out :
        norm_out;
endmodule
"""
# NOTE(review): presumably registered with @pytest.fixture upstream; no
# decorator is visible in this view of the file -- confirm before relying on it.
def systolic_reference_verilog() -> str:
    """Reference implementation for systolic_array task.

    Architecture: row i accumulates for exactly 4 cycles starting at cycle i
    (diagonal skewing via row-gated enable, no shift registers).

    output[i][j] = 4 * activations[i] * weights[i][j]

    done fires at posedge 7 from start (done_cycle <= 7).

    Further details visible in the returned source:

    * Weights are 4-bit values latched from ``weights_flat`` whenever
      ``load_weights`` is high, independent of ``rst``.
    * Row ``i`` reads its activation from ``activations_flat[i*8 +: 8]``.
    * ``start`` clears the accumulators and restarts the cycle counter;
      ``done`` is raised when the counter reaches 6 and drops again on the
      next idle cycle.

    Returns:
        Verilog source text for module ``systolic_array``.
    """
    return """
module systolic_array (
    input wire clk, rst, load_weights, start,
    input wire [63:0] weights_flat,
    input wire [127:0] activations_flat,
    output wire [255:0] outputs_flat,
    output wire done
);
    reg [3:0] weights [0:3][0:3];
    reg [15:0] acc [0:3][0:3];
    reg [2:0] cyc;
    reg running, done_reg;
    assign done = done_reg;
    genvar gi, gj;
    generate
        for (gi = 0; gi < 4; gi = gi + 1) begin : row_out
            for (gj = 0; gj < 4; gj = gj + 1) begin : col_out
                assign outputs_flat[(gi*4+gj)*16 +: 16] = acc[gi][gj];
            end
        end
    endgenerate
    integer li, lj;
    always @(posedge clk) begin
        if (load_weights)
            for (li = 0; li < 4; li = li + 1)
                for (lj = 0; lj < 4; lj = lj + 1)
                    weights[li][lj] <= weights_flat[(li*4+lj)*4 +: 4];
    end
    integer ci, cj;
    always @(posedge clk) begin
        if (rst) begin
            running <= 0; done_reg <= 0; cyc <= 0;
            for (ci = 0; ci < 4; ci = ci + 1)
                for (cj = 0; cj < 4; cj = cj + 1)
                    acc[ci][cj] <= 0;
        end else if (start) begin
            running <= 1; done_reg <= 0; cyc <= 0;
            for (ci = 0; ci < 4; ci = ci + 1)
                for (cj = 0; cj < 4; cj = cj + 1)
                    acc[ci][cj] <= 0;
        end else if (running) begin
            for (ci = 0; ci < 4; ci = ci + 1) begin
                if (cyc >= ci && cyc < ci + 4) begin
                    for (cj = 0; cj < 4; cj = cj + 1)
                        acc[ci][cj] <= acc[ci][cj]
                            + {{12{1'b0}}, weights[ci][cj]}
                            * activations_flat[ci*8 +: 8];
                end
            end
            cyc <= cyc + 1;
            if (cyc == 3'd6) begin
                done_reg <= 1;
                running <= 0;
            end
        end else begin
            done_reg <= 0;
        end
    end
endmodule
"""