# Financial_RAG / data_loader.py
# (Source: amaherovskyi — "Update data_loader.py", commit b27716d, verified)
from datasets import load_dataset
from typing import List, Dict, Any
def preprocess_context(text: str) -> str:
    """Normalize a context string for indexing.

    Collapses every run of whitespace (spaces, tabs, line breaks) into a
    single space. A ``None`` input is treated as an empty context.

    Args:
        text: raw context text, possibly ``None``.

    Returns:
        The cleaned, single-line string ("" when ``text`` is ``None``).
    """
    if text is None:
        return ""
    tokens = text.split()
    return " ".join(tokens)
def preprocess_table(table: Any) -> str:
    """Convert a table to plain text suitable for the retriever.

    A table given as a list/tuple of rows (each row itself a list/tuple of
    cells) is rendered one row per line with " | " between cells; this
    indexes far better than the Python repr that a bare ``str()`` would
    produce (e.g. ``[['a', 'b']]``). Any other non-None value falls back
    to ``str()``; ``None`` yields "".

    Args:
        table: table data in any shape the dataset provides, or ``None``.

    Returns:
        A plain-text rendering of the table ("" when ``table`` is ``None``).
    """
    if table is None:
        return ""
    if isinstance(table, (list, tuple)):
        lines = []
        for row in table:
            if isinstance(row, (list, tuple)):
                # Join the row's cells into a single pipe-separated line.
                lines.append(" | ".join(str(cell) for cell in row))
            else:
                lines.append(str(row))
        return "\n".join(lines)
    return str(table)
def create_documents(dataset: List[Dict[str, Any]], chunk_size: int = 500) -> List[Dict[str, Any]]:
    """
    Creates documents from context chunks and a table.

    Each example's cleaned context (plus its table rendered as text, when
    present) is split into fixed-size character chunks, producing one
    document per chunk. An example with no context and no table yields no
    documents.

    Args:
        dataset: list of examples from the dataset.
        chunk_size: chunk size in characters.

    Returns:
        List of documents in the format:
        [{"id": ..., "text": ..., "metadata": {...}}, ...]
    """
    documents = []
    for example in dataset:
        context_text = preprocess_context(example.get('context'))
        table_text = preprocess_table(example.get('table'))
        full_text = context_text
        if table_text:
            full_text += "\nTable:\n" + table_text
        # Use .get for 'id' too: every other field access tolerates a
        # missing key, so a malformed example should not raise KeyError here.
        example_id = example.get('id')
        # Metadata is identical for every chunk of an example — build it
        # once and copy per document instead of rebuilding inside the loop.
        metadata = {
            "question": example.get('question'),
            "original_answer": example.get('original_answer'),
            "file_name": example.get('file_name'),
            "company_name": example.get('company_name'),
            "report_year": example.get('report_year')
        }
        for chunk_index, start in enumerate(range(0, len(full_text), chunk_size)):
            documents.append({
                "id": f"{example_id}_{chunk_index}",
                "text": full_text[start:start + chunk_size],
                # Copy so no two documents share a mutable metadata dict.
                "metadata": dict(metadata)
            })
    return documents
def load_finqa_dataset(split: str = "train") -> List[Dict[str, Any]]:
    """Load the FinQA config of t2-ragbench and return retriever documents.

    Args:
        split: which dataset split to load (e.g. "train").

    Returns:
        The split converted to documents via ``create_documents``.
    """
    dataset_dict = load_dataset("G4KMU/t2-ragbench", "FinQA")
    examples = dataset_dict[split]
    return create_documents(examples)