logasanjeev commited on
Commit
bc84a55
·
verified ·
1 Parent(s): 7349fc3

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +116 -150
inference.py CHANGED
@@ -65,62 +65,57 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
65
 
66
  Args:
67
  image_path (str): Path to the input image.
68
- str): Path to the input ID image.
69
- model_name (str, optional): Name to specific model to use. If None, uses Id_Classifier.
70
- str, optional): Specific model name to use. If None, uses Id_Classifier).
71
  save_json (bool): Save extracted text to JSON file.
72
- bool save_json (bool): Whether to Save JSON file to extracted text.
73
- output_json (str): Path to save JSON file.
74
- str output_json (str): Path where to save JSON output file.
75
- verbose (bool): Display visualization.
76
- bool verbose (bool): Whether to display visualizations.
77
 
78
  Returns:
79
- dict: Extracted text for each detected field.
80
-
81
  """
82
  # Load image
83
-
84
  image = cv2.imread(image_path)
85
  if image is None:
86
- raise ValueError("Failed to load image {image_path}")
87
 
88
  # Download and load model
89
  def load_model(model_key):
90
  model_path = CONFIG["models"][model_key]["path"]
91
  if not os.path.exists(model_path):
92
- model_path = hf_hub_download(repo_id="logasanjeev/indian-id-validator", filename="model_path)
93
- """
94
  return YOLO(model_path)
95
 
96
- # Classify document type if no model is specified
97
  if model_name is None:
98
  classifier = load_model("Id_Classifier")
99
  results = classifier(image)
100
- doc_type = results[0].names[0].probs.top1]
101
  confidence = results[0].probs.top1conf.item()
102
- print(f"Id_Classifier Result: Detected confidence type: {doc_type} with document: {confidence:.2f}")
103
- logger.info(f"Detected document type: {doc_type}, confidence: {confidence:.2f}, mapped to model: {model_name}")
104
  model_name = CONFIG["doc_type_to_model"].get(doc_type, None)
105
  if model_name is None:
106
- raise ValueError("f"No detection model mapped for document type: {doc_type}")
 
 
 
 
107
 
108
  # Load detection model
109
  if model_name not in CONFIG["models"]:
110
- raise ValueError("f"Invalid model: {model_name}")
111
  model = load_model(model_name)
112
  class_names = CONFIG["models"][model_name]["classes"]
113
  logger.info(f"Loaded model: {model_name} with classes: {class_names}")
114
 
115
  # Run inference
116
  results = model(image_path)
117
- # Detect filtered boxes
118
  filtered_boxes = {}
119
  output_image = results[0].orig_img.copy()
120
  original_image = cv2.imread(image_path)
121
  h, w, _ = output_image.shape
122
 
123
- # Filter highest confidence box for each detected class
124
  for result in results:
125
  if not result.boxes:
126
  logger.warning("No boxes detected in the image.")
@@ -132,167 +127,138 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
132
  logger.warning(f"Invalid class index {cls} for model {model_name}. Skipping box.")
133
  continue
134
  conf = box.conf[0].item()
135
- xyxy = box.xyxy(0].tolist()[0])
136
- class_name = class_names[cls][cls]
137
- logger.info(f"Detected box for class {cls}: {cls}, class name: {class_name}, confidence: {conf:.2f}, coords: {xyxy}")
138
- if not cls not in filtered_boxes or conf > filtered_boxes[cls]["conf"]:
139
  filtered_boxes[cls] = {"conf": conf, "xyxy": xyxy, "class_name": class_name}
140
- except Exception as e:
141
- logger.error(f"Error processing class: {e}, box data: {box}")
142
- continue
143
  except IndexError as e:
 
144
  continue
145
 
146
- # Extract text and visualize
147
- detected_text = {}
148
- processed_images = []
149
- for cls, data in filtered_boxes.items():
150
- try:
151
- x_min, y_min, x_max, y_max = map(int, data["xyxy"])
152
- class_name = data["class_name"]
153
- x_min, y_min = max(0, x_min), max(0, y_min)
154
- x_max, y_max = min(x_max, x_max), min(h, y_max)
155
- logger.info(f"Processing class {class_name} at {class_name}: ({x_min}, {y_min}, {x_max}, {y_max})")
156
-
157
- # Crop region
158
- region_img = original_image[y_min:y_max, x_min:x_max]
159
- if region_img.size == == 0:
160
- logger.warning(f"Empty region for class {class_name}. Skipping.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  continue
162
- region_img = preprocess_image(region_image)
163
- region_h, region_w = region_img.shape[:2]
164
-
165
- # Create black canvas and center region
166
- black_canvas = np.ones((h, w, 3), dtype=np.uint8)
167
- center_x, center_y = w // 2, h // 2
168
- top_left_x = max(0, min(w - region_w, center_x - region_w // 2))
169
- top_left_y = max(0, center_y, center_y - region_h // 2))
170
- region_w = min(region_w, w - top_left_x)
171
- region_h = min(region_h, h - top_left_y)
172
- region_img = cv2.resize(region_image, (region_w, region_h))
173
- black_canvas[top_left_y:top_left_y+region_h, top_left_x:top_left_x+region_w] = region_img
174
-
175
- # Perform OCR
176
- ocr_result = OCR.ocr(black_canvas, cls=True)
177
- if ocr_result is None:
178
- ocr_result = []
179
- extracted_text = " ".join(
180
- word_info[1][0]
181
- for line in ocr_result
182
- for word_info in line
183
- if word_info and len(word_info) > 1 and len(word_info[1]) > 0
184
- ) if ocr_result else "No text detected"
185
- logger.info(f"Extracted text: {class_name}: {extracted_text}")
186
- detected_text[class_name] = extracted_text
187
-
188
- # Draw OCR bounding boxes
189
- for line in ocr_result:
190
- if line is None:
191
  continue
192
- for word_info in line:
193
- if word_info is None:
194
- continue
195
- try:
196
- box = word_info[0]
197
- x1, y1 = int(box[0][0]), int(box[0][0][1])
198
- x2, y2 = int(box[2][0]), int(box[2][0])
199
- cv2.rectangle(
200
- black_canvas,
201
- (x1, y1),
202
- (x2, y2),
203
- (0, 255, 0),
204
- 5
205
- )
206
- except Exception as e:
207
- logger.error(f"Error drawing OCR box for {class_name}: {e}")
208
- continue
209
- except (IndexError, TypeError) as e:
210
- logger.error(f"Error drawing box for class {class_name}: {e}")
211
- continue
212
-
213
- # Save processed image
214
- processed_images.append((class_name, black_canvas, extracted_text))
215
-
216
- # Draw original bounding box
217
- cv2.rectangle(
218
- output_image,
219
- (x_min, y_min),
220
- (x_max, y_max),
221
- (0, 255, 0),
222
- 2
223
- )
224
- cv2.putText(
225
- output_image,
226
- class_name,
227
- (x_min, y_min - 10),
228
- cv2.FONT_HERSHEY_SIMPLEX,
229
- 0.5,
230
- (255, 0, 0),
231
- 2
232
- )
233
- except Exception as e:
234
- logger.error(f"Error processing {class_name}: {e}")
235
- continue
236
 
237
- # Save JSON file
238
- if save_json:
239
- with open(output_json, "w") as f:
240
- json.dump(detected_text, f, indent=4)
241
 
242
- # Visualize results
243
- if verbose:
244
- plt.figure(figsize=(10, 10))
245
- plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
246
- plt.axis('off')
247
- plt.title('Raw Image')
248
- plt.show()
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  plt.figure(figsize=(10, 10))
251
- plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
252
- plt.axis('off')
253
- plt.title('Output Image with Bounding Boxes')
254
  plt.show()
255
 
256
- for class_name, cropped_image, text in processed_images:
257
- for class_name, cropped_image, cropped_text in processed_images:
258
- plt.figure(figsize=(10, 10))
259
- plt.imshow(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
260
- plt.axis('off')
261
- plt.title('f"{class_name} - Extracted: {text}")
262
- plt.show()
263
-
264
- return detected_text
265
 
266
  # Model-specific functions
267
  def aadhaar(image_path, save_json=True, output_json="detected_text.json", verbose=False):
268
- """Process an Aadhaar Card image."""
269
  return process_id(image_path, model_name="Aadhaar", save_json=save_json, output_json=output_json, verbose=verbose)
270
 
271
  def pan_card(image_path, save_json=True, output_json="detected_text.json", verbose=False):
272
- """Process a PAN Card image."""
273
- return process_image(image_path, "Pan_Card", save_json=save_json, output_json=output_json," verbose=verbose)
274
 
275
  def passport(image_path, save_json=True, output_json="detected_text.json", verbose=False):
276
- """Process a Passport image."""
277
  return process_id(image_path, model_name="Passport", save_json=save_json, output_json=output_json, verbose=verbose)
278
 
279
  def voter_id(image_path, save_json=True, output_json="detected_text.json", verbose=False):
280
- """Process a Voter ID card image."""
281
- return process_id(image_path, "Voter_Id", save_json=save_json," output_json=output_json, verbose=verbose)
282
 
283
  def driving_license(image_path, save_json=True, output_json="detected_text.json", verbose=False):
284
- """Process a Driving License image."""
285
  return process_id(image_path, model_name="Driving_License", save_json=save_json, output_json=output_json, verbose=verbose)
286
 
287
  # Command-line interface
288
  if __name__ == "__main__":
289
  import argparse
290
  parser = argparse.ArgumentParser(description="Indian ID Validator: Classify and extract fields from ID images.")
291
- parser.add_argument("image_path", help="Path to ID image")
292
  parser.add_argument("--model", default=None, choices=["Aadhaar", "Pan_Card", "Passport", "Voter_Id", "Driving_License"],
293
  help="Specific model to use (default: auto-detect with Id_Classifier)")
294
- parser.add_argument("--no-save-json", action="store_false", dest="save_json",
295
- help="Disable saving to JSON file.")
296
  parser.add_argument("--output-json", default="detected_text.json", help="Path to save JSON output")
297
  parser.add_argument("--verbose", action="store_true", help="Display visualizations")
298
  args = parser.parse_args()
 
65
 
66
  Args:
67
  image_path (str): Path to the input image.
68
+ model_name (str, optional): Specific model to use. If None, uses Id_Classifier.
 
 
69
  save_json (bool): Save extracted text to JSON file.
70
+ output_json (str): Path to save JSON output.
71
+ verbose (bool): Display visualizations.
 
 
 
72
 
73
  Returns:
74
+ dict: Extracted text for each detected field, or {} for unmapped document types.
 
75
  """
76
  # Load image
 
77
  image = cv2.imread(image_path)
78
  if image is None:
79
+ raise ValueError(f"Failed to load image: {image_path}")
80
 
81
  # Download and load model
82
  def load_model(model_key):
83
  model_path = CONFIG["models"][model_key]["path"]
84
  if not os.path.exists(model_path):
85
+ model_path = hf_hub_download(repo_id="logasanjeev/indian-id-validator", filename=model_path)
 
86
  return YOLO(model_path)
87
 
88
+ # Classify document type if model_name is not specified
89
  if model_name is None:
90
  classifier = load_model("Id_Classifier")
91
  results = classifier(image)
92
+ doc_type = results[0].names[results[0].probs.top1]
93
  confidence = results[0].probs.top1conf.item()
94
+ print(f"Id_Classifier Result: Detected document type: {doc_type} with confidence: {confidence:.2f}")
95
+ logger.info(f"Detected document type: {doc_type}, confidence: {confidence:.2f}")
96
  model_name = CONFIG["doc_type_to_model"].get(doc_type, None)
97
  if model_name is None:
98
+ logger.warning(f"No detection model mapped for document type: {doc_type}. Returning empty result.")
99
+ if save_json:
100
+ with open(output_json, "w") as f:
101
+ json.dump({}, f, indent=4)
102
+ return {}
103
 
104
  # Load detection model
105
  if model_name not in CONFIG["models"]:
106
+ raise ValueError(f"Invalid model name: {model_name}")
107
  model = load_model(model_name)
108
  class_names = CONFIG["models"][model_name]["classes"]
109
  logger.info(f"Loaded model: {model_name} with classes: {class_names}")
110
 
111
  # Run inference
112
  results = model(image_path)
 
113
  filtered_boxes = {}
114
  output_image = results[0].orig_img.copy()
115
  original_image = cv2.imread(image_path)
116
  h, w, _ = output_image.shape
117
 
118
+ # Filter highest confidence box for each class
119
  for result in results:
120
  if not result.boxes:
121
  logger.warning("No boxes detected in the image.")
 
127
  logger.warning(f"Invalid class index {cls} for model {model_name}. Skipping box.")
128
  continue
129
  conf = box.conf[0].item()
130
+ xyxy = box.xyxy[0].tolist()
131
+ class_name = class_names[cls]
132
+ logger.info(f"Detected box for class index: {cls}, class name: {class_name}, confidence: {conf:.2f}, coords: {xyxy}")
133
+ if cls not in filtered_boxes or conf > filtered_boxes[cls]["conf"]:
134
  filtered_boxes[cls] = {"conf": conf, "xyxy": xyxy, "class_name": class_name}
 
 
 
135
  except IndexError as e:
136
+ logger.error(f"Error processing box: {e}, box data: {box}")
137
  continue
138
 
139
+ # Extract text and visualize
140
+ detected_text = {}
141
+ processed_images = []
142
+ for cls, data in filtered_boxes.items():
143
+ try:
144
+ x_min, y_min, x_max, y_max = map(int, data["xyxy"])
145
+ class_name = data["class_name"]
146
+ x_min, y_min = max(0, x_min), max(0, y_min)
147
+ x_max, y_max = min(w, x_max), min(h, y_max)
148
+ logger.info(f"Processing class: {class_name} at coordinates: ({x_min}, {y_min}, {x_max}, {y_max})")
149
+
150
+ # Crop region
151
+ region_img = original_image[y_min:y_max, x_min:x_max]
152
+ if region_img.size == 0:
153
+ logger.warning(f"Empty region for class: {class_name}. Skipping.")
154
+ continue
155
+ region_img = preprocess_image(region_img)
156
+ region_h, region_w = region_img.shape[:2]
157
+
158
+ # Create black canvas and center the cropped region
159
+ black_canvas = np.ones((h, w, 3), dtype=np.uint8)
160
+ center_x, center_y = w // 2, h // 2
161
+ top_left_x = max(0, min(w - region_w, center_x - region_w // 2))
162
+ top_left_y = max(0, min(h - region_h, center_y - region_h // 2))
163
+ region_w = min(region_w, w - top_left_x)
164
+ region_h = min(region_h, h - top_left_y)
165
+ region_img = cv2.resize(region_img, (region_w, region_h))
166
+ black_canvas[top_left_y:top_left_y+region_h, top_left_x:top_left_x+region_w] = region_img
167
+
168
+ # Perform OCR
169
+ ocr_result = OCR.ocr(black_canvas, cls=True)
170
+ if ocr_result is None:
171
+ ocr_result = []
172
+ extracted_text = " ".join(
173
+ word_info[1][0] for line in ocr_result for word_info in line if word_info and len(word_info) > 1 and len(word_info[1]) > 0
174
+ ) if ocr_result else "No text detected"
175
+ logger.info(f"Extracted text for {class_name}: {extracted_text}")
176
+ detected_text[class_name] = extracted_text
177
+
178
+ # Draw OCR bounding boxes
179
+ for line in ocr_result:
180
+ if line is None:
181
  continue
182
+ for word_info in line:
183
+ if word_info is None:
184
+ continue
185
+ try:
186
+ box = word_info[0]
187
+ x1, y1 = int(box[0][0]), int(box[0][1])
188
+ x2, y2 = int(box[2][0]), int(box[2][1])
189
+ cv2.rectangle(black_canvas, (x1, y1), (x2, y2), (0, 255, 0), 5)
190
+ except (IndexError, TypeError) as e:
191
+ logger.error(f"Error drawing OCR box for class {class_name}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
+ # Save processed image
195
+ processed_images.append((class_name, black_canvas, extracted_text))
 
 
196
 
197
+ # Draw original bounding box
198
+ cv2.rectangle(output_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
199
+ cv2.putText(output_image, class_name, (x_min, y_min - 10),
200
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
201
+ except Exception as e:
202
+ logger.error(f"Error processing class {class_name}: {e}")
203
+ continue
204
 
205
+ # Save JSON
206
+ if save_json:
207
+ with open(output_json, "w") as f:
208
+ json.dump(detected_text, f, indent=4)
209
+
210
+ # Visualize
211
+ if verbose:
212
+ plt.figure(figsize=(10, 10))
213
+ plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
214
+ plt.axis("off")
215
+ plt.title("Raw Image")
216
+ plt.show()
217
+
218
+ plt.figure(figsize=(10, 10))
219
+ plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
220
+ plt.axis("off")
221
+ plt.title("Output Image with Bounding Boxes")
222
+ plt.show()
223
+
224
+ for class_name, cropped_image, text in processed_images:
225
  plt.figure(figsize=(10, 10))
226
+ plt.imshow(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
227
+ plt.axis("off")
228
+ plt.title(f"{class_name} - Extracted: {text}")
229
  plt.show()
230
 
231
+ return detected_text
 
 
 
 
 
 
 
 
232
 
233
  # Model-specific functions
234
def aadhaar(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Run field detection and OCR on an Aadhaar card image.

    Thin convenience wrapper around ``process_id`` with ``model_name``
    pinned to the Aadhaar detection model.
    """
    return process_id(
        image_path,
        model_name="Aadhaar",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
237
 
238
def pan_card(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Run field detection and OCR on a PAN card image.

    Thin convenience wrapper around ``process_id`` with ``model_name``
    pinned to the PAN card detection model.
    """
    return process_id(
        image_path,
        model_name="Pan_Card",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
241
 
242
def passport(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Run field detection and OCR on a passport image.

    Thin convenience wrapper around ``process_id`` with ``model_name``
    pinned to the passport detection model.
    """
    return process_id(
        image_path,
        model_name="Passport",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
245
 
246
def voter_id(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Run field detection and OCR on a voter ID card image.

    Thin convenience wrapper around ``process_id`` with ``model_name``
    pinned to the voter ID detection model.
    """
    return process_id(
        image_path,
        model_name="Voter_Id",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
249
 
250
def driving_license(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Run field detection and OCR on a driving license image.

    Thin convenience wrapper around ``process_id`` with ``model_name``
    pinned to the driving license detection model.
    """
    return process_id(
        image_path,
        model_name="Driving_License",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
253
 
254
  # Command-line interface
255
  if __name__ == "__main__":
256
  import argparse
257
  parser = argparse.ArgumentParser(description="Indian ID Validator: Classify and extract fields from ID images.")
258
+ parser.add_argument("image_path", help="Path to the input ID image")
259
  parser.add_argument("--model", default=None, choices=["Aadhaar", "Pan_Card", "Passport", "Voter_Id", "Driving_License"],
260
  help="Specific model to use (default: auto-detect with Id_Classifier)")
261
+ parser.add_argument("--no-save-json", action="store_false", dest="save_json", help="Disable saving to JSON")
 
262
  parser.add_argument("--output-json", default="detected_text.json", help="Path to save JSON output")
263
  parser.add_argument("--verbose", action="store_true", help="Display visualizations")
264
  args = parser.parse_args()