Create configuration_gslm.py
configuration_gslm.py (ADDED): +50 -0
@@ -0,0 +1,50 @@
+from transformers import PretrainedConfig
+
+class GSLMConfig(PretrainedConfig):
+    model_type = "gslm"
+
+    def __init__(
+        self,
+        vocab_size=200,  # number of discrete units
+        n_layer=12,
+        n_head=12,
+        n_embd=768,
+        seq_len=4096,
+        dropout=0.1,
+        bias=False,  # use bias in linear/ln layers
+        pos_embed="sinusoidal",  # 'sinusoidal' | 'learned' | 'rope' | 'none'
+        use_rope=None,  # kept for AuriStream API parity; if set, overrides pos_embed
+        rope_theta=500000,
+        n_pred_steps=1,  # >1 enables auxiliary future heads like AuriStream
+        activation="gelu",  # 'gelu' | 'silu'
+        norm_type="layernorm",  # 'layernorm' (fairseq compat) | 'rmsnorm'
+        attn_impl="fused_qkv",  # 'fused_qkv' (AuriStream-like) | 'separate_qkv' (fairseq-like)
+        tie_word_embeddings=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = int(vocab_size)
+        self.n_layer = int(n_layer)
+        self.n_head = int(n_head)
+        self.n_embd = int(n_embd)
+        self.seq_len = int(seq_len)
+        self.dropout = float(dropout)
+        self.bias = bool(bias)
+
+        # Positional embedding config
+        if use_rope is not None:
+            self.pos_embed = "rope" if use_rope else "learned"
+        else:
+            self.pos_embed = pos_embed
+        self.rope_theta = float(rope_theta)
+
+        # Multi-step heads
+        self.n_pred_steps = int(n_pred_steps)
+
+        # Blocks
+        self.activation = activation
+        self.norm_type = norm_type
+        self.attn_impl = attn_impl
+
+        # HF compat
+        self.tie_word_embeddings = bool(tie_word_embeddings)
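
For reference, a minimal usage sketch, not part of this commit and assuming transformers is installed: it instantiates the config, checks that use_rope overrides pos_embed as the constructor logic above intends, and does a standard PretrainedConfig save/load round trip. The checkpoint path is hypothetical.

# Usage sketch; not part of this commit.
from configuration_gslm import GSLMConfig

# use_rope takes precedence over pos_embed, per the override in __init__
config = GSLMConfig(use_rope=True, pos_embed="sinusoidal")
assert config.pos_embed == "rope"

# Standard PretrainedConfig round trip; custom attributes are serialized
# to config.json and restored on load.
config.save_pretrained("./gslm-checkpoint")  # hypothetical local path
reloaded = GSLMConfig.from_pretrained("./gslm-checkpoint")
assert reloaded.model_type == "gslm"
assert reloaded.pos_embed == "rope"

Because use_rope is never stored on the instance, only the resolved pos_embed value is written to config.json, so reloading reconstructs the same positional-embedding setting without re-running the override.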