#!/usr/bin/env python3
"""
Normalize Layer 1 financial instruction datasets into ChatML JSONL format.

Handles 5 financial instruction datasets with different column naming conventions:
- finance_instruct_500k (Josephgflowers/Finance-Instruct-500k)
- sujet_finance_177k (sujet-ai/Sujet-Finance-Instruct-177k)
- financial_qa_10k (virattt/financial-qa-10K)
- fingpt_convfinqa (FinGPT/fingpt-convfinqa)
- earnings_calls_qa (lamini/earnings-calls-qa)

Converts to unified ChatML format with system prompt, user message, and
assistant response. Filters out low-quality samples based on message length
thresholds.
"""

import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

try:
    from datasets import load_from_disk
except ImportError:
    # Keep the pure normalization helpers importable (e.g. for unit tests)
    # when the `datasets` package is absent; main() raises ImportError instead.
    load_from_disk = None


def get_system_prompt(data_dir: Path) -> str:
    """
    Load the CFO system prompt from file.

    Args:
        data_dir: Directory containing "cfo_system_prompt.txt"

    Returns:
        The file contents with leading/trailing whitespace stripped
    """
    prompt_path = data_dir / "cfo_system_prompt.txt"
    with open(prompt_path, "r", encoding="utf-8") as f:
        return f.read().strip()


def extract_field(sample: Dict, possible_names: List[str], default: str = "") -> str:
    """
    Extract a field from a sample using multiple possible column names with fallbacks.

    The first name that is present in the sample with a non-None value wins.
    Names that are missing, or whose value is None, are skipped.

    Args:
        sample: The data sample dictionary
        possible_names: List of possible column names to try, in priority order
        default: Default value if none of the names yields a value

    Returns:
        The field value converted to a stripped string, or default if not found
    """
    for name in possible_names:
        value = sample[name] if name in sample else None
        if value is not None:
            return str(value).strip()
    return default


def is_valid_sample(
    user_content: str,
    assistant_content: str,
    min_user_len: int = 10,
    min_assistant_len: int = 20,
) -> bool:
    """
    Check if a sample meets quality thresholds.

    Args:
        user_content: The user message content
        assistant_content: The assistant response content
        min_user_len: Minimum acceptable length for user message
        min_assistant_len: Minimum acceptable length for assistant response

    Returns:
        True if sample meets quality thresholds, False otherwise
    """
    if len(user_content) < min_user_len:
        return False
    return len(assistant_content) >= min_assistant_len


def _build_messages(user_content: str, assistant_content: str) -> Optional[Dict]:
    """
    Build a ChatML record from a user/assistant pair.

    Args:
        user_content: The user message content
        assistant_content: The assistant response content

    Returns:
        A {"messages": [...]} dict with an empty system message placeholder,
        or None if the pair fails is_valid_sample().
    """
    if not is_valid_sample(user_content, assistant_content):
        return None

    # A fresh dict is built on every call because main() later mutates the
    # system message of each record in place.
    return {
        "messages": [
            {"role": "system", "content": ""},  # Will be filled later
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": assistant_content},
        ]
    }


def normalize_finance_instruct_500k(sample: Dict) -> Optional[Dict]:
    """
    Normalize finance_instruct_500k dataset.

    Likely columns: instruction/context/output or input/output
    NOTE(review): the column names are guesses; verify against the dataset.

    Args:
        sample: One raw dataset row

    Returns:
        The ChatML record, or None if the sample is filtered out
    """
    instruction = extract_field(sample, ["instruction", "input", "prompt"])
    context = extract_field(sample, ["context", "background"])
    output = extract_field(sample, ["output", "response", "answer"])

    # Combine context with instruction if available
    if context:
        user_content = f"{instruction}\n\nContext: {context}".strip()
    else:
        user_content = instruction

    return _build_messages(user_content, output)


def normalize_sujet_finance_177k(sample: Dict) -> Optional[Dict]:
    """
    Normalize sujet_finance_177k dataset.

    Likely columns: instruction/output
    NOTE(review): the column names are guesses; verify against the dataset.

    Args:
        sample: One raw dataset row

    Returns:
        The ChatML record, or None if the sample is filtered out
    """
    instruction = extract_field(sample, ["instruction", "input", "question"])
    output = extract_field(sample, ["output", "response", "answer"])
    return _build_messages(instruction, output)


def normalize_financial_qa_10k(sample: Dict) -> Optional[Dict]:
    """
    Normalize financial_qa_10k dataset.

    Likely columns: question/answer/context
    NOTE(review): the column names are guesses; verify against the dataset.

    Args:
        sample: One raw dataset row

    Returns:
        The ChatML record, or None if the sample is filtered out
    """
    question = extract_field(sample, ["question", "query", "input"])
    answer = extract_field(sample, ["answer", "response", "output"])
    context = extract_field(sample, ["context", "background", "document"])

    # Combine context with question if available
    if context:
        user_content = f"{question}\n\nDocument context: {context}".strip()
    else:
        user_content = question

    return _build_messages(user_content, answer)


def normalize_fingpt_convfinqa(sample: Dict) -> Optional[Dict]:
    """
    Normalize fingpt_convfinqa dataset.

    Likely columns: input/output
    NOTE(review): the column names are guesses; verify against the dataset.

    Args:
        sample: One raw dataset row

    Returns:
        The ChatML record, or None if the sample is filtered out
    """
    user_input = extract_field(sample, ["input", "instruction", "question"])
    output = extract_field(sample, ["output", "response", "answer"])
    return _build_messages(user_input, output)


def normalize_earnings_calls_qa(sample: Dict) -> Optional[Dict]:
    """
    Normalize earnings_calls_qa dataset.

    Likely columns: question/answer
    NOTE(review): the column names are guesses; verify against the dataset.

    Args:
        sample: One raw dataset row

    Returns:
        The ChatML record, or None if the sample is filtered out
    """
    question = extract_field(sample, ["question", "query", "input"])
    answer = extract_field(sample, ["answer", "response", "output"])
    return _build_messages(question, answer)


# Directory name under raw/ paired with the function that normalizes its rows.
DATASET_NORMALIZERS: List[Tuple[str, Callable[[Dict], Optional[Dict]]]] = [
    ("finance_instruct_500k", normalize_finance_instruct_500k),
    ("sujet_finance_177k", normalize_sujet_finance_177k),
    ("financial_qa_10k", normalize_financial_qa_10k),
    ("fingpt_convfinqa", normalize_fingpt_convfinqa),
    ("earnings_calls_qa", normalize_earnings_calls_qa),
]


def process_dataset(
    dataset_name: str,
    dataset,
    normalize_fn: Callable[[Dict], Optional[Dict]],
    normalized_samples: Optional[List[Dict]] = None,
) -> Tuple[int, int]:
    """
    Process a single dataset and return (valid_count, filtered_count).

    Prints a one-line summary for the dataset once all splits are processed.

    Args:
        dataset_name: Name of the dataset for logging
        dataset: The loaded dataset object. A dict is treated as a mapping of
            split name to iterable of samples; anything else is iterated
            directly as a single split.
        normalize_fn: Function to normalize samples from this dataset; returns
            None for samples that should be filtered out
        normalized_samples: Optional list that receives every valid normalized
            sample, in iteration order. When omitted, samples are only counted.

    Returns:
        Tuple of (number of valid samples, number of filtered samples)
    """
    valid_count = 0
    filtered_count = 0

    # Handle both single split and multiple splits
    if isinstance(dataset, dict):
        splits = list(dataset.values())
    else:
        splits = [dataset]

    for split_data in splits:
        for sample in split_data:
            normalized = normalize_fn(sample)
            if normalized is None:
                filtered_count += 1
                continue
            valid_count += 1
            if normalized_samples is not None:
                normalized_samples.append(normalized)

    print(f" {dataset_name}: {valid_count} valid, {filtered_count} filtered")
    return valid_count, filtered_count


def main() -> None:
    """
    Main normalization pipeline.

    Reads each configured dataset from "raw/<name>" next to this script,
    normalizes it, injects the system prompt into every record and writes all
    records to "processed/layer1.jsonl", one JSON object per line.

    Raises:
        ImportError: If the `datasets` package is not installed
    """
    if load_from_disk is None:
        raise ImportError(
            "The 'datasets' package is required to run the normalization pipeline"
        )

    # Setup paths
    script_dir = Path(__file__).parent
    raw_dir = script_dir / "raw"
    processed_dir = script_dir / "processed"
    processed_dir.mkdir(parents=True, exist_ok=True)
    output_path = processed_dir / "layer1.jsonl"

    # Load system prompt
    system_prompt = get_system_prompt(script_dir)

    print(f"Loading datasets from: {raw_dir}")
    print(f"Output will be saved to: {output_path}")
    print("-" * 60)

    all_samples: List[Dict] = []
    total_valid = 0
    total_filtered = 0

    for dataset_name, normalize_fn in DATASET_NORMALIZERS:
        dataset_path = raw_dir / dataset_name

        if not dataset_path.exists():
            print(f" {dataset_name}: SKIPPED (directory not found)")
            continue

        print(f"\nProcessing {dataset_name}...")

        # Collect into a per-dataset list so that a failure part-way through
        # discards that dataset's partial output instead of keeping it.
        dataset_samples: List[Dict] = []
        try:
            dataset = load_from_disk(str(dataset_path))
            valid_count, filtered_count = process_dataset(
                dataset_name, dataset, normalize_fn, dataset_samples
            )
        except Exception as e:
            # Deliberate best-effort: one broken dataset must not abort the
            # remaining ones, so report it and move on.
            print(f" {dataset_name}: ERROR - {type(e).__name__}: {str(e)[:100]}")
            continue

        all_samples.extend(dataset_samples)
        total_valid += valid_count
        total_filtered += filtered_count

    # Add system prompt to all samples
    print("\nAdding system prompt to all samples...")
    for sample in all_samples:
        sample["messages"][0]["content"] = system_prompt

    # Write to output file
    print(f"\nWriting {len(all_samples)} samples to {output_path}...")
    with open(output_path, "w", encoding="utf-8") as f:
        for sample in all_samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")

    # Print summary
    print("\n" + "=" * 60)
    print("NORMALIZATION SUMMARY")
    print("=" * 60)
    print(f"Total valid samples: {total_valid}")
    print(f"Total filtered samples: {total_filtered}")
    print(f"Output file: {output_path}")
    print(f"Output file size: {output_path.stat().st_size / (1024*1024):.2f} MB")


if __name__ == "__main__":
    main()