---
base_model: meta-llama/Llama-3.2-1B-Instruct
datasets:
  - whynlp/gsm8k-aug
library_name: transformers
license: llama3.2
tags: []
pipeline_tag: text-generation
---

# Learning When to Stop: Adaptive Latent Reasoning via Reinforcement Learning

This repository contains model weights for the paper Learning When to Stop: Adaptive Latent Reasoning via Reinforcement Learning.

The model explores adaptive-length latent reasoning in Transformer language models: a post-SFT reinforcement learning stage optimizes reasoning length while maintaining accuracy. Experiments on the Llama 3.2 1B model with the GSM8K-Aug dataset show a 52% reduction in total reasoning length with no loss of accuracy.

For more details, including the full codebase and utilities, please refer to the GitHub repository.

## Sample Usage

You can load these models using the `automodelforcausallm_from_pretrained_latent` function from `src.model_creation`, as shown in the GitHub repository:

```python
from transformers import AutoTokenizer
from src.model_creation import automodelforcausallm_from_pretrained_latent

# Example: replace with the specific model variant you want to load
repo_id = "Lapisbird/Llama-adaLR-model-latent-6"

model = automodelforcausallm_from_pretrained_latent(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Further inference steps would follow from here, depending on your task.
# Note: `automodelforcausallm_from_pretrained_latent` is custom to this project and
# requires the `src/model_creation.py` file from the GitHub repository to be
# available on your Python path.
```
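As a rough continuation, the snippet below sketches standard text generation with the loaded model and tokenizer. This is an assumption-laden sketch: it presumes the returned model exposes the usual Hugging Face `generate` interface, whereas the project's latent-reasoning decoding may require its own inference loop (see the GitHub repository for the intended procedure). The prompt is an illustrative GSM8K-style question, not one taken from this project.

```python
# Hypothetical continuation of the loading snippet above; `model` and `tokenizer`
# are the objects created there. Assumes the standard `generate` API applies.
prompt = "Natalia sold 48 clips in April and half as many in May. How many in total?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Greedy decoding with a length cap; adjust generation parameters to taste.
output_ids = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

For the latent-reasoning variants, consult the repository's inference utilities rather than relying on plain `generate`.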

## Trained Model Weights

All weights used for results in the paper are available on Hugging Face.

From the main results:

| Model | Hugging Face repo |
|---|---|
| CoT SFT | `Lapisbird/Llama-adaLR-model-cot_sft` |
| No-CoT SFT | `Lapisbird/Llama-adaLR-model-no_cot_sft` |
| Latent-6 | `Lapisbird/Llama-adaLR-model-latent-6` |
| Latent-6 + RL | `Lapisbird/Llama-adaLR-model-latent-6_rl` |
| Latent-6-by-1 | `Lapisbird/Llama-adaLR-model-latent-6-by-1` |
| Latent-6-by-1 + RL | `Lapisbird/Llama-adaLR-model-latent-6-by-1_rl` |

From the knowledge distillation for SFT section in the Appendix:

| Model (Appendix) | Hugging Face repo |
|---|---|
| codi | `Lapisbird/Llama-adaLR-appendix-model-codi` |
| codi + intermediate | `Lapisbird/Llama-adaLR-appendix-model-codi_intermediate` |
| meaned | `Lapisbird/Llama-adaLR-appendix-model-meaned` |
| meaned + intermediate | `Lapisbird/Llama-adaLR-appendix-model-meaned_intermediate` |
| meaned + codi | `Lapisbird/Llama-adaLR-appendix-model-meaned_codi` |