# finance-entity-extractor / scripts/convert_to_instruction.py
# Uploaded by Ranjit0034 via huggingface_hub (commit fef7470, verified)
#!/usr/bin/env python3
"""
Instruction-Format Training Data Converter
==========================================
Converts existing training data to instruction-following format
compatible with Llama 3.1 / Qwen2.5 fine-tuning.
Formats:
1. Alpaca format (instruction/input/output)
2. ChatML format (messages with roles)
3. Llama 3 format (native template)
"""
import json
import argparse
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass, asdict
import random
# Task name -> system prompt. Keys correspond to InstructionConverter(task=...);
# an unknown task falls back to "extraction". The prompt text is runtime data
# that defines the model's expected output contract -- edit with care.
SYSTEM_PROMPTS = {
    "extraction": """You are a financial entity extraction assistant specialized in Indian banking messages.
Your task is to extract structured information from SMS, email, or bank statements.
Always output valid JSON with these fields (only include if found in the message):
- amount: float (the transaction amount in INR)
- type: "debit" or "credit"
- account: string (last 4 digits of account number)
- bank: string (bank name)
- date: string (transaction date in YYYY-MM-DD format)
- time: string (transaction time if present)
- reference: string (UPI/NEFT/IMPS reference number)
- merchant: string (business/merchant name for P2M transactions)
- beneficiary: string (person name for P2P transfers)
- vpa: string (UPI ID/VPA)
- category: string (food, shopping, travel, bills, investment, transfer, etc.)
- is_p2m: boolean (true if merchant, false if person-to-person)
- balance: float (available balance after transaction)
- status: string (success, failed, pending)
Be precise. Extract exactly what's present in the message.""",
    "categorization": """You are a financial categorization assistant.
Given a transaction or merchant name, categorize it into one of these categories:
- food: Restaurants, food delivery (Swiggy, Zomato)
- grocery: Supermarkets, grocery stores (BigBasket, Zepto)
- shopping: E-commerce, retail (Amazon, Flipkart)
- transport: Ride-hailing, fuel (Uber, Ola, petrol)
- travel: Flights, hotels, trains (IRCTC, MakeMyTrip)
- bills: Utilities, recharges (Airtel, electricity)
- entertainment: Movies, streaming (Netflix, BookMyShow)
- healthcare: Pharmacy, hospitals (Apollo, PharmEasy)
- investment: Stocks, mutual funds (Zerodha, Groww)
- transfer: P2P money transfers
- salary: Income credits
- emi: Loan EMI payments
- other: Uncategorized
Output a single category name.""",
    "analysis": """You are a financial analysis assistant.
Help users understand their spending patterns, detect anomalies, and provide insights.
Use the provided transaction history as context.
Be concise and data-driven in your responses."""
}
@dataclass
class AlpacaFormat:
    """One Alpaca-style training example (instruction / input / output)."""
    instruction: str
    input: str
    output: str

    def to_dict(self) -> Dict:
        """Return a plain-dict view of this example for JSON serialization."""
        return {
            "instruction": self.instruction,
            "input": self.input,
            "output": self.output,
        }
@dataclass
class ChatMessage:
    """Single chat message (role + content).

    NOTE(review): appears unused in this script -- the converters build
    plain message dicts directly; kept for API completeness.
    """
    role: str  # one of: "system", "user", "assistant"
    content: str  # message text
@dataclass
class ChatMLFormat:
    """A conversation in ChatML form: an ordered list of role/content dicts."""
    messages: List[Dict[str, str]]

    def to_dict(self) -> Dict:
        """Wrap the message list under the "messages" key for JSONL output."""
        payload = {"messages": self.messages}
        return payload
class Llama3Format:
    """Render a (system, user, assistant) triple in Llama 3's native chat template."""

    # Literal special-token layout of the Llama 3 template; the named
    # placeholders are filled via str.format (braces inside the substituted
    # values are not re-interpreted).
    _TEMPLATE = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        "{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        "{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        "{assistant}<|eot_id|>"
    )

    @staticmethod
    def format(system: str, user: str, assistant: str) -> str:
        """Return the fully formatted single-turn prompt string."""
        return Llama3Format._TEMPLATE.format(
            system=system, user=user, assistant=assistant
        )
class InstructionConverter:
    """Convert raw training records to instruction-tuning formats.

    Each input record is a dict carrying the message text under "input"
    (or "text") and the expected extraction under "output" (or
    "ground_truth", either a dict or a JSON string).
    """

    def __init__(self, task: str = "extraction"):
        """Select the system prompt for *task*; unknown tasks fall back to "extraction"."""
        self.task = task
        self.system_prompt = SYSTEM_PROMPTS.get(task, SYSTEM_PROMPTS["extraction"])
        # Instruction templates for variety; one is chosen at random per record.
        self.instructions = [
            "Extract financial entities from this message:",
            "Parse the financial information from this banking message:",
            "Identify and extract all transaction details:",
            "Extract structured data from this bank notification:",
            "Parse this banking SMS and extract entities:",
            "Analyze this transaction message and extract details:",
            "Extract the financial entities from the following message:",
            "Identify transaction details in this message:",
        ]

    def _prepare(self, record: Dict):
        """Normalize one record into (input_text, output_dict, instruction).

        - input_text: message text from "input" or "text" ("" if absent)
        - output_dict: ground truth with None-valued fields stripped; a
          JSON string is parsed, and a non-dict result is wrapped as
          {"raw": ...} instead of crashing on .items()
        - instruction: a randomly chosen instruction template
        """
        input_text = record.get("input", record.get("text", ""))
        output = record.get("output", record.get("ground_truth", {}))
        if isinstance(output, str):
            try:
                output = json.loads(output)
            except json.JSONDecodeError:
                output = {"raw": output}
        if not isinstance(output, dict):
            # e.g. ground truth parsed to a list or number
            output = {"raw": output}
        output = {k: v for k, v in output.items() if v is not None}
        return input_text, output, random.choice(self.instructions)

    def convert_to_alpaca(self, record: Dict) -> AlpacaFormat:
        """Convert one record to Alpaca (instruction/input/output) format."""
        input_text, output, instruction = self._prepare(record)
        return AlpacaFormat(
            instruction=instruction,
            input=input_text,
            output=json.dumps(output, ensure_ascii=False, indent=2)
        )

    def convert_to_chatml(self, record: Dict) -> ChatMLFormat:
        """Convert one record to ChatML (system/user/assistant messages) format."""
        input_text, output, instruction = self._prepare(record)
        return ChatMLFormat(messages=[
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{instruction}\n\n{input_text}"},
            {"role": "assistant", "content": json.dumps(output, ensure_ascii=False)}
        ])

    def convert_to_llama3(self, record: Dict) -> str:
        """Convert one record to the Llama 3 native template (a single string)."""
        input_text, output, instruction = self._prepare(record)
        user_content = f"{instruction}\n\n{input_text}"
        return Llama3Format.format(
            system=self.system_prompt,
            user=user_content,
            assistant=json.dumps(output, ensure_ascii=False)
        )

    def _convert_record(self, record: Dict, format_type: str) -> Dict:
        """Dispatch one record to the requested format; returns a JSON-able dict.

        Raises ValueError for an unknown format_type.
        """
        if format_type == "alpaca":
            return self.convert_to_alpaca(record).to_dict()
        if format_type == "chatml":
            return self.convert_to_chatml(record).to_dict()
        if format_type == "llama3":
            # Llama 3 output is a raw prompt string; store it under "text".
            return {"text": self.convert_to_llama3(record)}
        raise ValueError(f"Unknown format: {format_type}")

    def convert_file(
        self,
        input_path: Path,
        output_path: Path,
        format_type: str = "chatml",
        max_samples: Optional[int] = None
    ) -> int:
        """Convert an entire JSONL file; returns the number of records written.

        Unparseable input lines are skipped silently. Records are shuffled
        (uses module-level `random` state, so `random.seed` controls
        reproducibility) and truncated to max_samples when given.

        Raises ValueError for an unknown format_type (before any output is
        written, so no partial file is left behind).
        """
        print(f"Converting {input_path} to {format_type} format...")
        # Fail fast on a bad format before creating/truncating the output file.
        if format_type not in ("alpaca", "chatml", "llama3"):
            raise ValueError(f"Unknown format: {format_type}")
        # Load input
        records = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
        print(f" Loaded {len(records):,} records")
        # Shuffle and limit
        random.shuffle(records)
        if max_samples:
            records = records[:max_samples]
        # Convert and write one JSON object per line
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            for record in records:
                converted = self._convert_record(record, format_type)
                f.write(json.dumps(converted, ensure_ascii=False) + '\n')
        print(f" Saved {len(records):,} records to {output_path}")
        return len(records)
def create_train_val_test_split(
    input_path: Path,
    output_dir: Path,
    format_type: str = "chatml",
    train_ratio: float = 0.9,
    val_ratio: float = 0.05,
    test_ratio: float = 0.05
):
    """Shuffle a JSONL file and write train/valid/test splits in *format_type*.

    Split sizes come from train_ratio and val_ratio; the test split receives
    whatever remains. `test_ratio` is therefore informational only -- it is
    kept in the signature for backward compatibility and self-documentation.

    Output files are <output_dir>/train.jsonl, valid.jsonl ("valid" because
    MLX expects that name), and test.jsonl.

    Raises ValueError if format_type is unknown (checked before any file is
    written, so no partial splits are left behind) or the ratios exceed 1.0.
    """
    print(f"Creating data splits from {input_path}...")
    # Validate inputs up front -- the original dispatch raised mid-write,
    # leaving partial train/valid files on an unknown format.
    if format_type not in ("alpaca", "chatml", "llama3"):
        raise ValueError(f"Unknown format: {format_type}")
    if train_ratio + val_ratio > 1.0:
        raise ValueError("train_ratio + val_ratio must not exceed 1.0")
    # Load all records, skipping unparseable lines.
    records = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    print(f" Total records: {len(records):,}")
    # Shuffle (module-level random state; seeded by the caller for reproducibility).
    random.shuffle(records)
    # Split: train and val by ratio, test takes the remainder.
    n = len(records)
    train_end = int(n * train_ratio)
    val_end = train_end + int(n * val_ratio)
    train_records = records[:train_end]
    val_records = records[train_end:val_end]
    test_records = records[val_end:]
    print(f" Train: {len(train_records):,}")
    print(f" Val: {len(val_records):,}")
    print(f" Test: {len(test_records):,}")
    # Convert and save each split
    converter = InstructionConverter()
    output_dir.mkdir(parents=True, exist_ok=True)
    for name, split_records in [
        ("train", train_records),
        ("valid", val_records),  # MLX uses "valid"
        ("test", test_records)
    ]:
        output_path = output_dir / f"{name}.jsonl"
        with open(output_path, 'w', encoding='utf-8') as f:
            for record in split_records:
                if format_type == "alpaca":
                    converted = converter.convert_to_alpaca(record).to_dict()
                elif format_type == "chatml":
                    converted = converter.convert_to_chatml(record).to_dict()
                else:  # "llama3" -- validated above; store raw prompt text
                    converted = {"text": converter.convert_to_llama3(record)}
                f.write(json.dumps(converted, ensure_ascii=False) + '\n')
        print(f" Saved {output_path}")
def main():
    """CLI entry point: parse arguments and run the requested conversion."""
    parser = argparse.ArgumentParser(
        description="Convert training data to instruction format")
    parser.add_argument("input", help="Input JSONL file")
    parser.add_argument("-o", "--output", help="Output directory")
    parser.add_argument("-f", "--format", choices=["alpaca", "chatml", "llama3"],
                        default="chatml", help="Output format")
    parser.add_argument("-n", "--max-samples", type=int, help="Max samples")
    parser.add_argument("--split", action="store_true",
                        help="Create train/val/test splits")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # Seed the module-level RNG so shuffles are reproducible across runs.
    random.seed(args.seed)
    source = Path(args.input)

    if args.split:
        # Split mode: the output option names a directory for the three splits.
        target_dir = Path(args.output) if args.output else Path("data/instruction")
        create_train_val_test_split(source, target_dir, args.format)
        return

    # Single-file mode: derive the output name from the input when not given.
    if args.output:
        target = Path(args.output)
    else:
        target = source.with_suffix(f".{args.format}.jsonl")
    InstructionConverter().convert_file(source, target, args.format, args.max_samples)
# Script entry point; importing this module runs no conversion.
if __name__ == "__main__":
    main()