logasanjeev committed on
Commit
0d0944b
·
verified ·
1 Parent(s): 0d30727

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +150 -110
inference.py CHANGED
@@ -65,51 +65,62 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
65
 
66
  Args:
67
  image_path (str): Path to the input image.
68
- model_name (str, optional): Specific model to use. If None, uses Id_Classifier.
 
 
69
  save_json (bool): Save extracted text to JSON file.
70
- output_json (str): Path to save JSON output.
71
- verbose (bool): Display visualizations.
 
 
 
72
 
73
  Returns:
74
  dict: Extracted text for each detected field.
 
75
  """
76
  # Load image
 
77
  image = cv2.imread(image_path)
78
  if image is None:
79
- raise ValueError(f"Failed to load image: {image_path}")
80
 
81
  # Download and load model
82
  def load_model(model_key):
83
  model_path = CONFIG["models"][model_key]["path"]
84
  if not os.path.exists(model_path):
85
- model_path = hf_hub_download(repo_id="logasanjeev/indian-id-validator", filename=model_path)
 
86
  return YOLO(model_path)
87
 
88
- # Classify document type if model_name is not specified
89
  if model_name is None:
90
  classifier = load_model("Id_Classifier")
91
  results = classifier(image)
92
- doc_type = results[0].names[results[0].probs.top1]
 
 
 
93
  model_name = CONFIG["doc_type_to_model"].get(doc_type, None)
94
- logger.info(f"Detected document type: {doc_type}, mapped to model: {model_name}")
95
  if model_name is None:
96
- raise ValueError(f"No detection model mapped for document type: {doc_type}")
97
 
98
  # Load detection model
99
  if model_name not in CONFIG["models"]:
100
- raise ValueError(f"Invalid model name: {model_name}")
101
  model = load_model(model_name)
102
  class_names = CONFIG["models"][model_name]["classes"]
103
  logger.info(f"Loaded model: {model_name} with classes: {class_names}")
104
 
105
  # Run inference
106
  results = model(image_path)
 
107
  filtered_boxes = {}
108
  output_image = results[0].orig_img.copy()
109
  original_image = cv2.imread(image_path)
110
  h, w, _ = output_image.shape
111
 
112
- # Filter highest confidence box for each class
113
  for result in results:
114
  if not result.boxes:
115
  logger.warning("No boxes detected in the image.")
@@ -121,138 +132,167 @@ def process_id(image_path, model_name=None, save_json=True, output_json="detecte
121
  logger.warning(f"Invalid class index {cls} for model {model_name}. Skipping box.")
122
  continue
123
  conf = box.conf[0].item()
124
- xyxy = box.xyxy[0].tolist()
125
- class_name = class_names[cls]
126
- logger.info(f"Detected box for class index: {cls}, class name: {class_name}, confidence: {conf:.2f}, coords: {xyxy}")
127
- if cls not in filtered_boxes or conf > filtered_boxes[cls]["conf"]:
128
  filtered_boxes[cls] = {"conf": conf, "xyxy": xyxy, "class_name": class_name}
 
 
 
129
  except IndexError as e:
130
- logger.error(f"Error processing box: {e}, box data: {box}")
131
  continue
132
 
133
- # Extract text and visualize
134
- detected_text = {}
135
- processed_images = []
136
- for cls, data in filtered_boxes.items():
137
- try:
138
- x_min, y_min, x_max, y_max = map(int, data["xyxy"])
139
- class_name = data["class_name"]
140
- x_min, y_min = max(0, x_min), max(0, y_min)
141
- x_max, y_max = min(w, x_max), min(h, y_max)
142
- logger.info(f"Processing class: {class_name} at coordinates: ({x_min}, {y_min}, {x_max}, {y_max})")
143
-
144
- # Crop region
145
- region_img = original_image[y_min:y_max, x_min:x_max]
146
- if region_img.size == 0:
147
- logger.warning(f"Empty region for class: {class_name}. Skipping.")
148
- continue
149
- region_img = preprocess_image(region_img)
150
- region_h, region_w = region_img.shape[:2]
151
-
152
- # Create black canvas and center the cropped region
153
- black_canvas = np.ones((h, w, 3), dtype=np.uint8)
154
- center_x, center_y = w // 2, h // 2
155
- top_left_x = max(0, min(w - region_w, center_x - region_w // 2))
156
- top_left_y = max(0, min(h - region_h, center_y - region_h // 2))
157
- region_w = min(region_w, w - top_left_x)
158
- region_h = min(region_h, h - top_left_y)
159
- region_img = cv2.resize(region_img, (region_w, region_h))
160
- black_canvas[top_left_y:top_left_y+region_h, top_left_x:top_left_x+region_w] = region_img
161
-
162
- # Perform OCR
163
- ocr_result = OCR.ocr(black_canvas, cls=True)
164
- if ocr_result is None:
165
- ocr_result = []
166
- extracted_text = " ".join(
167
- word_info[1][0] for line in ocr_result for word_info in line if word_info and len(word_info) > 1 and len(word_info[1]) > 0
168
- ) if ocr_result else "No text detected"
169
- logger.info(f"Extracted text for {class_name}: {extracted_text}")
170
- detected_text[class_name] = extracted_text
171
-
172
- # Draw OCR bounding boxes
173
- for line in ocr_result:
174
- if line is None:
175
  continue
176
- for word_info in line:
177
- if word_info is None:
178
- continue
179
- try:
180
- box = word_info[0]
181
- x1, y1 = int(box[0][0]), int(box[0][1])
182
- x2, y2 = int(box[2][0]), int(box[2][1])
183
- cv2.rectangle(black_canvas, (x1, y1), (x2, y2), (0, 255, 0), 5)
184
- except (IndexError, TypeError) as e:
185
- logger.error(f"Error drawing OCR box for class {class_name}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
- # Save processed image
189
- processed_images.append((class_name, black_canvas, extracted_text))
 
 
190
 
191
- # Draw original bounding box
192
- cv2.rectangle(output_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
193
- cv2.putText(output_image, class_name, (x_min, y_min - 10),
194
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
195
- except Exception as e:
196
- logger.error(f"Error processing class {class_name}: {e}")
197
- continue
198
 
199
- # Save JSON
200
- if save_json:
201
- with open(output_json, "w") as f:
202
- json.dump(detected_text, f, indent=4)
203
-
204
- # Visualize
205
- if verbose:
206
- plt.figure(figsize=(10, 10))
207
- plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
208
- plt.axis("off")
209
- plt.title("Raw Image")
210
- plt.show()
211
-
212
- plt.figure(figsize=(10, 10))
213
- plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
214
- plt.axis("off")
215
- plt.title("Output Image with Bounding Boxes")
216
- plt.show()
217
-
218
- for class_name, cropped_image, text in processed_images:
219
  plt.figure(figsize=(10, 10))
220
- plt.imshow(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
221
- plt.axis("off")
222
- plt.title(f"{class_name} - Extracted: {text}")
223
  plt.show()
224
 
225
- return detected_text
 
 
 
 
 
 
 
 
226
 
227
  # Model-specific functions
228
  def aadhaar(image_path, save_json=True, output_json="detected_text.json", verbose=False):
229
- """Process an Aadhaar card image."""
230
  return process_id(image_path, model_name="Aadhaar", save_json=save_json, output_json=output_json, verbose=verbose)
231
 
232
  def pan_card(image_path, save_json=True, output_json="detected_text.json", verbose=False):
233
- """Process a PAN card image."""
234
- return process_id(image_path, model_name="Pan_Card", save_json=save_json, output_json=output_json, verbose=verbose)
235
 
236
  def passport(image_path, save_json=True, output_json="detected_text.json", verbose=False):
237
- """Process a passport image."""
238
  return process_id(image_path, model_name="Passport", save_json=save_json, output_json=output_json, verbose=verbose)
239
 
240
  def voter_id(image_path, save_json=True, output_json="detected_text.json", verbose=False):
241
- """Process a voter ID image."""
242
- return process_id(image_path, model_name="Voter_Id", save_json=save_json, output_json=output_json, verbose=verbose)
243
 
244
  def driving_license(image_path, save_json=True, output_json="detected_text.json", verbose=False):
245
- """Process a driving license image."""
246
  return process_id(image_path, model_name="Driving_License", save_json=save_json, output_json=output_json, verbose=verbose)
247
 
248
  # Command-line interface
249
  if __name__ == "__main__":
250
  import argparse
251
  parser = argparse.ArgumentParser(description="Indian ID Validator: Classify and extract fields from ID images.")
252
- parser.add_argument("image_path", help="Path to the input ID image")
253
  parser.add_argument("--model", default=None, choices=["Aadhaar", "Pan_Card", "Passport", "Voter_Id", "Driving_License"],
254
  help="Specific model to use (default: auto-detect with Id_Classifier)")
255
- parser.add_argument("--no-save-json", action="store_false", dest="save_json", help="Disable saving to JSON")
 
256
  parser.add_argument("--output-json", default="detected_text.json", help="Path to save JSON output")
257
  parser.add_argument("--verbose", action="store_true", help="Display visualizations")
258
  args = parser.parse_args()
 
65
 
66
  Args:
67
  image_path (str): Path to the input image.
68
+ model_name (str, optional): Specific model to use. If None, uses Id_Classifier.
69
+ save_json (bool): Whether to save extracted text to a JSON file.
71
+ output_json (str): Path to save the JSON output file.
73
+ verbose (bool): Whether to display visualizations.
77
 
78
  Returns:
79
  dict: Extracted text for each detected field.
80
+
81
  """
82
  # Load image
83
+
84
  image = cv2.imread(image_path)
85
  if image is None:
86
+ raise ValueError(f"Failed to load image: {image_path}")
87
 
88
  # Download and load model
89
  def load_model(model_key):
90
  model_path = CONFIG["models"][model_key]["path"]
91
  if not os.path.exists(model_path):
92
+ model_path = hf_hub_download(repo_id="logasanjeev/indian-id-validator", filename=model_path)
93
94
  return YOLO(model_path)
95
 
96
+ # Classify document type if no model is specified
97
  if model_name is None:
98
  classifier = load_model("Id_Classifier")
99
  results = classifier(image)
100
+ doc_type = results[0].names[results[0].probs.top1]
101
+ confidence = results[0].probs.top1conf.item()
102
+ print(f"Id_Classifier Result: Detected document type: {doc_type} with confidence: {confidence:.2f}")
102
  model_name = CONFIG["doc_type_to_model"].get(doc_type, None)
103
+ logger.info(f"Detected document type: {doc_type}, confidence: {confidence:.2f}, mapped to model: {model_name}")
 
105
  if model_name is None:
106
+ raise ValueError(f"No detection model mapped for document type: {doc_type}")
107
 
108
  # Load detection model
109
  if model_name not in CONFIG["models"]:
110
+ raise ValueError(f"Invalid model: {model_name}")
111
  model = load_model(model_name)
112
  class_names = CONFIG["models"][model_name]["classes"]
113
  logger.info(f"Loaded model: {model_name} with classes: {class_names}")
114
 
115
  # Run inference
116
  results = model(image_path)
117
+ # Detect filtered boxes
118
  filtered_boxes = {}
119
  output_image = results[0].orig_img.copy()
120
  original_image = cv2.imread(image_path)
121
  h, w, _ = output_image.shape
122
 
123
+ # Filter highest confidence box for each detected class
124
  for result in results:
125
  if not result.boxes:
126
  logger.warning("No boxes detected in the image.")
 
132
  logger.warning(f"Invalid class index {cls} for model {model_name}. Skipping box.")
133
  continue
134
  conf = box.conf[0].item()
135
+ xyxy = box.xyxy[0].tolist()
136
+ class_name = class_names[cls]
137
+ logger.info(f"Detected box for class index: {cls}, class name: {class_name}, confidence: {conf:.2f}, coords: {xyxy}")
138
+ if cls not in filtered_boxes or conf > filtered_boxes[cls]["conf"]:
139
  filtered_boxes[cls] = {"conf": conf, "xyxy": xyxy, "class_name": class_name}
140
+ except IndexError as e:
141
+ logger.error(f"Error processing box: {e}, box data: {box}")
142
+ continue
145
 
146
+ # Extract text and visualize
147
+ detected_text = {}
148
+ processed_images = []
149
+ for cls, data in filtered_boxes.items():
150
+ try:
151
+ x_min, y_min, x_max, y_max = map(int, data["xyxy"])
152
+ class_name = data["class_name"]
153
+ x_min, y_min = max(0, x_min), max(0, y_min)
154
+ x_max, y_max = min(w, x_max), min(h, y_max)
155
+ logger.info(f"Processing class {class_name} at coordinates: ({x_min}, {y_min}, {x_max}, {y_max})")
156
+
157
+ # Crop region
158
+ region_img = original_image[y_min:y_max, x_min:x_max]
159
+ if region_img.size == 0:
160
+ logger.warning(f"Empty region for class {class_name}. Skipping.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  continue
162
+ region_img = preprocess_image(region_img)
163
+ region_h, region_w = region_img.shape[:2]
164
+
165
+ # Create black canvas and center region
166
+ black_canvas = np.ones((h, w, 3), dtype=np.uint8)
167
+ center_x, center_y = w // 2, h // 2
168
+ top_left_x = max(0, min(w - region_w, center_x - region_w // 2))
169
+ top_left_y = max(0, min(h - region_h, center_y - region_h // 2))
170
+ region_w = min(region_w, w - top_left_x)
171
+ region_h = min(region_h, h - top_left_y)
172
+ region_img = cv2.resize(region_img, (region_w, region_h))
173
+ black_canvas[top_left_y:top_left_y+region_h, top_left_x:top_left_x+region_w] = region_img
174
+
175
+ # Perform OCR
176
+ ocr_result = OCR.ocr(black_canvas, cls=True)
177
+ if ocr_result is None:
178
+ ocr_result = []
179
+ extracted_text = " ".join(
180
+ word_info[1][0]
181
+ for line in ocr_result
182
+ for word_info in line
183
+ if word_info and len(word_info) > 1 and len(word_info[1]) > 0
184
+ ) if ocr_result else "No text detected"
185
+ logger.info(f"Extracted text for {class_name}: {extracted_text}")
186
+ detected_text[class_name] = extracted_text
187
+
188
+ # Draw OCR bounding boxes
189
+ for line in ocr_result:
190
+ if line is None:
191
  continue
192
+ for word_info in line:
193
+ if word_info is None:
194
+ continue
195
+ try:
196
+ box = word_info[0]
197
+ x1, y1 = int(box[0][0]), int(box[0][1])
197
+ x2, y2 = int(box[2][0]), int(box[2][1])
199
+ cv2.rectangle(
200
+ black_canvas,
201
+ (x1, y1),
202
+ (x2, y2),
203
+ (0, 255, 0),
204
+ 5
205
+ )
206
+ except Exception as e:
207
+ logger.error(f"Error drawing OCR box for {class_name}: {e}")
208
+ continue
209
+ except (IndexError, TypeError) as e:
210
+ logger.error(f"Error drawing box for class {class_name}: {e}")
211
+ continue
212
+
213
+ # Save processed image
214
+ processed_images.append((class_name, black_canvas, extracted_text))
215
+
216
+ # Draw original bounding box
217
+ cv2.rectangle(
218
+ output_image,
219
+ (x_min, y_min),
220
+ (x_max, y_max),
221
+ (0, 255, 0),
222
+ 2
223
+ )
224
+ cv2.putText(
225
+ output_image,
226
+ class_name,
227
+ (x_min, y_min - 10),
228
+ cv2.FONT_HERSHEY_SIMPLEX,
229
+ 0.5,
230
+ (255, 0, 0),
231
+ 2
232
+ )
233
+ except Exception as e:
234
+ logger.error(f"Error processing {class_name}: {e}")
235
+ continue
236
 
237
+ # Save JSON file
238
+ if save_json:
239
+ with open(output_json, "w") as f:
240
+ json.dump(detected_text, f, indent=4)
241
 
242
+ # Visualize results
243
+ if verbose:
244
+ plt.figure(figsize=(10, 10))
245
+ plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
246
+ plt.axis('off')
247
+ plt.title('Raw Image')
248
+ plt.show()
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  plt.figure(figsize=(10, 10))
251
+ plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
252
+ plt.axis('off')
253
+ plt.title('Output Image with Bounding Boxes')
254
  plt.show()
255
 
256
+ for class_name, cropped_image, text in processed_images:
257
258
+ plt.figure(figsize=(10, 10))
259
+ plt.imshow(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
260
+ plt.axis('off')
261
+ plt.title(f"{class_name} - Extracted: {text}")
262
+ plt.show()
263
+
264
+ return detected_text
265
 
266
  # Model-specific functions
267
def aadhaar(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process an Aadhaar Card image.

    Thin wrapper around the shared pipeline, pinned to the Aadhaar
    detection model.

    Args:
        image_path (str): Path to the input Aadhaar card image.
        save_json (bool): Whether to save extracted text to a JSON file.
        output_json (str): Path to save the JSON output file.
        verbose (bool): Whether to display visualizations.

    Returns:
        dict: Extracted text for each detected field.
    """
    return process_id(
        image_path,
        model_name="Aadhaar",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
270
 
271
def pan_card(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a PAN Card image.

    Args:
        image_path (str): Path to the input PAN card image.
        save_json (bool): Whether to save extracted text to a JSON file.
        output_json (str): Path to save the JSON output file.
        verbose (bool): Whether to display visualizations.

    Returns:
        dict: Extracted text for each detected field.
    """
    # Fixes: undefined `process_image` -> `process_id`, stray quote before
    # `verbose=`, and pass `model_name` by keyword like the sibling wrappers.
    return process_id(image_path, model_name="Pan_Card", save_json=save_json, output_json=output_json, verbose=verbose)
274
 
275
def passport(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a Passport image.

    Thin wrapper around the shared pipeline, pinned to the Passport
    detection model.

    Args:
        image_path (str): Path to the input passport image.
        save_json (bool): Whether to save extracted text to a JSON file.
        output_json (str): Path to save the JSON output file.
        verbose (bool): Whether to display visualizations.

    Returns:
        dict: Extracted text for each detected field.
    """
    return process_id(
        image_path,
        model_name="Passport",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
278
 
279
def voter_id(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a Voter ID card image.

    Args:
        image_path (str): Path to the input voter ID image.
        save_json (bool): Whether to save extracted text to a JSON file.
        output_json (str): Path to save the JSON output file.
        verbose (bool): Whether to display visualizations.

    Returns:
        dict: Extracted text for each detected field.
    """
    # Fixes: stray quote after `save_json=save_json`, and pass `model_name`
    # by keyword like the sibling wrappers.
    return process_id(image_path, model_name="Voter_Id", save_json=save_json, output_json=output_json, verbose=verbose)
282
 
283
def driving_license(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a Driving License image.

    Thin wrapper around the shared pipeline, pinned to the Driving License
    detection model.

    Args:
        image_path (str): Path to the input driving license image.
        save_json (bool): Whether to save extracted text to a JSON file.
        output_json (str): Path to save the JSON output file.
        verbose (bool): Whether to display visualizations.

    Returns:
        dict: Extracted text for each detected field.
    """
    return process_id(
        image_path,
        model_name="Driving_License",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
286
 
287
  # Command-line interface
288
  if __name__ == "__main__":
289
  import argparse
290
  parser = argparse.ArgumentParser(description="Indian ID Validator: Classify and extract fields from ID images.")
291
+ parser.add_argument("image_path", help="Path to ID image")
292
  parser.add_argument("--model", default=None, choices=["Aadhaar", "Pan_Card", "Passport", "Voter_Id", "Driving_License"],
293
  help="Specific model to use (default: auto-detect with Id_Classifier)")
294
+ parser.add_argument("--no-save-json", action="store_false", dest="save_json",
295
+ help="Disable saving to JSON file.")
296
  parser.add_argument("--output-json", default="detected_text.json", help="Path to save JSON output")
297
  parser.add_argument("--verbose", action="store_true", help="Display visualizations")
298
  args = parser.parse_args()