Sai Sandesh Reddy commited on
Commit
3a79690
·
verified ·
1 Parent(s): 9a35ace

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +124 -3
README.md CHANGED
@@ -1,3 +1,124 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language: en
4
+ base_model: meta-llama/Meta-Llama-3-8B-Instruct
5
+ tags:
6
+ - math
7
+ - differential-equations
8
+ - dpo
9
+ - lbt
10
+ - instruction-tuned
11
+ ---
12
+
13
+ # LMT-tuning: Llama-3-8B Fine-tuned for Differential Equations
14
+
15
+ This model is a fine-tuned version of `meta-llama/Meta-Llama-3-8B-Instruct`, specialized for solving university-level differential equations problems.
16
+
17
+ The model was trained using the **Learning by Teaching (LbT)** paradigm combined with **Direct Preference Optimization (DPO)**. This approach aims to improve a "teacher" model's reasoning capabilities by having it teach a "student" model and learning from the student's performance.
18
+
19
+ ## Model Description
20
+
21
+ The core idea of the training process was to create a high-quality preference dataset where the "better" response was not just more correct, but also a better piece of teaching material.
22
+
23
+ The pipeline involved:
24
+ 1. **Data Augmentation:** A raw corpus of ~1500 differential equations problems was flattened and structured into a training set (~1200 problems) and a test set (~300 problems).
25
+ 2. **Teacher Generation:** The base Llama-3-8B model generated 32 step-by-step solutions (rationales) for each of the 1200 training problems.
26
+ 3. **Student Examination (LbT Scoring):** For each of the ~39,000 generated rationales, a "student" model (also Llama-3-8B) was taught using that rationale as a one-shot example. The student then took a similarity-based exam, and its performance yielded an "LbT score" for the rationale.
27
+ 4. **Preference Creation:** Rationales were scored based on a combination of correctness and their LbT score. High-scoring rationales were paired with low-scoring ones to create a preference dataset of `(prompt, chosen, rejected)` triplets.
28
+ 5. **DPO Fine-tuning:** The base Llama-3-8B model was fine-tuned on this preference dataset using `trl`'s `DPOTrainer` and QLoRA.
29
+
30
+ ## Intended Use
31
+
32
+ This model is primarily intended for:
33
+ - **Solving differential equations problems:** Providing step-by-step reasoning and a final answer.
34
+ - **Educational purposes:** Serving as a tool for students to check their work and understand problem-solving steps.
35
+ - **Research:** Acting as a baseline for further fine-tuning on specialized mathematical domains.
36
+
37
+ **Note:** This is a specialist model. While it has been fine-tuned for differential equations, its capabilities on general-purpose chat or other reasoning tasks may have degraded.
38
+
39
+ ## How to Use
40
+
41
+ You can use this model with the `transformers` library pipeline. It is crucial to use the Llama 3 chat template for best results.
42
+
43
+ ```python
44
+ import torch
45
+ from transformers import pipeline
46
+
47
+ # Load the model and tokenizer
48
+ pipe = pipeline(
49
+ "text-generation",
50
+ model="Sandesh-Zenteiq/LMT-tuning",
51
+ torch_dtype=torch.bfloat16,
52
+ device_map="auto"
53
+ )
54
+
55
+ # Your differential equations problem
56
+ problem = "Solve the initial value problem: y' - 2y = 0, with y(0) = 3."
57
+
58
+ # This is the full instruction set the model was trained on
59
+ instruction_text = (
60
+ "Your task is to answer the last question below. "
61
+ "Give step by step reasoning before you answer. "
62
+ "When you're ready to answer, please wrap your answer and conclude using the format\n"
63
+ "'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n"
64
+ )
65
+ exam_template = (
66
+ "[[Question]]:\n{question}\n\n"
67
+ "[[Solution]]:\nLet's think step by step.\n\n"
68
+ )
69
+
70
+ # Format the prompt using the Llama 3 chat template
71
+ prompt = (
72
+ f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
73
+ f"{instruction_text}{exam_template.format(question=problem)}"
74
+ f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
75
+ )
76
+
77
+ # Generate the response
78
+ # The pipeline will handle the prompt and only show you the generated part
79
+ response = pipe(
80
+ prompt,
81
+ max_new_tokens=1024,
82
+ do_sample=False, # Use do_sample=True for more creative answers
83
+ temperature=0.7,
84
+ top_p=0.9
85
+ )
86
+
87
+ # Extract and print the generated text
88
+ # The pipeline returns a list of outputs
89
+ generated_text = response['generated_text']
90
+ # The generated text includes the prompt, so we can slice it to see only the model's answer
91
+ assistant_response = generated_text[len(prompt):]
92
+ print(assistant_response)
93
+
94
+ Training Details
95
+
96
+ Base Model: meta-llama/Meta-Llama-3-8B-Instruct
97
+
98
+ Framework: trl.DPOTrainer with QLoRA
99
+
100
+ Hardware: NVIDIA A6000 / H200 class GPUs
101
+
102
+ Key Hyperparameters:
103
+
104
+ learning_rate: 2e-5
105
+
106
+ num_epochs: 1
107
+
108
+ lora_r: 128
109
+
110
+ lora_alpha: 256
111
+
112
+ gradient_accumulation_steps: 16
113
+
114
+ Evaluation
115
+
116
+ The model was evaluated on a held-out test set of 305 differential equations problems that were not seen during training. The metric is Pass@1 accuracy.
117
+
118
+ Model Accuracy
119
+ meta-llama/Llama-3-8B-Instruct (Base) 10.16%
120
+ LMT-tuning (This Model) 16.07%
121
+
122
+ This represents a +5.90 point absolute improvement and a ~58% relative improvement in performance on this specialized task.
123
+
124
+ Model fine-tuned by Sandesh-Zenteiq. The methodology is based on the paper "Can LLMs Learn by Teaching for Better Reasoning?"```