# Known limitations — tritllm-kernel

Items previously raised in code review have been addressed:

- The implicit one-warp-per-block launch contract in the educational kernels
  is now enforced by an early-return guard: kernels return without writing if
  launched with `blockDim.x != 32` or `in_features % 64 != 0`.
- The dead `trit_pipeline` / `k_v29_pipeline` path was removed.
- The C API now validates pointers, dimensions, and the
  `cols / GROUP_SIZE == num_groups` invariant, and reports the result via
  `trit_gemv_get_last_error()`. CUDA launch errors are captured into the same
  channel.
- `get_gpu_name(buf, buflen)` now refuses null pointers and `buflen <= 0`.
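
The validation pattern described above can be sketched on the host side as follows. This is a minimal illustration, not the library's code: the names `sketch_validate_gemv_args` and `sketch_get_last_error`, the argument list, the error strings, the thread-local storage, and the value of `kGroupSize` are all assumptions made for the example.

```cpp
#include <string>

// Hypothetical placeholder; the real GROUP_SIZE is defined in the library source.
constexpr int kGroupSize = 64;

// Last-error channel, modeled here as thread-local state (an assumption).
static thread_local std::string g_last_error;

// Illustrative stand-in for trit_gemv_get_last_error(); an empty string means "no error".
const char* sketch_get_last_error() { return g_last_error.c_str(); }

// Returns 0 if the arguments pass all checks, -1 otherwise.
// On failure the reason is recorded in the last-error channel.
int sketch_validate_gemv_args(const void* weights, const float* scales,
                              const float* x, float* y,
                              int rows, int cols, int num_groups) {
    g_last_error.clear();
    if (!weights || !scales || !x || !y) {
        g_last_error = "null pointer argument";
        return -1;
    }
    if (rows <= 0 || cols <= 0 || num_groups <= 0) {
        g_last_error = "non-positive dimension";
        return -1;
    }
    // The invariant named above, using integer division as written.
    if (cols / kGroupSize != num_groups) {
        g_last_error = "cols / GROUP_SIZE != num_groups";
        return -1;
    }
    return 0;
}
```

A caller would check the return code first and read the error string only on failure, mirroring how the real API is described as reporting through `trit_gemv_get_last_error()`.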

This document lists what remains.

## Design tradeoff (not a bug)

### Lane-0 scale-and-add in `trit_gemv_uniform` / `trit_gemv_variable`
**Where:** [`trit_gemv.cu:223-232, 279-286`](trit_gemv.cu#L223)

After the warp reduction in the educational kernels, only lane 0 multiplies
the group sum by the scale and accumulates into `row_acc`. The other 31 lanes
are idle for the scale/add path. This is correct but slow. The published
paper benchmarks are produced by the deferred-reduction kernel
`k_d3_hardened` in `trit_gemv_standalone.cu`, which does not have this
limitation.

The `trit_gemv_uniform` / `trit_gemv_variable` kernels in `trit_gemv.cu` are
kept as a smaller, single-file reference implementation that is easier to read
and reason about. If you need maximum throughput, use the C API in
`trit_gemv_standalone.cu`.
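
To make the lane-0 path concrete, here is a host-side C++ model of the arithmetic. It is not the CUDA kernel. The 32 lanes are array slots, the tree loop stands in for a shuffle-down warp reduction, and the names `model_warp_reduce` and `model_row` are hypothetical. The `float` type of the per-lane partial sums is also an assumption.

```cpp
#include <array>
#include <cstddef>
#include <vector>

constexpr int kWarpSize = 32;

// CPU model of a shuffle-down tree reduction. Each step folds the upper half
// of the active lanes into the lower half; after the loop, slot 0 holds the
// sum of all 32 lane values.
float model_warp_reduce(std::array<float, kWarpSize> lanes) {
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        for (int lane = 0; lane < offset; ++lane) {
            lanes[lane] += lanes[lane + offset];
        }
    }
    return lanes[0];
}

// One output row: for each group, reduce the lane partials, then apply the
// group's scale once and accumulate. In the reference kernels this final
// multiply-add is the step that only lane 0 performs.
float model_row(const std::vector<std::array<float, kWarpSize>>& group_partials,
                const std::vector<float>& scales) {
    float row_acc = 0.0f;
    for (std::size_t g = 0; g < group_partials.size(); ++g) {
        const float group_sum = model_warp_reduce(group_partials[g]);
        row_acc += scales[g] * group_sum;
    }
    return row_acc;
}
```

The model shows why the reference path is serial per group: every group ends in a full reduction followed by a single multiply-add, so 31 of the 32 lanes have nothing to do during that last step.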

## Future cleanup

The C API in `trit_gemv_standalone.cu` exposes several historical kernel
variants (`v9`, `v27`, `v28`, `v29`, plus `k_d3_hardened` via
`trit_gemv_d3_int8_dp4a`). They all work, but the public API is wider than
needed. A future release will trim it to one canonical entry point per depth
(`trit_gemv_d1`, `trit_gemv_d2`, `trit_gemv_d3`, `trit_gemv_d4`).