# Source: CFO-Agent-14B / normalize_layer1.py
# Uploaded by OsamaAli313 ("Upload normalize layer1", commit cb7ce4a, verified)
#!/usr/bin/env python3
"""
Normalize Layer 1 financial instruction datasets into ChatML JSONL format.
Handles 5 financial instruction datasets with different column naming conventions:
- finance_instruct_500k (Josephgflowers/Finance-Instruct-500k)
- sujet_finance_177k (sujet-ai/Sujet-Finance-Instruct-177k)
- financial_qa_10k (virattt/financial-qa-10K)
- fingpt_convfinqa (FinGPT/fingpt-convfinqa)
- earnings_calls_qa (lamini/earnings-calls-qa)
Converts to unified ChatML format with system prompt, user message, and assistant response.
Filters out low-quality samples based on message length thresholds.
"""
import json
import os
from pathlib import Path
from datasets import load_from_disk
from typing import Optional, Dict, List, Tuple
def get_system_prompt(data_dir: Path) -> str:
    """
    Return the CFO system prompt stored in ``data_dir``.

    Args:
        data_dir: Directory containing ``cfo_system_prompt.txt``

    Returns:
        The UTF-8 decoded file contents with surrounding whitespace removed
    """
    prompt_file = data_dir / "cfo_system_prompt.txt"
    raw_text = prompt_file.read_text(encoding="utf-8")
    return raw_text.strip()
def extract_field(sample: Dict, possible_names: List[str], default: str = "") -> str:
    """
    Extract a field from a sample using multiple possible column names with fallbacks.

    Candidate names are tried in order. A column whose value is ``None`` or
    blank (empty / whitespace-only after stripping) is skipped so that a later
    candidate holding real content can still be used.

    Args:
        sample: The data sample dictionary
        possible_names: List of possible column names to try, in priority order
        default: Default value if none of the names are found

    Returns:
        The first non-blank value, converted to ``str`` and stripped.
        If every matching column was blank, returns ``""``.
        If no candidate column is present with a non-None value, returns ``default``.
    """
    saw_blank = False
    for name in possible_names:
        value = sample.get(name)
        if value is None:
            # Missing column or explicit null: try the next candidate.
            continue
        text = str(value).strip()
        if text:
            return text
        # Column exists but carries no content; remember that and keep
        # looking instead of masking later fallback columns.
        saw_blank = True
    # A present-but-blank column still yields "" rather than the default.
    return "" if saw_blank else default
def is_valid_sample(user_content: str, assistant_content: str,
                    min_user_len: int = 10, min_assistant_len: int = 20) -> bool:
    """
    Decide whether a user/assistant pair is long enough to keep.

    Args:
        user_content: The user message content
        assistant_content: The assistant response content
        min_user_len: Minimum acceptable length for user message
        min_assistant_len: Minimum acceptable length for assistant response

    Returns:
        True when both messages reach their minimum length, False otherwise
    """
    # Reject early on a too-short prompt; the response is only checked after.
    if len(user_content) < min_user_len:
        return False
    return len(assistant_content) >= min_assistant_len
def normalize_finance_instruct_500k(sample: Dict) -> Optional[Dict]:
    """
    Normalize finance_instruct_500k dataset.

    Likely columns: instruction/context/output or input/output

    Args:
        sample: One raw row of the dataset

    Returns:
        A ChatML-style dict with system/user/assistant messages (the system
        content is left empty for the caller to fill), or None when the
        sample fails the length thresholds of ``is_valid_sample``
    """
    # Try instruction + context + output pattern
    instruction = extract_field(sample, ["instruction", "input", "prompt"])
    # Only real column names are listed; an empty-string name is not a usable candidate.
    context = extract_field(sample, ["context", "background"])
    output = extract_field(sample, ["output", "response", "answer"])

    # Combine context with instruction if available
    if context:
        user_content = f"{instruction}\n\nContext: {context}".strip()
    else:
        user_content = instruction

    if not is_valid_sample(user_content, output):
        return None

    return {
        "messages": [
            {"role": "system", "content": ""},  # Will be filled later
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": output},
        ]
    }
def normalize_sujet_finance_177k(sample: Dict) -> Optional[Dict]:
    """
    Convert one sujet_finance_177k row into a ChatML-style record.

    The prompt is taken from the first usable column among
    instruction/input/question and the reply from output/response/answer.
    Returns None when the pair fails ``is_valid_sample``; otherwise a dict
    whose system message is left empty for the caller to fill.
    """
    prompt = extract_field(sample, ["instruction", "input", "question"])
    reply = extract_field(sample, ["output", "response", "answer"])

    if is_valid_sample(prompt, reply):
        system_msg = {"role": "system", "content": ""}
        user_msg = {"role": "user", "content": prompt}
        assistant_msg = {"role": "assistant", "content": reply}
        return {"messages": [system_msg, user_msg, assistant_msg]}
    return None
def normalize_financial_qa_10k(sample: Dict) -> Optional[Dict]:
    """
    Convert one financial_qa_10k row into a ChatML-style record.

    The question comes from question/query/input, the reply from
    answer/response/output, and optional supporting text from
    context/background/document. When supporting text is present it is
    appended to the question under a "Document context:" label.
    Returns None when the pair fails ``is_valid_sample``; otherwise a dict
    whose system message is left empty for the caller to fill.
    """
    query = extract_field(sample, ["question", "query", "input"])
    reply = extract_field(sample, ["answer", "response", "output"])
    supporting_text = extract_field(sample, ["context", "background", "document"])

    # Without supporting text the question is used as-is.
    prompt = query
    if supporting_text:
        prompt = f"{query}\n\nDocument context: {supporting_text}".strip()

    if is_valid_sample(prompt, reply):
        system_msg = {"role": "system", "content": ""}
        user_msg = {"role": "user", "content": prompt}
        assistant_msg = {"role": "assistant", "content": reply}
        return {"messages": [system_msg, user_msg, assistant_msg]}
    return None
def normalize_fingpt_convfinqa(sample: Dict) -> Optional[Dict]:
    """
    Convert one fingpt_convfinqa row into a ChatML-style record.

    The prompt is taken from the first usable column among
    input/instruction/question and the reply from output/response/answer.
    Returns None when the pair fails ``is_valid_sample``; otherwise a dict
    whose system message is left empty for the caller to fill.
    """
    prompt = extract_field(sample, ["input", "instruction", "question"])
    reply = extract_field(sample, ["output", "response", "answer"])

    if is_valid_sample(prompt, reply):
        system_msg = {"role": "system", "content": ""}
        user_msg = {"role": "user", "content": prompt}
        assistant_msg = {"role": "assistant", "content": reply}
        return {"messages": [system_msg, user_msg, assistant_msg]}
    return None
def normalize_earnings_calls_qa(sample: Dict) -> Optional[Dict]:
    """
    Convert one earnings_calls_qa row into a ChatML-style record.

    The prompt is taken from the first usable column among
    question/query/input and the reply from answer/response/output.
    Returns None when the pair fails ``is_valid_sample``; otherwise a dict
    whose system message is left empty for the caller to fill.
    """
    query = extract_field(sample, ["question", "query", "input"])
    reply = extract_field(sample, ["answer", "response", "output"])

    if is_valid_sample(query, reply):
        system_msg = {"role": "system", "content": ""}
        user_msg = {"role": "user", "content": query}
        assistant_msg = {"role": "assistant", "content": reply}
        return {"messages": [system_msg, user_msg, assistant_msg]}
    return None
def process_dataset(dataset_name: str, dataset, normalize_fn,
                    normalized_samples: Optional[List[Dict]] = None) -> Tuple[int, int]:
    """
    Process a single dataset and return (valid_count, filtered_count).

    Args:
        dataset_name: Name of the dataset for logging
        dataset: The loaded dataset object; either a dict mapping split names
            to iterables of samples, or a single sized iterable of samples
        normalize_fn: Function to normalize samples from this dataset; it
            returns the normalized sample, or None to filter the sample out
        normalized_samples: Optional list that receives every normalized
            sample, in iteration order. When omitted, samples are only counted

    Returns:
        Tuple of (number of valid samples, number of filtered samples)
    """
    # The collector used to be an undefined name, which raised NameError on
    # the first valid sample. Fall back to a throwaway list when the caller
    # only wants the counts.
    if normalized_samples is None:
        normalized_samples = []

    valid_count = 0
    filtered_count = 0

    # Handle both single split and multiple splits
    if isinstance(dataset, dict):
        splits = list(dataset.keys())
    else:
        # Objects without a length are not iterated at all.
        splits = ["train"] if hasattr(dataset, "__len__") else []

    for split in splits:
        split_data = dataset[split] if isinstance(dataset, dict) else dataset
        for sample in split_data:
            normalized = normalize_fn(sample)
            if normalized is not None:
                normalized_samples.append(normalized)
                valid_count += 1
            else:
                filtered_count += 1

    print(f" {dataset_name}: {valid_count} valid, {filtered_count} filtered")
    return valid_count, filtered_count
def main():
    """
    Run the Layer 1 normalization pipeline end to end.

    Reads each configured dataset from ``raw/<name>`` next to this script,
    normalizes its samples, fills the system prompt into every kept sample,
    and writes the result as one JSON object per line to
    ``processed/layer1.jsonl``. Missing dataset directories are skipped and a
    failure while handling one dataset is reported without stopping the rest.
    """
    # Resolve working directories relative to this script.
    base_dir = Path(__file__).parent
    raw_dir = base_dir / "raw"
    processed_dir = base_dir / "processed"
    processed_dir.mkdir(exist_ok=True)

    system_prompt = get_system_prompt(base_dir)

    print(f"Loading datasets from: {raw_dir}")
    print(f"Output will be saved to: {processed_dir / 'layer1.jsonl'}")
    print("-" * 60)

    # Dataset directory name paired with the normalizer that understands it.
    normalizers = [
        ("finance_instruct_500k", normalize_finance_instruct_500k),
        ("sujet_finance_177k", normalize_sujet_finance_177k),
        ("financial_qa_10k", normalize_financial_qa_10k),
        ("fingpt_convfinqa", normalize_fingpt_convfinqa),
        ("earnings_calls_qa", normalize_earnings_calls_qa),
    ]

    all_samples = []
    total_valid = 0
    total_filtered = 0

    for dataset_name, normalize_fn in normalizers:
        dataset_path = raw_dir / dataset_name
        if not dataset_path.exists():
            print(f" {dataset_name}: SKIPPED (directory not found)")
            continue

        try:
            print(f"\nProcessing {dataset_name}...")
            dataset = load_from_disk(str(dataset_path))

            # A dict-like dataset holds several splits; anything else is
            # iterated directly as a single split.
            has_splits = isinstance(dataset, dict)
            split_names = list(dataset.keys()) if has_splits else ["train"]

            kept = []
            dropped = 0
            for split_name in split_names:
                rows = dataset[split_name] if has_splits else dataset
                for row in rows:
                    converted = normalize_fn(row)
                    if converted is None:
                        dropped += 1
                    else:
                        kept.append(converted)

            kept_count = len(kept)
            print(f" {dataset_name}: {kept_count} valid, {dropped} filtered")

            # Totals are only updated once the whole dataset went through.
            all_samples.extend(kept)
            total_valid += kept_count
            total_filtered += dropped
        except Exception as exc:
            print(f" {dataset_name}: ERROR - {type(exc).__name__}: {str(exc)[:100]}")

    # The normalizers leave the system message empty; fill it in here.
    print("\nAdding system prompt to all samples...")
    for record in all_samples:
        record["messages"][0]["content"] = system_prompt

    output_path = processed_dir / "layer1.jsonl"
    print(f"\nWriting {len(all_samples)} samples to {output_path}...")
    with open(output_path, "w", encoding="utf-8") as out_file:
        out_file.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in all_samples
        )

    size_mb = output_path.stat().st_size / (1024 * 1024)
    banner = "=" * 60
    print("\n" + banner)
    print("NORMALIZATION SUMMARY")
    print(banner)
    print(f"Total valid samples: {total_valid}")
    print(f"Total filtered samples: {total_filtered}")
    print(f"Output file: {output_path}")
    print(f"Output file size: {size_mb:.2f} MB")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()