File size: 3,625 Bytes
f2f112a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <math.h>
#include <stddef.h>

// --------- activations ----------
// Rectified linear unit: positive inputs pass through, everything else
// (including NaN, which fails the comparison) maps to 0.
static inline float relu_f(float x) {
    if (x > 0.0f) {
        return x;
    }
    return 0.0f;
}

// Numerically-stable sigmoid
static inline float sigmoid_f(float x) {
    if (x >= 0.0f) {
        float z = expf(-x);
        return 1.0f / (1.0f + z);
    } else {
        float z = expf(x);
        return z / (1.0f + z);
    }
}

// --------- linear layer ----------
// y[n, out_dim] = sum_i x[n, in_dim] * W[out_dim, in_dim] + b[out_dim]
/*
 * Dense (fully-connected) layer forward pass:
 *   y[n][o] = b[o] + sum_i x[n][i] * W[o][i]
 *
 * x: [N, in_dim]  input activations
 * y: [N, out_dim] output buffer; must not alias x, W, or b
 * W: [out_dim, in_dim] weights, row-major (W[o*in_dim + i])
 * b: [out_dim] bias, or NULL for a bias-free layer
 *
 * `restrict` documents the no-aliasing contract (every caller in this
 * file passes distinct buffers) and lets the compiler vectorize the
 * inner dot product without reload-after-store hazards.
 */
static void linear_forward(
    const float *restrict x,   // [N, in_dim]
    float *restrict y,         // [N, out_dim]
    const float *restrict W,   // [out_dim, in_dim] row-major
    const float *restrict b,   // [out_dim], may be NULL
    int N,
    int in_dim,
    int out_dim
) {
    for (int n = 0; n < N; ++n) {
        const float *xn = x + (size_t)n * (size_t)in_dim;
        float *yn = y + (size_t)n * (size_t)out_dim;

        for (int o = 0; o < out_dim; ++o) {
            const float *Wo = W + (size_t)o * (size_t)in_dim;
            float acc = (b != NULL) ? b[o] : 0.0f;

            // dot(xn, Wo)
            for (int i = 0; i < in_dim; ++i) {
                acc += xn[i] * Wo[i];
            }
            yn[o] = acc;
        }
    }
}

// In-place ReLU on a [N, dim] tensor
// Apply ReLU element-wise, in place, over a flat [N, dim] buffer.
// Uses the same (v > 0 ? v : 0) select as relu_f, inlined here so the
// routine is self-contained.
static void relu_inplace(float *x, int N, int dim) {
    const size_t count = (size_t)N * (size_t)dim;
    for (size_t k = 0; k < count; ++k) {
        const float v = x[k];
        x[k] = (v > 0.0f) ? v : 0.0f;
    }
}

// In-place Sigmoid on a [N, dim] tensor
static void sigmoid_inplace(float *x, int N, int dim) {
    size_t total = (size_t)N * (size_t)dim;
    for (size_t idx = 0; idx < total; ++idx) {
        x[idx] = sigmoid_f(x[idx]);
    }
}

// --------- the requested block ----------
//
// input/output buffer "out" starts as [N, D] and ends as [N, 1].
//
// You must provide temporary buffers:
//   tmp1: [N, D/2]
//   tmp2: [N, D/8]
//
void score_tail_forward(
    float *out,                 // IN: [N, D], OUT: [N, 1]
    int N,
    int D,                      // must be divisible by 8
    // Linear2 params: (D -> D/2)
    const float *W2,            // [D/2, D]
    const float *b2,            // [D/2]
    // Linear3 params: (D/2 -> D/8)
    const float *W3,            // [D/8, D/2]
    const float *b3,            // [D/8]
    // Linear4 params: (D/8 -> 1)
    const float *W4,            // [1, D/8]  (or just [D/8])
    const float *b4,            // [1] (optional; PyTorch Linear has bias)
    float bias_scalar,          // your Parameter(torch.zeros(1))
    // workspaces
    float *tmp1,                // [N, D/2]
    float *tmp2                 // [N, D/8]
) {
    const int half = D / 2;
    const int eighth = D / 8;

    // Stage 1: Linear2 then ReLU, [N, D] -> [N, D/2]
    linear_forward(out, tmp1, W2, b2, N, D, half);
    relu_inplace(tmp1, N, half);

    // Stage 2: Linear3 then ReLU, [N, D/2] -> [N, D/8]
    linear_forward(tmp1, tmp2, W3, b3, N, half, eighth);
    relu_inplace(tmp2, N, eighth);

    // Stage 3: Linear4 plus the learned scalar bias, [N, D/8] -> [N, 1].
    // W4 is treated as a single row-major row of length D/8.
    // Summation order matches the reference: b4, then the dot product,
    // then bias_scalar — so results are bit-identical.
    for (int row = 0; row < N; ++row) {
        const float *feat = tmp2 + (size_t)row * (size_t)eighth;
        float sum = (b4 != NULL) ? b4[0] : 0.0f;

        for (int k = 0; k < eighth; ++k) {
            sum += feat[k] * W4[k];
        }

        out[row] = sum + bias_scalar;
    }

    // Stage 4: squash scores into (0, 1).
    sigmoid_inplace(out, N, 1);
}