Spaces:
Sleeping
Sleeping
File size: 2,599 Bytes
06e7bdc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | import json
from pathlib import Path
# Original SFT training and validation splits.
base_data_path = Path("data/train.jsonl")
valid_data_path = Path("data/valid.jsonl")

# Augmented question/answer files written by the Ollama augmentation step.
aug_finance_path = Path("finance_augmented.json")
aug_physics_path = Path("physics_augmented.json")

# Destination directory for the merged dataset. It is created as soon as this
# module is loaded so that later writes never hit a missing directory.
out_dir = Path("data/sft_v3")
out_dir.mkdir(parents=True, exist_ok=True)
out_train = out_dir.joinpath("train.jsonl")
out_valid = out_dir.joinpath("valid.jsonl")
def format_sft_example(prompt, response):
    """Wrap a prompt/response pair in the chat layout used by the SFT data.

    The result is ``{"messages": [user_turn, assistant_turn]}``. Each turn's
    ``content`` is a one-element list holding a ``{"type": "text", ...}``
    part, mirroring the nested structure of the original SFT dataset.
    """

    def _text_turn(role, text):
        # One chat turn whose content is a single text part.
        return {"role": role, "content": [{"type": "text", "text": text}]}

    user_turn = _text_turn("user", prompt)
    assistant_turn = _text_turn("assistant", response)
    return {"messages": [user_turn, assistant_turn]}
def process_file(file_path, train_data, domain_name, oversample=3):
    """Append oversampled augmented examples from a JSON file to ``train_data``.

    The file is expected to contain a JSON array of objects with a
    ``"question"`` and a ``"chosen"`` field. Every object in which both
    fields are non-empty is converted with ``format_sft_example`` and
    appended ``oversample`` times. If the file does not exist, a notice is
    printed and ``train_data`` is left untouched.

    Args:
        file_path: ``pathlib.Path`` of the augmented JSON file.
        train_data: List that receives the formatted examples; it is
            mutated in place.
        domain_name: Label used only in the progress messages.
        oversample: Number of copies appended per usable item (default 3).
    """
    if not file_path.exists():
        print(f"Missing augmented file: {file_path}")
        return

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        # Anything other than an array cannot be iterated as a list of
        # examples; report it instead of failing on the first element.
        print(f"Unexpected JSON structure in {file_path}: expected a list.")
        return

    # `added` counts distinct source items, not the oversampled copies.
    added = 0
    skipped = 0
    for item in data:
        if not isinstance(item, dict):
            skipped += 1
            continue
        prompt = item.get("question", "")
        response = item.get("chosen", "")
        if not (prompt and response):
            continue
        # A fresh dict is built for every copy so the appended examples do
        # not alias one another.
        for _ in range(oversample):
            train_data.append(format_sft_example(prompt, response))
        added += 1

    print(f"Added {added} augmented {domain_name} examples (oversampled {oversample}x).")
    if skipped:
        print(f"Skipped {skipped} malformed {domain_name} entries in {file_path}.")
def main():
    """Build the SFT V3 dataset from the original data plus augmentations.

    Steps:
      1. Load every record from ``base_data_path`` (JSON Lines). If the file
         is missing, print a notice and return without writing anything.
      2. Append the Finance and Physics augmentations via ``process_file``.
      3. Write the combined records to ``out_train`` as JSON Lines.
      4. Copy ``valid_data_path`` to ``out_valid`` unchanged, if it exists.
    """
    import shutil

    if not base_data_path.exists():
        print(f"Missing base training data at {base_data_path}!")
        return

    # 1. Load original SFT data. UTF-8 is requested explicitly so decoding
    # does not depend on the platform's locale encoding.
    train_data = []
    with open(base_data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at end of file) are not
            # records; feeding them to json.loads would raise.
            if line:
                train_data.append(json.loads(line))
    print(f"Loaded {len(train_data)} original SFT training examples.")

    # 2. Add augmented Finance data
    process_file(aug_finance_path, train_data, "Finance")
    # 3. Add augmented Physics data
    process_file(aug_physics_path, train_data, "Physics")

    # Write new train.jsonl, one JSON object per line.
    with open(out_train, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex) + "\n" for ex in train_data)

    # Copy valid.jsonl directly; the validation split is not augmented.
    if valid_data_path.exists():
        shutil.copy(valid_data_path, out_valid)
        print("Copied original valid.jsonl.")
    else:
        print(f"Missing validation data at {valid_data_path}; skipping copy.")

    print(f"Successfully generated SFT V3 dataset with {len(train_data)} examples at {out_dir}")
# Build the dataset only when this file is executed as a script.
if __name__ == "__main__":
    main()
|