---
base_model: meta-llama/Llama-3.2-1B-Instruct
datasets:
- whynlp/gsm8k-aug
library_name: transformers
license: llama3.2
pipeline_tag: text-generation
---
# Learning When to Stop: Adaptive Latent Reasoning via Reinforcement Learning

This repository contains model weights for the paper *Learning When to Stop: Adaptive Latent Reasoning via Reinforcement Learning*.
The model explores adaptive-length latent reasoning in Transformer language models: a post-SFT reinforcement learning stage optimizes reasoning length while maintaining accuracy. Experiments with Llama 3.2 1B on the GSM8K-Aug dataset show a 52% reduction in total reasoning length with no loss in accuracy.
For more details, including the full codebase and utilities, please refer to the GitHub repository.
## Sample Usage
You can load these models using the `automodelforcausallm_from_pretrained_latent` function from `src.model_creation`, as shown in the GitHub repository:

```python
from transformers import AutoTokenizer

from src.model_creation import automodelforcausallm_from_pretrained_latent

repo_id = "Lapisbird/Llama-adaLR-model-latent-6"  # Replace with the specific model variant you want to load

model = automodelforcausallm_from_pretrained_latent(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Further inference steps would follow from here, depending on your task.
# Note: `automodelforcausallm_from_pretrained_latent` is custom to this project and
# requires `src/model_creation.py` from the GitHub repository to be on your Python path.
```
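As an illustration only, a standard `transformers` generation loop might follow the loading code above. The prompt wording and generation parameters below are assumptions for a GSM8K-style task, not something specified by this repository:

```python
# Hypothetical continuation: greedy decoding with the loaded model and tokenizer.
# Prompt text and max_new_tokens are illustrative choices, not project defaults.
prompt = "Q: A shop sells 12 apples per hour. How many apples does it sell in 5 hours?\nA:"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)

# Decode only the newly generated tokens, dropping the prompt.
answer = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(answer)
```

Latent-reasoning variants may emit latent steps before the final answer, so the decoded text may differ from a plain CoT model's output; see the GitHub repository for the intended evaluation pipeline.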
## Trained Model Weights
All weights used for results in the paper are available on Hugging Face.
From the main results:
| Model | Hugging Face repo |
|---|---|
| CoT SFT | Lapisbird/Llama-adaLR-model-cot_sft |
| No-CoT SFT | Lapisbird/Llama-adaLR-model-no_cot_sft |
| Latent-6 | Lapisbird/Llama-adaLR-model-latent-6 |
| Latent-6 + RL | Lapisbird/Llama-adaLR-model-latent-6_rl |
| Latent-6-by-1 | Lapisbird/Llama-adaLR-model-latent-6-by-1 |
| Latent-6-by-1 + RL | Lapisbird/Llama-adaLR-model-latent-6-by-1_rl |
From the knowledge distillation for SFT section in the Appendix:
| Model (Appendix) | Hugging Face repo |
|---|---|
| codi | Lapisbird/Llama-adaLR-appendix-model-codi |
| codi + intermediate | Lapisbird/Llama-adaLR-appendix-model-codi_intermediate |
| meaned | Lapisbird/Llama-adaLR-appendix-model-meaned |
| meaned + intermediate | Lapisbird/Llama-adaLR-appendix-model-meaned_intermediate |
| meaned + codi | Lapisbird/Llama-adaLR-appendix-model-meaned_codi |