File size: 3,358 Bytes
3951517
68acda4
 
 
 
 
3951517
 
68acda4
3951517
 
68acda4
 
 
3951517
 
 
 
 
68acda4
3951517
 
 
 
 
 
 
 
 
68acda4
 
 
 
 
 
 
 
 
 
 
 
 
3951517
 
 
68acda4
 
3951517
68acda4
3951517
 
 
 
68acda4
 
 
3951517
 
 
68acda4
 
 
3951517
 
68acda4
3951517
 
68acda4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3951517
68acda4
3951517
 
 
 
 
68acda4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// systolic_array: 4x4 grid of multiply-accumulate cells with locally stored
// 4-bit weights.
//
// Ports
//   load_weights     - when high on a rising edge, latches 16 4-bit weights;
//                      cell (r,c) takes weights_flat[(4*r+c)*4 +: 4].
//                      The weights are NOT cleared by rst.
//   start            - the edge on which start is high is effective cycle 0 of
//                      a run; it discards all previous accumulator contents.
//   activations_flat - every cell in row r multiplies by byte r,
//                      activations_flat[8*r +: 8], on each edge where that row
//                      is active.
//   outputs_flat     - accumulator (r,c) is exposed at outputs_flat[(4*r+c)*16 +: 16].
//   done             - single-cycle pulse, registered on the edge of cycle 6.
//
// Schedule: a run spans effective cycles 0..6. Row r adds its products on
// cycles r..r+3, so every row accumulates exactly four times per run.
//
// NOTE(review): only activations_flat[31:0] is ever read; confirm the upper
// 96 bits are intentionally unused.
module systolic_array (
    input  wire         clk,
    input  wire         rst,
    input  wire         load_weights,
    input  wire [63:0]  weights_flat,
    input  wire         start,
    input  wire [127:0] activations_flat,
    output wire [255:0] outputs_flat,
    output wire         done
);

    localparam N         = 4;     // rows == columns of the grid
    localparam W_BITS    = 4;     // weight width
    localparam A_BITS    = 8;     // activation width
    localparam P_BITS    = 12;    // product width (W_BITS + A_BITS)
    localparam ACC_BITS  = 16;    // accumulator width
    localparam [2:0] FINAL_CYC = 3'd6;  // last effective cycle of a run

    reg [W_BITS-1:0]   weights [0:N-1][0:N-1];
    reg [ACC_BITS-1:0] acc     [0:N-1][0:N-1];
    reg [2:0]          cyc;
    reg                running;
    reg                done_reg;

    // Cycle index the datapath acts on this edge: a start pulse always counts
    // as cycle 0, whatever the stored counter holds.
    wire [2:0]   eff_cyc = start ? 3'd0 : cyc;
    wire [N-1:0] row_active;

    // Products kept at their natural 4x8 -> 12-bit width rather than being
    // widened to the accumulator width before multiplying.
    wire [P_BITS-1:0] prod [0:N-1][0:N-1];

    assign done = done_reg;

    genvar k;
    generate
        // One iteration per cell; row = k / N, column = k % N.
        for (k = 0; k < N*N; k = k + 1) begin : cell
            assign prod[k / N][k % N] =
                weights[k / N][k % N] * activations_flat[(k / N)*A_BITS +: A_BITS];
            assign outputs_flat[k*ACC_BITS +: ACC_BITS] = acc[k / N][k % N];
        end
        // Row k participates on effective cycles k .. k+N-1.
        for (k = 0; k < N; k = k + 1) begin : row
            assign row_active[k] = (eff_cyc >= k) && (eff_cyc < k + N);
        end
    endgenerate

    // Weight register file: bulk load, independent of rst.
    integer w;
    always @(posedge clk) begin
        if (load_weights)
            for (w = 0; w < N*N; w = w + 1)
                weights[w / N][w % N] <= weights_flat[w*W_BITS +: W_BITS];
    end

    // Run control: cycle counter, busy flag and the completion pulse.
    always @(posedge clk) begin
        if (rst) begin
            running  <= 1'b0;
            done_reg <= 1'b0;
            cyc      <= 3'd0;
        end else if (start) begin
            // The start edge is itself cycle 0, so the counter resumes at 1.
            running  <= 1'b1;
            done_reg <= 1'b0;
            cyc      <= 3'd1;
        end else if (running) begin
            cyc <= cyc + 3'd1;
            if (cyc == FINAL_CYC) begin
                running  <= 1'b0;
                done_reg <= 1'b1;
            end
        end else begin
            // Idle: drop done one cycle after it was raised.
            done_reg <= 1'b0;
        end
    end

    // Accumulator datapath.
    integer r, c;
    always @(posedge clk) begin
        if (rst) begin
            for (r = 0; r < N; r = r + 1)
                for (c = 0; c < N; c = c + 1)
                    acc[r][c] <= {ACC_BITS{1'b0}};
        end else if (start) begin
            // Fresh run: rows active on cycle 0 load their product directly,
            // every other row is cleared.
            for (r = 0; r < N; r = r + 1)
                for (c = 0; c < N; c = c + 1)
                    acc[r][c] <= row_active[r] ? prod[r][c] : {ACC_BITS{1'b0}};
        end else if (running) begin
            // Mid-run: active rows add, inactive rows hold.
            for (r = 0; r < N; r = r + 1)
                if (row_active[r])
                    for (c = 0; c < N; c = c + 1)
                        acc[r][c] <= acc[r][c] + prod[r][c];
        end
    end

endmodule