File size: 8,166 Bytes
83d04a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
"""
Optimized feature extractor for classifying the lines of a document.
Provides the 20 most effective features, including contextual patterns derived from neighboring lines.
"""
import numpy as np
import pandas as pd
import re
class OptimizedFeatureExtractor:
    """Extract 20 optimized features for document line classification with contextual information.

    Three families of features are produced:

    * basic per-line text statistics (``line_length``, ``word_count``,
      ``digit_ratio``, keyword counts, ...), see ``_extract_basic_features``;
    * positional features describing where the line sits in its document
      (``line_position_ratio``, ``is_near_start``, ``is_near_end``);
    * contextual features: the basic features of the neighbouring lines
      (``prev_*`` / ``next_*``) and simple cross-line patterns
      (``follows_label_pattern``, ``precedes_input_pattern``,
      ``surrounded_by_form_pattern``).

    Only names listed in ``selected_features`` are emitted; a selected feature
    that does not apply to a line (e.g. ``prev_*`` on the first line) is 0.
    """

    # Pre-compiled patterns, reused for every line.
    _PHONE_RE = re.compile(r'\d{3}[-.\s]?\d{3}[-.\s]?\d{4}')
    # Runs of two or more whitespace characters, or a tab, separate columns.
    _COLUMN_SEP_RE = re.compile(r'\s{2,}|\t')

    def __init__(self):
        # Substrings suggesting a form line (matched case-insensitively, anywhere in the line).
        self.form_keywords = [
            'name', 'date', 'address', 'phone', 'email', 'signature',
            'number', 'ssn', 'dob', 'zip', ':', '_____'
        ]
        # Substrings suggesting a table line.
        self.table_keywords = [
            'total', 'qty', 'quantity', 'price', 'amount', 'item',
            'cost', 'subtotal', 'tax', '%', '$'
        ]
        # Selected features (in order of importance). The matrix built by
        # extract_features_for_document orders its columns alphabetically instead.
        self.selected_features = [
            'word_count', 'line_position_ratio', 'line_length', 'avg_word_length',
            'column_count', 'prev_line_length', 'digit_ratio', 'next_line_length',
            'uppercase_ratio', 'next_line_digit_ratio', 'next_line_word_count',
            'surrounded_by_form_pattern', 'prev_line_word_count', 'prev_line_digit_ratio',
            'form_keyword_count', 'special_char_count', 'next_line_form_keyword_count',
            'next_line_special_char_count', 'prev_line_form_keyword_count',
            'prev_line_special_char_count',
        ]

    @staticmethod
    def _to_text(line):
        """Return *line* as a str; None, NaN and other falsy values become ''."""
        if not line or pd.isna(line):
            return ""
        return str(line)

    def _has_form_keyword(self, text):
        """Return True if any form keyword occurs (case-insensitively) in *text*."""
        text_lower = text.lower()
        return any(keyword in text_lower for keyword in self.form_keywords)

    def _extract_basic_features(self, line):
        """Extract the selected core text features for a single line.

        Only features whose name appears in ``selected_features`` are computed.

        Args:
            line: Raw line value. None/NaN are treated as an empty line;
                anything else is converted with ``str``.

        Returns:
            dict mapping each computed feature name to its int/float value.
        """
        line = self._to_text(line)
        words = line.split()
        line_lower = line.lower()
        selected = self.selected_features

        basic_features = {}
        if 'line_length' in selected:
            basic_features['line_length'] = len(line)
        if 'word_count' in selected:
            basic_features['word_count'] = len(words)
        if 'avg_word_length' in selected:
            # len(line) includes whitespace, so this is characters-per-word rather
            # than the mean word length; kept as-is so values stay consistent with
            # any model already trained on this feature.
            basic_features['avg_word_length'] = len(line) / max(len(words), 1)
        if 'starts_with_whitespace' in selected:
            basic_features['starts_with_whitespace'] = 1 if line.startswith(' ') else 0
        if 'digit_ratio' in selected:
            basic_features['digit_ratio'] = sum(c.isdigit() for c in line) / max(len(line), 1)
        if 'uppercase_ratio' in selected:
            # Ratio of uppercase letters among alphabetic characters only.
            basic_features['uppercase_ratio'] = (
                sum(c.isupper() for c in line) / max(sum(c.isalpha() for c in line), 1)
            )
        if 'special_char_count' in selected:
            basic_features['special_char_count'] = sum(
                not c.isalnum() and not c.isspace() for c in line
            )
        if 'ends_with_colon' in selected:
            basic_features['ends_with_colon'] = 1 if line.strip().endswith(':') else 0
        if 'has_underscore_field' in selected:
            basic_features['has_underscore_field'] = 1 if '___' in line else 0
        if 'is_all_caps' in selected:
            basic_features['is_all_caps'] = 1 if line.isupper() and len(line.strip()) > 1 else 0
        if 'has_currency' in selected:
            basic_features['has_currency'] = 1 if '$' in line else 0
        if 'has_percentage' in selected:
            basic_features['has_percentage'] = 1 if '%' in line else 0
        if 'has_email_pattern' in selected:
            basic_features['has_email_pattern'] = 1 if '@' in line and '.' in line else 0
        if 'has_phone_pattern' in selected:
            basic_features['has_phone_pattern'] = 1 if self._PHONE_RE.search(line) else 0
        if 'column_count' in selected:
            # An empty line splits into [''] and therefore counts as one column.
            basic_features['column_count'] = len(self._COLUMN_SEP_RE.split(line.strip()))
        if 'form_keyword_count' in selected:
            # Substring matches, so e.g. 'date' also matches 'update'.
            basic_features['form_keyword_count'] = sum(
                1 for word in self.form_keywords if word in line_lower
            )
        if 'table_keyword_count' in selected:
            basic_features['table_keyword_count'] = sum(
                1 for word in self.table_keywords if word in line_lower
            )
        return basic_features

    def extract_features_for_line(self, line, all_lines=None, line_index=0):
        """Extract the selected features for one line, including neighbour context.

        Args:
            line: The line to featurize (str, None or NaN).
            all_lines: Optional collection of every line in the document,
                required for positional and contextual features. Lists and
                tuples are used directly; any other iterable (numpy array,
                pandas Series, generator, ...) is materialized into a list,
                so a Series is indexed by position rather than by label.
            line_index: Position of *line* within *all_lines*.

        Returns:
            dict with exactly one entry per name in ``selected_features``;
            features that do not apply to this line are 0.
        """
        if all_lines is None:
            lines = []
        elif isinstance(all_lines, (list, tuple)):
            lines = all_lines
        else:
            # Avoids the ambiguous truth value of numpy/pandas containers and
            # makes neighbour lookups positional.
            lines = list(all_lines)
        n_lines = len(lines)
        selected = self.selected_features
        text = self._to_text(line)

        features = self._extract_basic_features(text)

        # Positional features (0 when no document context is available).
        ratio = line_index / n_lines if n_lines else 0
        if 'line_position_ratio' in selected:
            features['line_position_ratio'] = ratio
        if 'is_near_start' in selected:
            features['is_near_start'] = 1 if n_lines and ratio < 0.1 else 0
        if 'is_near_end' in selected:
            features['is_near_end'] = 1 if n_lines and ratio > 0.9 else 0

        # Context requires at least one neighbouring line in the document.
        has_prev = n_lines > 1 and line_index > 0
        has_next = n_lines > 1 and line_index < n_lines - 1
        prev_text = self._to_text(lines[line_index - 1]) if has_prev else ""
        next_text = self._to_text(lines[line_index + 1]) if has_next else ""

        # prev_*/next_* features: basic features of the neighbours, computed once.
        if has_prev:
            for feat_name, feat_value in self._extract_basic_features(prev_text).items():
                if f'prev_{feat_name}' in selected:
                    features[f'prev_{feat_name}'] = feat_value
        if has_next:
            for feat_name, feat_value in self._extract_basic_features(next_text).items():
                if f'next_{feat_name}' in selected:
                    features[f'next_{feat_name}'] = feat_value

        # Cross-line patterns are computed from the text itself so they do not
        # depend on which basic features happen to be selected.
        if 'follows_label_pattern' in selected:
            # Previous line looks like a label ("Name:") and this line is short.
            features['follows_label_pattern'] = int(
                has_prev and prev_text.strip().endswith(':') and len(text) < 50
            )
        if 'precedes_input_pattern' in selected:
            # This line is a label and the next line holds an underscore input field.
            features['precedes_input_pattern'] = int(
                has_next and text.strip().endswith(':') and '___' in next_text
            )
        if 'surrounded_by_form_pattern' in selected:
            # Both neighbours exist and at least one of them contains a form keyword.
            features['surrounded_by_form_pattern'] = int(
                has_prev and has_next
                and (self._has_form_keyword(prev_text) or self._has_form_keyword(next_text))
            )

        for feat_name in selected:
            features.setdefault(feat_name, 0)
        return features

    def extract_features_for_document(self, lines):
        """Extract the feature matrix for all lines of a document.

        Args:
            lines: Iterable of raw lines (list, tuple, numpy array, pandas
                Series, generator, ...). None or empty input yields an empty
                result.

        Returns:
            Tuple ``(X, feature_names)``. ``X`` is a numpy array with one row
            per line and one column per selected feature, in the order of
            ``feature_names`` (which is ``sorted(self.selected_features)``).
            For empty input, ``(np.array([]), [])``.
        """
        if lines is None:
            return np.array([]), []
        # Materialize once so every per-line call indexes the same list by position.
        lines = list(lines)
        if not lines:
            return np.array([]), []

        # Columns are emitted in alphabetical order, not importance order.
        feature_names = sorted(self.selected_features)
        all_features = []
        for i, line in enumerate(lines):
            features = self.extract_features_for_line(line, lines, i)
            all_features.append([features[name] for name in feature_names])
        return np.array(all_features), feature_names
|