Spaces:
Runtime error
Runtime error
| #include "common.comp" | |
// TODO: use a local size of 32 or more (Metal uses 1024)
// One invocation per workgroup along x; y and z keep their GLSL default of 1.
layout(local_size_x = 1) in;
// Per-dispatch parameters delivered as Vulkan push constants. The member order
// and types fix this block's memory layout, so they must stay in sync with the
// host-side struct that fills it (not visible in this file).
layout (push_constant) uniform parameter {
    // NOTE(review): presumably offsets into the input A, input B and output
    // buffers bound by the including shader; confirm units against the host.
    uint inAOff;
    uint inBOff;
    uint outOff;
    // RoPE / YaRN settings. n_dims, n_orig_ctx, freq_base, beta_fast and
    // beta_slow share names with the parameters of rope_yarn_corr_dims below,
    // and freq_scale / ext_factor with those of rope_yarn.
    int n_dims;
    int mode; // NOTE(review): variant selector; semantics live in the including shader.
    int n_orig_ctx;
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor; // NOTE(review): presumably forwarded as rope_yarn's mscale; confirm.
    float beta_fast;
    float beta_slow;
    // NOTE(review): names follow ggml conventions (nb* = byte strides,
    // ne0 = extent of dimension 0). nb00..nb03 presumably describe the source
    // tensor and nb0..nb3 the destination; confirm against the host code.
    uint nb00;
    uint nb01;
    uint nb02;
    uint nb03;
    int ne0;
    uint nb0;
    uint nb1;
    uint nb2;
    uint nb3;
} pcs;
// YaRN blend weight for the rotary pair containing index i0.
// With span = max(0.001, high - low), the function behaves as follows:
//   - it returns 1 while i0 / 2 <= low;
//   - it returns 0 once i0 / 2 >= low + span;
//   - it falls off linearly in between.
// The 0.001 floor keeps the division finite when high <= low.
float rope_yarn_ramp(const float low, const float high, const float i0) {
    const float span = max(0.001f, high - low);
    const float pos = (i0 / 2.0f - low) / span;
    const float clamped = min(1.0f, max(0.0f, pos));
    return 1.0f - clamped;
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Produces the magnitude-scaled cosine and sine of a rotary angle.
//   theta_extrap - unscaled ("extrapolated") angle
//   freq_scale   - multiplier giving the interpolated angle freq_scale * theta_extrap
//   corr_dims    - [low, high] bounds forwarded to rope_yarn_ramp
//   i0           - index forwarded to rope_yarn_ramp
//   ext_factor   - weight on the ramp blend. When it is 0, the interpolated
//                  angle and the plain mscale are used unchanged.
//   mscale       - base magnitude multiplier
// Writes cos(theta) * m and sin(theta) * m. Here theta is the (possibly blended)
// angle and m is the (possibly corrected) magnitude.
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    float theta_interp = freq_scale * theta_extrap;
    float angle = theta_interp;
    float magnitude = mscale;

    if (ext_factor != 0.0f) {
        // Weight of the extrapolated angle, driven by the ramp. The
        // interpolation is spelled out instead of using mix() so the
        // arithmetic stays exactly x * (1 - w) + y * w.
        float weight = ext_factor * rope_yarn_ramp(corr_dims[0], corr_dims[1], i0);
        angle = theta_interp * (1.0f - weight) + theta_extrap * weight;
        // Magnitude correction for interpolation: 1 + 0.1 * ln(1 / freq_scale).
        magnitude = mscale * (1.0f + 0.1f * log(1.0f / freq_scale));
    }

    cos_theta = magnitude * cos(angle);
    sin_theta = magnitude * sin(angle);
}
// A rotary pair with dimension index x has wavelength 2pi * base^(2x / n_dims).
// Over max_pos_emb positions it therefore completes
//   n_rot = max_pos_emb / (2pi * base^(2x / n_dims))
// full rotations. Solving for x gives
//   corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))
// which is what this returns, with max_pos_emb = n_orig_ctx.
//
// The result is fractional and unclamped. Callers floor/ceil and clamp it.
// TWOPI_F comes from common.comp (presumably 2 * pi, per the formula above).
float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
    // max_pos_emb / (n_rot * 2pi), i.e. base^(2x / n_dims) at the solution x.
    float ratio = float(n_orig_ctx) / (n_rot * TWOPI_F);
    return float(n_dims) * log(ratio) / (2.0f * log(base));
}
// Start and end dimensions of the YaRN correction range.
//   dims[0] = floor(corr_factor(beta_fast)), clamped below at 0
//   dims[1] = ceil(corr_factor(beta_slow)),  clamped above at n_dims - 1
// Both bounds are computed with rope_yarn_corr_factor, using n_orig_ctx and freq_base.
void rope_yarn_corr_dims(
    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    float first = floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base));
    float last  = ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base));
    dims[0] = max(0.0f, first);
    dims[1] = min(float(n_dims) - 1.0f, last);
}