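// Provenance note (residue from the Hugging Face upload page, kept as a comment so the file compiles):
//   Supreeth's picture
//   Upload folder using huggingface_hub
//   68acda4 verified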
// 4x4 weight-stationary multiply-accumulate array.
//
// Ports:
//   load_weights     - while high, latches weights_flat into the 4x4 weight grid on each clock edge.
//                      Weight (r,c) is taken from weights_flat[(r*4+c)*4 +: 4]. This path is not gated by rst.
//   start            - begins a computation. Every accumulator is cleared, and row 0 accumulates in that same cycle.
//   activations_flat - row r multiplies by activations_flat[r*8 +: 8]. Only bits [31:0] are read.
//                      NOTE(review): the caller presumably presents a new (skewed) activation per cycle; confirm.
//   outputs_flat     - accumulator (r,c) is driven on outputs_flat[(r*4+c)*16 +: 16].
//   done             - one-cycle pulse once the last row has finished accumulating.
//
// Schedule: the start cycle is internal cycle 0. Row r adds prod[r][c] on cycles r .. r+3.
// Cycle 6 is therefore the final update, and done is high on the cycle that follows it.
module systolic_array (
    input  wire         clk,
    input  wire         rst,
    input  wire         load_weights,
    input  wire [63:0]  weights_flat,
    input  wire         start,
    input  wire [127:0] activations_flat,
    output wire [255:0] outputs_flat,
    output wire         done
);
    // Geometry and datapath widths. Port widths above are N*N*W_W, (>=N)*A_W and N*N*ACC_W.
    localparam N     = 4;               // rows == columns
    localparam W_W   = 4;               // weight width
    localparam A_W   = 8;               // activation width
    localparam P_W   = W_W + A_W;       // full-precision product width (12)
    localparam ACC_W = 16;              // accumulator width
    localparam [2:0] LAST_CYC = 2*N - 2; // cycle of the final row-(N-1) update (6)

    reg [W_W-1:0]   weights [0:N-1][0:N-1];
    reg [ACC_W-1:0] acc     [0:N-1][0:N-1];
    reg [2:0]       cyc;
    reg             running;
    reg             done_reg;

    assign done = done_reg;

    // A start pulse is treated as cycle 0, so the first row can accumulate in the same edge.
    wire [2:0] next_cyc = start ? 3'd0 : cyc;

    // Narrow W_W x A_W products, computed outside the sequential block so synthesis
    // infers small multipliers rather than ACC_W-wide ones.
    wire [P_W-1:0] prod [0:N-1][0:N-1];

    // row_active[r] is high while row r is inside its N-cycle accumulation window.
    wire [N-1:0] row_active;

    genvar r, c;
    generate
        for (r = 0; r < N; r = r + 1) begin : g_row
            assign row_active[r] = (next_cyc >= r) && (next_cyc < r + N);
            for (c = 0; c < N; c = c + 1) begin : g_col
                assign prod[r][c] = weights[r][c] * activations_flat[r*A_W +: A_W];
                assign outputs_flat[(r*N + c)*ACC_W +: ACC_W] = acc[r][c];
            end
        end
    endgenerate

    // Weight grid: synchronous load, independent of rst.
    integer wr, wc;
    always @(posedge clk) begin
        if (load_weights)
            for (wr = 0; wr < N; wr = wr + 1)
                for (wc = 0; wc < N; wc = wc + 1)
                    weights[wr][wc] <= weights_flat[(wr*N + wc)*W_W +: W_W];
    end

    // Sequencer and accumulators.
    integer ar, ac;
    always @(posedge clk) begin
        if (rst) begin
            cyc      <= 3'd0;
            running  <= 1'b0;
            done_reg <= 1'b0;
            for (ar = 0; ar < N; ar = ar + 1)
                for (ac = 0; ac < N; ac = ac + 1)
                    acc[ar][ac] <= {ACC_W{1'b0}};
        end else if (start || running) begin
            for (ar = 0; ar < N; ar = ar + 1)
                for (ac = 0; ac < N; ac = ac + 1)
                    if (row_active[ar])
                        // On start, the old accumulator value is discarded before adding.
                        acc[ar][ac] <= (start ? {ACC_W{1'b0}} : acc[ar][ac]) + prod[ar][ac];
                    else if (start)
                        acc[ar][ac] <= {ACC_W{1'b0}};

            if (start) begin
                cyc      <= 3'd1;
                running  <= 1'b1;
                done_reg <= 1'b0;
            end else begin
                cyc <= cyc + 3'd1;
                if (cyc == LAST_CYC) begin
                    // Final row has just been updated; raise done for one cycle.
                    running  <= 1'b0;
                    done_reg <= 1'b1;
                end
            end
        end else begin
            // Idle: done is only a pulse.
            done_reg <= 1'b0;
        end
    end
endmodule