---
license: apache-2.0
language:
- en
tags:
- mathematics
- reasoning
- chain-of-thought
- structured-cot
- tag-cot
- nested-reasoning
- qwen
- lora
- peft
- amd
- mi300x
- rocm
datasets:
- Xerv-AI/GRAD
widget:
- text: Prove that for any prime p > 3, p² − 1 is divisible by 24.
  example_title: Classic number theory proof
- text: Solve the equation x³ + 6x² - 11x - 6 = 0 for real roots.
  example_title: Cubic polynomial factoring
- text: Show that the sum of the first n odd numbers is n².
  example_title: Induction-friendly identity
base_model:
- Qwen/Qwen2.5-3B-Instruct
pipeline_tag: text-generation
---

<div align="center">

# **ReasonBorn-Qwen-3B**
**Structured Nested Tag-CoT Reasoning Model**
*(GRAD-only fine-tune – March 2026)*

</div>

<br>
## Model Details

- **Base model**: [Qwen/Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct)
- **Fine-tuning method**: LoRA (PEFT)
- **LoRA rank (r)**: 16
- **LoRA alpha**: 32
- **Target modules**: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
- **LoRA dropout**: 0.04
- **Trainable parameters**: ~0.59% of total (≈18.35M params)
- **Training precision**: bfloat16
- **Optimizer**: adamw_torch
- **Learning rate**: 1.8e-4 (cosine schedule + 6% warmup)
- **Epochs**: 3
- **Effective batch size**: 24 (per_device=4 × accum=6)
- **Gradient checkpointing**: enabled (non-reentrant)
- **Context length used**: 3072 tokens
- **Hardware**: AMD Instinct MI300X (192 GB HBM3) – ROCm
- **Training dataset**: Xerv-AI/GRAD (full train split, graduate-level math & proofs)
- **Training date**: March 2026
- **Upload date**: March 2026
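
The LoRA settings above translate directly into a PEFT adapter config. A minimal sketch, assuming `peft` is installed and using the GRAD-only values from this list:

```python
from peft import LoraConfig

# Sketch of the adapter config implied by the values above (GRAD-only run).
lora_config = LoraConfig(
    r=16,                       # LoRA rank
    lora_alpha=32,              # effective scaling: alpha / r = 2.0
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0.04,
    bias="none",
    task_type="CAUSAL_LM",
)
```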

## Intended Use

This is a **specialized 3B reasoning model**, fine-tuned to produce **highly consistent nested structured Chain-of-Thought** output in the following rigid tag format:

```xml
<plan> … high-level decomposition … </plan>
<reasoning>
  <step index="1"> … </step>
  <step index="2"> … <verify>optional verification</verify> … </step>
  …
</reasoning>
<conclusion>\boxed{final answer}</conclusion>
```
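
Because the format is rigid, downstream tools can extract the plan, steps, and final answer with a few regular expressions. A minimal sketch (the helper name is hypothetical, not something shipped with the model):

```python
import re

def parse_reasonborn(output: str) -> dict:
    """Parse the nested tag-CoT format into plan, steps, and conclusion."""
    plan = re.search(r"<plan>(.*?)</plan>", output, re.DOTALL)
    steps = re.findall(r'<step index="(\d+)">(.*?)</step>', output, re.DOTALL)
    conclusion = re.search(r"<conclusion>(.*?)</conclusion>", output, re.DOTALL)
    return {
        "plan": plan.group(1).strip() if plan else None,
        "steps": [(int(i), s.strip()) for i, s in steps],
        "conclusion": conclusion.group(1).strip() if conclusion else None,
    }

sample = (
    "<plan>Factor the cubic.</plan>\n<reasoning>\n"
    '<step index="1">Try x = 1: 1 + 6 - 11 - 6 = -10, not a root.</step>\n'
    "</reasoning>\n<conclusion>\\boxed{x = -1}</conclusion>"
)
parsed = parse_reasonborn(sample)
```

The `<verify>` sub-tags, when present, remain inside the captured step text and can be stripped or inspected separately.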

It is designed for:

- Mathematical proof generation
- Step-by-step scientific reasoning
- Competition-style problem solving (AMC, AIME, IMO shortlist level)
- Educational tools that require verifiable, auditable reasoning traces
- Agents / tool-use pipelines that parse structured reasoning

## Prompting Recommendation

**Strongly recommended inference prompt** (copy-paste this):

```text
<|im_start|>system
You are ReasonBorn – rigorous scientific & mathematical reasoner.
Respond **only** using this exact nested structure:
<plan>…</plan>
<reasoning> containing multiple <step index="…"> tags (with optional <verify> sub-tags)
<conclusion>\boxed{…}</conclusion>
Never write text outside the tags. Never skip tags.
<|im_end|>
<|im_start|>user
{question}
<|im_end|>
<|im_start|>assistant
```

A low temperature (0.1–0.25) combined with top_p ≈ 0.90–0.95 usually gives the cleanest structure.
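
For programmatic use, the recommended prompt can be assembled with a small helper (the function name is hypothetical; shown here as a sketch):

```python
def build_prompt(question: str) -> str:
    """Wrap a question in the recommended ChatML-style ReasonBorn prompt."""
    system = (
        "You are ReasonBorn – rigorous scientific & mathematical reasoner.\n"
        "Respond **only** using this exact nested structure:\n"
        "<plan>…</plan>\n"
        '<reasoning> containing multiple <step index="…"> tags (with optional <verify> sub-tags)\n'
        "<conclusion>\\boxed{…}</conclusion>\n"
        "Never write text outside the tags. Never skip tags."
    )
    return (
        f"<|im_start|>system\n{system}\n<|im_end|>\n"
        f"<|im_start|>user\n{question}\n<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

prompt = build_prompt("Show that the sum of the first n odd numbers is n².")
```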

## Training Script

Note: the script below is from the broader multi-dataset run (`rb-qwen3b-16ds-lora`); some of its hyperparameters (epochs, learning rate, batch size, context length, dropout) therefore differ from the GRAD-only values listed under Model Details.

```python
import os
import gc
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from huggingface_hub import login
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model

os.environ["TOKENIZERS_PARALLELISM"] = "false"

MODEL_ID = "Qwen/Qwen2.5-3B"
REPO_NAME = "rb-qwen3b-16ds-lora"
SAVE_DIR = "./rb-qwen-16ds-lora-final"

MAX_CTX = 512
EPOCHS = 1.15
LR = 2.5e-4
LORA_R = 16
LORA_ALPHA = 32
BATCH_SIZE = 48
GRAD_ACCUM = 2
WORKERS = 12

DATA_MIX = {
    "NuminaMath": {"path": "AI-MO/NuminaMath-CoT", "max_samples": 60000, "split": "train"},
    "OrcaMath": {"path": "microsoft/orca-math-word-problems-200k", "max_samples": 60000, "split": "train"},
    "UltraMath-Conv": {"path": "openbmb/UltraData-Math", "config": "UltraData-Math-L3-Conversation-Synthetic", "max_samples": 50000, "split": "train"},
    "GSM8K": {"path": "openai/gsm8k", "config": "main", "max_samples": 7473, "split": "train"},
    "AI2_ARC": {"path": "allenai/ai2_arc", "config": "ARC-Challenge", "max_samples": 7500, "split": "train"},
    "SciQ": {"path": "sciq", "max_samples": 11679, "split": "train"},
    "OpenBookQA": {"path": "openbookqa", "max_samples": 4957, "split": "train"},
    "GPQA": {"path": "Idavidrein/gpqa", "config": "gpqa_diamond", "max_samples": 198, "split": "train"},
    "ChemistryQA": {"path": "avaliev/ChemistryQA", "max_samples": 4000, "split": "train"},
    "HLE": {"path": "cais/hle", "max_samples": 2700, "split": "test"},
    "GRAD": {"path": "Xerv-AI/GRAD", "max_samples": 1933, "split": "train"},
}

def format_example(ex):
    """Normalize a raw example into the nested tag-CoT chat format."""
    try:
        q = str(ex.get("question") or ex.get("problem") or ex.get("prompt") or "").strip()
        s = str(ex.get("answer") or ex.get("solution") or ex.get("response") or "").strip()
        if len(q) < 5 or len(s) < 5:
            return None
        boxed = re.search(r'\\boxed\{(.*?)\}', s, re.DOTALL)
        ans = boxed.group(1).strip() if boxed else s[:80]
        reasoning = re.sub(r'\\boxed\{.*?\}', '', s, flags=re.DOTALL).strip()
        steps = [l.strip() for l in reasoning.split('\n') if len(l.strip()) > 8][:5]
        xml = "<plan>Decompose→reason→verify→conclude.</plan>\n<reasoning>\n"
        for i, step in enumerate(steps, 1):
            v = "<verify>ok</verify>" if i == len(steps) else ""
            xml += f'<step index="{i}">{step}{v}</step>\n'
        xml += f"</reasoning>\n<conclusion>\\boxed{{{ans}}}</conclusion>"
        sys_p = "You are ReasonBorn. Output only: <plan>,<reasoning><step>...</step></reasoning>,<conclusion>\\boxed{}."
        return {"text": (
            f"<|im_start|>system\n{sys_p}<|im_end|>\n"
            f"<|im_start|>user\n{q}<|im_end|>\n"
            f"<|im_start|>assistant\n{xml}<|im_end|>"
        )}
    except Exception:
        return None

def load_one(name, cfg):
    """Load one dataset, falling back to streaming mode if the full download fails."""
    examples = []
    kwargs = {"split": cfg["split"], "trust_remote_code": True}
    if "config" in cfg:
        kwargs["name"] = cfg["config"]
    try:
        ds = load_dataset(cfg["path"], **kwargs)
        if len(ds) > cfg["max_samples"]:
            ds = ds.select(range(cfg["max_samples"]))
        for ex in ds:
            r = format_example(ex)
            if r:
                examples.append(r)
        return name, examples, "ok"
    except Exception:
        pass
    try:
        ds = load_dataset(cfg["path"], streaming=True, **kwargs)
        for ex in ds:
            if len(examples) >= cfg["max_samples"]:
                break
            r = format_example(ex)
            if r:
                examples.append(r)
        return name, examples, "stream"
    except Exception:
        return name, [], "failed"

login()

all_ex = []
with ThreadPoolExecutor(max_workers=6) as pool:
    futs = {pool.submit(load_one, n, c): n for n, c in DATA_MIX.items()}
    for fut in as_completed(futs):
        n, exs, status = fut.result()
        all_ex.extend(exs)

train_ds = Dataset.from_list(all_ex).shuffle(seed=42)
del all_ex
gc.collect()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

tokenized = train_ds.map(
    lambda b: tokenizer(b["text"], truncation=True, max_length=MAX_CTX, padding=False),
    batched=True, batch_size=4000, num_proc=16,
    remove_columns=["text"],
)
tokenized = tokenized.filter(lambda x: len(x["input_ids"]) >= 8, num_proc=16)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation="eager",
)
model = model.to("cuda")
torch.cuda.synchronize()

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()

model = get_peft_model(model, LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
))

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="./chk",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    learning_rate=LR,
    bf16=True,
    fp16=False,
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    weight_decay=0.01,
    max_grad_norm=0.5,
    dataloader_num_workers=WORKERS,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=4,
    report_to="none",
    remove_unused_columns=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=collator,
)

trainer.train()

os.makedirs(SAVE_DIR, exist_ok=True)
trainer.save_model(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
```
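
To see what the `format_example` conversion produces, a condensed, self-contained version of its core logic can be run on a toy record (illustrative only; it mirrors, not replaces, the function above):

```python
import re

def to_tag_cot(question: str, solution: str) -> str:
    """Convert a raw solution into the nested tag-CoT format (condensed sketch)."""
    # Extract the boxed final answer, falling back to a prefix of the solution.
    boxed = re.search(r'\\boxed\{(.*?)\}', solution, re.DOTALL)
    ans = boxed.group(1).strip() if boxed else solution[:80]
    # Strip the box and break the remaining text into up to five steps.
    reasoning = re.sub(r'\\boxed\{.*?\}', '', solution, flags=re.DOTALL).strip()
    steps = [l.strip() for l in reasoning.split('\n') if len(l.strip()) > 8][:5]
    xml = "<plan>Decompose→reason→verify→conclude.</plan>\n<reasoning>\n"
    for i, step in enumerate(steps, 1):
        verify = "<verify>ok</verify>" if i == len(steps) else ""
        xml += f'<step index="{i}">{step}{verify}</step>\n'
    return xml + f"</reasoning>\n<conclusion>\\boxed{{{ans}}}</conclusion>"

example = to_tag_cot(
    "What is 2 + 2?",
    "Adding the two numbers gives four.\nTherefore the answer is \\boxed{4}.",
)
```

Note that only the last step receives a `<verify>` tag, which matches the observation below that verification most often appears on the final step.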

## Performance Notes (March 2026 observations)

After only 3 epochs on GRAD:

- ✅ Very strong format adherence when strongly prompted
- ✅ Good proof structure and logical flow on number theory, algebra, and basic inequalities
- ✅ Often includes verification steps (especially on the last step)
- ⚠️ Format can still degrade on very long / multi-part questions without a strong system prompt
- ⚠️ Generalization to non-math domains is limited (this is a math-first fine-tune)
- ⚠️ Weaker zero-shot format obedience compared to multi-dataset versions

## Training Hyperparameters Summary

| Parameter | Value |
|---|---|
| Epochs | 3 |
| Per-device batch size | 4 |
| Gradient accumulation steps | 6 |
| Global batch size | 24 |
| Learning rate | 1.8 × 10⁻⁴ |
| LR scheduler | cosine |
| Warmup ratio | 0.06 |
| Weight decay | 0.015 |
| Max grad norm | 0.8 |
| Optimizer | adamw_torch |
| Mixed precision | bf16 |
| Gradient checkpointing | Yes |

## VRAM Usage (MI300X 192 GB)

| Stage | Approx. reserved VRAM | Utilization |
|---|---|---|
| After model load | ~7–12 GiB | ~4–6% |
| After LoRA injection | ~8–15 GiB | ~5–8% |
| Peak during training | ~140–175 GiB | ~73–91% |
| After training (inference) | ~40–60 GiB | ~21–31% |

## How to Use (minimal example)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "Xerv-AI/ReasonBorn-Qwen-3B")
tokenizer = AutoTokenizer.from_pretrained("Xerv-AI/ReasonBorn-Qwen-3B")

prompt = """<|im_start|>system
You are ReasonBorn. Use <plan>, <reasoning> with <step> & <verify>, <conclusion> strictly.
<|im_end|>
<|im_start|>user
Prove that √2 is irrational.
<|im_end|>
<|im_start|>assistant
"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(
    **inputs,
    max_new_tokens=1200,
    do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
    temperature=0.2,
    top_p=0.92,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
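
Once the text is generated, the final answer can be pulled out of the `<conclusion>` tag. A small post-processing sketch (the helper name is hypothetical; the simple regex deliberately rejects nested braces inside `\boxed{}`):

```python
import re

def extract_boxed(generated: str):
    """Return the content of \\boxed{...} inside <conclusion>, or None.
    Naive: assumes no nested braces inside the box."""
    m = re.search(r"<conclusion>.*?\\boxed\{([^{}]*)\}.*?</conclusion>",
                  generated, re.DOTALL)
    return m.group(1).strip() if m else None

answer = extract_boxed("<conclusion>\\boxed{24 \\mid p^2 - 1}</conclusion>")
```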

## Acknowledgments

- Qwen team for the excellent base model
- Xerv-AI for releasing GRAD – one of the cleanest graduate-level math reasoning datasets available in 2026
- Hugging Face for the ecosystem
- AMD ROCm team for making MI300X training possible

---

**Xerv-AI / ReasonBorn-Qwen-3B**
First step toward verifiable, tagged, auditable AI mathematical reasoning.
Trained in Kolkata, March 2026.