# NOTE(review): the three lines below are Hugging Face upload-page residue
# accidentally captured into the source file; commented out so the module parses.
# Ash2749's picture
# Upload 17 files
# befccc3 verified
# eval.py - OCR Evaluation Methods
# Comprehensive accuracy evaluation for OCR text extraction
import re
import difflib
from typing import Dict, List, Any
from collections import defaultdict
import unicodedata
def clean_control_characters(text: str) -> str:
    """
    Remove or replace control characters that can cause JSON encoding issues.

    Control characters (Unicode category ``C*``) other than tab, newline and
    carriage return are replaced with a space.  Printable characters — and the
    Bangla, Devanagari and Arabic Unicode ranges — are kept.  Whitespace runs
    are then collapsed to a single space and the result is stripped.

    Args:
        text: Input text; ``bytes`` are tolerated and decoded as UTF-8 with
            replacement.

    Returns:
        The cleaned, whitespace-normalized string (falsy inputs are returned
        unchanged).
    """
    if not text:
        return text
    # Upstream may hand us raw bytes; decode defensively.
    if isinstance(text, bytes):
        try:
            text = text.decode("utf-8", errors="replace")
        except Exception:
            text = str(text)
    # Collect characters in a list and join once — the original per-character
    # string concatenation is quadratic on long inputs.
    parts = []
    for char in text:
        category = unicodedata.category(char)
        if category.startswith("C") and char not in "\t\n\r":
            # Control character: replace with a space placeholder.
            parts.append(" ")
        elif (
            char.isprintable()
            or char in "\t\n\r"
            or "\u0980" <= char <= "\u09ff"  # Bangla
            or "\u0900" <= char <= "\u097f"  # Devanagari
            or "\u0600" <= char <= "\u06ff"  # Arabic
        ):
            parts.append(char)
        else:
            # Unprintable (non-control) character: replace with a space.
            parts.append(" ")
    cleaned = "".join(parts)
    # Collapse all whitespace runs (including the \t\n\r kept above).
    cleaned = re.sub(r"\s+", " ", cleaned)
    return cleaned.strip()
def safe_json_serialize(data: Any) -> Any:
    """
    Recursively ensure all string values in *data* are safe for JSON
    serialization, handling Unicode characters properly.

    Dicts and lists are walked recursively; strings are stripped of control
    characters and, if they still fail a JSON round-trip probe, re-encoded to
    ASCII with replacement.  Every other value passes through unchanged.

    Args:
        data: Arbitrary JSON-like structure (dict / list / str / scalar).

    Returns:
        A structure of the same shape with JSON-safe strings.
    """
    # Hoisted to function scope: the original re-ran `import json` inside the
    # try-block for every string encountered during recursion.
    import json

    if isinstance(data, dict):
        return {key: safe_json_serialize(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [safe_json_serialize(item) for item in data]
    elif isinstance(data, str):
        cleaned = clean_control_characters(data)
        try:
            # Probe that the cleaned string actually survives JSON encoding.
            json.dumps(cleaned, ensure_ascii=False)
            return cleaned
        except (TypeError, ValueError, UnicodeError):
            # Last resort: force ASCII with replacement characters.
            return cleaned.encode("ascii", errors="replace").decode("ascii")
    else:
        return data
def edit_distance(s1: str, s2: str) -> int:
    """
    Return the Levenshtein (edit) distance between two sequences.

    Classic single-row dynamic-programming formulation; the shorter sequence
    is kept as the inner dimension so only O(min(len)) memory is used.
    """
    # Keep s2 as the shorter sequence (swap instead of recursing).
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    previous = list(range(len(s2) + 1))
    for row, a in enumerate(s1, start=1):
        current = [row]
        for col, b in enumerate(s2):
            cost_substitute = previous[col] + (a != b)
            cost_insert = previous[col + 1] + 1
            cost_delete = current[col] + 1
            current.append(min(cost_substitute, cost_insert, cost_delete))
        previous = current
    return previous[-1]
def normalize_text(text: str) -> str:
    """
    Canonicalize text for comparison.

    Applies NFKD Unicode normalization (decomposes accents and compatibility
    characters), lowercases, collapses every whitespace run to a single space,
    and strips leading/trailing whitespace.
    """
    folded = unicodedata.normalize("NFKD", text).lower()
    collapsed = re.sub(r"\s+", " ", folded)
    return collapsed.strip()
def calculate_character_accuracy(extracted: str, baseline: str) -> Dict[str, float]:
    """
    Calculate character-level accuracy metrics.

    Args:
        extracted: OCR output text.
        baseline: Ground-truth text.

    Returns:
        Dict with ``character_accuracy`` (%), ``character_error_rate`` (%),
        ``edit_distance`` and ``total_characters`` — present on every path.
    """
    extracted_norm = normalize_text(extracted)
    baseline_norm = normalize_text(baseline)
    total_chars = len(baseline_norm)
    if total_chars == 0:
        # FIX: keep the same key schema as the normal path so callers can
        # always rely on "edit_distance" / "total_characters" being present.
        return {
            "character_accuracy": 0.0,
            "character_error_rate": 100.0,
            "edit_distance": len(extracted_norm),
            "total_characters": 0,
        }
    # Levenshtein distance between the normalized texts.
    edit_dist = edit_distance(extracted_norm, baseline_norm)
    # Accuracy is clamped at 0 so a very long wrong extraction can't go negative.
    char_accuracy = max(0, (total_chars - edit_dist) / total_chars) * 100
    char_error_rate = (edit_dist / total_chars) * 100
    return {
        "character_accuracy": round(char_accuracy, 2),
        "character_error_rate": round(char_error_rate, 2),
        "edit_distance": edit_dist,
        "total_characters": total_chars,
    }
def calculate_word_accuracy(extracted: str, baseline: str) -> Dict[str, float]:
    """
    Calculate word-level accuracy metrics.

    Returns the share of baseline words found in the extraction, the word
    error rate (WER, word-level edit distance over baseline word count), and
    counts of correct, missing and extra words.
    """
    extracted_words = normalize_text(extracted).split()
    baseline_words = normalize_text(baseline).split()
    total_words = len(baseline_words)
    if total_words == 0:
        # FIX: return the full key schema so callers see consistent dicts.
        return {
            "word_accuracy": 0.0,
            "word_error_rate": 100.0,
            "correct_words": 0,
            "total_words": 0,
            "missing_words": 0,
            "extra_words": len(set(extracted_words)),
        }
    # BUG FIX: the original computed a *character* edit distance over the
    # re-joined strings, so the "WER" divided character edits by the word
    # count.  edit_distance() is sequence-generic, so pass the word lists
    # directly to get a true word-level edit distance.
    word_edit_dist = edit_distance(extracted_words, baseline_words)
    # NOTE: set-based matching credits duplicate baseline words only once.
    extracted_set = set(extracted_words)
    baseline_set = set(baseline_words)
    correct_words = len(extracted_set & baseline_set)
    word_accuracy = (correct_words / total_words) * 100
    # Word Error Rate (WER)
    word_error_rate = (word_edit_dist / total_words) * 100
    return {
        "word_accuracy": round(word_accuracy, 2),
        "word_error_rate": round(word_error_rate, 2),
        "correct_words": correct_words,
        "total_words": total_words,
        "missing_words": len(baseline_set - extracted_set),
        "extra_words": len(extracted_set - baseline_set),
    }
def calculate_line_accuracy(extracted: str, baseline: str) -> Dict[str, float]:
    """
    Calculate line-level accuracy metrics.

    Each non-empty baseline line is matched against its best-scoring extracted
    line; a line counts as matched when the best similarity exceeds 0.8.
    """
    extracted_lines = [line.strip() for line in extracted.split("\n") if line.strip()]
    baseline_lines = [line.strip() for line in baseline.split("\n") if line.strip()]
    total_lines = len(baseline_lines)
    if total_lines == 0:
        # FIX: keep the same key schema as the normal path.
        return {
            "line_accuracy": 0.0,
            "average_line_similarity": 0.0,
            "lines_matched": 0,
            "total_lines": 0,
        }
    # PERF: normalize each extracted line once, not once per baseline line —
    # the original re-ran normalize_text() inside the O(B*E) inner loop.
    normalized_extracted = [normalize_text(line) for line in extracted_lines]
    matched_lines = 0
    line_similarities = []
    for baseline_line in baseline_lines:
        normalized_baseline = normalize_text(baseline_line)
        best_similarity = 0.0
        for candidate in normalized_extracted:
            ratio = difflib.SequenceMatcher(
                None, normalized_baseline, candidate
            ).ratio()
            best_similarity = max(best_similarity, ratio)
        line_similarities.append(best_similarity)
        if best_similarity > 0.8:  # 80% similarity threshold
            matched_lines += 1
    line_accuracy = (matched_lines / total_lines) * 100
    avg_line_similarity = (sum(line_similarities) / len(line_similarities)) * 100
    return {
        "line_accuracy": round(line_accuracy, 2),
        "average_line_similarity": round(avg_line_similarity, 2),
        "lines_matched": matched_lines,
        "total_lines": total_lines,
    }
def calculate_language_specific_accuracy(
    extracted: str, baseline: str
) -> Dict[str, Any]:
    """
    Calculate accuracy for different language components (English, Bangla, Math).

    Characters are bucketed by script/type, and character accuracy is computed
    per bucket.  A bucket absent from the baseline scores 100.0 unless the
    extraction hallucinated characters of that type (then 0.0).
    """
    math_symbols = "=+-×÷∑∫√π∞∂→≤≥∝∴∵∠∆∇∀∃∈∉⊂⊃⊆⊇∪∩∧∨¬αβγδεζηθικλμνξοπρστυφχψω"

    def classify_char(char):
        # Order matters: Bangla block first, then ASCII letters, digits, math.
        if "\u0980" <= char <= "\u09ff":  # Bangla unicode range
            return "bangla"
        if char.isascii() and char.isalpha():
            return "english"
        if char.isdigit():
            return "number"
        if char in math_symbols:
            return "math"
        return "other"

    def partition_by_type(text):
        # Group characters of each type into one string.
        buckets = defaultdict(list)
        for ch in text:
            buckets[classify_char(ch)].append(ch)
        return {kind: "".join(chars) for kind, chars in buckets.items()}

    extracted_parts = partition_by_type(extracted)
    baseline_parts = partition_by_type(baseline)

    language_accuracy = {}
    for lang_type in ("english", "bangla", "math", "number"):
        baseline_part = baseline_parts.get(lang_type, "")
        extracted_part = extracted_parts.get(lang_type, "")
        if baseline_part:
            metrics = calculate_character_accuracy(extracted_part, baseline_part)
            language_accuracy[f"{lang_type}_accuracy"] = metrics["character_accuracy"]
        else:
            language_accuracy[f"{lang_type}_accuracy"] = (
                100.0 if not extracted_part else 0.0
            )
    return language_accuracy
def calculate_similarity_score(extracted: str, baseline: str) -> float:
    """
    Return the overall similarity of the two texts as a percentage (0-100),
    based on difflib's SequenceMatcher over the normalized texts.
    """
    matcher = difflib.SequenceMatcher(
        None, normalize_text(extracted), normalize_text(baseline)
    )
    return round(matcher.ratio() * 100, 2)
def generate_detailed_diff(extracted: str, baseline: str) -> List[Dict[str, str]]:
    """
    Generate a detailed diff showing insertions, deletions, and matches.

    Returns a list of ``{"type": "deletion"|"insertion"|"match", "content": str}``
    entries derived from a unified diff of the per-line normalized texts.
    """
    # BUG FIX: normalize each line individually.  normalize_text() collapses
    # *all* whitespace — including newlines — so the original whole-text call
    # produced a single-line document and the diff degenerated to at most one
    # deletion/insertion pair.
    baseline_lines = [normalize_text(line) for line in baseline.splitlines()]
    extracted_lines = [normalize_text(line) for line in extracted.splitlines()]
    differ = difflib.unified_diff(
        baseline_lines,
        extracted_lines,
        fromfile="baseline",
        tofile="extracted",
        lineterm="",
    )
    diff_result = []
    for line in differ:
        # Skip the unified-diff file headers and hunk markers.
        if line.startswith(("---", "+++", "@@")):
            continue
        if line.startswith("-"):
            diff_result.append(
                {"type": "deletion", "content": clean_control_characters(line[1:])}
            )
        elif line.startswith("+"):
            diff_result.append(
                {"type": "insertion", "content": clean_control_characters(line[1:])}
            )
        else:
            diff_result.append(
                {"type": "match", "content": clean_control_characters(line)}
            )
    return diff_result
def evaluate_ocr_accuracy(extracted_text: str, baseline_text: str) -> Dict[str, Any]:
    """
    Comprehensive OCR accuracy evaluation.

    Args:
        extracted_text: The text extracted by OCR
        baseline_text: The ground truth text

    Returns:
        Dictionary containing various accuracy metrics, or an ``{"error": ...}``
        dictionary when the baseline is unusable.
    """
    # Guard clauses: nothing meaningful can be computed without a baseline.
    if not extracted_text and not baseline_text:
        return {"error": "Both texts are empty"}
    if not baseline_text:
        return {"error": "Baseline text is empty"}

    # Sanitize both inputs before any metric computation.
    extracted_text = clean_control_characters(extracted_text)
    baseline_text = clean_control_characters(baseline_text)

    # Gather every individual metric.
    char_metrics = calculate_character_accuracy(extracted_text, baseline_text)
    word_metrics = calculate_word_accuracy(extracted_text, baseline_text)
    line_metrics = calculate_line_accuracy(extracted_text, baseline_text)
    lang_metrics = calculate_language_specific_accuracy(extracted_text, baseline_text)
    similarity_score = calculate_similarity_score(extracted_text, baseline_text)
    detailed_diff = generate_detailed_diff(extracted_text, baseline_text)

    # Weighted blend: characters 40%, words 30%, lines 20%, similarity 10%.
    overall_score = (
        char_metrics["character_accuracy"] * 0.4
        + word_metrics["word_accuracy"] * 0.3
        + line_metrics["line_accuracy"] * 0.2
        + similarity_score * 0.1
    )

    text_statistics = {
        "extracted_length": len(extracted_text),
        "baseline_length": len(baseline_text),
        "extracted_words": len(extracted_text.split()),
        "baseline_words": len(baseline_text.split()),
        "extracted_lines": len(extracted_text.split("\n")),
        "baseline_lines": len(baseline_text.split("\n")),
    }
    result = {
        "overall_accuracy": round(overall_score, 2),
        "similarity_score": similarity_score,
        "character_metrics": char_metrics,
        "word_metrics": word_metrics,
        "line_metrics": line_metrics,
        "language_specific": lang_metrics,
        "text_statistics": text_statistics,
        # Cap the diff payload so the response stays small.
        "detailed_diff": detailed_diff[:50],
        "evaluation_summary": {
            "grade": get_accuracy_grade(overall_score),
            "recommendations": get_recommendations(
                char_metrics, word_metrics, lang_metrics
            ),
        },
    }
    # Final pass: guarantee every string in the payload is JSON-safe.
    return safe_json_serialize(result)
def get_accuracy_grade(score: float) -> str:
    """Convert accuracy score to letter grade."""
    # Thresholds are checked highest-first; the first one met wins.
    grade_table = (
        (95, "A+ (Excellent)"),
        (90, "A (Very Good)"),
        (80, "B (Good)"),
        (70, "C (Fair)"),
        (60, "D (Poor)"),
    )
    for threshold, grade in grade_table:
        if score >= threshold:
            return grade
    return "F (Very Poor)"
def get_recommendations(
    char_metrics: Dict, word_metrics: Dict, lang_metrics: Dict
) -> List[str]:
    """Generate recommendations based on accuracy metrics."""
    # Each rule pairs a trigger condition with its advice message.
    rules = [
        (
            char_metrics["character_accuracy"] < 80,
            "Consider improving image preprocessing (noise reduction, contrast adjustment)",
        ),
        (
            word_metrics["word_accuracy"] < 70,
            "Word-level accuracy is low - check language model configuration",
        ),
        (
            lang_metrics.get("bangla_accuracy", 100) < 80,
            "Bangla text accuracy is low - ensure Bengali language pack is installed",
        ),
        (
            lang_metrics.get("math_accuracy", 100) < 70,
            "Mathematical expression accuracy is low - consider tuning Pix2Text parameters",
        ),
        (
            lang_metrics.get("english_accuracy", 100) < 85,
            "English text accuracy could be improved - check OCR engine settings",
        ),
    ]
    recommendations = [message for triggered, message in rules if triggered]
    if not recommendations:
        recommendations.append("Excellent accuracy! No specific improvements needed.")
    return recommendations