---
license: mit
language: en
base_model: meta-llama/Meta-Llama-3-8B-Instruct
tags:
- math
- differential-equations
- dpo
- lbt
- instruction-tuned
---

# LMT-tuning: Llama-3-8B Fine-tuned for Differential Equations

This model is a fine-tuned version of `meta-llama/Meta-Llama-3-8B-Instruct`, specialized for solving university-level differential equations problems.
The model was trained using the **Learning by Teaching (LbT)** paradigm combined with **Direct Preference Optimization (DPO)**. This approach aims to improve a "teacher" model's reasoning capabilities by having it teach a "student" model and learning from the student's performance.
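For context, DPO optimizes the standard pairwise preference objective of Rafailov et al. (2023). With `x` a problem prompt, `y_w` the chosen rationale, `y_l` the rejected one, `π_ref` the frozen base model, and `β` the preference-strength coefficient (the value used for this model is not reported in this card):

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$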
## Model Description

The core idea of the training process was to create a high-quality preference dataset where the "better" response was not just more correct, but also a better piece of teaching material.
The pipeline involved:

1. **Data Augmentation:** A raw corpus of ~1500 differential equations problems was flattened and structured into a training set (~1200 problems) and a test set (~300 problems).
2. **Teacher Generation:** The base Llama-3-8B model generated 32 step-by-step solutions (rationales) for each of the ~1200 training problems.
3. **Student Examination (LbT Scoring):** For each of the ~39,000 generated rationales, a "student" model (also Llama-3-8B) was taught using that rationale as a one-shot example. The student then took a similarity-based exam, and its performance yielded an "LbT score" for the rationale (see the scoring sketch below).
4. **Preference Creation:** Rationales were scored based on a combination of correctness and their LbT score. High-scoring rationales were paired with low-scoring ones to create a preference dataset of `(prompt, chosen, rejected)` triplets (see the pairing sketch below).
5. **DPO Fine-tuning:** The base Llama-3-8B model was fine-tuned on this preference dataset using `trl`'s `DPOTrainer` and QLoRA.
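The student examination in step 3 can be pictured with a short sketch. Everything here is illustrative: `student_generate` and `grade` are hypothetical callables standing in for the student model's decoding and the answer checker, and the real exam construction is not part of this card.

```python
# Illustrative sketch of step 3 (LbT scoring); `student_generate` and `grade`
# are hypothetical stand-ins, not the project's actual code.
def lbt_score(rationale: str, exam: list[dict], student_generate, grade) -> float:
    """Teach the student with `rationale` as a one-shot example, then return
    its accuracy on a similarity-based exam of {"question", "answer"} items."""
    correct = 0
    for item in exam:
        prompt = (
            f"{rationale}\n\n"                      # the teacher's worked example
            f"[[Question]]:\n{item['question']}\n\n"
            f"[[Solution]]:\nLet's think step by step.\n\n"
        )
        prediction = student_generate(prompt)       # student answers the exam item
        correct += int(grade(prediction, item["answer"]))
    return correct / len(exam)
```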
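Step 4 can then be sketched as pairing each problem's best and worst rationales. The `Rationale` container and the 50/50 weighting below are assumptions for illustration, not the project's actual scoring rule:

```python
# Illustrative sketch of step 4 (preference-pair creation); all names and the
# exact scoring rule are assumptions, not the project's actual code.
from dataclasses import dataclass

@dataclass
class Rationale:
    prompt: str          # the formatted exam question
    text: str            # the teacher-generated step-by-step solution
    correct: bool        # did the rationale reach the right final answer?
    lbt_score: float     # student exam accuracy when taught with this rationale

def combined_score(r: Rationale, alpha: float = 0.5) -> float:
    """Blend correctness with the LbT score (the weighting is hypothetical)."""
    return alpha * float(r.correct) + (1 - alpha) * r.lbt_score

def make_preference_pairs(rationales_by_prompt: dict[str, list[Rationale]]) -> list[dict]:
    """Pair the highest- and lowest-scoring rationale for each problem."""
    pairs = []
    for prompt, rs in rationales_by_prompt.items():
        rs = sorted(rs, key=combined_score, reverse=True)
        if len(rs) >= 2 and combined_score(rs[0]) > combined_score(rs[-1]):
            pairs.append({"prompt": prompt,
                          "chosen": rs[0].text,
                          "rejected": rs[-1].text})
    return pairs
```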
## Intended Use

This model is primarily intended for:
- **Solving differential equations problems:** Providing step-by-step reasoning and a final answer.
- **Educational purposes:** Serving as a tool for students to check their work and understand problem-solving steps.
- **Research:** Acting as a baseline for further fine-tuning on specialized mathematical domains.

**Note:** This is a specialist model. While it has been fine-tuned for differential equations, its capabilities on general-purpose chat or other reasoning tasks may have degraded.
## How to Use

You can use this model with the `transformers` library pipeline. It is crucial to use the Llama 3 chat template for best results.
```python
import torch
from transformers import pipeline

# Load the model and tokenizer
pipe = pipeline(
    "text-generation",
    model="Sandesh-Zenteiq/LMT-tuning",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Your differential equations problem
problem = "Solve the initial value problem: y' - 2y = 0, with y(0) = 3."

# This is the full instruction set the model was trained on
instruction_text = (
    "Your task is to answer the last question below. "
    "Give step by step reasoning before you answer. "
    "When you're ready to answer, please wrap your answer and conclude using the format\n"
    "'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n"
)
exam_template = (
    "[[Question]]:\n{question}\n\n"
    "[[Solution]]:\nLet's think step by step.\n\n"
)

# Format the prompt using the Llama 3 chat template
prompt = (
    f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    f"{instruction_text}{exam_template.format(question=problem)}"
    f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Generate the response
response = pipe(
    prompt,
    max_new_tokens=1024,
    do_sample=False  # greedy decoding; set do_sample=True (with temperature/top_p, e.g. 0.7/0.9) to sample
)

# The pipeline returns a list with one dict per input
generated_text = response[0]["generated_text"]
# The generated text includes the prompt, so slice it off to keep only the model's answer
assistant_response = generated_text[len(prompt):]
print(assistant_response)
```
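Continuing from the snippet above: since the model is trained to wrap its conclusion in a `[[Final Answer]]:` block delimited by `'''`, you can optionally pull out just the final answer. The regex below is illustrative and assumes the model followed the trained format:

```python
import re

# Pull out the text between "[[Final Answer]]:" and the closing ''' marker.
# Illustrative only; assumes the model followed the trained answer format.
match = re.search(r"\[\[Final Answer\]\]:\s*(.*?)\s*'''", assistant_response, re.DOTALL)
if match:
    print("Final answer:", match.group(1))
```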
## Training Details

- **Base Model:** `meta-llama/Meta-Llama-3-8B-Instruct`
- **Framework:** `trl.DPOTrainer` with QLoRA
- **Hardware:** NVIDIA A6000 / H200 class GPUs
- **Key Hyperparameters:**
  - `learning_rate`: 2e-5
  - `num_epochs`: 1
  - `lora_r`: 128
  - `lora_alpha`: 256
  - `gradient_accumulation_steps`: 16
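For reproducibility, here is a minimal sketch of what this setup might look like with `trl` and `peft`. The dataset contents, quantization details, and batch sizes are assumptions rather than the project's actual training script, and the `DPOTrainer` keyword names vary slightly across `trl` versions:

```python
# Minimal DPO + QLoRA sketch under the hyperparameters above.
# Dataset contents and quantization details are assumptions, not the real script.
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# 4-bit quantization for QLoRA
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# LoRA adapter sized as reported above
peft_config = LoraConfig(r=128, lora_alpha=256, task_type="CAUSAL_LM")

# Preference triplets from step 4 of the pipeline (placeholder example)
train_dataset = Dataset.from_list([
    {"prompt": "...", "chosen": "...", "rejected": "..."},
])

training_args = DPOConfig(
    output_dir="lmt-dpo",
    learning_rate=2e-5,
    num_train_epochs=1,
    gradient_accumulation_steps=16,
    bf16=True,
)

trainer = DPOTrainer(
    model=model,                 # reference model is handled automatically with PEFT
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # `tokenizer=` in older trl versions
    peft_config=peft_config,
)
trainer.train()
```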
## Evaluation

The model was evaluated on a held-out test set of 305 differential equations problems that were not seen during training. The metric is Pass@1 accuracy.

| Model | Pass@1 Accuracy |
| --- | --- |
| `meta-llama/Meta-Llama-3-8B-Instruct` (Base) | 10.16% |
| **LMT-tuning (This Model)** | **16.07%** |

This represents a +5.91-point absolute improvement and a ~58% relative improvement in performance on this specialized task.
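The card does not specify the grading procedure. As a point of reference, Pass@1 with a single greedy sample per problem reduces to plain accuracy; a minimal sketch (exact string match only; real grading of math answers typically needs equivalence checking, e.g. via `sympy`):

```python
# Minimal Pass@1 sketch: one greedy answer per problem, exact-match grading.
# Illustrative only; the actual grading procedure is not specified in this card.
def pass_at_1(predictions: list[str], references: list[str]) -> float:
    correct = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return correct / len(references)
```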
Model fine-tuned by Sandesh-Zenteiq. The methodology is based on the paper "Can LLMs Learn by Teaching for Better Reasoning?".