File size: 3,088 Bytes
a04c389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// Uniform parameter block read by the BitNet linear kernel in this file.
// NOTE(review): field order presumably has to match the host-side buffer
// layout - confirm before reordering or inserting fields.
struct BitNetLinearParams {
  // Upper bound on the row index; invocations with gid.y >= rows exit early.
  rows: u32,
  // Number of columns accumulated per row; also the row stride of `input`.
  in_features: u32,
  // Upper bound on the output index; also the row stride of `output`.
  out_features: u32,
  // Divided by 4 to obtain the per-output-row stride, in bytes, of the packed weights.
  padded_in_features: u32,
  // Selects the weight-scale lookup: 0 = weight_scales[0], 1 = per segment
  // (via segment_offsets), 2 = per group of scale_group_size outputs.
  // Any other value makes the resolved scale 0.0.
  scale_granularity: u32,
  // Divisor applied to the output index when scale_granularity == 2.
  scale_group_size: u32,
  // Number of segments searched when scale_granularity == 1.
  segment_count: u32,
  // Non-zero: bias_values[out_idx] is added to the result.
  has_bias: u32,
  // 0: input values are used as read. Any other value: each input is divided
  // by its scale, rounded, clamped and multiplied by the scale again.
  input_quant_mode: u32,
  // Bit width handed to quant_max() to derive the clamp limit.
  input_quant_bits: u32,
  // 1: input_scales[0] is used for every row; otherwise the scale is indexed by row.
  input_scale_rows: u32,
  // Not read by the code visible here; presumably padding - TODO confirm.
  reserved0: u32,
};

// Input activations, read at index row * params.in_features + col.
@group(0) @binding(0) var<storage, read> input: array<f32>;
// Packed weights: four 2-bit codes per byte, four bytes per u32 word
// (byte lane 0 sits in the low-order 8 bits of the word).
@group(0) @binding(1) var<storage, read> packed_weight_words: array<u32>;
// Weight scales; which element is read depends on params.scale_granularity.
@group(0) @binding(2) var<storage, read> weight_scales: array<f32>;
// Segment boundaries used when params.scale_granularity == 1: segment `s`
// covers output indices [segment_offsets[s], segment_offsets[s + 1]).
@group(0) @binding(3) var<storage, read> segment_offsets: array<u32>;
// Per-output bias, read at out_idx only when params.has_bias != 0.
@group(0) @binding(4) var<storage, read> bias_values: array<f32>;
// Input quantization scales, read only when params.input_quant_mode != 0.
@group(0) @binding(5) var<storage, read> input_scales: array<f32>;
// Result buffer, written at index row * params.out_features + out_idx.
@group(0) @binding(6) var<storage, read_write> output: array<f32>;
// Kernel parameters (see BitNetLinearParams).
@group(0) @binding(7) var<uniform> params: BitNetLinearParams;

// Decodes one ternary weight from a packed byte.
// The byte carries four 2-bit codes; the low two bits of `in_idx` choose which
// one is extracted. Code 0 yields -1.0, code 2 yields 1.0, and the remaining
// codes (1 and 3) both yield 0.0.
fn decode_signed_ternary(packed_byte: u32, in_idx: u32) -> f32 {
  let lane = in_idx % 4u;
  let two_bits = (packed_byte >> (lane * 2u)) % 4u;
  // select(f, t, cond) returns `t` when `cond` is true.
  let negative_part = select(0.0, -1.0, two_bits == 0u);
  return select(negative_part, 1.0, two_bits == 2u);
}

// Fetches the byte of packed weights that holds the code for (out_idx, in_idx).
// Each output row spans params.padded_in_features / 4 bytes and every byte
// covers four consecutive input indices. Bytes are stored four per u32 word,
// with byte lane 0 in the low-order bits.
fn load_packed_byte(out_idx: u32, in_idx: u32) -> u32 {
  let bytes_per_row = params.padded_in_features >> 2u;
  let byte_index = out_idx * bytes_per_row + (in_idx >> 2u);
  let lane_shift = (byte_index % 4u) * 8u;
  let containing_word = packed_weight_words[byte_index >> 2u];
  return (containing_word >> lane_shift) & 0xffu;
}

// Returns the weight scale that applies to output index `out_idx`.
//   scale_granularity 0: the single scale weight_scales[0].
//   scale_granularity 1: the scale of the first segment `s` for which
//     segment_offsets[s] <= out_idx < segment_offsets[s + 1]; 0.0 if none matches.
//   scale_granularity 2: weight_scales[out_idx / scale_group_size].
//   anything else: 0.0.
fn resolve_weight_scale(out_idx: u32) -> f32 {
  let mode = params.scale_granularity;
  if (mode == 0u) {
    return weight_scales[0];
  }
  if (mode == 2u) {
    return weight_scales[out_idx / params.scale_group_size];
  }
  if (mode != 1u) {
    return 0.0;
  }

  // Per-segment mode: linear scan over the segment table.
  for (var segment = 0u; segment < params.segment_count; segment = segment + 1u) {
    let starts_at_or_before = segment_offsets[segment] <= out_idx;
    if (starts_at_or_before && out_idx < segment_offsets[segment + 1u]) {
      return weight_scales[segment];
    }
  }
  return 0.0;
}

// Returns 2^(bits - 1) - 1 as an f32: the largest magnitude used as the
// symmetric clamp limit for a `bits`-wide quantized value.
fn quant_max(bits: u32) -> f32 {
  let half_range = 1u << (bits - 1u);
  return f32(half_range - 1u);
}

// Reads the input element at (row, col).
// When params.input_quant_mode is 0 the stored value is returned unchanged.
// Otherwise the value is divided by its scale, rounded, clamped to
// [-quant_max(bits), quant_max(bits)] and multiplied by the scale again.
fn input_value(row: u32, col: u32) -> f32 {
  let raw = input[row * params.in_features + col];
  if (params.input_quant_mode != 0u) {
    // A single scale row means every input row shares input_scales[0].
    var scale_index = row;
    if (params.input_scale_rows == 1u) {
      scale_index = 0u;
    }
    // The scale is floored at a small positive value before dividing by it.
    let step = max(input_scales[scale_index], 0.00000001);
    let limit = quant_max(params.input_quant_bits);
    let level = clamp(round(raw / step), -limit, limit);
    return level * step;
  }
  return raw;
}

// Compute entry point: each invocation produces one element of `output`.
// gid.x selects the output index and gid.y the input row; invocations outside
// params.out_features x params.rows exit without writing.
//
// The result is
//   sum over col of input_value(row, col) * ternary_weight(out_idx, col)
// multiplied by the resolved weight scale, plus bias_values[out_idx] when
// params.has_bias is non-zero.
@compute @workgroup_size(8, 8, 1)
fn bitnet_linear_main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_idx = gid.x;
  let row = gid.y;
  if (out_idx >= params.out_features || row >= params.rows) {
    return;
  }

  var acc = 0.0;
  // One packed byte carries the codes for four consecutive columns, so it is
  // fetched once per group of four instead of once per column. The loop starts
  // at col == 0, so the byte is always loaded before its first use.
  var packed_byte = 0u;
  for (var col = 0u; col < params.in_features; col = col + 1u) {
    if ((col & 3u) == 0u) {
      packed_byte = load_packed_byte(out_idx, col);
    }
    let w = decode_signed_ternary(packed_byte, col);
    acc = acc + input_value(row, col) * w;
  }

  var y = acc * resolve_weight_scale(out_idx);
  if (params.has_bias != 0u) {
    y = y + bias_values[out_idx];
  }
  output[row * params.out_features + out_idx] = y;
}