https://huggingface.co/EleutherAI/gpt-j-6b
I'm almost certain that GPTJForCausalLM is not yet supported by llama.cpp, nor is it likely to ever be, which is a real shame given how many amazing old GPT-J based models exist on Hugging Face. The main issue with GPT-J and similarly old models is that they predate llama.cpp itself, so by the time llama.cpp became popular those models had already lost attention, and nobody seems willing to spend their spare time implementing support for legacy models, even though supporting them would be quite valuable for historical purposes.
My assumption was correct, and it is indeed unfortunately not yet supported by llama.cpp:
INFO:hf-to-gguf:Loading model: gpt-j-6b
INFO:hf-to-gguf:Model architecture: GPTJForCausalLM
ERROR:hf-to-gguf:Model GPTJForCausalLM is not supported
By the way, the same applies to OPTForCausalLM, another amazing legacy model that lost attention before llama.cpp took off.
I very much agree. It's a real shame not to have those classic models available, especially the KoboldAI ones.
I was actually able to get it to work; hope this helps someone.
First, in the conversion script hf_to_gguf.py, you need this:
@ModelBase.register("GPTJForCausalLM")
class GPTJModel(TextModel):
    model_arch = gguf.MODEL_ARCH.GPTJ

    def set_gguf_parameters(self):
        self.gguf_writer.add_block_count(self.hparams["n_layer"])
        self.gguf_writer.add_context_length(self.hparams["n_positions"])
        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
        n_inner = self.hparams.get("n_inner")
        if n_inner is None:
            n_inner = 4 * self.hparams["n_embd"]
        self.gguf_writer.add_feed_forward_length(n_inner)
        self.gguf_writer.add_head_count(self.hparams["n_head"])
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        # rotary embedding dimension
        self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
        self.gguf_writer.add_file_type(self.ftype)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused
        tensors: list[tuple[str, Tensor]] = []
        # skip the causal-mask buffers that GPT-J checkpoints carry around
        if name.endswith((".attn.bias", ".attn.masked_bias")):
            return tensors
        new_name = self.map_tensor_name(name)
        tensors.append((new_name, data_torch))
        return tensors

    def set_vocab(self):
        self._set_vocab_gpt2()
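With that class in place, conversion should just be the usual invocation (this is my assumption that the script in question is llama.cpp's convert_hf_to_gguf.py; exact flag names may differ between versions):

python convert_hf_to_gguf.py /path/to/gpt-j-6b --outfile gpt-j-6b-f16.gguf --outtype f16

If I remember correctly, gguf-py's tensor mappings already know the GPT-J weight names (transformer.h.N.attn.q_proj and friends), so map_tensor_name should resolve them without further changes, but double-check that on your version.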
Then, to actually update llama.cpp to run inference:
In llama-arch.cpp, fix the tensor-name section that's already there:
    {
        LLM_ARCH_GPTJ,
        {
            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT,      "output" },
            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
        },
    },
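For orientation (my reading of the post above, not part of the patch): since that tensor-name section already exists, LLM_ARCH_GPTJ and its "gptj" architecture name should already be declared in llama-arch.h / llama-arch.cpp, so nothing else in that file needs touching. The existing entry looks roughly like this, though the exact declaration varies by version:

    // already present upstream, shown only for context:
    { LLM_ARCH_GPTJ, "gptj" },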
In llama-model.cpp, add cases for these:
In load_hparams:
        case LLM_ARCH_GPTJ:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT,    hparams.n_rot);

                switch (hparams.n_layer) {
                    case 28: type = LLM_TYPE_6B; break;
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
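As a sanity check (this correspondence is my assumption based on the standard {arch}-prefixed GGUF keys, not something stated in the post), the two keys read here should line up with what set_gguf_parameters writes on the Python side:

    // assumed key correspondence:
    //   add_layer_norm_eps(...)       -> "gptj.attention.layer_norm_epsilon" -> LLM_KV_ATTENTION_LAYERNORM_EPS
    //   add_rope_dimension_count(...) -> "gptj.rope.dimension_count"         -> LLM_KV_ROPE_DIMENSION_COUNT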
In load_tensors, under switch (arch):
        case LLM_ARCH_GPTJ:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                // output
                output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
                output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
                if (output == NULL) {
                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                }
                output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0);

                for (int i = 0; i < n_layer; ++i) {
                    auto & layer = layers[i];

                    layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                    layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd}, 0);

                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd}, 0);
                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd}, 0);
                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

                    layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                    layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
                    layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
                    layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias",   i), {n_ff}, 0);
                }
            } break;
Then in build_graph add the case statement:
        case LLM_ARCH_GPTJ:
            {
                llm = std::make_unique<llm_build_gptj>(*this, params);
            } break;
And then add the actual struct alongside the others:
struct llm_build_gptj : public llm_graph_context {
    llm_build_gptj(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_v;
        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
        const int64_t n_rot       = hparams.n_rot;

        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

        ggml_tensor * cur;
        ggml_tensor * inpL;

        inpL = build_inp_embd(model.tok_embd);

        ggml_tensor * inp_pos = build_inp_pos();

        auto * inp_attn = build_attn_inp_kv();

        // No positional embeddings added here (uses rotary instead)
        cb(inpL, "inpL", -1);

        ggml_tensor * inp_out_ids = build_inp_out_ids();

        for (int il = 0; il < n_layer; ++il) {
            // Save the original layer input
            ggml_tensor * layer_inp = inpL;

            cur = build_norm(layer_inp,
                    model.layers[il].attn_norm,
                    model.layers[il].attn_norm_b,
                    LLM_NORM, il);
            cb(cur, "attn_norm", il);

            ggml_tensor * attn_inp = cur;

            // self-attention
            ggml_tensor * attn_out;
            {
                // Separate Q, K, V projections
                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, attn_inp);
                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, attn_inp);
                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, attn_inp);

                cb(Qcur, "Qcur", il);
                cb(Kcur, "Kcur", il);
                cb(Vcur, "Vcur", il);

                Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

                // apply rotary embeddings to Q and K
                Qcur = ggml_rope(ctx0, Qcur, inp_pos, n_rot, 0);
                Kcur = ggml_rope(ctx0, Kcur, inp_pos, n_rot, 0);

                cb(Qcur, "Qcur_rope", il);
                cb(Kcur, "Kcur_rope", il);

                attn_out = build_attn(inp_attn,
                        model.layers[il].wo, nullptr, // no bias for attention output
                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                cb(attn_out, "attn_out", il);
            }

            // FFN - processes the same input as attention (parallel structure)
            ggml_tensor * ffn_out;
            {
                cur = build_lora_mm(model.layers[il].ffn_up, attn_inp); // Use attn_inp, same as attention
                cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b);
                cur = ggml_gelu(ctx0, cur);
                cur = build_lora_mm(model.layers[il].ffn_down, cur);
                cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b);
                ffn_out = cur;
                cb(ffn_out, "ffn_out", il);
            }

            if (il == n_layer - 1 && inp_out_ids) {
                attn_out  = ggml_get_rows(ctx0, attn_out,  inp_out_ids);
                ffn_out   = ggml_get_rows(ctx0, ffn_out,   inp_out_ids);
                layer_inp = ggml_get_rows(ctx0, layer_inp, inp_out_ids);
            }

            // Combine attention + FFN + original input (parallel structure)
            cur = ggml_add(ctx0, attn_out, ffn_out);
            cur = ggml_add(ctx0, cur, layer_inp);

            cur = build_cvec(cur, il);
            cb(cur, "l_out", il);

            // input for next layer
            inpL = cur;
        }

        cur = build_norm(inpL,
                model.output_norm,
                model.output_norm_b,
                LLM_NORM, -1);
        cb(cur, "result_norm", -1);
        res->t_embd = cur;

        cur = build_lora_mm(model.output, cur);
        cb(cur, "result_output", -1);
        res->t_logits = cur;

        ggml_build_forward_expand(gf, cur);
    }
};
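If anyone wants to try this, a rough usage sketch for building and running the patched tree (these are my own commands, not part of the patch; the CUDA flag, binary name, and layer count are assumptions that depend on your llama.cpp version and hardware):

cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j
./build/bin/llama-cli -m gpt-j-6b-f16.gguf -p "The meaning of life is" -n 64 -ngl 28

With -ngl 28 all 28 GPT-J layers should land on the GPU, which is where the speedup over the ggml example inference comes from.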
@rgbpanda Wow that is awesome. Why don't you submit your amazing work to https://github.com/ggml-org/llama.cpp by creating a Pull Request so everyone can use it?
@rgbpanda Yes indeed, this is awesome! But even if we created GGUFs, nobody could use them, since changes to the inference engine are needed. If you manage to get support into llama.cpp (and I think the threshold is not very high), we would happily quant all those tasty GPT-J/OPT models.
I completely understand. From my perspective, I was mostly curious whether I could get these models to run faster than with the ggml repo's example inference, so it's probably not super useful from a general perspective. But I wanted to share it in this thread in case someone else wants to run these models with at least some variant of llama.cpp, because I got significant speed increases with llama.cpp and layers on my GPU compared to ggml's example GPT-J inference that I built yesterday. The changes I made to the inference engine are pretty isolated and follow the existing patterns, I will say that much. I'm not very involved in this community, so if there are certain test frameworks that have to be updated or other conditions to meet for a PR, I could do it this weekend. I really just wanted to share some code I was happy to get working out of personal curiosity, after seeing this thread where nobody had it working before.
@rgbpanda Your changes are amazing. They are isolated and follow the guidance on how to add a new architecture to llama.cpp quite well. There are no preconditions for creating a PR for llama.cpp: just fork the project, push your code, and open the PR. Your code easily clears the fairly low bar they expect when it comes to adding support for a new architecture. Please create a PR no matter what, because by creating a PR you give the llama.cpp project the right to use your code, and others like myself can improve on it if it isn't already perfect. If you only post your code here, there is legally no easy way to integrate it into llama.cpp, and without it being integrated almost nobody will be able to run GPT-J. There are many really popular and, from a historical perspective, very impactful GPT-J based models, such as GPT4-Chan, which even has its own Wikipedia page: https://en.wikipedia.org/wiki/GPT4-Chan.
Yeah, I was able to get an F32 variant of GPT-4Chan from archive.org to work on llama.cpp, along with the original GPT-J and others. I was mostly curious about getting better performance than the ggml example inference, and I tested some other older GPT-J based models that I thought were interesting. I'm glad the code is useful to others, so yes, I can open a PR this weekend, but feel free to just build llama.cpp with the changes I posted; there are only a few.