---
base_model: meta-llama/Llama-3.2-1B-Instruct
datasets:
- whynlp/gsm8k-aug
library_name: transformers
license: llama3.2
tags: []
pipeline_tag: text-generation
---

# Learning When to Stop: Adaptive Latent Reasoning via Reinforcement Learning

This repository contains model weights for the paper [Learning When to Stop: Adaptive Latent Reasoning via Reinforcement Learning](https://huggingface.co/papers/2511.21581).

The paper studies adaptive-length latent reasoning in Transformer language models, using a post-SFT reinforcement learning method to shorten reasoning while maintaining accuracy. In experiments with Llama 3.2 1B on the GSM8K-Aug dataset, this approach reduced total reasoning length by 52% without sacrificing accuracy.

For more details, including the full codebase and utilities, please refer to the [GitHub repository](https://github.com/apning/adaptive-latent-reasoning).

## Sample Usage

You can load these models using the `automodelforcausallm_from_pretrained_latent` function from `src.model_creation`, as shown in the GitHub repository:

```python
# Note: `automodelforcausallm_from_pretrained_latent` is custom to this project and
# requires `src/model_creation.py` from the GitHub repository to be on your Python path.
from transformers import AutoTokenizer

from src.model_creation import automodelforcausallm_from_pretrained_latent

# Example repo ID: replace it with the model variant you want to load
# (see "Trained Model Weights" below).
repo_id = "Lapisbird/Llama-adaLR-model-latent-6"

model = automodelforcausallm_from_pretrained_latent(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Further inference steps depend on your task.
```
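
The note in the example above says `src/model_creation.py` must be importable. One way to arrange that is to put the root of a local clone of the GitHub repository on `sys.path` before the import. This is a minimal sketch, assuming the repository has been cloned locally; the helper name `add_repo_to_path` and the clone location are illustrative and not part of the project.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir: str) -> Path:
    """Hypothetical helper: put a local clone of the GitHub repository on
    sys.path so that `from src.model_creation import ...` can be resolved."""
    root = Path(repo_dir).expanduser().resolve()
    # Fail early with a clear message if the clone lacks the expected file.
    if not (root / "src" / "model_creation.py").is_file():
        raise FileNotFoundError(f"{root} does not contain src/model_creation.py")
    # Prepend so the clone takes precedence over any other module named `src`.
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))
    return root


# Usage, assuming the repository was cloned into ./adaptive-latent-reasoning with
#   git clone https://github.com/apning/adaptive-latent-reasoning
# add_repo_to_path("adaptive-latent-reasoning")
# from src.model_creation import automodelforcausallm_from_pretrained_latent
```

Prepending (rather than appending) the path is a design choice that avoids accidentally importing an unrelated `src` package installed elsewhere.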

## Trained Model Weights

All model weights used to produce the results in the paper are available on Hugging Face.

**From the main results:**

| Model | Hugging Face repo |
| --- | --- |
| CoT SFT | Lapisbird/Llama-adaLR-model-cot_sft |
| No-CoT SFT | Lapisbird/Llama-adaLR-model-no_cot_sft |
| Latent-6 | Lapisbird/Llama-adaLR-model-latent-6 |
| Latent-6 + RL | Lapisbird/Llama-adaLR-model-latent-6_rl |
| Latent-6-by-1 | Lapisbird/Llama-adaLR-model-latent-6-by-1 |
| Latent-6-by-1 + RL | Lapisbird/Llama-adaLR-model-latent-6-by-1_rl |

**From the knowledge distillation for SFT section of the Appendix:**

| Model (Appendix) | Hugging Face repo |
| --- | --- |
| codi | Lapisbird/Llama-adaLR-appendix-model-codi |
| codi + intermediate | Lapisbird/Llama-adaLR-appendix-model-codi_intermediate |
| meaned | Lapisbird/Llama-adaLR-appendix-model-meaned |
| meaned + intermediate | Lapisbird/Llama-adaLR-appendix-model-meaned_intermediate |
| meaned + codi | Lapisbird/Llama-adaLR-appendix-model-meaned_codi |
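
If you script over several variants, the repo IDs from the two tables can be kept in a small lookup keyed by the table labels. The sketch below is illustrative only: the names `MAIN_RESULTS`, `APPENDIX`, `repo_id_for`, and `download_variant` are hypothetical, and `download_variant` assumes the `huggingface_hub` package is installed.

```python
# Hypothetical lookup tables: repo IDs copied from the model card tables,
# keyed by the labels used there.
MAIN_RESULTS = {
    "CoT SFT": "Lapisbird/Llama-adaLR-model-cot_sft",
    "No-CoT SFT": "Lapisbird/Llama-adaLR-model-no_cot_sft",
    "Latent-6": "Lapisbird/Llama-adaLR-model-latent-6",
    "Latent-6 + RL": "Lapisbird/Llama-adaLR-model-latent-6_rl",
    "Latent-6-by-1": "Lapisbird/Llama-adaLR-model-latent-6-by-1",
    "Latent-6-by-1 + RL": "Lapisbird/Llama-adaLR-model-latent-6-by-1_rl",
}

APPENDIX = {
    "codi": "Lapisbird/Llama-adaLR-appendix-model-codi",
    "codi + intermediate": "Lapisbird/Llama-adaLR-appendix-model-codi_intermediate",
    "meaned": "Lapisbird/Llama-adaLR-appendix-model-meaned",
    "meaned + intermediate": "Lapisbird/Llama-adaLR-appendix-model-meaned_intermediate",
    "meaned + codi": "Lapisbird/Llama-adaLR-appendix-model-meaned_codi",
}


def repo_id_for(label: str) -> str:
    """Return the Hugging Face repo ID for a table label, checking the
    main-results table first and then the appendix table."""
    for table in (MAIN_RESULTS, APPENDIX):
        if label in table:
            return table[label]
    raise KeyError(f"unknown model label: {label!r}")


def download_variant(label: str) -> str:
    """Download all files of one variant and return the local directory."""
    # Imported lazily so the lookup above works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=repo_id_for(label))
```

For example, `repo_id_for("Latent-6 + RL")` returns a repo ID that can be passed to the loader shown in Sample Usage. This card does not state whether the CoT SFT and No-CoT SFT baselines also need the latent loader, so check the GitHub repository before loading those.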