#!/usr/bin/env python3
"""
Calculate OCR accuracy using edit distance (Levenshtein distance)
between label and predict data from JSON files or BigQuery SQL queries.
Supports two modes:
1. Separate files: Provide --label and --predict files (records matched by keys)
2. Combined data: Provide --data file with label fields having a suffix (e.g., _label)
Example with combined data (data.sql):
python calculate_ocr_accuracy.py --data scripts/data.sql --normalize
Example with separate files:
python calculate_ocr_accuracy.py --label labels.json --predict predicts.json --normalize
"""
import json
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from collections import defaultdict
try:
from google.cloud import bigquery
BIGQUERY_AVAILABLE = True
except ImportError:
BIGQUERY_AVAILABLE = False
print("Warning: google-cloud-bigquery not installed. SQL query support disabled.")
def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein (edit) distance between two strings.

    Classic single-row dynamic-programming formulation: only the previous
    row of the DP table is kept, so memory is O(min(len(s1), len(s2))).

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        Minimum number of single-character insertions, deletions and
        substitutions needed to transform one string into the other.
    """
    # Keep the shorter string as the DP row dimension to minimize memory.
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    prev_row = list(range(len(s2) + 1))
    for row_idx, ch_a in enumerate(s1, start=1):
        curr_row = [row_idx]
        for col_idx, ch_b in enumerate(s2):
            cost_insert = prev_row[col_idx + 1] + 1
            cost_delete = curr_row[col_idx] + 1
            cost_substitute = prev_row[col_idx] + (ch_a != ch_b)
            curr_row.append(min(cost_insert, cost_delete, cost_substitute))
        prev_row = curr_row
    return prev_row[-1]
def normalize_string(s: str) -> str:
    """Normalize a string for OCR comparison.

    Lowercases the text and removes surrounding whitespace so comparisons
    are insensitive to case and padding.

    Args:
        s: Input string.

    Returns:
        The normalized string.
    """
    lowered = s.lower()
    return lowered.strip()
def calculate_field_accuracy(
    label_value: str,
    predict_value: str,
    normalize: bool = False
) -> Tuple[int, float, bool]:
    """Calculate accuracy metrics for a single field.

    Args:
        label_value: Ground truth value.
        predict_value: Predicted value.
        normalize: Whether to normalize (lowercase, strip) both strings
            before comparison.

    Returns:
        Tuple of (edit_distance, normalized_accuracy, exact_match), where
        normalized_accuracy = 1 - edit_distance / max(len(label), len(predict), 1)
        and exact_match compares the (possibly normalized) strings.
    """
    if normalize:
        label_value = normalize_string(label_value)
        predict_value = normalize_string(predict_value)
    exact_match = label_value == predict_value
    edit_dist = levenshtein_distance(label_value, predict_value)
    # The trailing 1 guards against division by zero when both strings are
    # empty (edit distance is 0 then, so accuracy correctly comes out 1.0).
    max_len = max(len(label_value), len(predict_value), 1)
    normalized_accuracy = 1.0 - (edit_dist / max_len)
    return edit_dist, normalized_accuracy, exact_match
def execute_bigquery_query(query_file: str, project_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """Run a BigQuery SQL query stored in a file and return rows as dicts.

    Args:
        query_file: Path to the SQL query file.
        project_id: GCP project ID (if None, the client uses the default
            project from the environment).

    Returns:
        One dictionary per result row; every value is stringified and SQL
        NULLs become empty strings so downstream string comparison is uniform.

    Raises:
        ImportError: If google-cloud-bigquery is not installed.
        FileNotFoundError: If the query file does not exist.
    """
    if not BIGQUERY_AVAILABLE:
        raise ImportError("google-cloud-bigquery is required for SQL query support. "
                          "Install it with: poetry add google-cloud-bigquery")
    query_path = Path(query_file)
    if not query_path.exists():
        raise FileNotFoundError(f"Query file not found: {query_file}")
    query = query_path.read_text(encoding='utf-8')
    # Fall back to the environment-default project when none is given.
    client = bigquery.Client(project=project_id) if project_id else bigquery.Client()
    print(f"Executing BigQuery query from: {query_file}")
    result_rows = client.query(query).result()
    records = [
        {key: str(value) if value is not None else '' for key, value in row.items()}
        for row in result_rows
    ]
    print(f"Retrieved {len(records)} records from BigQuery")
    return records
def load_data_from_file(file_path: str, project_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """Load records from a JSON file or, for '.sql' files, a BigQuery query.

    Args:
        file_path: Path to a JSON file or a SQL query file.
        project_id: GCP project ID for BigQuery queries (optional).

    Returns:
        List of record dictionaries.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    # Dispatch on the extension: '.sql' goes to BigQuery, anything else is
    # parsed as JSON.
    if path.suffix.lower() == '.sql':
        return execute_bigquery_query(file_path, project_id)
    with open(path, 'r', encoding='utf-8') as fh:
        return json.load(fh)
def match_records(
    labels: List[Dict[str, Any]],
    predicts: List[Dict[str, Any]],
    match_keys: Optional[List[str]] = None
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Match label and predict records by a shared key field.

    The first key in *match_keys* that exists in both datasets (probed on
    the first record of each list, so homogeneous schemas are assumed) is
    used. If no key is shared, records are paired by index order instead.

    Args:
        labels: List of label records.
        predicts: List of predict records.
        match_keys: Candidate field names for matching, tried in order
            (default: ['filename', 'project_id_mother']).

    Returns:
        Tuple of (matched_labels, matched_predicts) with equal length, in
        label order. Labels with no matching predict are dropped.
    """
    if match_keys is None:
        match_keys = ['filename', 'project_id_mother']
    # Find the first candidate key present in both datasets.
    available_keys = []
    if labels and predicts:
        label_keys = set(labels[0].keys())
        predict_keys = set(predicts[0].keys())
        for key in match_keys:
            if key in label_keys and key in predict_keys:
                available_keys.append(key)
                break
    if not available_keys:
        # No shared key (or an empty dataset): fall back to positional
        # pairing, truncated to the shorter list.
        print("Warning: No matching keys found. Matching by index order.")
        min_len = min(len(labels), len(predicts))
        return labels[:min_len], predicts[:min_len]
    match_key = available_keys[0]
    print(f"Matching records by key: {match_key}")
    # Index predict records by key value. Records with an empty key are
    # skipped; duplicate key values keep the last occurrence.
    predict_index = {}
    for predict_record in predicts:
        key_value = str(predict_record.get(match_key, ''))
        if key_value:
            predict_index[key_value] = predict_record
    matched_labels = []
    matched_predicts = []
    for label_record in labels:
        key_value = str(label_record.get(match_key, ''))
        if key_value in predict_index:
            matched_labels.append(label_record)
            matched_predicts.append(predict_index[key_value])
    print(f"Matched {len(matched_labels)} records out of {len(labels)} labels and {len(predicts)} predicts")
    return matched_labels, matched_predicts
def _pair_label_predict_fields(field_names: set, label_suffix: str) -> Dict[str, str]:
    """Map each predict field to its matching '<field><label_suffix>' label field.

    A pair exists only when both the bare field name and the suffixed name
    appear in *field_names*. Iterating either member of a pair yields the
    same mapping entry, so duplicates are harmless.
    """
    pairs: Dict[str, str] = {}
    for field in field_names:
        if field.endswith(label_suffix):
            predict_field = field[:-len(label_suffix)]
            if predict_field in field_names:
                pairs[predict_field] = field
        elif field + label_suffix in field_names:
            pairs[field] = field + label_suffix
    return pairs


def calculate_ocr_accuracy_from_combined_data(
    data_file: str,
    normalize: bool = False,
    project_id: Optional[str] = None,
    label_suffix: str = '_label'
) -> Dict:
    """Calculate per-field OCR accuracy from a single combined data source.

    Expects records where prediction fields and their ground-truth values
    live side by side, the latter carrying *label_suffix* (e.g. 'name' and
    'name_label'), as produced by queries like data.sql.

    Args:
        data_file: Path to JSON file or SQL query file containing combined data.
        normalize: Whether to normalize strings before comparison.
        project_id: GCP project ID for BigQuery queries (optional).
        label_suffix: Suffix used to identify label fields (default: '_label').

    Returns:
        Dictionary of accuracy metrics per predict field; empty when no
        records or no label/predict field pairs are found.
    """
    records = load_data_from_file(data_file, project_id)
    if not records:
        print("Warning: No records found in data file")
        return {}
    # Field discovery is based on the first record only; all records are
    # assumed to share the same schema.
    field_pairs = _pair_label_predict_fields(set(records[0].keys()), label_suffix)
    # Exclude metadata columns from the comparison.
    exclude_fields = {'filename', 'project_id_mother', 'id', 'rn'}
    predict_fields = {f for f in field_pairs if f not in exclude_fields}
    if not predict_fields:
        print("Warning: No matching field pairs found. "
              f"Looking for fields with '{label_suffix}' suffix.")
        return {}
    print(f"Found {len(predict_fields)} field pairs to compare")
    print(f"Fields: {', '.join(sorted(predict_fields))}")
    # Per-field accumulators, created lazily on first use.
    field_stats = defaultdict(lambda: {
        'total_records': 0,
        'exact_matches': 0,
        'total_edit_distance': 0,
        'total_normalized_accuracy': 0.0,
        'min_accuracy': 1.0,
        'max_edit_distance': 0,
        'examples': []  # Up to 3 sample mismatches per field
    })
    for record_idx, record in enumerate(records):
        for predict_field in predict_fields:
            label_value = str(record.get(field_pairs[predict_field], ''))
            predict_value = str(record.get(predict_field, ''))
            # Skip rows where neither side has content -- nothing to score.
            if not label_value.strip() and not predict_value.strip():
                continue
            edit_dist, normalized_acc, exact_match = calculate_field_accuracy(
                label_value, predict_value, normalize
            )
            stats = field_stats[predict_field]
            stats['total_records'] += 1
            stats['total_edit_distance'] += edit_dist
            stats['total_normalized_accuracy'] += normalized_acc
            stats['max_edit_distance'] = max(stats['max_edit_distance'], edit_dist)
            stats['min_accuracy'] = min(stats['min_accuracy'], normalized_acc)
            if exact_match:
                stats['exact_matches'] += 1
            elif len(stats['examples']) < 3:
                stats['examples'].append({
                    'record_idx': record_idx,
                    'label': label_value,
                    'predict': predict_value,
                    'edit_distance': edit_dist,
                    'accuracy': normalized_acc
                })
    # Reduce accumulators to per-field summary metrics. Fields where every
    # row was skipped (total == 0) are omitted from the result.
    results = {}
    for predict_field in sorted(predict_fields):
        stats = field_stats[predict_field]
        total = stats['total_records']
        if total == 0:
            continue
        results[predict_field] = {
            'exact_match_rate': stats['exact_matches'] / total,
            'average_edit_distance': stats['total_edit_distance'] / total,
            'average_normalized_accuracy': stats['total_normalized_accuracy'] / total,
            'min_accuracy': stats['min_accuracy'],
            'max_edit_distance': stats['max_edit_distance'],
            'exact_matches': stats['exact_matches'],
            'total_records': total,
            'examples': stats['examples']
        }
    return results
def calculate_ocr_accuracy(
    label_file: str,
    predict_file: str,
    normalize: bool = False,
    match_keys: Optional[List[str]] = None,
    project_id: Optional[str] = None
) -> Dict:
    """Calculate OCR accuracy per field between label and predict data.

    Supports both JSON files and BigQuery SQL query files. Records from the
    two sources are aligned via match_records() before comparison.

    Args:
        label_file: Path to label JSON file or SQL query file.
        predict_file: Path to predict JSON file or SQL query file.
        normalize: Whether to normalize strings before comparison.
        match_keys: Field names to use for matching records (match_records
            defaults to ['filename', 'project_id_mother'] when None).
        project_id: GCP project ID for BigQuery queries (optional).

    Returns:
        Dictionary containing accuracy metrics per field; empty when there
        are no matched records or no comparable fields.
    """
    # Load data from files (JSON or SQL).
    labels = load_data_from_file(label_file, project_id)
    predicts = load_data_from_file(predict_file, project_id)
    # Match records by common keys.
    labels, predicts = match_records(labels, predicts, match_keys)
    # Defensive truncation: match_records already returns equal-length
    # lists, but keep both sides aligned if that ever changes.
    if len(labels) != len(predicts):
        print(f"Warning: Label file has {len(labels)} records, "
              f"predict file has {len(predicts)} records")
        min_records = min(len(labels), len(predicts))
        labels = labels[:min_records]
        predicts = predicts[:min_records]
    if len(labels) == 0:
        return {}
    # Field list comes from the first label record; metadata columns are
    # excluded from the accuracy calculation.
    exclude_fields = {'filename', 'project_id_mother'}
    field_names = [f for f in labels[0].keys() if f not in exclude_fields]
    if not field_names:
        print("Warning: No fields to calculate accuracy for (all fields excluded)")
        return {}
    print(f"Calculating accuracy for {len(field_names)} fields (excluding: {', '.join(exclude_fields)})")
    # Per-field accumulators, created lazily on first use.
    field_stats = defaultdict(lambda: {
        'total_records': 0,
        'exact_matches': 0,
        'total_edit_distance': 0,
        'total_normalized_accuracy': 0.0,
        'min_accuracy': 1.0,
        'max_edit_distance': 0,
        'examples': []  # Up to 3 sample mismatches per field
    })
    # Walk the aligned record pairs and score every field.
    for record_idx, (label_record, predict_record) in enumerate(zip(labels, predicts)):
        for field_name in field_names:
            label_value = str(label_record.get(field_name, ''))
            predict_value = str(predict_record.get(field_name, ''))
            # Skip rows where neither side has content -- nothing to score.
            if not label_value.strip() and not predict_value.strip():
                continue
            edit_dist, normalized_acc, exact_match = calculate_field_accuracy(
                label_value, predict_value, normalize
            )
            stats = field_stats[field_name]
            stats['total_records'] += 1
            stats['total_edit_distance'] += edit_dist
            stats['total_normalized_accuracy'] += normalized_acc
            stats['max_edit_distance'] = max(stats['max_edit_distance'], edit_dist)
            stats['min_accuracy'] = min(stats['min_accuracy'], normalized_acc)
            if exact_match:
                stats['exact_matches'] += 1
            elif len(stats['examples']) < 3:
                stats['examples'].append({
                    'record_idx': record_idx,
                    'label': label_value,
                    'predict': predict_value,
                    'edit_distance': edit_dist,
                    'accuracy': normalized_acc
                })
    # Reduce accumulators to per-field summary metrics. Fields where every
    # row was skipped (total == 0) are omitted from the result.
    results = {}
    for field_name in field_names:
        stats = field_stats[field_name]
        total = stats['total_records']
        if total == 0:
            continue
        results[field_name] = {
            'exact_match_rate': stats['exact_matches'] / total,
            'average_edit_distance': stats['total_edit_distance'] / total,
            'average_normalized_accuracy': stats['total_normalized_accuracy'] / total,
            'min_accuracy': stats['min_accuracy'],
            'max_edit_distance': stats['max_edit_distance'],
            'exact_matches': stats['exact_matches'],
            'total_records': total,
            'examples': stats['examples']
        }
    return results
def print_results(results: Dict, output_file: Optional[str] = None) -> None:
    """Print accuracy results as a formatted report (optionally saved to file).

    Args:
        results: Dictionary of per-field accuracy metrics, as produced by
            the calculate_ocr_accuracy* functions.
        output_file: Optional file path to also save the report text.
    """
    # Guard against empty results: the summary averages below divide by the
    # number of fields, which would raise ZeroDivisionError on {} (both
    # calculators can legitimately return an empty dict).
    if not results:
        print("No results to report (no fields were compared).")
        return
    output_lines = []
    # Header
    header = "=" * 100
    output_lines.append(header)
    output_lines.append("OCR ACCURACY REPORT (Edit Distance Analysis)")
    output_lines.append(header)
    output_lines.append("")
    # Sort fields by average normalized accuracy (descending)
    sorted_fields = sorted(
        results.items(),
        key=lambda x: x[1]['average_normalized_accuracy'],
        reverse=True
    )
    # Summary statistics: simple per-field means (not record-weighted).
    total_fields = len(results)
    avg_exact_match = sum(r['exact_match_rate'] for r in results.values()) / total_fields
    avg_normalized_acc = sum(r['average_normalized_accuracy'] for r in results.values()) / total_fields
    output_lines.append("SUMMARY STATISTICS")
    output_lines.append("-" * 100)
    output_lines.append(f"Total Fields Analyzed: {total_fields}")
    output_lines.append(f"Overall Exact Match Rate: {avg_exact_match:.2%}")
    output_lines.append(f"Overall Average Normalized Accuracy: {avg_normalized_acc:.2%}")
    output_lines.append("")
    # Per-field statistics table
    output_lines.append("PER-FIELD STATISTICS")
    output_lines.append("-" * 100)
    output_lines.append(f"{'Field Name':<50} {'Exact Match':<15} {'Avg Accuracy':<15} {'Avg Edit Dist':<15}")
    output_lines.append("-" * 100)
    for field_name, stats in sorted_fields:
        exact_match_pct = stats['exact_match_rate'] * 100
        avg_acc = stats['average_normalized_accuracy'] * 100
        avg_edit = stats['average_edit_distance']
        output_lines.append(
            f"{field_name:<50} {exact_match_pct:>6.2f}% ({stats['exact_matches']}/{stats['total_records']}) "
            f"{avg_acc:>6.2f}% {avg_edit:>6.2f}"
        )
    output_lines.append("")
    output_lines.append("")
    # Detailed examples for fields with errors
    output_lines.append("EXAMPLES OF MISMATCHES (Top 3 per field)")
    output_lines.append("-" * 100)
    for field_name, stats in sorted_fields:
        # Only show fields that actually had mismatches and stored examples.
        if stats['exact_matches'] < stats['total_records'] and stats['examples']:
            output_lines.append(f"\nField: {field_name}")
            output_lines.append(f"  Exact Match Rate: {stats['exact_match_rate']:.2%}")
            output_lines.append(f"  Average Accuracy: {stats['average_normalized_accuracy']:.2%}")
            for example in stats['examples']:
                output_lines.append(f"  Record {example['record_idx']}:")
                output_lines.append(f"    Label:   '{example['label']}'")
                output_lines.append(f"    Predict: '{example['predict']}'")
                output_lines.append(f"    Edit Distance: {example['edit_distance']}, "
                                    f"Accuracy: {example['accuracy']:.2%}")
    # Print to console
    output_text = "\n".join(output_lines)
    print(output_text)
    # Save to file if specified
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(output_text)
        print(f"\nResults saved to: {output_file}")
def main() -> None:
    """Main function to run the OCR accuracy calculation.

    Parses command-line arguments, dispatches to the combined-data workflow
    (--data) or the separate label/predict workflow (--label/--predict),
    prints a report, and optionally saves text/JSON output. Exits with
    status 1 on missing input files or calculation errors.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Calculate OCR accuracy using edit distance between label and predict data. '
                    'Supports JSON files and BigQuery SQL query files (.sql). '
                    'Can use either two separate files (--label and --predict) or a single '
                    'combined file (--data) with label fields having a suffix (e.g., _label).'
    )
    # Mode selection: either combined data or separate label/predict files.
    # The mutually exclusive group guarantees exactly one of --data/--label.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        '--data',
        type=str,
        help='Path to combined data file (JSON or SQL) with label and predict fields in same records. '
             'Label fields should have a suffix (default: _label). Use this for queries like data.sql'
    )
    input_group.add_argument(
        '--label',
        type=str,
        help='Path to label JSON file or SQL query file (use with --predict)'
    )
    parser.add_argument(
        '--predict',
        type=str,
        help='Path to predict JSON file or SQL query file (use with --label)'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Path to output file for results (optional)'
    )
    parser.add_argument(
        '--normalize',
        action='store_true',
        help='Normalize strings (lowercase, strip) before comparison'
    )
    parser.add_argument(
        '--json-output',
        type=str,
        default=None,
        help='Path to save results as JSON (optional)'
    )
    parser.add_argument(
        '--match-keys',
        type=str,
        nargs='+',
        default=['filename', 'project_id_mother'],
        help='Field names to use for matching records (default: filename project_id_mother). '
             'Only used with --label/--predict mode.'
    )
    parser.add_argument(
        '--project-id',
        type=str,
        default=None,
        help='GCP project ID for BigQuery (default: uses default from environment)'
    )
    parser.add_argument(
        '--label-suffix',
        type=str,
        default='_label',
        help='Suffix used to identify label fields in combined data mode (default: _label)'
    )
    args = parser.parse_args()
    # Validate arguments based on mode
    if args.data:
        # Combined data mode: one file, labels identified by --label-suffix.
        data_path = Path(args.data)
        if not data_path.exists():
            print(f"Error: Data file not found: {data_path}")
            sys.exit(1)
        print(f"Calculating OCR accuracy from combined data...")
        print(f"Data file: {data_path}")
        print(f"Label suffix: {args.label_suffix}")
        print(f"Normalize: {args.normalize}")
        print()
        try:
            results = calculate_ocr_accuracy_from_combined_data(
                str(data_path),
                normalize=args.normalize,
                project_id=args.project_id,
                label_suffix=args.label_suffix
            )
            # Print results (and save text report if --output was given)
            print_results(results, args.output)
            # Save JSON output if requested
            if args.json_output:
                with open(args.json_output, 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)
                print(f"\nJSON results saved to: {args.json_output}")
        except Exception as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        # Separate label/predict files mode.
        # argparse only enforces that --label was given; --predict must be
        # checked by hand here.
        if not args.label or not args.predict:
            parser.error("--label and --predict are required when not using --data mode")
        label_path = Path(args.label)
        predict_path = Path(args.predict)
        if not label_path.exists():
            print(f"Error: Label file not found: {label_path}")
            sys.exit(1)
        if not predict_path.exists():
            print(f"Error: Predict file not found: {predict_path}")
            sys.exit(1)
        # Calculate accuracy
        print(f"Calculating OCR accuracy...")
        print(f"Label file: {label_path}")
        print(f"Predict file: {predict_path}")
        print(f"Normalize: {args.normalize}")
        print(f"Match keys: {args.match_keys}")
        print()
        try:
            results = calculate_ocr_accuracy(
                str(label_path),
                str(predict_path),
                normalize=args.normalize,
                match_keys=args.match_keys,
                project_id=args.project_id
            )
            # Print results (and save text report if --output was given)
            print_results(results, args.output)
            # Save JSON output if requested
            if args.json_output:
                with open(args.json_output, 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)
                print(f"\nJSON results saved to: {args.json_output}")
        except Exception as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()