---
library_name: transformers
license: mit
base_model:
- meta-llama/Llama-3.2-1B-Instruct
---

# CoLaR Model

<div align="center">

[ModalityDance/latent-tts-colar](https://huggingface.co/ModalityDance/latent-tts-colar)

</div>

## Overview

**CoLaR** (Compressed Latent Reasoning) is a latent reasoning model built on LLaMA. It uses a specialized LatentHead module to generate continuous latent representations. This model is part of the [Parallel Test-Time Scaling for Latent Reasoning Models](https://arxiv.org/abs/2510.07745) framework.

## Model Details

- **Base Architecture**: LLaMA Language Model
- **Model Class**: `ColarLlama` (extends `LlamaForCausalLM`)
- **Special Features**: LatentHead module for latent space generation
- **Latent Tokens**: Uses special token `<|latent|>` for latent reasoning
- **End Token**: Uses `###` as the end-of-latent marker
- **Input Format**: Direct input format with latent tokens

## Related Models

The collection below includes other latent reasoning models that you might find useful:

[ModalityDance/latent-tts](https://huggingface.co/collections/ModalityDance/latent-tts)

## Installation

Download the model from Hugging Face:

```bash
huggingface-cli download ModalityDance/latent-tts-colar --local-dir checkpoints/colar
```

## Quick Start

### Basic Usage

```python
import torch
from transformers import AutoTokenizer

# `src` is assumed to be the package shipped with the accompanying code repository
from src.generation_mixin import LatentGenerationMixin, LatentGenerationConfig
from src.paths import MODELS

# Load tokenizer
model_id = "checkpoints/colar"
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Get latent token IDs
latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
end_id = tokenizer.convert_tokens_to_ids("###")

# Create model class with generation mixin
class LatentCoLaR(MODELS["colar"]["class"], LatentGenerationMixin):
    pass

# Load model
model = LatentCoLaR.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # Recommended for LLaMA models
)

# Prepare input
question = "What is 2 + 2?<|latent|>"
inputs = tokenizer(question, return_tensors="pt").to(model.device)

# Configure generation
generation_config = LatentGenerationConfig(
    max_new_tokens=128,
    max_latent_length=64,  # CoLaR uses max_latent_length instead of latent_length
    latent_do_sample=True,
    latent_do_sample_by="dropout",  # or "noise"
    dropout_p=0.1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Generate
output = model.generate(
    **inputs,
    generation_config=generation_config,
    num_return_sequences=1,
)

# Decode result
result = tokenizer.decode(output[0], skip_special_tokens=True)
print(result)
```

### Batch Processing

The model fully supports batch processing with Transformers:

```python
# Reuses `tokenizer`, `model`, and `generation_config` from Basic Usage

# Prepare batch inputs
questions = [
    "What is 2 + 2?<|latent|>",
    "What is 5 * 3?<|latent|>",
    "What is 10 - 4?<|latent|>",
]
inputs = tokenizer(questions, return_tensors="pt", padding=True).to(model.device)

# Generate for batch
outputs = model.generate(
    **inputs,
    generation_config=generation_config,
    num_return_sequences=1,
)

# Decode batch results
results = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for result in results:
    print(result)
```

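As a small illustration of the input format used above, the sketch below appends the latent token to raw questions and splits them into fixed-size batches before tokenization. `build_prompts` and `chunked` are hypothetical helper names that are not part of the repository, and the batch size is an arbitrary choice.

```python
from typing import Iterator, List

LATENT_TOKEN = "<|latent|>"  # special token named in Model Details


def build_prompts(questions: List[str], latent_token: str = LATENT_TOKEN) -> List[str]:
    """Append the latent token to each question (hypothetical helper)."""
    return [q if q.endswith(latent_token) else q + latent_token for q in questions]


def chunked(items: List[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `size` items (hypothetical helper)."""
    if size < 1:
        raise ValueError("size must be >= 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


prompts = build_prompts(["What is 2 + 2?", "What is 5 * 3?", "What is 10 - 4?"])
batches = list(chunked(prompts, size=2))
```

Each batch can then be passed to `tokenizer(batch, return_tensors="pt", padding=True)` in the same way as the `questions` list above.
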
## Model Architecture

### LatentHead Module

CoLaR uses a specialized LatentHead for generating latent representations:

```python
import torch.nn as nn


class LatentHead(nn.Module):
    def __init__(self, feature_size, intermediate_size=512):
        super().__init__()
        # MLP that projects features into an intermediate space
        self.fc = nn.Sequential(
            nn.Linear(feature_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
        )
        # Linear projection from the intermediate space back to feature_size
        self.mean = nn.Linear(intermediate_size, feature_size)
```

The latent embeddings are scaled by `latent_embedding_std` (default: 0.018 for LLaMA-3.2 models).

## Generation Parameters

### LatentGenerationConfig

- `max_new_tokens` (int): Maximum number of tokens to generate
- `max_latent_length` (int): Maximum number of latent tokens (default: 64)
- `latent_do_sample` (bool): Whether to use stochastic sampling
- `latent_do_sample_by` (str): Sampling method, either `"dropout"` or `"noise"`
- `dropout_p` (float): Dropout probability for Monte Carlo Dropout (e.g., 0.1)
- `noise_std` (float): Standard deviation for Additive Gaussian Noise

### Sampling Methods

1. **Monte Carlo Dropout**: Randomly drops activations during forward passes

```python
generation_config = LatentGenerationConfig(
    latent_do_sample_by="dropout",
    dropout_p=0.1,
    # ...
)
```

2. **Additive Gaussian Noise**: Injects noise into latent embeddings

```python
generation_config = LatentGenerationConfig(
    latent_do_sample_by="noise",
    noise_std=0.1,
    # ...
)
```

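To make the two perturbations concrete, here is a minimal pure-Python sketch applied to a plain list of floats standing in for a latent vector. It assumes standard inverted dropout and zero-mean Gaussian noise; where exactly the model applies them is not specified in this card, and `mc_dropout` and `add_gaussian_noise` are hypothetical names rather than repository functions.

```python
import random
from typing import List


def mc_dropout(vec: List[float], p: float, rng: random.Random) -> List[float]:
    """Zero each element with probability p and rescale survivors by 1 / (1 - p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep = 1.0 - p
    return [x / keep if rng.random() >= p else 0.0 for x in vec]


def add_gaussian_noise(vec: List[float], std: float, rng: random.Random) -> List[float]:
    """Add independent zero-mean Gaussian noise with standard deviation std."""
    return [x + rng.gauss(0.0, std) for x in vec]


rng = random.Random(0)  # seeded only to make the sketch reproducible
latent = [0.5, -1.0, 0.25, 2.0]  # toy stand-in for a latent embedding
dropped = mc_dropout(latent, p=0.1, rng=rng)
noisy = add_gaussian_noise(latent, std=0.1, rng=rng)
```

Each call draws a fresh mask or noise sample, so repeated calls on the same input yield different perturbed vectors. That randomness is what lets repeated generations from one prompt differ.
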
## Answer Extraction

CoLaR uses a special answer format with an "Answer:" prefix:

```python
from src.paths import colar_extract_answer_number

# Extract answer from generated text
answer = colar_extract_answer_number(result)
print(f"Answer: {answer}")
```

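The repository's `colar_extract_answer_number` is not reproduced here. As a rough stand-in, the sketch below pulls the number following the last "Answer:" prefix with a regular expression, then takes a majority vote over several sampled generations, which is one simple way to aggregate parallel samples. The regex, `extract_answer_number`, and `majority_vote` are assumptions for illustration and may differ from the repository's behavior.

```python
import re
from collections import Counter
from typing import List, Optional

# Matches "Answer:" followed by an optionally signed number with optional commas/decimals.
_ANSWER_RE = re.compile(r"Answer:\s*(-?\d[\d,]*(?:\.\d+)?)")


def extract_answer_number(text: str) -> Optional[float]:
    """Return the number after the last "Answer:" prefix, or None if absent."""
    matches = _ANSWER_RE.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def majority_vote(texts: List[str]) -> Optional[float]:
    """Return the most frequent extracted answer, ignoring unparseable outputs."""
    answers = [a for a in map(extract_answer_number, texts) if a is not None]
    if not answers:
        return None
    # Ties resolve to the answer that was seen first.
    return Counter(answers).most_common(1)[0][0]


samples = ["Answer: 4", "Answer: 5", "The result. Answer: 4", "no answer here"]
voted = majority_vote(samples)
```

In practice the `samples` list would hold decoded outputs from several stochastic generations of the same question.
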
## Evaluation

Run evaluation using the provided scripts:

```bash
# For CoLaR (LLaMA based models)
./run_tests_llama.sh
```

## Model Card

- **Paper**: [Parallel Test-Time Scaling for Latent Reasoning Models](https://arxiv.org/abs/2510.07745)
- **HuggingFace**: [ModalityDance/latent-tts-colar](https://huggingface.co/ModalityDance/latent-tts-colar)
- **Benchmarks**: GSM8K Test, GSM8K Hard, MultiArith

## Notes

- **Data Type**: `torch.bfloat16` or `torch.float16` is recommended for LLaMA models
- **Memory**: LLaMA models typically require more GPU memory than GPT-2 models
- **Latent Length**: CoLaR uses `max_latent_length` instead of a fixed `latent_length`

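As a back-of-the-envelope illustration of the data-type note, the sketch below estimates the memory taken by model weights alone at different precisions. The parameter count is a rough assumed figure for a 1B-class model, so substitute the real count; activations, the KV cache, and framework overhead are not included.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Approximate weight-only memory in GiB for a parameter count and dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


ASSUMED_PARAMS = 1_200_000_000  # assumption: roughly 1.2B parameters

for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(ASSUMED_PARAMS, dtype):.2f} GiB")
```

Half precision halves the weight footprint relative to `float32`, which is the main practical reason for the recommendation.
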
## Citation

If you use this model, please cite:

```bibtex
@misc{you2025paralleltesttimescalinglatent,
      title={Parallel Test-Time Scaling for Latent Reasoning Models},
      author={Runyang You and Yongqi Li and Meng Liu and Wenjie Wang and Liqiang Nie and Wenjie Li},
      year={2025},
      eprint={2510.07745},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2510.07745},
}

@misc{tan2025thinksilentlythinkfast,
      title={Think Silently, Think Fast: Dynamic Latent Compression of LLM Reasoning Chains},
      author={Wenhui Tan and Jiaze Li and Jianzhong Ju and Zhenbo Luo and Jian Luan and Ruihua Song},
      year={2025},
      eprint={2505.16552},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2505.16552},
}
```