# math-ocr / main6_pix2text.py
# Uploaded by Ash2749 ("Upload 5 files"), commit 1a3e965.
import json
import logging
import re
import unicodedata
from collections import defaultdict

import cv2
import numpy as np
import pytesseract
from pdf2image import convert_from_path
from PIL import Image
from pytesseract import Output
from tqdm import tqdm
# Pix2Text is an optional dependency: when it is missing, math regions are
# handled by the Tesseract-only fallback path.
try:
    from pix2text import Pix2Text
except ImportError:
    PIX2TEXT_AVAILABLE = False
    print("Pix2Text not available. Install with: pip install pix2text")
    print(" Falling back to traditional OCR for math expressions")
else:
    PIX2TEXT_AVAILABLE = True
    print("Pix2Text imported successfully for advanced math extraction")

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ----------------------------
# STEP 1: Enhanced Character Classification
# ----------------------------
# Individual characters treated as mathematical. Checked before the ranges
# below. Note that "-" is in this set, so a hyphen classifies as "math".
_MATH_CHARS = frozenset(
    "=+-×÷∑∫√π∞∂→≤≥∝∴∵∠∆∇∀∃∈∉⊂⊃⊆⊇∪∩∧∨¬"
    "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ"
    "±≈≠≡⇒⇔∘∗⊕⊗⊙⊥∥∦∝∞"
)

# Inclusive Unicode ranges treated as mathematical.
_MATH_RANGES = (
    ("\u2200", "\u22ff"),  # Mathematical Operators
    ("\u2190", "\u21ff"),  # Arrows
    ("\u0370", "\u03ff"),  # Greek and Coptic
    ("\u2070", "\u209f"),  # Superscripts and Subscripts
    ("\u27c0", "\u27ef"),  # Miscellaneous Mathematical Symbols-A
    ("\u2980", "\u29ff"),  # Miscellaneous Mathematical Symbols-B
)

# Punctuation characters (a substring test, matching the lookup in the function).
_PUNCTUATION_CHARS = ".,;:!?()[]{}\"'-_/\\^"


def classify_character(char):
    """Classify a single character by script / role.

    The lookup tables are module-level constants, so they are built once
    instead of on every call (this function runs for every OCR'd character).

    Args:
        char: a one-character string (an empty string is accepted).

    Returns:
        One of "space", "bangla", "math", "number", "english",
        "punctuation" or "other". Checks are applied in that priority order.
        Empty strings and whitespace are "space". Characters in the Bangla
        block U+0980-U+09FF (including Bangla digits) are "bangla".
    """
    if not char or char.isspace():
        return "space"
    if "\u0980" <= char <= "\u09ff":  # Bangla Unicode block
        return "bangla"
    if char in _MATH_CHARS:
        return "math"
    if any(start <= char <= end for start, end in _MATH_RANGES):
        return "math"
    # Digits are tracked separately; classify_text_region weights them as
    # partial evidence of math.
    if char.isdigit():
        return "number"
    if char.isascii() and char.isalpha():
        return "english"
    if char in _PUNCTUATION_CHARS:
        return "punctuation"
    return "other"
def classify_text_region(text):
    """Label an OCR token as "empty", "bangla", "math", "english" or "mixed".

    Every non-whitespace character is classified with classify_character()
    and the counts are turned into fractions of the non-space total.
    Decision order:
      1. more than half Bangla characters -> "bangla";
      2. math fraction plus half the digit fraction above 0.3, or
         has_math_patterns() matches -> "math";
      3. more than half ASCII letters -> "english";
      4. otherwise -> "mixed".
    """
    if not text.strip():
        return "empty"

    counts = defaultdict(int)
    for ch in text:
        kind = classify_character(ch)
        if kind != "space":
            counts[kind] += 1

    total = sum(counts.values())
    if total == 0:
        return "empty"

    fraction = {kind: n / total for kind, n in counts.items()}
    if fraction.get("bangla", 0) > 0.5:
        return "bangla"

    # Digits count as half-strength evidence of math content.
    math_score = fraction.get("math", 0) + fraction.get("number", 0) * 0.5
    if math_score > 0.3 or has_math_patterns(text):
        return "math"
    if fraction.get("english", 0) > 0.5:
        return "english"
    return "mixed"
# Heuristic regexes that flag a token as mathematical. They are compiled once
# at import time because has_math_patterns() runs for many OCR tokens.
_MATH_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"\d+[\+\-\*/=]\d+",  # Simple arithmetic
        r"[xy]\^?\d+",  # Variables with powers
        r"\\[a-zA-Z]+",  # LaTeX commands
        r"\$.*?\$",  # LaTeX inline math
        r"[a-zA-Z]\([a-zA-Z,\d\s]+\)",  # Functions like f(x)
        r"\b(sin|cos|tan|log|ln|exp|sqrt|int|sum|lim)\b",  # Math functions
        r"[≤≥≠≈∫∑∂∞]",  # Math symbols
    )
)


def has_math_patterns(text):
    """Return True if ``text`` matches any of the math heuristics.

    Matching is case-insensitive and uses ``search`` (a match anywhere in
    the string counts).
    """
    return any(pattern.search(text) for pattern in _MATH_PATTERNS)
# ----------------------------
# STEP 2: Initialize Pix2Text
# ----------------------------
def initialize_pix2text():
    """Build a Pix2Text model for math extraction, or return None.

    Three construction strategies are tried in order; the first one that
    succeeds is returned:
      1. ``Pix2Text.from_config()`` with library defaults;
      2. the bare ``Pix2Text()`` constructor;
      3. ``Pix2Text.from_config(device="cpu")`` to sidestep CUDA problems.
    Each failure is logged.

    Returns:
        The model, or None when pix2text could not be imported
        (PIX2TEXT_AVAILABLE is False) or all three attempts raised.
    """
    if not PIX2TEXT_AVAILABLE:
        return None
    logger.info("Initializing Pix2Text...")

    try:
        p2t = Pix2Text.from_config()
    except Exception as e1:
        logger.warning(f"Default Pix2Text init failed: {e1}")
    else:
        logger.info("✅ Pix2Text initialized with default config")
        return p2t

    try:
        p2t = Pix2Text()
    except Exception as e2:
        logger.warning(f"Basic Pix2Text init failed: {e2}")
    else:
        logger.info("✅ Pix2Text initialized with basic constructor")
        return p2t

    try:
        # Pass the device by keyword. from_config()'s first positional
        # parameter is presumably the full configuration dict
        # (``total_configs`` in pix2text >= 1.0 — confirm against the
        # installed version), so a positional {"device": "cpu"} would not
        # select the CPU.
        p2t = Pix2Text.from_config(device="cpu")
    except Exception as e3:
        logger.error(f"All Pix2Text initialization methods failed: {e3}")
        return None
    logger.info("✅ Pix2Text initialized with CPU config")
    return p2t
# ----------------------------
# STEP 3: Enhanced Image Preprocessing
# ----------------------------
def preprocess_image_advanced(pil_image, scale=2):
    """Denoise, binarize and upscale a page image for Tesseract.

    Args:
        pil_image: PIL image of the page. Any mode is accepted; it is
            converted to RGB first, so grayscale or RGBA pages also work.
        scale: positive integer upscaling factor applied last (default 2).
            Region boxes detected on the result are in these upscaled
            pixels, not in ``pil_image`` pixels.

    Returns:
        Single-channel numpy image, ``scale`` times the input width and height.
    """
    rgb = np.array(pil_image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

    # Noise reduction
    gray = cv2.fastNlMeansDenoising(gray, h=15)

    # Adaptive thresholding (15px Gaussian neighbourhood, offset 5) to
    # separate text from an unevenly lit background.
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 15, 5
    )

    # On a binary image this maps 0 -> 10 and 255 -> 255 (saturated).
    enhanced = cv2.convertScaleAbs(binary, alpha=1.2, beta=10)

    # Upscale for better OCR accuracy on small glyphs.
    height, width = enhanced.shape
    return cv2.resize(
        enhanced, (width * scale, height * scale), interpolation=cv2.INTER_CUBIC
    )
def preprocess_for_pix2text(pil_image, region, padding=10, min_size=32):
    """Crop a padded math region out of the page and upscale tiny crops.

    Args:
        pil_image: PIL page image whose pixel coordinates match ``region``.
        region: dict with integer "left", "top", "width" and "height".
        padding: pixels added on every side of the box, clipped to the image.
        min_size: minimum width and height of the returned crop. Smaller
            crops are enlarged with LANCZOS resampling, keeping the aspect ratio.

    Returns:
        The cropped PIL image, or None (with a log message) when the region
        or the crop is unusable.
    """
    img = np.array(pil_image)
    x, y, w, h = region["left"], region["top"], region["width"], region["height"]
    if w <= 0 or h <= 0:
        logger.warning(f"Invalid region dimensions: w={w}, h={h}. Skipping Pix2Text.")
        return None

    # Padding around the formula gives the recognizer some context.
    x_start = max(0, x - padding)
    y_start = max(0, y - padding)
    x_end = min(img.shape[1], x + w + padding)
    y_end = min(img.shape[0], y + h + padding)
    if x_end <= x_start or y_end <= y_start:
        logger.warning(
            f"Invalid crop bounds: x({x_start}:{x_end}), y({y_start}:{y_end}). Skipping Pix2Text."
        )
        return None
    # The bounds check above guarantees this slice is non-empty.
    cropped = img[y_start:y_end, x_start:x_end]

    try:
        cropped_pil = Image.fromarray(cropped)
    except Exception as e:
        logger.error(f"Failed to create PIL image from cropped array: {e}")
        return None

    if cropped_pil.width < min_size or cropped_pil.height < min_size:
        # The larger ratio brings both sides up to at least min_size while
        # keeping the aspect ratio. max(min_size, round(...)) keeps a float
        # product such as 31.999... from ending one pixel short.
        ratio = max(min_size / cropped_pil.width, min_size / cropped_pil.height)
        new_width = max(min_size, round(cropped_pil.width * ratio))
        new_height = max(min_size, round(cropped_pil.height * ratio))
        try:
            cropped_pil = cropped_pil.resize((new_width, new_height), Image.LANCZOS)
        except Exception as e:
            logger.error(f"Failed to resize image: {e}")
            return None
    return cropped_pil
# ----------------------------
# STEP 4: Text Detection and Line Segmentation
# ----------------------------
def detect_text_regions(image, min_confidence=25, min_size=3, lang="eng+ben"):
    """Run Tesseract word detection and return classified word boxes.

    Args:
        image: image passed to ``pytesseract.image_to_data`` (here the
            preprocessed page array).
        min_confidence: words with confidence at or below this are dropped.
            The default is low so that math tokens are not discarded.
        min_size: words narrower or shorter than this many pixels are dropped.
        lang: Tesseract language string.

    Returns:
        List of dicts with "text", "left", "top", "width", "height",
        "confidence" (int) and "type" (from classify_text_region), in
        Tesseract's output order. Coordinates are in ``image`` pixels.
    """
    data = pytesseract.image_to_data(image, output_type=Output.DICT, lang=lang)
    text_regions = []
    for i, raw_text in enumerate(data["text"]):
        text = str(raw_text).strip()
        if not text:
            continue
        # Depending on the pytesseract/Tesseract version, "conf" may arrive as
        # an int, a float, or a numeric string such as "96.5", and int() alone
        # rejects the last form. int(float(...)) truncates exactly like int().
        confidence = int(float(data["conf"][i]))
        if confidence <= min_confidence:
            continue

        width = int(data["width"][i])
        height = int(data["height"][i])
        if width <= 0 or height <= 0:
            logger.debug(f"Skipping region with invalid dimensions: {width}x{height}")
            continue
        if width < min_size or height < min_size:
            logger.debug(f"Skipping tiny region: {width}x{height}")
            continue

        text_regions.append(
            {
                "text": text,
                "left": int(data["left"][i]),
                "top": int(data["top"][i]),
                "width": width,
                "height": height,
                "confidence": confidence,
                "type": classify_text_region(text),
            }
        )
    logger.info(f"Detected {len(text_regions)} valid text regions")
    return text_regions
def group_regions_by_line(regions, line_tolerance=15):
    """Cluster word regions into text lines, each sorted left-to-right.

    Regions are visited top-to-bottom. The first region of a line is its
    anchor. A later region joins the line when its "top" is within
    ``max(line_tolerance, 0.3 * mean height)`` of the anchor's "top". The
    mean is taken over the anchor and the candidate, with heights floored
    at 1. Otherwise the candidate starts a new line.

    Returns:
        List of lines (lists of the original region dicts); [] for no input.
    """
    if not regions:
        return []

    ordered = sorted(regions, key=lambda r: r["top"])
    anchor = ordered[0]
    bucket = [anchor]
    grouped = []

    for candidate in ordered[1:]:
        mean_height = (max(1, anchor["height"]) + max(1, candidate["height"])) / 2
        allowed_gap = max(line_tolerance, mean_height * 0.3)
        if abs(candidate["top"] - anchor["top"]) > allowed_gap:
            grouped.append(sorted(bucket, key=lambda r: r["left"]))
            anchor = candidate
            bucket = [candidate]
        else:
            bucket.append(candidate)

    grouped.append(sorted(bucket, key=lambda r: r["left"]))
    return grouped
# ----------------------------
# STEP 5: Advanced OCR Extractors
# ----------------------------
def _ocr_word_region(image, region, lang):
    """OCR one word box from ``image`` with Tesseract in a single language.

    Falls back to the detector's own text (``region["text"]``) when the
    crop is empty, Tesseract returns nothing, or Tesseract raises. Failures
    are logged at DEBUG level instead of being silently swallowed.
    """
    x, y, w, h = region["left"], region["top"], region["width"], region["height"]
    roi = image[y : y + h, x : x + w]
    if roi.size == 0:
        return region["text"]
    # --psm 8: treat the crop as a single word.
    config = r"--oem 3 --psm 8"
    try:
        result = pytesseract.image_to_string(roi, lang=lang, config=config).strip()
    except Exception as exc:
        logger.debug(f"Tesseract ({lang}) failed on region {region['text']!r}: {exc}")
        return region["text"]
    return result or region["text"]


def extract_english_region(image, region):
    """Re-OCR an English word box (numpy image, same coordinates as ``region``)."""
    return _ocr_word_region(image, region, "eng")


def extract_bangla_region(image, region):
    """Re-OCR a Bangla word box (numpy image, same coordinates as ``region``)."""
    return _ocr_word_region(image, region, "ben")
def extract_math_region_pix2text(pil_image, region, p2t_model):
    """Recognize a math region with Pix2Text, falling back to Tesseract.

    extract_math_region_traditional() is used instead when:
      * no model is given;
      * preprocess_for_pix2text() cannot produce a crop;
      * the parsed Pix2Text output is empty or fails
        is_valid_pix2text_result();
      * any step raises.

    Returns:
        The stripped Pix2Text text, or the fallback's result.
    """
    if not p2t_model:
        return extract_math_region_traditional(pil_image, region)

    try:
        crop = preprocess_for_pix2text(pil_image, region)
        if crop is None:
            logger.warning(
                "Pix2Text preprocessing failed, falling back to traditional OCR"
            )
            return extract_math_region_traditional(pil_image, region)

        parsed = parse_pix2text_result(p2t_model(crop))
        if not (parsed and parsed.strip()):
            logger.warning(
                "⚠️ Pix2Text returned empty result, falling back to traditional OCR"
            )
            return extract_math_region_traditional(pil_image, region)

        if not is_valid_pix2text_result(parsed):
            logger.warning(f"Invalid Pix2Text result: {parsed[:100]}...")
            return extract_math_region_traditional(pil_image, region)

        logger.info(f"✅ Pix2Text extracted: {parsed[:50]}...")
        return parsed.strip()
    except Exception as e:
        logger.error(f"❌ Pix2Text extraction failed: {e}")
        return extract_math_region_traditional(pil_image, region)
# Keys checked, in priority order, for the recognized text in a dict result.
_PIX2TEXT_TEXT_KEYS = ("text", "formula", "latex", "content", "output")


def parse_pix2text_result(result):
    """Turn a Pix2Text return value into plain text.

    - str: returned unchanged.
    - dict: the value (as str) of the first truthy key in
      _PIX2TEXT_TEXT_KEYS. If none is present, the dict's repr is returned,
      or "" when that repr is longer than 1000 characters (likely debug output).
    - list: dict items go through the dict rule above; other items go
      through str(). Each is stripped and kept unless empty or flagged by
      is_debug_content(), and the survivors are joined with single spaces.
    - anything else: str(result).

    Any exception is logged and yields "".
    """
    try:
        if isinstance(result, dict):
            for key in _PIX2TEXT_TEXT_KEYS:
                if key in result and result[key]:
                    return str(result[key])
            result_str = str(result)
            return "" if len(result_str) > 1000 else result_str

        if isinstance(result, list):
            valid_items = []
            for item in result:
                # Dict items (e.g. per-snippet records that carry a "text"
                # field) are parsed by key. str() would serialize every field
                # of the record into the output text.
                if isinstance(item, dict):
                    item_str = parse_pix2text_result(item).strip()
                else:
                    item_str = str(item).strip()
                if item_str and not is_debug_content(item_str):
                    valid_items.append(item_str)
            return " ".join(valid_items)

        if isinstance(result, str):
            return result
        return str(result)
    except Exception as e:
        logger.error(f"Error parsing Pix2Text result: {e}")
        return ""
# Substrings marking Pix2Text output as debug/log/repr noise rather than content.
_INVALID_PIX2TEXT_MARKERS = (
    "Page(id=",
    "elements=[]",
    "number=0",
    "Error:",
    "Exception:",
    "Traceback:",
    "DEBUG:",
    "INFO:",
    "WARNING:",
    "ERROR:",
)

# A valid result must contain at least one of these characters (Latin letters,
# digits, basic operators/brackets and a few math/Greek symbols).
_PIX2TEXT_CONTENT_RE = re.compile(r"[a-zA-Z0-9=+\-*/(){}\[\]^_√∫∑∂πθαβγδλμΩ]")


def is_valid_pix2text_result(text):
    """Return True if ``text`` looks like real recognized content.

    Falsy or whitespace-only input is invalid. So is any text containing one
    of _INVALID_PIX2TEXT_MARKERS. Otherwise the stripped text must contain
    at least one character matched by _PIX2TEXT_CONTENT_RE.
    """
    if not text:
        return False
    text = text.strip()
    if not text:
        return False
    if any(marker in text for marker in _INVALID_PIX2TEXT_MARKERS):
        return False
    return _PIX2TEXT_CONTENT_RE.search(text) is not None
# Substrings that identify logging output, tracebacks or Python object reprs.
# Traceback frame lines ('File "x.py", line 3') are caught by 'File "'. There
# is deliberately no bare "line " marker, which would also reject ordinary
# content such as "the line y = mx + c".
_DEBUG_INDICATORS = (
    "Page(",
    "id=",
    "number=",
    "elements=",
    "[])",
    "DEBUG",
    "INFO",
    "WARNING",
    "ERROR",
    "Exception",
    "Traceback",
    'File "',
    " at 0x",
)


def is_debug_content(text):
    """Return True if ``text`` contains any of _DEBUG_INDICATORS (case-sensitive)."""
    return any(indicator in text for indicator in _DEBUG_INDICATORS)
# Characters Tesseract may emit for math regions (tessedit_char_whitelist).
_MATH_OCR_WHITELIST = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz=+-×÷∑∫√π∞∂→≤≥∝∴∵∠∆∇()[]{}.,;:^_αβγδλμθΩ±≈≠≡⇒⇔"


def extract_math_region_traditional(pil_image, region):
    """Fallback Tesseract OCR for a math region of a PIL page image.

    Only the region box (clamped to the image) is cropped and converted to
    grayscale, so a large page is not re-converted for every region. Any
    PIL mode is accepted.

    Returns:
        Tesseract's text (block mode, --psm 6, restricted to
        _MATH_OCR_WHITELIST), or ``region["text"]`` when the box is empty,
        Tesseract returns nothing, or Tesseract raises.
    """
    x, y, w, h = region["left"], region["top"], region["width"], region["height"]
    left, top = max(0, x), max(0, y)
    right = min(pil_image.width, x + w)
    bottom = min(pil_image.height, y + h)
    if right <= left or bottom <= top:
        return region["text"]

    crop = np.array(pil_image.crop((left, top, right, bottom)).convert("RGB"))
    roi = cv2.cvtColor(crop, cv2.COLOR_RGB2GRAY)

    config = f"--oem 3 --psm 6 -c tessedit_char_whitelist={_MATH_OCR_WHITELIST}"
    try:
        result = pytesseract.image_to_string(roi, config=config).strip()
    except Exception as exc:
        logger.debug(f"Traditional math OCR failed on region {region['text']!r}: {exc}")
        return region["text"]
    return result or region["text"]
def extract_mixed_region(pil_image, region, p2t_model):
    """Run several OCR passes on a mixed-script region and keep the best.

    English and Bangla Tesseract passes always run. Only the region's box
    is cropped and converted to grayscale, not the whole page. If the
    detected text matches has_math_patterns(), a Pix2Text pass also runs
    and the longest non-empty candidate wins (falling back to
    ``region["text"]``). Otherwise the longer of the Bangla and English
    readings is returned, with English winning ties.
    """
    x, y, w, h = region["left"], region["top"], region["width"], region["height"]
    left, top = max(0, x), max(0, y)
    right = min(pil_image.width, x + w)
    bottom = min(pil_image.height, y + h)
    if right > left and bottom > top:
        crop = pil_image.crop((left, top, right, bottom)).convert("RGB")
        gray = cv2.cvtColor(np.array(crop), cv2.COLOR_RGB2GRAY)
    else:
        # Empty box: the OCR helpers see an empty ROI and return region["text"].
        gray = np.zeros((0, 0), dtype=np.uint8)

    # The word-level helpers slice their box out of the image they receive,
    # so express the box relative to the crop.
    local_region = {**region, "left": 0, "top": 0}
    eng_result = extract_english_region(gray, local_region)
    bangla_result = extract_bangla_region(gray, local_region)

    if has_math_patterns(region["text"]):
        math_result = extract_math_region_pix2text(pil_image, region, p2t_model)
        candidates = [r for r in (eng_result, bangla_result, math_result) if r.strip()]
        return max(candidates, key=len) if candidates else region["text"]

    return bangla_result if len(bangla_result) > len(eng_result) else eng_result
# ----------------------------
# STEP 6: Character Analysis (unchanged)
# ----------------------------
def analyze_character_by_character(text):
    """Classify every character and group them into same-type segments.

    Returns a dict with:
        "characters": one record per character (char, position, type,
            Unicode name or "UNKNOWN");
        "language_segments": runs of identical type. Space and punctuation
            characters are skipped without closing a run, so a segment's
            "text" omits them while its "start"/"end" positions span them;
        "total_chars": len(text);
        "language_distribution": defaultdict(int) mapping type -> count.
    """
    char_records = []
    distribution = defaultdict(int)
    for position, ch in enumerate(text):
        kind = classify_character(ch)
        char_records.append(
            {
                "char": ch,
                "position": position,
                "type": kind,
                "unicode_name": unicodedata.name(ch, "UNKNOWN"),
            }
        )
        distribution[kind] += 1

    segments = []
    open_segment = None
    for record in char_records:
        kind = record["type"]
        if kind in ("space", "punctuation"):
            continue
        if open_segment is not None and open_segment["type"] == kind:
            open_segment["end"] = record["position"]
            open_segment["text"] += record["char"]
            continue
        if open_segment is not None:
            segments.append(open_segment)
        open_segment = {
            "type": kind,
            "start": record["position"],
            "end": record["position"],
            "text": record["char"],
        }
    if open_segment is not None:
        segments.append(open_segment)

    return {
        "characters": char_records,
        "language_segments": segments,
        "total_chars": len(text),
        "language_distribution": distribution,
    }
# ----------------------------
# STEP 7: Main Processing Pipeline
# ----------------------------
def _to_page_coords(region, scale_x, scale_y):
    """Copy ``region`` with its box mapped from preprocessed pixels to page pixels."""
    return {
        **region,
        "left": round(region["left"] * scale_x),
        "top": round(region["top"] * scale_y),
        "width": max(1, round(region["width"] * scale_x)),
        "height": max(1, round(region["height"] * scale_y)),
    }


def process_page_advanced(page_image, page_num, p2t_model):
    """OCR one page and return one result dict per detected word region.

    Regions are detected on the preprocessed (upscaled) image. English and
    Bangla regions are re-OCR'd on that same image. Math and mixed regions
    are extracted from the original ``page_image``, so their boxes are first
    rescaled to page coordinates. Reported "position" values stay in
    preprocessed-image pixels.

    Args:
        page_image: PIL image of the page.
        page_num: zero-based page index (printed one-based).
        p2t_model: Pix2Text model or None.

    Returns:
        List of dicts with page, line, text, original_text, position,
        confidence, detected_type, extraction_method and character_analysis.
    """
    print(f"Processing page {page_num + 1}...")
    processed_image = preprocess_image_advanced(page_image)
    regions = detect_text_regions(processed_image)
    lines = group_regions_by_line(regions)

    # preprocess_image_advanced() upscales the page, so region boxes must be
    # scaled back before cropping page_image.
    proc_height, proc_width = processed_image.shape[:2]
    scale_x = page_image.width / proc_width
    scale_y = page_image.height / proc_height

    page_results = []
    for line_num, line in enumerate(lines):
        line_text_parts = []
        for region in line:
            region_type = region["type"]
            if region_type == "english":
                extracted_text = extract_english_region(processed_image, region)
            elif region_type == "bangla":
                extracted_text = extract_bangla_region(processed_image, region)
            elif region_type == "math":
                extracted_text = extract_math_region_pix2text(
                    page_image, _to_page_coords(region, scale_x, scale_y), p2t_model
                )
            elif region_type == "mixed":
                extracted_text = extract_mixed_region(
                    page_image, _to_page_coords(region, scale_x, scale_y), p2t_model
                )
            else:
                extracted_text = region["text"]

            page_results.append(
                {
                    "page": page_num,
                    "line": line_num,
                    "text": extracted_text,
                    "original_text": region["text"],
                    "position": {
                        "left": region["left"],
                        "top": region["top"],
                        "width": region["width"],
                        "height": region["height"],
                    },
                    "confidence": region["confidence"],
                    "detected_type": region_type,
                    "extraction_method": "pix2text"
                    if region_type == "math" and p2t_model
                    else "tesseract",
                    "character_analysis": analyze_character_by_character(extracted_text),
                }
            )
            line_text_parts.append(extracted_text)

        if line_text_parts:
            line_text = " ".join(line_text_parts)
            print(f" Line {line_num + 1}: {line_text[:100]}...")
    return page_results
def extract_all_text_advanced_pix2text(
    pdf_path, output_text_file, output_json_file, output_analysis_file, dpi=300
):
    """Run the full OCR pipeline on a PDF and write three output files.

    Args:
        pdf_path: input PDF path.
        output_text_file: UTF-8 text output. Regions of a page are joined
            with single spaces and pages with blank lines.
        output_json_file: UTF-8 JSON list of per-region results.
        output_analysis_file: UTF-8 JSON summary from create_extraction_summary().
        dpi: rasterization resolution passed to pdf2image (default 300).
    """
    print("[INFO] Initializing Pix2Text for mathematical expression extraction...")
    p2t_model = initialize_pix2text()
    if p2t_model:
        print("✅ Pix2Text ready for advanced math extraction")
    else:
        print("⚠️ Using traditional OCR for math expressions")

    print("[INFO] Converting PDF to images...")
    pages = convert_from_path(pdf_path, dpi=dpi)

    all_results = []
    page_texts = []
    for page_num, page_image in enumerate(tqdm(pages, desc="Processing pages")):
        page_results = process_page_advanced(page_image, page_num, p2t_model)
        all_results.extend(page_results)
        page_texts.append(" ".join(result["text"] for result in page_results))

    with open(output_text_file, "w", encoding="utf-8") as f:
        f.write("\n\n".join(page_texts))

    with open(output_json_file, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    summary_analysis = create_extraction_summary(all_results)
    with open(output_analysis_file, "w", encoding="utf-8") as f:
        json.dump(summary_analysis, f, ensure_ascii=False, indent=2)

    print("\n[✅] Advanced Pix2Text extraction complete!")
    print(f"→ Text file saved to: {output_text_file}")
    print(f"→ Detailed JSON saved to: {output_json_file}")
    print(f"→ Analysis report saved to: {output_analysis_file}")

    types = summary_analysis["type_distribution"]
    print("\n📊 Extraction Summary:")
    print(f" Total text regions: {len(all_results)}")
    print(f" English regions: {types['english']}")
    print(f" Bangla regions: {types['bangla']}")
    print(f" Math regions: {types['math']}")
    print(f" Mixed regions: {types['mixed']}")

    # create_extraction_summary() already counts extraction methods, in
    # first-seen order, so reuse that rather than recounting.
    print("\n🔧 Extraction Methods Used:")
    for method, count in summary_analysis["extraction_methods"].items():
        print(f" {method}: {count} regions")
def create_extraction_summary(results):
    """Aggregate per-region extraction results into a summary report.

    Args:
        results: list of region dicts as produced by process_page_advanced().

    Returns:
        Dict with total_regions, total_pages (distinct "page" values),
        type_distribution, character_distribution, confidence_stats
        (min/max/avg), language_segments_summary and extraction_methods.
        The counters are defaultdict(int). With no results, confidence_stats
        is {"min": 100, "max": 0, "avg": 0}.
    """
    type_counts = defaultdict(int)
    char_counts = defaultdict(int)
    segment_counts = defaultdict(int)
    method_counts = defaultdict(int)
    confidences = []

    for entry in results:
        type_counts[entry["detected_type"]] += 1
        method_counts[entry.get("extraction_method", "unknown")] += 1
        confidences.append(entry["confidence"])
        analysis = entry["character_analysis"]
        for kind, amount in analysis["language_distribution"].items():
            char_counts[kind] += amount
        for segment in analysis["language_segments"]:
            segment_counts[segment["type"]] += 1

    # min/max are seeded with 100/0, so they keep those values when empty.
    confidence_stats = {
        "min": min([100, *confidences]),
        "max": max([0, *confidences]),
        "avg": sum(confidences) / len(confidences) if confidences else 0,
    }

    return {
        "total_regions": len(results),
        "total_pages": len({entry["page"] for entry in results}),
        "type_distribution": type_counts,
        "character_distribution": char_counts,
        "confidence_stats": confidence_stats,
        "language_segments_summary": segment_counts,
        "extraction_methods": method_counts,
    }
# ----------------------
# MAIN EXECUTION SECTION
# ----------------------
if __name__ == "__main__":
    # Command-line entry point. With no arguments it processes math102.pdf
    # into the same default output files as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract English, Bangla and math text from a PDF "
        "(Tesseract + Pix2Text)."
    )
    parser.add_argument(
        "pdf_path", nargs="?", default="math102.pdf",
        help="input PDF (default: %(default)s)",
    )
    parser.add_argument(
        "--text-out", default="math102_pix2text.txt",
        help="plain-text output file (default: %(default)s)",
    )
    parser.add_argument(
        "--json-out", default="math102_pix2text.json",
        help="per-region JSON output file (default: %(default)s)",
    )
    parser.add_argument(
        "--analysis-out", default="math102_pix2text_analysis.json",
        help="summary JSON output file (default: %(default)s)",
    )
    args = parser.parse_args()

    extract_all_text_advanced_pix2text(
        args.pdf_path, args.text_out, args.json_out, args.analysis_out
    )