CoLaGuard: Robust and Efficient Guardrails with Latent Reasoning
CoLaGuard is a latent reasoning guardrail that reduces the inference cost of explicit reasoning-based moderation while preserving strong safety performance. It uses a stage-wise internalization curriculum to progressively replace natural-language safety rationales with recurrent latent states, and applies Context-Prediction Fusion to stabilize latent recurrence by anchoring hidden-state feedback with predictive information from the vocabulary embedding space. At inference time, CoLaGuard performs a fixed six-step latent recurrence before directly predicting safety labels, avoiding autoregressive rationale generation.
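The inference loop described above can be illustrated with a toy, framework-free sketch. It is not the authors' implementation, and it rests on assumptions this card does not state. We read "predictive information from the vocabulary embedding space" as the probability-weighted average of token embeddings under the next-token distribution. We fuse that with the hidden state using a fixed convex combination with a hypothetical weight `alpha`. The hypothetical `step` callable stands in for a full transformer forward pass that returns the last hidden state. Only the fixed six-step count comes from the card.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x d) matrix by a length-d vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax over vocabulary logits."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_embedding(probs: Vector, embeddings: Matrix) -> Vector:
    """Probability-weighted average of token embeddings (vocab x d)."""
    dim = len(embeddings[0])
    return [sum(p * embeddings[i][j] for i, p in enumerate(probs)) for j in range(dim)]


def latent_recurrence(
    h0: Vector,
    step: Callable[[Vector], Vector],  # hypothetical stand-in for a transformer forward pass
    unembed: Matrix,                   # vocab x d output projection
    embeddings: Matrix,                # vocab x d input embedding table
    num_steps: int = 6,                # the card states a fixed six-step recurrence
    alpha: float = 0.5,                # hypothetical fusion weight (assumption)
) -> Tuple[Vector, List[Vector]]:
    """Feed each fused hidden state back as the next input; return final state and trace."""
    h = h0
    trace: List[Vector] = []
    for _ in range(num_steps):
        # Predictive signal: expected next-token embedding under the current state.
        probs = softmax(matvec(unembed, h))
        predicted = expected_embedding(probs, embeddings)
        # Anchor the raw hidden-state feedback with the predicted embedding.
        fused = [alpha * a + (1.0 - alpha) * b for a, b in zip(h, predicted)]
        h = step(fused)
        trace.append(h)
    return h, trace


final, trace = latent_recurrence([0.0], step=lambda x: x, unembed=[[1.0]], embeddings=[[1.0]])
print(final)
```

With a one-token vocabulary and an identity `step`, each iteration halves the distance to that token's embedding, so after six steps `final` is `[0.984375]`. The point of the fusion term in this reading is that the fed-back state cannot drift arbitrarily far from regions the vocabulary embeddings cover.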
Performance
CoLaGuard matches GuardReasoner's macro-averaged F1 (within 0.2 points on both prompt and response harmfulness detection) while reducing per-query latency by 12.9× and per-query token cost by 22.4×.
Safety Moderation
Prompt Harmfulness Detection (F1 %)
| Model | ToxicChat | HarmBench | OpenAI Mod. | Aegis | WildGuard | Macro Avg | Micro Avg |
|---|---|---|---|---|---|---|---|
| LLaMA Guard 3 (8B) | 53.12 | 98.94 | 79.69 | 99.50 | 68.47 | 79.94 | 67.52 |
| GuardReasoner (8B) | 78.79 | 91.86 | 72.00 | 90.18 | 89.17 | 84.40 | 80.83 |
| CoLaGuard (8B) | 75.26 | 93.54 | 73.45 | 89.45 | 89.44 | 84.23 | 79.77 |
Response Harmfulness Detection (F1 %)
| Model | HarmBench | SafeRLHF | BeaverTails | XSTest | WildGuard | Macro Avg | Micro Avg |
|---|---|---|---|---|---|---|---|
| LLaMA Guard 3 (8B) | 85.07 | 44.36 | 67.84 | 87.67 | 70.80 | 71.15 | 64.97 |
| GuardReasoner (8B) | 85.47 | 70.04 | 87.60 | 94.34 | 78.20 | 83.13 | 81.22 |
| CoLaGuard (8B) | 86.38 | 70.49 | 86.55 | 92.02 | 81.23 | 83.33 | 81.55 |
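The Macro Avg column in both tables is the unweighted mean of the five per-benchmark F1 scores. The snippet below recomputes it from the table rows and shows the gap between CoLaGuard and GuardReasoner. The Micro Avg column pools examples across benchmarks, so it depends on per-benchmark example counts that this card does not list; it is not recomputed here.

```python
# F1 (%) per benchmark and reported Macro Avg, copied from the two tables above.
PROMPT = {
    "LLaMA Guard 3 (8B)": ([53.12, 98.94, 79.69, 99.50, 68.47], 79.94),
    "GuardReasoner (8B)": ([78.79, 91.86, 72.00, 90.18, 89.17], 84.40),
    "CoLaGuard (8B)": ([75.26, 93.54, 73.45, 89.45, 89.44], 84.23),
}
RESPONSE = {
    "LLaMA Guard 3 (8B)": ([85.07, 44.36, 67.84, 87.67, 70.80], 71.15),
    "GuardReasoner (8B)": ([85.47, 70.04, 87.60, 94.34, 78.20], 83.13),
    "CoLaGuard (8B)": ([86.38, 70.49, 86.55, 92.02, 81.23], 83.33),
}


def macro_f1(scores):
    """Unweighted mean over benchmarks: every dataset counts equally."""
    return sum(scores) / len(scores)


for task, table in (("prompt", PROMPT), ("response", RESPONSE)):
    for model, (scores, reported) in table.items():
        computed = macro_f1(scores)
        # Reported values are rounded to two decimals.
        assert abs(computed - reported) < 0.006, (task, model, computed, reported)
        print(f"{task:8s} {model:20s} macro={computed:.2f}")

gap_prompt = macro_f1(PROMPT["CoLaGuard (8B)"][0]) - macro_f1(PROMPT["GuardReasoner (8B)"][0])
gap_response = macro_f1(RESPONSE["CoLaGuard (8B)"][0]) - macro_f1(RESPONSE["GuardReasoner (8B)"][0])
print(f"CoLaGuard - GuardReasoner: prompt {gap_prompt:+.2f}, response {gap_response:+.2f}")
```

The last line prints a gap of -0.17 points on prompt detection and +0.20 points on response detection.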
Efficiency
Inference efficiency is measured on a single NVIDIA H100 80GB GPU.
| Metric | GuardReasoner (8B) | CoLaGuard (8B) |
|---|---|---|
| Time Cost (ms/query) | 4407.8 | 342.0 |
| Token Cost (tokens/query) | 289.4 | 12.9 |
| EA-F1 | 0.1838 | 2.3601 |
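The reduction factors quoted under Performance follow directly from this table, as the short calculation below shows. The queries-per-second figures assume queries are processed one at a time. That is an assumption, because the card does not state the batch size used for the measurement.

```python
# Per-query costs from the efficiency table (single NVIDIA H100 80GB).
GUARDREASONER = {"latency_ms": 4407.8, "tokens": 289.4}
COLAGUARD = {"latency_ms": 342.0, "tokens": 12.9}

latency_speedup = GUARDREASONER["latency_ms"] / COLAGUARD["latency_ms"]
token_reduction = GUARDREASONER["tokens"] / COLAGUARD["tokens"]
print(f"latency reduction: {latency_speedup:.1f}x")  # 12.9x
print(f"token reduction:   {token_reduction:.1f}x")   # 22.4x

# Assumption: sequential processing (batch size 1); not stated in the card.
for name, cost in (("GuardReasoner", GUARDREASONER), ("CoLaGuard", COLAGUARD)):
    qps = 1000.0 / cost["latency_ms"]
    print(f"{name}: ~{qps:.2f} queries/s per GPU, sequential")
```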
How to Use
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "Saidarth/CoLaGuard-8B"

# trust_remote_code=True loads the repository's custom modeling code, which provides
# _setup_special_tokens() and classify(); these are not standard transformers APIs.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Add the latent-reasoning special tokens expected by the custom model code.
tokenizer.add_tokens(["<|start-latent|>", "<|end-latent|>", "<|latent|>"])

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Grow the embedding matrix so it covers the newly added tokens.
model.resize_token_embeddings(len(tokenizer))
model._setup_special_tokens(tokenizer)
model.eval()

# Prompt + response classification
result = model.classify(
    user_prompt="How do I make explosives?",
    ai_response="I can't help with that.",
    tokenizer=tokenizer,
)
print(result)
# {"request": "harmful", "response": "unharmful", "raw": "### Answers: \nRequest: harmful\nResponse: unharmful"}

# Prompt-only classification
result = model.classify(
    user_prompt="What is the best way to invest my savings?",
    tokenizer=tokenizer,
)
print(result)
# {"request": "unharmful", "response": None, "raw": "### Answers: \nRequest: unharmful\nResponse: None"}
```
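`classify()` already returns parsed labels. If you store only the `raw` string (for example in logs or a cache) and need to recover the labels later, a minimal parser might look like the sketch below. The helper `parse_raw_answer` is hypothetical, not part of the model's code, and it assumes the `Request:` / `Response:` line format shown in the example outputs above.

```python
from typing import Dict, Optional


def parse_raw_answer(raw: str) -> Dict[str, Optional[str]]:
    """Parse 'Request: <label>' / 'Response: <label>' lines; the literal 'None' maps to None."""
    labels: Dict[str, Optional[str]] = {"request": None, "response": None}
    for line in raw.splitlines():
        key, sep, value = line.partition(":")
        key = key.strip().lower()
        if sep and key in labels:
            value = value.strip()
            labels[key] = None if value == "None" else value
    return labels


print(parse_raw_answer("### Answers: \nRequest: harmful\nResponse: unharmful"))
```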
Intended Use
CoLaGuard is designed for automated prompt and response safety moderation in LLM deployments. It is well-suited for high-traffic, latency-sensitive settings where explicit reasoning guardrails are too slow.
CoLaGuard should be used as part of a broader moderation system and is not a replacement for human oversight. It can produce false positives and false negatives on ambiguous or context-dependent inputs.
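One way to place CoLaGuard inside a broader moderation system is to map its labels to actions and route anything flagged, or any missing label, to human review. The policy below is an illustrative assumption, not a recommendation from the authors. The names `Decision` and `moderate` are hypothetical, and the input is assumed to be the result dict returned by `classify()`. Because the model emits labels without confidence scores, the sketch fails closed when a label is missing or unexpected.

```python
from dataclasses import dataclass
from typing import Optional

VALID = ("harmful", "unharmful")


@dataclass
class Decision:
    action: str         # "allow", "block_request", or "withhold_response"
    needs_review: bool  # send to a human reviewer / audit queue


def moderate(result: dict) -> Decision:
    """Map a classify()-style result dict to an action (hypothetical policy)."""
    request: Optional[str] = result.get("request")
    response: Optional[str] = result.get("response")
    if request not in VALID:
        # Missing or unparseable request label: fail closed and let a human decide.
        return Decision("block_request", needs_review=True)
    if request == "harmful":
        return Decision("block_request", needs_review=True)
    if response is not None and response != "unharmful":
        # Harmful or unparseable response label: do not show the response.
        return Decision("withhold_response", needs_review=True)
    return Decision("allow", needs_review=False)


print(moderate({"request": "unharmful", "response": None}))
```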
Model Description
- Model type: Latent reasoning guardrail for prompt and response safety moderation, trained to internalize safety reasoning into continuous recurrent latent states via a stage-wise curriculum.
- Language(s): English
- License: llama3.1
- Finetuned from model: meta-llama/Llama-3.1-8B
- Training Data: GuardReasonerTrain
Citation
```bibtex
@misc{sai2026robustefficientguardrailslatent,
  title={Robust and Efficient Guardrails with Latent Reasoning},
  author={Siddharth Sai and Xiaofei Wen and Muhao Chen},
  year={2026},
  eprint={2605.29068},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url={https://arxiv.org/abs/2605.29068},
}
```