import os
import json
import random
import time
from datasets import load_dataset
import google.generativeai as genai
def generate_fine_tuning_dataset(num_samples=1000, output_file="models/local_mvm2_adapter/mvm2_training_data.jsonl"):
    """
    Download GSM8K and use Gemini 2.5 Flash as a "teacher" model to convert
    each sampled solution into the strict MVM2 JSON Triplet format, writing
    one chat-style training row per line (JSONL) for QLoRA fine-tuning.

    Parameters
    ----------
    num_samples : int
        Number of GSM8K training problems to sample without replacement.
    output_file : str
        Path of the JSONL file to write; parent directories are created.

    Raises
    ------
    RuntimeError
        If the GEMINI_API_KEY environment variable is not set.
    """
    print(f"Loading GSM8K to generate {num_samples} training examples...")
    dataset = load_dataset("gsm8k", "main", split="train")
    indices = random.sample(range(len(dataset)), num_samples)

    # SECURITY: the API key must come from the environment. Never hard-code
    # credentials (the previous fallback literal was a leaked key).
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY environment variable is not set.")
    genai.configure(api_key=api_key)
    teacher = genai.GenerativeModel('gemini-2.5-flash')

    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    success_count = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for i, idx in enumerate(indices):
            problem = dataset[idx]["question"]
            raw_solution = dataset[idx]["answer"]
            prompt = f"""
You are creating a fine-tuning dataset for a math reasoning model.
Convert this problem and solution into the strict MVM2 Triplet schema.
Problem: {problem}
Raw Solution: {raw_solution}
You MUST return ONLY a raw JSON object matching this schema:
{{
"final_answer": "The numerical final answer",
"reasoning_trace": ["step 1 equation", "step 2 equation"],
"confidence_explanation": "A statement confirming the algebraic steps are sound."
}}
"""
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    response = teacher.generate_content(prompt)
                    # Strip any markdown code fences the model wraps around the JSON.
                    text = response.text.replace("```json", "").replace("```", "").strip()
                    json_data = json.loads(text)
                    # Build the instruction-tuned chat row expected by the trainer.
                    training_row = {
                        "messages": [
                            {"role": "system", "content": "You are an MVM2 math reasoning agent. You strictly output JSON triplets: {final_answer, reasoning_trace, confidence_explanation}."},
                            {"role": "user", "content": problem},
                            {"role": "assistant", "content": json.dumps(json_data)}
                        ]
                    }
                    f.write(json.dumps(training_row) + "\n")
                    success_count += 1
                    print(f"Successfully generated triplet for problem {i+1}/{num_samples}")
                    break
                except Exception as e:
                    # Back off on quota/rate-limit errors; anything else
                    # (e.g. malformed JSON) skips this problem immediately.
                    if "429" in str(e) or "quota" in str(e).lower():
                        print(f"Rate limited. Sleeping 20s... (Attempt {attempt+1}/{max_retries})")
                        time.sleep(20)
                    else:
                        print(f"Skipping problem {i+1} due to parsing error: {e}")
                        break
            else:
                # All retries were consumed by rate limits; previously this
                # dropped the sample silently -- now it is reported.
                print(f"Skipping problem {i+1}: rate-limit retries exhausted.")

    print(f"\n✅ Dataset generation complete! Created {success_count} perfect MVM2 triplets at {output_file}")
    print("You can now run `python scripts/train_qlora_math_agent.py` to commence fine-tuning.")
if __name__ == "__main__":
    import sys

    # Reconfigure stdout so the emoji/checkmark status lines print cleanly
    # on consoles that default to a non-UTF-8 encoding (e.g. Windows).
    sys.stdout.reconfigure(encoding='utf-8')

    QUICK_START_SAMPLES = 100  # small batch for a fast first run
    generate_fine_tuning_dataset(num_samples=QUICK_START_SAMPLES)