gemma4-prometheus-merged

Prometheus-steered merged model: google/gemma-4-31B-it with Prometheus adversarial steering baked in and all adapter weights merged into the base model.

Related repositories

| Repo | Description |
|---|---|
| groxaxo/gemma4-prometheus-gptq-4bit | GPTQ-4bit quantized version of this model |
| groxaxo/gemma4-prometheus-workflow | Reproducible scripts, config, and checkpoint journal |
| groxaxo/gemma4-prometheus-fixes | All local patches applied to make this work |
| google/gemma-4-31B-it | Original base model |

What is this?

  1. Downloaded google/gemma-4-31B-it (31B parameters, BF16).
  2. Ran Prometheus adversarial-steering optimization for a single trial with 6 behaviors.
  3. Merged the best steering vectors back into the base model weights.
  4. Saved as a standalone, loadable model (no Prometheus runtime needed).
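Step 3 folds the trained adapter weights into the base weights, so no extra runtime is needed at inference. The card does not spell out the merge math; as a minimal sketch, assuming the adapters are LoRA-style (an assumption based on the PEFT adapter mentioned under "Patches applied"), merging replaces each targeted weight W with W + (alpha / r) · B · A. In the PEFT library this is what `merge_and_unload()` does. The matrices below are made-up toy values; the point is that the merged layer produces the same output as base plus adapter.

```python
# Minimal sketch of a LoRA-style merge, ASSUMING the adapter is LoRA:
#   W_merged = W + (alpha / r) * B @ A
# All matrix values are illustrative toy numbers, not real model weights.

def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))] for i in range(len(x))]

def matvec(m, v):
    """Apply matrix m to vector v."""
    return [sum(row[k] * v[k] for k in range(len(v))) for row in m]

def merge_lora(w, a, b, alpha, r):
    """Return W + (alpha / r) * B @ A as a new matrix."""
    scale = alpha / r
    delta = matmul(b, a)  # (d_out x r) @ (r x d_in) -> d_out x d_in
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]

# Hypothetical shapes: d_out = 2, d_in = 3, rank r = 1.
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]
A = [[0.5, -0.5, 1.0]]   # r x d_in
B = [[2.0], [-1.0]]      # d_out x r
alpha, r = 2.0, 1

W_merged = merge_lora(W, A, B, alpha, r)

x = [1.0, 2.0, 3.0]
# Un-merged path: base output plus the scaled adapter output.
unmerged = [base + (alpha / r) * lora
            for base, lora in zip(matvec(W, x), matvec(B, matvec(A, x)))]
merged = matvec(W_merged, x)
print(merged, unmerged)  # prints [17.0, -6.0] [17.0, -6.0]
```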

Size: ~58 GiB (two BF16 shards).


How to run

Minimum requirements

  • 3 × RTX 3090 (24 GB each) or any combination totalling ≥ 65 GiB VRAM
  • On 2 GPUs (≤ 48 GiB): load with BnB 8-bit (see below)
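The VRAM figures follow from simple arithmetic: 31B parameters at 2 bytes each is about 57.7 GiB, matching the ~58 GiB download. The sketch below counts weights only (no activations, KV cache, or CUDA context) and assumes ideal bit widths; real 8-bit and 4-bit formats add some overhead for scales, and the extra headroom in the ≥ 65 GiB requirement covers the runtime buffers.

```python
# Back-of-the-envelope weight memory, using 31e9 parameters (from the card).
# Ignores activations, KV cache, CUDA context, and quantization-scale overhead.

GIB = 2 ** 30

def weight_gib(n_params: float, bits_per_param: float) -> float:
    """Memory needed just to hold the weights, in GiB."""
    return n_params * bits_per_param / 8 / GIB

N_PARAMS = 31e9

for label, bits in [("BF16", 16), ("8-bit", 8), ("4-bit", 4)]:
    need = weight_gib(N_PARAMS, bits)
    for gpus in (2, 3):
        budget = gpus * 24  # RTX 3090: 24 GiB per card
        print(f"{label:>5}: {need:5.1f} GiB weights, {gpus} GPUs -> "
              f"{budget} GiB total, headroom {budget - need:5.1f} GiB")
```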

BF16 on 3 GPUs

```python
from transformers import AutoModelForImageTextToText, AutoTokenizer
import torch

model_id = "groxaxo/gemma4-prometheus-merged"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # spread layers across all visible GPUs
)
model.eval()

messages = [{"role": "user", "content": "Explain gradient descent."}]
text = tokenizer.apply_chat_template(
    messages, tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # see "Important: disable thinking tokens" below
)
# Gemma chat templates typically insert <bos> already; don't add a second one.
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=512, do_sample=False,
                         pad_token_id=tokenizer.eos_token_id)
# Decode only the newly generated tokens.
print(tokenizer.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```

BnB 8-bit on 2 GPUs (48 GiB)

```python
# Requires the bitsandbytes package.
from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig

model_id = "groxaxo/gemma4-prometheus-merged"

bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_cfg,
    device_map="auto",
    max_memory={0: "23GiB", 1: "23GiB"},  # cap weight placement below 24 GiB per card
)
model.eval()
# Generation works the same as in the BF16 example above.
```

Important: disable thinking tokens

Always pass enable_thinking=False to apply_chat_template. Otherwise the model may emit <start_of_turn>think tokens and spend part of the max_new_tokens budget on chain-of-thought before answering.


Evaluation results

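All tests were run on 2 × RTX 3090 (24 GB each) with device_map="auto" via 🤗 Accelerate. This places whole layers on each GPU and runs them sequentially (naive pipeline/model parallelism), so only one GPU computes at a time. True tensor parallelism (TP=2) would need an inference engine such as vLLM, which did not yet support the gemma4 architecture natively at the time of testing.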
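With device_map="auto", Accelerate assigns whole decoder layers to each GPU and a forward pass walks through them in order. After loading, `model.hf_device_map` records the placement. The helper below is a hypothetical convenience that counts decoder layers per device; the example dictionary and its module names are made up, and real keys may differ (Accelerate can also place an entire submodule under a single key).

```python
import re
from collections import Counter

def layers_per_device(device_map: dict, prefix: str = "language_model") -> Counter:
    """Count decoder layers per device in an hf_device_map-style dict.

    Only keys containing `prefix` and ending in '.layers.<index>' are counted,
    so vision-tower layers (if listed) are ignored.
    """
    counts = Counter()
    for name, device in device_map.items():
        if prefix in name and re.search(r"\.layers\.\d+$", name):
            counts[device] += 1
    return counts

# HYPOTHETICAL device map for illustration; real names come from
# model.hf_device_map after loading with device_map="auto".
example = {
    "model.language_model.embed_tokens": 0,
    **{f"model.language_model.layers.{i}": 0 for i in range(0, 31)},
    **{f"model.language_model.layers.{i}": 1 for i in range(31, 60)},
    "model.language_model.norm": 1,
    "lm_head": 1,
}
print(layers_per_device(example))
```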

Coherence test (GPTQ-4bit model, same backbone)

Five diverse ML questions were asked; 5/5 responses passed (each ≥ 15 words and on-topic).

| Prompt | Response excerpt | OK |
|---|---|---|
| Explain how neural networks learn from data. | "…a neural network learns by trial and error. It makes a guess, finds out how wrong…" | ✓ |
| What is the difference between supervised and unsupervised learning? | "…In supervised learning, the data is 'labeled'…In unsupervised learning, the data is 'unlabeled'…" | ✓ |
| Describe the concept of gradient descent in machine learning. | "…Gradient Descent is an optimization algorithm used to minimize a function…" | ✓ |
| What are transformers in NLP? | "…a Transformer is a deep learning architecture…focusing on the most important parts…" | ✓ |
| Explain quantization for neural network models. | "…quantization is the process of reducing the precision of the numbers…" | ✓ |
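The pass criterion (≥ 15 words and on-topic) can be approximated as below. The card does not say how "on-topic" was judged, so the keyword check is an assumed stand-in; `passes_coherence` and its sample text are hypothetical and not taken from the model's outputs.

```python
import re

def passes_coherence(response: str, topic_keywords: list[str], min_words: int = 15) -> bool:
    """Heuristic pass/fail: long enough AND mentions at least one topic keyword.

    The keyword test is an ASSUMED proxy for "on-topic"; the card does not
    describe the actual check.
    """
    words = re.findall(r"[A-Za-z0-9'-]+", response)
    if len(words) < min_words:
        return False
    text = response.lower()
    return any(kw.lower() in text for kw in topic_keywords)

# Made-up example text (not a model response).
answer = ("Quantization stores weights with fewer bits, for example 4-bit integers "
          "instead of 16-bit floats, which shrinks memory use at some cost in accuracy.")
print(passes_coherence(answer, ["quantization", "precision"]))
```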

Context length

| KV cache | Max tokens | Bottleneck | Notes |
|---|---|---|---|
| FP16 | 6,144 | Attention-score memory, O(n²) | Without flash-attn, the attention matrix is 32 heads × n² × 2 B |
| FP8 (software) | 6,144 | Same: attention-score memory | FP8 saves KV storage, not the attention matrix |
| FP16 + flash-attn (estimated) | ~113,000 | KV cache | Recommended: pip install flash-attn --no-build-isolation |
| FP8 + flash-attn (estimated) | ~226,000 | KV cache | Upper bound is max_position_embeddings = 262,144 |

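Action: installing flash-attn (and loading with attn_implementation="flash_attention_2") would increase usable context by an estimated ~18× (≈113,000 / 6,144). Caveat: FlashAttention-2 supports head dimensions up to 256, so the 512-dim global-attention layers may not be able to use it; verify this before relying on the estimates.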
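The 6,144-token ceiling follows from the table note: with eager attention, each layer materializes a heads × n × n score matrix. The sketch below evaluates that formula (32 heads, 2 bytes per element, one layer at a time, ignoring any FP32 upcast or extra softmax buffers) to show how quickly it grows; flash-attn avoids materializing this matrix, which is why the bottleneck moves to the KV cache.

```python
GIB = 2 ** 30

def attn_scores_bytes(n_tokens: int, n_heads: int = 32, bytes_per_elem: int = 2) -> int:
    """Size of one layer's materialized attention-score matrix (heads x n x n)."""
    return n_heads * n_tokens * n_tokens * bytes_per_elem

for n in (6_144, 12_288, 113_000):
    print(f"{n:>7} tokens -> {attn_scores_bytes(n) / GIB:8.2f} GiB per layer")
```

Doubling the context quadruples this buffer, so the 6,144-token limit is roughly where one layer's score matrix (2.25 GiB) stops fitting next to the weights.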

Perplexity (WikiText-2, sliding window stride=512, 4096 tokens)

| Model | Perplexity | Notes |
|---|---|---|
| Merged (BnB-8bit reference) | 1782.3 | Chat model tested on raw text; high PPL is expected |
| GPTQ-4bit | 1815.8 | +1.9% vs. merged reference |

Chat-tuned models have high perplexity on raw text, so the absolute values are not meaningful on their own. The ΔPPL between variants is the meaningful signal: +1.9% degradation from 4-bit quantization.
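The sliding-window protocol scores every token exactly once while giving each window up to `max_length` tokens of context. Below is a minimal sketch in the style of the Hugging Face fixed-length perplexity recipe; the card gives stride=512 and 4096 tokens but not the window length, so `max_length=1024` is an assumed value. The last lines reproduce the +1.9% ΔPPL from the perplexity table.

```python
import math

def sliding_windows(seq_len: int, max_length: int, stride: int):
    """Yield (begin, end, n_scored) windows.

    Each window sees up to max_length tokens, but only the tokens not already
    scored by the previous window contribute to the loss.
    """
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        yield begin, end, end - prev_end
        prev_end = end
        if end == seq_len:
            break

def perplexity(total_nll: float, n_tokens: int) -> float:
    """PPL = exp(mean negative log-likelihood per scored token)."""
    return math.exp(total_nll / n_tokens)

# max_length=1024 is an ASSUMPTION; the card does not state the window size.
windows = list(sliding_windows(seq_len=4096, max_length=1024, stride=512))
print(len(windows), sum(n for _, _, n in windows))

merged_ppl, gptq_ppl = 1782.3, 1815.8  # values reported in the perplexity table
print(f"ΔPPL = {100 * (gptq_ppl / merged_ppl - 1):+.1f}%")
```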

KL divergence (GPTQ-4bit vs merged reference)

| Metric | Value |
|---|---|
| Direction | KL(merged_bnb8 ‖ gptq_4bit) |
| Mean KL | 4.77 nats |
| Std KL | 3.65 nats |
| Prompts | 8 ML-domain questions |
| Top-k tokens | 1000 |

The mean KL of ~4.77 nats is attributed to 4-bit quantization error relative to an 8-bit reference. Part of this KL comes from bnb-8bit noise in the reference itself, so the KL against the unquantized (BF16) merged model would likely be somewhat lower.
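For reference, KL(P ‖ Q) = Σ p · (log p - log q), computed per token position and then averaged. The card does not say exactly how the top-1000 restriction was applied; the sketch below assumes both distributions are renormalized over the reference model's top-k tokens. The logits are toy values, not model outputs.

```python
import math
import statistics

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def topk_kl(ref_logits, test_logits, k):
    """KL(P_ref || Q_test) restricted to the reference's top-k token ids.

    ASSUMED convention: both distributions are renormalized over that shared
    support. The card does not specify how top-k was applied.
    """
    idx = sorted(range(len(ref_logits)), key=lambda i: ref_logits[i], reverse=True)[:k]
    log_p = log_softmax([ref_logits[i] for i in idx])
    log_q = log_softmax([test_logits[i] for i in idx])
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

# Toy logits for two positions (made up for illustration).
ref = [[0.0, 0.0, -5.0], [2.0, 1.0, 0.0]]
test = [[math.log(3), 0.0, -5.0], [2.0, 1.0, 0.0]]
kls = [topk_kl(r, t, k=2) for r, t in zip(ref, test)]
print(kls, statistics.mean(kls), statistics.pstdev(kls))
```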


Architecture notes

  • 60 transformer layers alternating:
    • Sliding-window attention (window=1024, 16 KV heads, head_dim=256)
    • Full (global) attention (4 KV heads, head_dim=512)
  • GQA with 32 query heads; 16 KV heads in sliding-window layers, 4 in global layers
  • VLM wrapper (model.language_model): the vision tower is present, but text-only inference works
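These numbers explain why the KV cache, not attention, becomes the limit once flash-attn is used: only the global layers grow with context, while sliding-window layers stop at 1,024 tokens. The sketch below assumes a 30/30 split (reading "alternating" literally) and that the cache implementation actually trims sliding layers to the window; if it does not, memory is higher. Halving bytes per element for FP8 halves the cache, consistent with the ~2× gap between the two flash-attn rows of the context-length table.

```python
GIB = 2 ** 30

# ASSUMED layout: 60 layers split evenly between sliding-window and global.
SLIDING = dict(n_layers=30, kv_heads=16, head_dim=256, window=1024)
GLOBAL = dict(n_layers=30, kv_heads=4, head_dim=512)

def kv_bytes_per_token_per_layer(kv_heads: int, head_dim: int, bytes_per_elem: int) -> int:
    """K and V for one token in one layer."""
    return 2 * kv_heads * head_dim * bytes_per_elem

def kv_cache_bytes(n_tokens: int, bytes_per_elem: int = 2) -> int:
    """Total KV cache, ASSUMING sliding layers keep only the last `window` tokens."""
    g = GLOBAL["n_layers"] * kv_bytes_per_token_per_layer(
        GLOBAL["kv_heads"], GLOBAL["head_dim"], bytes_per_elem) * n_tokens
    s = SLIDING["n_layers"] * kv_bytes_per_token_per_layer(
        SLIDING["kv_heads"], SLIDING["head_dim"], bytes_per_elem) * min(n_tokens, SLIDING["window"])
    return g + s

for n in (6_144, 113_000):
    print(f"{n:>7} tokens: FP16 {kv_cache_bytes(n, 2) / GIB:6.2f} GiB, "
          f"FP8 {kv_cache_bytes(n, 1) / GIB:6.2f} GiB")
```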

Patches applied

All source patches are documented at groxaxo/gemma4-prometheus-fixes.

Key fixes:

  1. Prometheus PEFT adapter targeting — resolved exact module paths via named_modules() traversal to prevent over-matching vision layers.
  2. Prometheus steering FP16 — defaulted steering vector compute dtype to FP16 (not FP32) to prevent VRAM OOM on quantized layers.
  3. gptqmodel Gemma4 support — added Gemma4QModel definition with layer_modules_strict=False.
  4. gptqmodel rotary embedding — per-layer position_embeddings regeneration with correct layer_type (sliding vs global).
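Fix 1 amounts to turning suffix-style targets into an explicit list of module paths. A hypothetical sketch follows: `resolve_targets`, the include/exclude patterns, the suffixes, and the module names are all illustrative, not the actual patch (that lives in groxaxo/gemma4-prometheus-fixes). An explicit list like this can be passed as `target_modules` in a PEFT `LoraConfig`, where full paths cannot accidentally match vision-tower modules.

```python
import re

def resolve_targets(module_names, suffixes, include=r"language_model", exclude=r"vision"):
    """Return exact module paths whose last component is in `suffixes`,
    keeping only paths matching `include` and dropping any matching `exclude`.
    """
    keep = []
    for name in module_names:
        if name.rsplit(".", 1)[-1] not in suffixes:
            continue
        if not re.search(include, name) or re.search(exclude, name):
            continue
        keep.append(name)
    return keep

# HYPOTHETICAL module paths; in practice: [n for n, _ in model.named_modules()].
names = [
    "model.language_model.layers.0.self_attn.q_proj",
    "model.language_model.layers.0.self_attn.o_proj",
    "model.language_model.layers.0.mlp.down_proj",
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",
    "model.vision_tower.encoder.layers.0.self_attn.o_proj",
]
targets = resolve_targets(names, suffixes={"o_proj", "down_proj"})
print(targets)
```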

Citation / acknowledgements
