qwen3-14b-compact-cot-lora

A rank-16 LoRA adapter for Qwen/Qwen3-14B trained with GRPO on Hendrycks MATH to produce more compact chain-of-thought while preserving accuracy.

Reward = math_accuracy + λ × length_penalty, with λ = 0.05 and length_penalty = −tokens / max_completion_length ∈ [−1, 0]. The model is rewarded for solving the problem and lightly penalised for each token it uses, so the length term can cost at most 0.05 reward per completion.
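As a rough illustration of the reward above, here is a minimal sketch. It assumes math_accuracy is a 0/1 correctness signal and that the token count is clamped at max_completion_length; the function and argument names are hypothetical and are not taken from the training repo.

```python
def length_penalty(num_tokens: int, max_completion_length: int = 4096) -> float:
    """Return -tokens / max_completion_length, kept inside [-1, 0]."""
    clamped = min(max(num_tokens, 0), max_completion_length)
    return -clamped / max_completion_length


def reward(is_correct: bool, num_tokens: int, lam: float = 0.05,
           max_completion_length: int = 4096) -> float:
    """math_accuracy + lam * length_penalty (accuracy assumed to be 0 or 1)."""
    return float(is_correct) + lam * length_penalty(num_tokens, max_completion_length)


# A correct answer that uses half the budget loses 0.025 reward.
print(reward(True, 2048))
# A wrong answer that runs into the cap gets the full penalty of -0.05.
print(reward(False, 4096))
```

Under these assumptions a correct completion that hits the cap (reward 0.95) still scores above any incorrect one (reward at most 0), which is why a λ of 0.05 acts mainly as a tie-breaker among correct answers.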

Training summary

  • Base: Qwen/Qwen3-14B (thinking mode enabled — penalty applies to the full completion including <think>…</think>)
  • Dataset: DigitalLearningGmbH/MATH-lighteval, train split (~7.5K problems)
  • Algorithm: GRPO via TRL, single-GPU policy + dedicated trl vllm-serve on a second H100
  • LoRA: r=16, α=32, on q,k,v,o,gate,up,down
  • Other hyperparameters: lr 5e-6 (cosine schedule), bf16, gradient checkpointing, per-device batch 1 × grad-accum 8, num_generations 8, max_completion_length 4096
  • Steps: 300 (~5 h on 2× H100 SXM)
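The settings listed above can be written down as keyword arguments. The sketch below uses the argument names of `peft.LoraConfig` and `trl.GRPOConfig` as an assumed mapping; the config file in the repo remains authoritative, and the exact set of accepted arguments depends on the installed TRL version. It is kept as plain dictionaries so it runs without either library.

```python
# LoRA settings: rank 16, alpha 32, on all attention and MLP projections.
lora_kwargs = {
    "r": 16,
    "lora_alpha": 32,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
                       "gate_proj", "up_proj", "down_proj"],
    "task_type": "CAUSAL_LM",
}

# GRPO settings (assumed argument names; check your TRL version).
grpo_kwargs = {
    "learning_rate": 5e-6,
    "lr_scheduler_type": "cosine",
    "bf16": True,
    "gradient_checkpointing": True,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "num_generations": 8,
    "max_completion_length": 4096,
    "max_steps": 300,
    "use_vllm": True,  # generation is served by a separate vLLM process
}

# With the libraries installed, these would be passed on as:
#   peft_config = LoraConfig(**lora_kwargs)
#   training_args = GRPOConfig(**grpo_kwargs)

# Standard LoRA scaling factor applied to the low-rank update.
lora_scale = lora_kwargs["lora_alpha"] / lora_kwargs["r"]

# Samples per optimizer step on the single policy GPU.
samples_per_step = (grpo_kwargs["per_device_train_batch_size"]
                    * grpo_kwargs["gradient_accumulation_steps"])

print(lora_scale, samples_per_step)
```

Here samples_per_step equals num_generations, which suggests one group of eight completions per optimizer step; whether that corresponds to exactly one prompt per step depends on how the installed TRL version batches generations.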

Results (first-15 vs last-30 step rolling means)

metric | start | end | Δ
--- | --- | --- | ---
mean completion length (tokens) | 3423 | 2627 | −23%
terminated length, when not clipped (tokens) | 2265 | 2057 | −9%
clipped@4096 ratio | 0.44 | 0.22 | halved
math_accuracy (in-batch) | 0.70 | 0.83 | +13 pp

The policy got shorter and more accurate: roughly half as many completions hit the 4096-token cap, in-batch accuracy on the training prompts climbed from 0.70 to 0.83, and the average completion shrank by ~23%. These are in-training measurements, not a held-out evaluation. λ=0.05 is mild; for stronger compression, re-run with λ=0.2 or 0.5 (the prior 0.6B sweep showed accuracy held up to ~0.2 on GSM8K; 14B on MATH may be more sensitive).
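The Δ values reported in the results can be rechecked with a few lines. This sketch assumes that "start" is the mean over the first 15 logged steps and "end" the mean over the last 30; the helper names are hypothetical, and the numbers are copied from the results above rather than pulled from the Wandb run.

```python
from statistics import mean


def start_end_means(series, first=15, last=30):
    """Mean of the first `first` and of the last `last` logged values."""
    return mean(series[:first]), mean(series[-last:])


def rel_change(start, end):
    """Relative change from start to end, as a fraction of start."""
    return (end - start) / start


# (start, end) pairs copied from the reported results.
lengths = {
    "mean completion length": (3423, 2627),
    "terminated length": (2265, 2057),
}
for name, (start, end) in lengths.items():
    print(f"{name}: {100 * rel_change(start, end):+.1f}%")

# Accuracy is reported as an absolute difference in percentage points.
print(f"math_accuracy: {round((0.83 - 0.70) * 100):+d} pp")
```

To apply the same summary to a logged metric, pass the per-step list to `start_end_means` and feed the two means into `rel_change`.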

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-14B", torch_dtype="bfloat16", device_map="auto")
tok  = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")
model = PeftModel.from_pretrained(base, "japhba/qwen3-14b-compact-cot-lora")

msgs = [{"role": "user", "content": "Find all real x with x^4 - 5x^2 + 4 = 0."}]
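# return_dict=True also returns the attention mask that generate() expects.
inputs = tok.apply_chat_template(msgs, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
# do_sample=True makes the temperature/top_p settings take effect explicitly.
out = model.generate(**inputs, max_new_tokens=4096, do_sample=True, temperature=0.6, top_p=0.95)
# Decode only the newly generated tokens, not the prompt.
print(tok.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))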
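Because the model runs in thinking mode, the decoded text normally contains the reasoning inside <think>…</think> followed by the final answer. A minimal sketch for separating the two is shown here; it assumes the literal tags survive decoding, and `split_thinking` is a hypothetical helper, not part of any library.

```python
def split_thinking(text: str, close_tag: str = "</think>"):
    """Split a decoded completion into (reasoning, answer, finished).

    finished is False when no closing tag is found, which usually means
    generation stopped at the token cap while still reasoning.
    """
    head, sep, tail = text.partition(close_tag)
    reasoning = head.replace("<think>", "", 1).strip()
    if not sep:
        return reasoning, "", False
    return reasoning, tail.strip(), True


demo = "<think>\nLet y = x^2, then y^2 - 5y + 4 = 0.\n</think>\n\nx = ±1 or x = ±2"
reasoning, answer, finished = split_thinking(demo)
print(finished)
print(answer)
```

Checking the finished flag is a cheap way to spot completions that were clipped at max_new_tokens before producing an answer.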

Reproducibility

Code at https://github.com/japhba/rl-cot-length; this run used config configs/grpo_qwen3_14b.yaml and --lambda_ 0.05 --run_name qwen3_14b_l005_v2. Wandb: https://wandb.ai/japhba-personal/rl-cot-length/runs/oyzcjith

Trained with TRL GRPO + vLLM-serve mode.
