---
base_model:
  - Qwen/Qwen2.5-3B-Instruct
tags:
  - text-generation-inference
  - transformers
  - unsloth
  - qwen2
  - trl
license: apache-2.0
language:
  - en
datasets:
  - openai/gsm8k
metrics:
  - accuracy
pipeline_tag: text-generation
---

# Uploaded model

- **Developed by:** rushigulum
- **License:** apache-2.0
- **Finetuned from model:** Qwen/Qwen2.5-3B-Instruct

This qwen2 model was trained 2x faster with Unsloth and Hugging Face's TRL library.

Nano R1 is a fine-tuned variant of Qwen2.5-3B-Instruct, aligned using Group Relative Preference Optimization (GRPO) for reasoning-intensive tasks such as math problem-solving. The model is trained with Unsloth + TRL + vLLM to ensure efficient fine-tuning, faster inference, and improved contextual accuracy.

## Key Highlights

- **Base Model:** Qwen2.5-3B-Instruct (via Hugging Face)
- **Fine-Tuning:** GRPO reinforcement learning with custom reward functions
- **Optimizations:** LoRA adapters, 4-bit quantization, vLLM inference
- **Dataset:** GSM8K (math reasoning) with structured XML reasoning prompts
- **Deployment:** Hugging Face Hub integration

## Model Loading with LoRA

- **Base:** Qwen/Qwen2.5-3B-Instruct
- **Optimizations:** 4-bit quantization, LoRA rank 64, gradient checkpointing
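A loading setup along these lines can be sketched with Unsloth's `FastLanguageModel`. This is a configuration sketch, not the exact training script: `max_seq_length`, `lora_alpha`, and the target-module list are assumptions.

```python
from unsloth import FastLanguageModel

# Load the base model in 4-bit; fast_inference enables the vLLM backend.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=1024,      # assumed context length
    load_in_4bit=True,
    fast_inference=True,
)

# Attach LoRA adapters at rank 64 with Unsloth's gradient checkpointing.
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    lora_alpha=64,            # assumed; often set equal to r
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",
)
```

Training only the low-rank adapters on a 4-bit base keeps memory low enough for single-GPU GRPO runs.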

## Reward Functions

- Semantic Correctness → via Sentence-BERT embeddings
- Strict XML Compliance → ensures reasoning/answer separation
- Numerical Answer Check → enforces valid math outputs
- Length & Format Penalty → prevents overly long or unstructured responses
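The rule-based rewards can be sketched in plain Python. These are illustrative implementations under assumed `<reasoning>`/`<answer>` tags and an assumed length cutoff; the Sentence-BERT similarity reward is omitted here since it needs an embedding model:

```python
import re

def xml_format_reward(completion: str) -> float:
    """Reward strict compliance with the reasoning/answer structure
    (tag names are illustrative assumptions)."""
    pattern = r"^<reasoning>.*?</reasoning>\s*<answer>.*?</answer>$"
    return 1.0 if re.match(pattern, completion.strip(), re.DOTALL) else 0.0

def numeric_answer_reward(completion: str) -> float:
    """Reward completions whose answer block parses as a number."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", completion, re.DOTALL)
    if not match:
        return 0.0
    try:
        float(match.group(1).replace(",", ""))
        return 0.5
    except ValueError:
        return 0.0

def length_penalty(completion: str, max_chars: int = 1500) -> float:
    """Small penalty that grows with response length past a cutoff (assumed)."""
    return -0.001 * max(0, len(completion) - max_chars)

good = "<reasoning>2 * 3 = 6</reasoning>\n<answer>6</answer>"
print(xml_format_reward(good), numeric_answer_reward(good))  # -> 1.0 0.5
```

GRPO sums rewards like these per completion and ranks completions within each sampled group, so even coarse 0/1 signals shape the policy.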

## GRPO Training

- **Optimizer:** AdamW (8-bit)
- **Batch size:** 1 (with gradient accumulation)
- **Learning rate:** 5e-6
- **Steps:** 150 (demo run)
- **Inference engine:** vLLM for efficient generation
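A TRL training setup with these hyperparameters might look like the sketch below. The gradient-accumulation factor, output directory, and generation lengths are assumptions; `model`, `dataset`, and the reward-function list are placeholders for the objects defined elsewhere:

```python
from trl import GRPOConfig, GRPOTrainer

# Hyperparameters mirror the card; accumulation factor is an assumption.
training_args = GRPOConfig(
    output_dir="grpo-qwen2.5-3b",   # assumed output path
    learning_rate=5e-6,
    optim="adamw_8bit",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # assumed
    max_steps=150,
    use_vllm=True,                  # generate completions with vLLM
)

trainer = GRPOTrainer(
    model=model,                    # PEFT model from the loading step
    reward_funcs=reward_functions,  # list of custom reward callables
    args=training_args,
    train_dataset=dataset,          # GSM8K formatted with XML prompts
)
trainer.train()
```

With batch size 1 plus accumulation, GRPO still compares multiple sampled completions per prompt, which is where its group-relative advantage estimate comes from.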

## Evaluation

Benchmarked on the GSM8K validation set. Metrics:

- Final Answer Accuracy (semantic similarity above a threshold)
- Format Compliance (% of responses following the XML structure)
- Average Reward Score across completions
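The format-compliance and average-reward metrics reduce to simple aggregation. A minimal sketch, again assuming `<reasoning>`/`<answer>` tags (the embedding-based accuracy metric is omitted since it needs a similarity model):

```python
import re

def format_compliant(completion: str) -> bool:
    """Check for the reasoning/answer XML structure (tag names assumed)."""
    return bool(re.search(r"<reasoning>.*?</reasoning>.*?<answer>.*?</answer>",
                          completion, re.DOTALL))

def summarize(completions: list[str], rewards: list[float]) -> dict:
    """Aggregate format compliance and mean reward over an eval set."""
    compliance = sum(format_compliant(c) for c in completions) / len(completions)
    return {
        "format_compliance_pct": 100.0 * compliance,
        "avg_reward": sum(rewards) / len(rewards),
    }

completions = [
    "<reasoning>1 + 1 = 2</reasoning><answer>2</answer>",
    "The answer is 2.",  # non-compliant formatting
]
print(summarize(completions, [1.5, 0.2]))
# -> {'format_compliance_pct': 50.0, 'avg_reward': 0.85}
```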

## Results

- Improved reasoning structure with a consistent reasoning/answer format
- Higher semantic accuracy than the baseline Qwen2.5-3B
- Faster inference using vLLM