Spaces:
Sleeping
Sleeping
// systolic_array
//
// 4x4 grid of multiply-accumulate cells.  Each cell (r, c) holds a 4-bit
// weight and a 16-bit accumulator.  On every step in which row r is active,
// the cell adds  weights[r][c] * activations_flat[r*8 +: 8]  to its accumulator.
//
// Ports
//   clk              : clock; all state updates on the rising edge.
//   rst              : synchronous reset of the accumulators and control state
//                      (the weight registers are NOT cleared by rst).
//   load_weights     : when high at a clock edge, all 16 weights are captured
//                      from weights_flat.
//   weights_flat     : weight (r, c) occupies bits [(r*4+c)*4 +: 4].
//   start            : when high at a clock edge, performs step 0 of a new run
//                      (also restarts a run that is already in progress).
//   activations_flat : row r reads bits [r*8 +: 8] on every step.
//                      NOTE(review): only bits [31:0] are read in this module;
//                      bits [127:32] are unused here - confirm against the
//                      driver whether that is intended.
//   outputs_flat     : accumulator (r, c) occupies bits [(r*4+c)*16 +: 16].
//   done             : high for one cycle, starting at the clock edge that
//                      processes step 6 (the last of the 7 steps).
//
// Schedule
//   Row r accumulates during steps r .. r+3, so rows are staggered by one
//   step each and a full run takes steps 0 .. 6.
module systolic_array (
    input  wire         clk,
    input  wire         rst,
    input  wire         load_weights,
    input  wire [63:0]  weights_flat,
    input  wire         start,
    input  wire [127:0] activations_flat,
    output wire [255:0] outputs_flat,
    output wire         done
);

    // Fixed geometry of the array (named constants, not overridable).
    localparam DIM      = 4;                 // rows == columns
    localparam W_BITS   = 4;                 // bits per weight
    localparam A_BITS   = 8;                 // bits per activation
    localparam P_BITS   = W_BITS + A_BITS;   // bits per product (12)
    localparam ACC_BITS = 16;                // bits per accumulator

    // Value of `cyc` while the final step of a run is being processed.
    localparam [2:0] LAST_STEP = 3'd6;

    // ------------------------------------------------------------------
    // State
    // ------------------------------------------------------------------
    reg [W_BITS-1:0]   weights [0:DIM-1][0:DIM-1];
    reg [ACC_BITS-1:0] acc     [0:DIM-1][0:DIM-1];
    reg [2:0]          cyc;        // index of the step handled at the next edge
    reg                running;    // high while steps 1..6 are pending
    reg                done_reg;

    assign done = done_reg;

    // ------------------------------------------------------------------
    // Per-cell combinational logic
    // ------------------------------------------------------------------
    // The product is formed from the narrow 4-bit and 8-bit operands before
    // it reaches the accumulator adder, so each cell only needs a 4x8
    // multiplier.
    wire [P_BITS-1:0] prod [0:DIM-1][0:DIM-1];

    genvar g_r, g_c;
    generate
        for (g_r = 0; g_r < DIM; g_r = g_r + 1) begin : g_row
            for (g_c = 0; g_c < DIM; g_c = g_c + 1) begin : g_col
                // Every cell in a row multiplies by that row's activation byte.
                assign prod[g_r][g_c] =
                    weights[g_r][g_c] * activations_flat[g_r*A_BITS +: A_BITS];

                // Flatten the accumulator grid onto the output bus, row-major.
                assign outputs_flat[(g_r*DIM + g_c)*ACC_BITS +: ACC_BITS] =
                    acc[g_r][g_c];
            end
        end
    endgenerate

    // ------------------------------------------------------------------
    // Weight registers: synchronous load, independent of rst / start
    // ------------------------------------------------------------------
    integer w_r, w_c;
    always @(posedge clk) begin
        if (load_weights) begin
            for (w_r = 0; w_r < DIM; w_r = w_r + 1) begin
                for (w_c = 0; w_c < DIM; w_c = w_c + 1) begin
                    weights[w_r][w_c] <= weights_flat[(w_r*DIM + w_c)*W_BITS +: W_BITS];
                end
            end
        end
    end

    // ------------------------------------------------------------------
    // Accumulators and step sequencing
    // ------------------------------------------------------------------
    integer a_r, a_c;
    always @(posedge clk) begin
        if (rst) begin
            running  <= 1'b0;
            done_reg <= 1'b0;
            cyc      <= 3'd0;
            for (a_r = 0; a_r < DIM; a_r = a_r + 1) begin
                for (a_c = 0; a_c < DIM; a_c = a_c + 1) begin
                    acc[a_r][a_c] <= {ACC_BITS{1'b0}};
                end
            end
        end else if (start) begin
            // Step 0 happens in the same cycle as the start pulse.  Only row 0
            // is inside its window at step 0: it is seeded with its first
            // product (zero-extended to 16 bits); every other row is cleared.
            for (a_r = 0; a_r < DIM; a_r = a_r + 1) begin
                for (a_c = 0; a_c < DIM; a_c = a_c + 1) begin
                    if (a_r == 0) begin
                        acc[a_r][a_c] <= prod[a_r][a_c];
                    end else begin
                        acc[a_r][a_c] <= {ACC_BITS{1'b0}};
                    end
                end
            end
            running  <= 1'b1;
            done_reg <= 1'b0;
            cyc      <= 3'd1;
        end else if (running) begin
            // Steps 1..6: row a_r accumulates while a_r <= cyc < a_r + DIM;
            // rows outside their window hold their value.
            for (a_r = 0; a_r < DIM; a_r = a_r + 1) begin
                if (cyc >= a_r && cyc < a_r + DIM) begin
                    for (a_c = 0; a_c < DIM; a_c = a_c + 1) begin
                        acc[a_r][a_c] <= acc[a_r][a_c] + prod[a_r][a_c];
                    end
                end
            end

            cyc <= cyc + 1;

            // The final step completes the run: raise done and go idle.
            if (cyc == LAST_STEP) begin
                done_reg <= 1'b1;
                running  <= 1'b0;
            end
        end else begin
            // Idle: done is a single-cycle pulse, so drop it here.
            done_reg <= 1'b0;
        end
    end

endmodule