LLaMA-3.1-8B-TS2 (Training with Sparsemax+, Testing with Softmax)

This model is a Supervised Fine-Tuned (SFT) variant of LLaMA-3.1-8B trained with the TS² objective (Training with Sparsemax+, Testing with Softmax) on the UltraFeedback dataset.

It is designed to improve alignment stability and mitigate token-level probability collapse during fine-tuning by incorporating entropy-aware adaptive weighting into the training objective.

For more details, see our ICLR 2026 paper "TS²: Training with Sparsemax+, Testing with Softmax for Accurate and Diverse LLM Fine-Tuning".
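Sparsemax+ itself is defined in the paper. For reference only, the sketch below implements standard sparsemax (Martins & Astudillo, 2016), the simplex projection that produces sparse output distributions and that the training activation presumably builds on; it is an illustration, not the released training code.

import torch

def sparsemax(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Standard sparsemax: Euclidean projection of the logits onto the
    # probability simplex; many entries of the result are exactly zero.
    z, _ = torch.sort(logits, dim=dim, descending=True)
    cumsum = z.cumsum(dim=dim)
    k = torch.arange(1, logits.size(dim) + 1, device=logits.device, dtype=logits.dtype)
    shape = [1] * logits.dim()
    shape[dim] = -1
    k = k.view(shape)
    support = 1 + k * z > cumsum                      # sorted entries kept in the support
    k_z = support.sum(dim=dim, keepdim=True).to(logits.dtype)
    tau = (torch.where(support, z, torch.zeros_like(z)).sum(dim=dim, keepdim=True) - 1) / k_z
    return torch.clamp(logits - tau, min=0.0)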


🧠 Model Description

  • Base Model: meta-llama/Llama-3.1-8B
  • Training Method: Sparsemax+
  • Dataset: UltraFeedback (binarized preference pairs)
  • Objective: Token-level entropy-aware TS²-style regularization
  • Framework: PyTorch + HuggingFace Transformers
  • Precision: bfloat16

Instead of applying uniform likelihood maximization across all tokens (as in standard SFT), this model introduces an adaptive weighting mechanism that dynamically adjusts training emphasis based on the predictive entropy of the model.

This approach is inspired by recent observations in preference alignment that overconfident likelihood training may lead to:

  • Degeneration of token diversity
  • Inference-time mode collapse
  • Reduced generalization under distribution shift

To address this, the training objective is modified: training uses the Sparsemax+ activation together with entropy-aware token weighting, while standard softmax decoding is kept at test time. A minimal sketch of such an entropy-weighted objective follows.
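The exact weighting scheme is defined in the paper; the sketch below only illustrates the general idea, assuming a simple scheme in which per-token cross-entropy is rescaled by the model's normalized predictive entropy so that overconfident (low-entropy) tokens are down-weighted. The function name entropy_weighted_loss is hypothetical and not part of the released training code.

import math
import torch
import torch.nn.functional as F

def entropy_weighted_loss(logits: torch.Tensor, labels: torch.Tensor,
                          ignore_index: int = -100) -> torch.Tensor:
    # Illustrative entropy-aware SFT objective (not the paper's exact formulation):
    # per-token cross-entropy is rescaled by the normalized predictive entropy,
    # so tokens the model already predicts with high confidence contribute less.
    log_probs = F.log_softmax(logits, dim=-1)                 # (batch, seq, vocab)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)      # (batch, seq)
    weights = (entropy / math.log(logits.size(-1))).detach()  # normalize to [0, 1], no gradient

    nll = F.cross_entropy(logits.transpose(1, 2), labels,
                          ignore_index=ignore_index, reduction="none")  # (batch, seq)
    mask = (labels != ignore_index).to(logits.dtype)
    return (weights * nll * mask).sum() / mask.sum().clamp(min=1)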

🚀 Usage

You can load the model with the HuggingFace Transformers library:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained(
    "xzybit/llama3.1-8b-ts2"
)

model = AutoModelForCausalLM.from_pretrained(
    "xzybit/llama3.1-8b-ts2",
    torch_dtype=torch.bfloat16,  # match the bfloat16 training precision
    device_map="auto"
)
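
A minimal generation example; the prompt and sampling settings below are illustrative choices, not recommendations from the model authors. At test time the model decodes with standard softmax sampling:

prompt = "Explain the trade-off between diversity and accuracy in LLM fine-tuning."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))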