logasanjeev committed on
Commit
f93fee0
·
verified ·
1 Parent(s): 221f146

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +31 -45
inference.py CHANGED
@@ -46,22 +46,12 @@ def enhance_contrast(image):
46
  l = clahe.apply(l)
47
  return cv2.cvtColor(cv2.merge((l, a, b)), cv2.COLOR_LAB2BGR)
48
 
49
- def preprocess_mrz(image):
50
- """Special preprocessing for MRZ regions."""
51
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
52
- clahe = cv2.createCLAHE(clipLimit=5.0, tileGridSize=(8,8))
53
- gray = clahe.apply(gray)
54
- _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
55
- return cv2.cvtColor(thresh, cv2.COLOR_GRAY2BGR)
56
-
57
- def preprocess_image(image, is_mrz=False):
58
- """Applies preprocessing steps, with special handling for MRZ."""
59
  if isinstance(image, str):
60
  image = cv2.imread(image)
61
  if image is None or not isinstance(image, np.ndarray):
62
  raise ValueError("Invalid image input. Provide a valid file path or numpy array.")
63
- if is_mrz:
64
- return preprocess_mrz(image)
65
  image = upscale_image(image, scale=2)
66
  image = unblur_image(image)
67
  image = denoise_image(image)
@@ -75,10 +65,10 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
75
 
76
  Args:
77
  image_path (str): Path to the input image.
78
- model_name (str, optional): Specific model to use (e.g., 'Aadhaar', 'Pan_Card'). If None, uses Id_Classifier.
79
  save_json (bool): Save extracted text to JSON file.
80
  output_json (str): Path to save JSON output.
81
- verbose (bool): Display visualizations (bounding boxes, cropped images).
82
 
83
  Returns:
84
  dict: Extracted text for each detected field.
@@ -115,12 +105,11 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
115
  # Run inference
116
  results = model(image_path)
117
  filtered_boxes = {}
118
- class_counts = {} # Track multiple instances of the same class
119
  output_image = results[0].orig_img.copy()
120
  original_image = cv2.imread(image_path)
121
  h, w, _ = output_image.shape
122
 
123
- # Filter boxes, allowing multiple instances of the same class
124
  for result in results:
125
  if not result.boxes:
126
  logger.warning("No boxes detected in the image.")
@@ -128,13 +117,14 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
128
  for box in result.boxes:
129
  try:
130
  cls = int(box.cls[0].item())
 
 
 
131
  conf = box.conf[0].item()
132
  xyxy = box.xyxy[0].tolist()
133
- class_name = class_names[cls]
134
- class_counts[class_name] = class_counts.get(class_name, 0) + 1
135
- unique_class_name = f"{class_name}_{class_counts[class_name]}" if class_counts[class_name] > 1 else class_name
136
- filtered_boxes[unique_class_name] = {"conf": conf, "xyxy": xyxy, "class_name": unique_class_name}
137
- logger.info(f"Detected box for class: {unique_class_name}, confidence: {conf:.2f}")
138
  except IndexError as e:
139
  logger.error(f"Error processing box: {e}, box data: {box}")
140
  continue
@@ -142,7 +132,7 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
142
  # Extract text and visualize
143
  detected_text = {}
144
  processed_images = []
145
- for unique_class_name, data in filtered_boxes.items():
146
  try:
147
  x_min, y_min, x_max, y_max = map(int, data["xyxy"])
148
  class_name = data["class_name"]
@@ -155,8 +145,7 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
155
  if region_img.size == 0:
156
  logger.warning(f"Empty region for class: {class_name}. Skipping.")
157
  continue
158
- is_mrz = "MRZ" in class_name.upper()
159
- region_img = preprocess_image(region_img, is_mrz=is_mrz)
160
  region_h, region_w = region_img.shape[:2]
161
 
162
  # Create black canvas and center the cropped region
@@ -170,32 +159,29 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
170
  black_canvas[top_left_y:top_left_y+region_h, top_left_x:top_left_x+region_w] = region_img
171
 
172
  # Perform OCR
173
- ocr_result = OCR.ocr(black_canvas, cls=True) or []
174
- extracted_text = ""
175
- if ocr_result:
176
- try:
177
- extracted_text = " ".join(word_info[1][0] for line in ocr_result for word_info in line if word_info and len(word_info) > 1 and len(word_info[1]) > 0)
178
- except (IndexError, TypeError) as e:
179
- logger.error(f"Error processing OCR result for class {class_name}: {e}")
180
- extracted_text = "OCR failed"
181
- else:
182
- logger.warning(f"No OCR results for class: {class_name}")
183
- extracted_text = "No text detected"
184
-
185
  detected_text[class_name] = extracted_text
186
 
187
  # Draw OCR bounding boxes
188
  for line in ocr_result:
 
 
189
  for word_info in line:
190
- if word_info and len(word_info) > 0:
191
- try:
192
- box = word_info[0]
193
- x1, y1 = int(box[0][0]), int(box[0][1])
194
- x2, y2 = int(box[2][0]), int(box[2][1])
195
- cv2.rectangle(black_canvas, (x1, y1), (x2, y2), (0, 255, 0), 5)
196
- except (IndexError, TypeError) as e:
197
- logger.error(f"Error drawing OCR box for class {class_name}: {e}")
198
- continue
 
199
 
200
  # Save processed image
201
  processed_images.append((class_name, black_canvas, extracted_text))
 
46
  l = clahe.apply(l)
47
  return cv2.cvtColor(cv2.merge((l, a, b)), cv2.COLOR_LAB2BGR)
48
 
49
+ def preprocess_image(image):
50
+ """Applies all preprocessing steps."""
 
 
 
 
 
 
 
 
51
  if isinstance(image, str):
52
  image = cv2.imread(image)
53
  if image is None or not isinstance(image, np.ndarray):
54
  raise ValueError("Invalid image input. Provide a valid file path or numpy array.")
 
 
55
  image = upscale_image(image, scale=2)
56
  image = unblur_image(image)
57
  image = denoise_image(image)
 
65
 
66
  Args:
67
  image_path (str): Path to the input image.
68
+ model_name (str, optional): Specific model to use. If None, uses Id_Classifier.
69
  save_json (bool): Save extracted text to JSON file.
70
  output_json (str): Path to save JSON output.
71
+ verbose (bool): Display visualizations.
72
 
73
  Returns:
74
  dict: Extracted text for each detected field.
 
105
  # Run inference
106
  results = model(image_path)
107
  filtered_boxes = {}
 
108
  output_image = results[0].orig_img.copy()
109
  original_image = cv2.imread(image_path)
110
  h, w, _ = output_image.shape
111
 
112
+ # Filter highest confidence box for each class
113
  for result in results:
114
  if not result.boxes:
115
  logger.warning("No boxes detected in the image.")
 
117
  for box in result.boxes:
118
  try:
119
  cls = int(box.cls[0].item())
120
+ if cls >= len(class_names):
121
+ logger.warning(f"Invalid class index {cls} for model {model_name}. Skipping box.")
122
+ continue
123
  conf = box.conf[0].item()
124
  xyxy = box.xyxy[0].tolist()
125
+ if cls not in filtered_boxes or conf > filtered_boxes[cls]["conf"]:
126
+ filtered_boxes[cls] = {"conf": conf, "xyxy": xyxy, "class_name": class_names[cls]}
127
+ logger.info(f"Detected box for class: {class_names[cls]}, confidence: {conf:.2f}")
 
 
128
  except IndexError as e:
129
  logger.error(f"Error processing box: {e}, box data: {box}")
130
  continue
 
132
  # Extract text and visualize
133
  detected_text = {}
134
  processed_images = []
135
+ for cls, data in filtered_boxes.items():
136
  try:
137
  x_min, y_min, x_max, y_max = map(int, data["xyxy"])
138
  class_name = data["class_name"]
 
145
  if region_img.size == 0:
146
  logger.warning(f"Empty region for class: {class_name}. Skipping.")
147
  continue
148
+ region_img = preprocess_image(region_img)
 
149
  region_h, region_w = region_img.shape[:2]
150
 
151
  # Create black canvas and center the cropped region
 
159
  black_canvas[top_left_y:top_left_y+region_h, top_left_x:top_left_x+region_w] = region_img
160
 
161
  # Perform OCR
162
+ ocr_result = OCR.ocr(black_canvas, cls=True)
163
+ if ocr_result is None:
164
+ ocr_result = []
165
+ extracted_text = " ".join(
166
+ word_info[1][0] for line in ocr_result for word_info in line if word_info and len(word_info) > 1 and len(word_info[1]) > 0
167
+ ) if ocr_result else "No text detected"
 
 
 
 
 
 
168
  detected_text[class_name] = extracted_text
169
 
170
  # Draw OCR bounding boxes
171
  for line in ocr_result:
172
+ if line is None:
173
+ continue
174
  for word_info in line:
175
+ if word_info is None:
176
+ continue
177
+ try:
178
+ box = word_info[0]
179
+ x1, y1 = int(box[0][0]), int(box[0][1])
180
+ x2, y2 = int(box[2][0]), int(box[2][1])
181
+ cv2.rectangle(black_canvas, (x1, y1), (x2, y2), (0, 255, 0), 5)
182
+ except (IndexError, TypeError) as e:
183
+ logger.error(f"Error drawing OCR box for class {class_name}: {e}")
184
+ continue
185
 
186
  # Save processed image
187
  processed_images.append((class_name, black_canvas, extracted_text))