# src/data_loader.py
import json
import ast

from datasets import load_dataset

# --- CONFIGURATION ---
# Maps source-dataset field names to the NER classes used for training.
LABEL_MAPPING = {
    # Vendor/Company
    "seller": "COMPANY",
    "store_name": "COMPANY",
    # Address
    "store_addr": "ADDRESS",
    # Date
    "date": "DATE",
    "invoice_date": "DATE",
    # Total
    "total": "TOTAL",
    "total_gross_worth": "TOTAL",
    # Receipt Number / Invoice No
    "invoice_no": "INVOICE_NO",
    # Bill To / Client
    "client": "BILL_TO",
}


def safe_parse(content):
    """Robustly parses input that might already be a list or dict, a JSON string,
    or a Python string literal."""
    if isinstance(content, (list, dict)):
        return content
    if isinstance(content, str):
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass
        try:
            return ast.literal_eval(content)
        except (ValueError, SyntaxError):
            pass
    return []


def normalize_box(box, width, height):
    """Converts a 4-point polygon (8 coordinates) to a 4-value bbox normalized to [0, 1000]."""
    try:
        # Handle nested format variations
        if isinstance(box, list) and len(box) == 2 and isinstance(box[0], list):
            polygon = box[0]  # [polygon, text] pair: take the polygon
        elif isinstance(box, list) and len(box) == 4 and isinstance(box[0], list):
            polygon = box  # bare 4-point polygon
        else:
            return None
        xs = [point[0] for point in polygon]
        ys = [point[1] for point in polygon]
        return [
            int(max(0, min(1000 * (min(xs) / width), 1000))),
            int(max(0, min(1000 * (min(ys) / height), 1000))),
            int(max(0, min(1000 * (max(xs) / width), 1000))),
            int(max(0, min(1000 * (max(ys) / height), 1000))),
        ]
    except Exception:
        return None


def tokenize_and_spread_boxes(words, boxes):
    """
    Splits phrases into individual words and duplicates the bounding box.
    Input:  ['Invoice #123'], [BOX_A]
    Output: ['Invoice', '#123'], [BOX_A, BOX_A]
    """
    tokenized_words = []
    tokenized_boxes = []
    for word, box in zip(words, boxes):
        # Split by whitespace; every sub-word inherits the phrase's box.
        for sub_word in str(word).split():
            tokenized_words.append(sub_word)
            tokenized_boxes.append(box)
    return tokenized_words, tokenized_boxes


def align_labels(ocr_words, label_map):
    """Matches OCR words to ground-truth values via sliding-window sub-sequence matching."""
    tags = ["O"] * len(ocr_words)
    for target_text, label_class in label_map.items():
        if not target_text:
            continue
        target_tokens = str(target_text).split()
        if not target_tokens:
            continue
        n_target = len(target_tokens)
        # Sliding window search
        for i in range(len(ocr_words) - n_target + 1):
            window = ocr_words[i : i + n_target]
            match = True
            for j in range(n_target):
                # Clean punctuation for comparison
                w_clean = window[j].strip(".,-:")
                t_clean = target_tokens[j].strip(".,-:")
                # Guard: '' is a substring of everything, so a token that is
                # pure punctuation would spuriously match any target token.
                if bool(w_clean) != bool(t_clean):
                    match = False
                    break
                if w_clean not in t_clean and t_clean not in w_clean:
                    match = False
                    break
            if match:
                tags[i] = f"B-{label_class}"
                for k in range(1, n_target):
                    tags[i + k] = f"I-{label_class}"
    return tags


def load_unified_dataset(split="train", sample_size=None):
    print(f"🔄 Loading dataset 'mychen76/invoices-and-receipts_ocr_v1' ({split})...")
    dataset = load_dataset("mychen76/invoices-and-receipts_ocr_v1", split=split)
    if sample_size:
        dataset = dataset.select(range(sample_size))

    processed_data = []
    print("⚙️ Processing, Tokenizing, and Aligning...")
    for example in dataset:
        try:
            image = example['image']
            if image.mode != "RGB":
                image = image.convert("RGB")
            width, height = image.size

            # 1. Parse Raw OCR (parse 'raw_data' once, not once per field)
            raw_data = json.loads(example['raw_data'])
            raw_words = safe_parse(raw_data.get('ocr_words'))
            raw_boxes = safe_parse(raw_data.get('ocr_boxes'))
            if not raw_words or not raw_boxes or len(raw_words) != len(raw_boxes):
                continue
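            # NOTE (assumption, not verified against the dataset card):
            # 'ocr_boxes' appears to follow a PaddleOCR-style layout, where each
            # entry is either a bare 4-point polygon
            #     [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
            # or that polygon nested inside a [polygon, text] pair;
            # normalize_box() handles both shapes.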
            # 2. Normalize Boxes first, keeping only words whose box parses
            norm_boxes = []
            valid_words = []
            for i, box in enumerate(raw_boxes):
                nb = normalize_box(box, width, height)
                if nb:
                    norm_boxes.append(nb)
                    valid_words.append(raw_words[i])

            # 3. TOKENIZE (The Fix): split phrases, spread boxes
            final_words, final_boxes = tokenize_and_spread_boxes(valid_words, norm_boxes)

            # 4. Map Labels: ground-truth value -> NER class
            parsed_json = json.loads(example['parsed_data'])
            fields = safe_parse(parsed_json.get('json', {}))
            label_value_map = {}
            if isinstance(fields, dict):
                for key, value in fields.items():
                    if key in LABEL_MAPPING and value:
                        label_value_map[value] = LABEL_MAPPING[key]

            # 5. Align Labels
            final_tags = align_labels(final_words, label_value_map)

            # Only keep examples with at least one entity (cleaner training data)
            if len(set(final_tags)) > 1:
                processed_data.append({
                    "image": image,
                    "words": final_words,
                    "bboxes": final_boxes,
                    "ner_tags": final_tags,
                })
        except Exception:
            continue

    print(f"✅ Successfully processed {len(processed_data)} examples.")
    return processed_data


if __name__ == "__main__":
    # Test run
    data = load_unified_dataset(sample_size=20)
    if len(data) > 0:
        print(f"\nSample 0 Words: {data[0]['words'][:10]}...")
        print(f"Sample 0 Tags: {data[0]['ner_tags'][:10]}...")
        all_tags = [tag for item in data for tag in item['ner_tags']]
        print(f"\nUnique Tags Found in Sample: {set(all_tags)}")
    else:
        print("No valid examples found in sample.")
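
    # --- Offline sanity check for align_labels (hypothetical values) ---
    # A minimal sketch of the sliding-window alignment on made-up strings;
    # 'Total Due' and the amount are illustrative, not taken from the dataset.
    demo_words = ["Total", "Due:", "$1,234.56", "Thank", "you"]
    demo_map = {"Total Due": "TOTAL"}
    print(f"\nalign_labels demo: {align_labels(demo_words, demo_map)}")
    # Expected: ['B-TOTAL', 'I-TOTAL', 'O', 'O', 'O'] -- the trailing ':'
    # on 'Due:' is stripped before comparison, so the phrase still matches.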