logasanjeev committed on
Commit
0dc0b49
·
verified ·
1 Parent(s): cfa1cb9

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +206 -158
inference.py CHANGED
@@ -1,177 +1,225 @@
1
- import json
2
- import argparse
3
- import os
4
  import cv2
 
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
  from ultralytics import YOLO
8
  from paddleocr import PaddleOCR
9
- from pathlib import Path
 
10
 
 
11
  def load_config(config_path="config.json"):
12
- """Load configuration from JSON file."""
13
  if not os.path.exists(config_path):
14
- raise FileNotFoundError(f"Config file {config_path} not found.")
15
- with open(config_path, 'r') as f:
16
  return json.load(f)
17
 
18
- def preprocess_image(image):
19
- """Apply preprocessing steps to enhance OCR accuracy."""
20
- scale_factor = 2
21
- image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)
22
-
23
- image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
24
-
 
 
 
 
 
25
  kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
26
- image = cv2.filter2D(image, -1, kernel)
27
-
28
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
29
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
30
- image = clahe.apply(gray)
31
-
32
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
33
- return image
34
 
35
- def run_ocr(cropped_image, ocr):
36
- """Run PaddleOCR on a cropped image and return extracted text with confidence."""
37
- result = ocr.ocr(cropped_image, cls=True)
38
- if not result or not result[0]:
39
- return None, 0.0
40
- text = result[0][0][1][0]
41
- confidence = result[0][0][1][1]
42
- return text, confidence
43
-
44
- def visualize_yolo_output(image, boxes, class_names, save_path=None, show=False):
45
- """Visualize YOLO bounding boxes on the image."""
46
- img = image.copy()
47
- for box in boxes:
48
- x1, y1, x2, y2 = box.xyxy[0].numpy().astype(int)
49
- label = class_names[int(box.cls)]
50
- conf = box.conf[0].numpy()
51
- cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
52
- cv2.putText(img, f"{label}: {conf:.2f}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
53
- if save_path:
54
- cv2.imwrite(save_path, img)
55
- if show:
56
- plt.imshow(img[:, :, ::-1])
57
- plt.axis('off')
58
- plt.show()
59
- return img
60
-
61
- def visualize_ocr_output(cropped_image, ocr_result, text, confidence, save_path=None, show=False):
62
- """Visualize OCR bounding boxes and text on the cropped image."""
63
- img = cropped_image.copy()
64
- if ocr_result and ocr_result[0]:
65
- for line in ocr_result[0]:
66
- box = line[0]
67
- x1, y1 = int(box[0][0]), int(box[0][1])
68
- x2, y2 = int(box[2][0]), int(box[2][1])
69
- cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
70
- cv2.putText(img, f"{text} ({confidence:.2f})", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
71
- if save_path:
72
- cv2.imwrite(save_path, img)
73
- if show:
74
- plt.imshow(img[:, :, ::-1])
75
- plt.axis('off')
76
- plt.show()
77
- return img
78
 
79
- def process_image(image_path, config, model_choice=None, show_yolo=False, show_ocr=False, save_json=True, verbose=False):
80
- """Process an input image to classify document type, detect fields, and extract text."""
81
- if not os.path.exists(image_path):
82
- raise FileNotFoundError(f"Image {image_path} not found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
 
 
 
 
 
 
 
 
 
 
 
84
  image = cv2.imread(image_path)
85
  if image is None:
86
- raise ValueError(f"Failed to load image {image_path}.")
87
-
88
- ocr = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
89
-
90
- doc_type = model_choice
91
- if model_choice is None:
92
- classifier = YOLO(config["models"]["id_classifier"]["path"])
93
- results = classifier(image, verbose=verbose)
94
- top_class_idx = results[0].probs.top1
95
- doc_type = config["models"]["id_classifier"]["classes"][str(top_class_idx)]
96
- if verbose:
97
- print(f"Classified document as: {doc_type} (confidence: {results[0].probs.top1conf:.2f})")
98
-
99
- if doc_type not in config["doc_type_to_model"]:
100
- raise ValueError(f"Document type {doc_type} not supported.")
101
- model_name = config["doc_type_to_model"][doc_type]
102
- if model_name not in config["models"]:
103
- raise ValueError(f"Model {model_name} not found in config.")
104
-
105
- detector = YOLO(config["models"][model_name]["path"])
106
- class_names = config["models"][model_name]["classes"]
107
- results = detector(image, verbose=verbose)
108
-
109
- output = {}
110
-
111
- for i, box in enumerate(results[0].boxes):
112
- x1, y1, x2, y2 = box.xyxy[0].numpy().astype(int)
113
- label = class_names[int(box.cls)]
114
- conf = box.conf[0].numpy()
115
-
116
- cropped = image[y1:y2, x1:x2]
117
- if cropped.size == 0:
118
- continue
119
-
120
- preprocessed = preprocess_image(cropped)
121
-
122
- text, ocr_conf = run_ocr(preprocessed, ocr)
123
- if text:
124
- output[label] = {"text": text, "yolo_conf": float(conf), "ocr_conf": float(ocr_conf)}
125
- if verbose:
126
- print(f"Field: {label}, Text: {text}, YOLO Conf: {conf:.2f}, OCR Conf: {ocr_conf:.2f}")
127
-
128
- if show_ocr or (save_json and show_ocr):
129
- ocr_result = ocr.ocr(preprocessed, cls=True)
130
- save_path = f"ocr_output_{label}_{i}.jpg" if save_json else None
131
- visualize_ocr_output(preprocessed, ocr_result, text, ocr_conf, save_path=save_path, show=show_ocr)
132
-
133
- if show_yolo or (save_json and show_yolo):
134
- save_path = "yolo_output.jpg" if save_json else None
135
- visualize_yolo_output(image, results[0].boxes, class_names, save_path=save_path, show=show_yolo)
136
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  if save_json:
138
- output_path = "detected_text.json"
139
- with open(output_path, 'w') as f:
140
- json.dump(output, f, indent=2)
141
- if verbose:
142
- print(f"Saved results to {output_path}")
143
-
144
- return output
145
-
146
- def main():
147
- """Command-line interface for inference."""
148
- parser = argparse.ArgumentParser(description="Indian ID Validator Inference Script")
149
- parser.add_argument("--image", required=True, help="Path to input image")
150
- parser.add_argument("--model", default=None, choices=["aadhaar", "pan_card", "passport", "voter_id", "driving_license"],
151
- help="Specify detection model (default: auto via id_classifier)")
152
- parser.add_argument("--show-yolo", action="store_true", help="Display/save YOLO bounding box image")
153
- parser.add_argument("--show-ocr", action="store_true", help="Display/save OCR results for each field")
154
- parser.add_argument("--no-save-json", action="store_true", help="Disable saving detected_text.json")
155
- parser.add_argument("--verbose", action="store_true", help="Print detailed inference results")
156
- args = parser.parse_args()
157
-
158
- config = load_config()
159
- try:
160
- output = process_image(
161
- image_path=args.image,
162
- config=config,
163
- model_choice=args.model,
164
- show_yolo=args.show_yolo,
165
- show_ocr=args.show_ocr,
166
- save_json=not args.no_save_json,
167
- verbose=args.verbose
168
- )
169
- if not args.verbose:
170
- print("Detected Fields:")
171
- for label, data in output.items():
172
- print(f"{label}: {data['text']} (YOLO Conf: {data['yolo_conf']:.2f}, OCR Conf: {data['ocr_conf']:.2f})")
173
- except Exception as e:
174
- print(f"Error: {str(e)}")
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  if __name__ == "__main__":
177
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
+ import json
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  from ultralytics import YOLO
6
  from paddleocr import PaddleOCR
7
+ from huggingface_hub import hf_hub_download
8
+ import os
9
 
10
# Load configuration
def load_config(config_path="config.json"):
    """Load the JSON configuration, downloading it from the Hub when no local copy exists."""
    resolved = config_path
    if not os.path.exists(resolved):
        # No local file — fall back to the published copy on the Hugging Face Hub.
        resolved = hf_hub_download(repo_id="logasanjeev/indian-id-validator", filename="config.json")
    with open(resolved, "r") as f:
        config = json.load(f)
    return config
16
 
17
# Configuration is loaded once at import time and shared by every function below.
CONFIG = load_config()

# Initialize PaddleOCR
# Single shared OCR engine (English, text-angle classification enabled),
# reused by process_id() for every field crop.
OCR = PaddleOCR(use_angle_cls=True, lang="en")
21
+
22
# Preprocessing functions
def upscale_image(image, scale=2):
    """Upscales the image to improve OCR accuracy."""
    height, width = image.shape[:2]
    # Bicubic interpolation preserves edges better than the default bilinear.
    return cv2.resize(image, (width * scale, height * scale), interpolation=cv2.INTER_CUBIC)
26
+
27
def unblur_image(image):
    """Sharpens the image to reduce blurriness."""
    # 3x3 sharpening kernel: boost the center pixel, subtract the 4-neighbours.
    sharpen = np.array([[0, -1, 0],
                        [-1, 5, -1],
                        [0, -1, 0]])
    return cv2.filter2D(image, -1, sharpen)
 
 
 
 
 
 
 
31
 
32
def denoise_image(image):
    """Removes noise using Non-Local Means Denoising."""
    # Named parameters for readability: filter strengths, template and search windows.
    h_luminance, h_color = 10, 10
    template_window, search_window = 7, 21
    return cv2.fastNlMeansDenoisingColored(image, None, h_luminance, h_color,
                                           template_window, search_window)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
def enhance_contrast(image):
    """Enhances contrast using CLAHE."""
    # Work in LAB space so only the lightness channel is equalized and colour is preserved.
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    equalized = clahe.apply(lightness)
    merged = cv2.merge((equalized, chan_a, chan_b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)
43
+
44
def preprocess_image(image):
    """Applies all preprocessing steps."""
    # Accept either a file path or an already-decoded BGR array.
    if isinstance(image, str):
        image = cv2.imread(image)
    if image is None or not isinstance(image, np.ndarray):
        raise ValueError("Invalid image input. Provide a valid file path or numpy array.")
    # Fixed pipeline: upscale -> sharpen -> denoise -> contrast-enhance.
    result = upscale_image(image, scale=2)
    result = unblur_image(result)
    result = denoise_image(result)
    return enhance_contrast(result)
55
+
56
# Core inference function
def process_id(image_path, model_name=None, save_json=True, output_json="detected_text.json", verbose=False):
    """
    Process an ID image to classify document type, detect fields, and extract text.

    Args:
        image_path (str): Path to the input image.
        model_name (str, optional): Specific model to use (e.g., 'Aadhaar', 'Pan_Card'). If None, uses Id_Classifier.
        save_json (bool): Save extracted text to JSON file.
        output_json (str): Path to save JSON output.
        verbose (bool): Display visualizations (bounding boxes, cropped images).

    Returns:
        dict: Extracted text for each detected field.

    Raises:
        ValueError: If the image cannot be loaded, the document type has no
            mapped model, or the model name is not present in the config.
    """
    # Load image once; reused for cropping and visualization below.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")

    # Download and load model
    def load_model(model_key):
        """Resolve a model path from CONFIG, downloading from the Hub if not local."""
        model_path = CONFIG["models"][model_key]["path"]
        if not os.path.exists(model_path):
            model_path = hf_hub_download(repo_id="logasanjeev/indian-id-validator", filename=model_path)
        return YOLO(model_path)

    # Classify document type if model_name is not specified
    if model_name is None:
        classifier = load_model("Id_Classifier")
        results = classifier(image)
        doc_type = results[0].names[results[0].probs.top1]
        model_name = CONFIG["doc_type_to_model"].get(doc_type, None)
        if model_name is None:
            raise ValueError(f"No detection model mapped for document type: {doc_type}")
    else:
        # BUG FIX: str.capitalize() lower-cases everything after the first character
        # ("Pan_Card" -> "Pan_card"), which broke CONFIG["models"] lookups for every
        # multi-word model name. Normalize each underscore-separated part instead,
        # so "pan_card", "PAN_CARD" and "Pan_Card" all map to "Pan_Card".
        model_name = "_".join(part.capitalize() for part in model_name.split("_"))

    # Load detection model
    if model_name not in CONFIG["models"]:
        raise ValueError(f"Invalid model name: {model_name}")
    model = load_model(model_name)
    class_names = CONFIG["models"][model_name]["classes"]

    # Run inference on the already-decoded image (avoids a redundant disk read).
    results = model(image)
    filtered_boxes = {}
    output_image = results[0].orig_img.copy()
    original_image = image
    h, w, _ = output_image.shape

    # Keep only the highest-confidence box for each class.
    for result in results:
        for box in result.boxes:
            cls = int(box.cls[0].item())
            conf = box.conf[0].item()
            xyxy = box.xyxy[0].tolist()
            if cls not in filtered_boxes or conf > filtered_boxes[cls]["conf"]:
                filtered_boxes[cls] = {"conf": conf, "xyxy": xyxy, "class_name": class_names[cls]}

    # Extract text and visualize
    detected_text = {}
    processed_images = []
    for cls, data in filtered_boxes.items():
        x_min, y_min, x_max, y_max = map(int, data["xyxy"])
        class_name = data["class_name"]
        # Clamp coordinates to the image bounds before cropping.
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(w, x_max), min(h, y_max)

        # Crop region; skip degenerate (empty) boxes rather than crash in resize.
        region_img = original_image[y_min:y_max, x_min:x_max]
        if region_img.size == 0:
            continue
        region_img = preprocess_image(region_img)
        region_h, region_w = region_img.shape[:2]

        # Create a black canvas and center the cropped region on it.
        # BUG FIX: np.ones produced a value-1 canvas; np.zeros is truly black.
        black_canvas = np.zeros((h, w, 3), dtype=np.uint8)
        center_x, center_y = w // 2, h // 2
        top_left_x = max(0, min(w - region_w, center_x - region_w // 2))
        top_left_y = max(0, min(h - region_h, center_y - region_h // 2))
        # If the preprocessed (2x upscaled) crop exceeds the canvas, shrink to fit.
        region_w = min(region_w, w - top_left_x)
        region_h = min(region_h, h - top_left_y)
        region_img = cv2.resize(region_img, (region_w, region_h))
        black_canvas[top_left_y:top_left_y + region_h, top_left_x:top_left_x + region_w] = region_img

        # Perform OCR. BUG FIX: PaddleOCR returns [None] when no text is found,
        # so each line must be guarded before iterating over it.
        ocr_result = OCR.ocr(black_canvas, cls=True) or []
        extracted_text = " ".join(
            word_info[1][0]
            for line in ocr_result if line
            for word_info in line if word_info
        )
        detected_text[class_name] = extracted_text

        # Draw OCR bounding boxes
        for line in ocr_result:
            if not line:
                continue
            for word_info in line:
                if word_info:
                    box = word_info[0]
                    x1, y1 = int(box[0][0]), int(box[0][1])
                    x2, y2 = int(box[2][0]), int(box[2][1])
                    cv2.rectangle(black_canvas, (x1, y1), (x2, y2), (0, 255, 0), 5)

        # Save processed image for verbose visualization.
        processed_images.append((class_name, black_canvas, extracted_text))

        # Draw original bounding box on the full-size output image.
        cv2.rectangle(output_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        cv2.putText(output_image, class_name, (x_min, y_min - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    # Save JSON
    if save_json:
        with open(output_json, "w") as f:
            json.dump(detected_text, f, indent=4)

    # Visualize
    if verbose:
        plt.figure(figsize=(10, 10))
        plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        plt.axis("off")
        plt.title("Raw Image")
        plt.show()

        plt.figure(figsize=(10, 10))
        plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
        plt.axis("off")
        plt.title("Output Image with Bounding Boxes")
        plt.show()

        for class_name, cropped_image, text in processed_images:
            plt.figure(figsize=(10, 10))
            plt.imshow(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
            plt.axis("off")
            plt.title(f"{class_name} - Extracted: {text}")
            plt.show()

    return detected_text
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
# Model-specific functions
def aadhaar(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process an Aadhaar card image."""
    return process_id(
        image_path,
        model_name="Aadhaar",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
194
+
195
def pan_card(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a PAN card image."""
    return process_id(
        image_path,
        model_name="Pan_Card",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
198
+
199
def passport(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a passport image."""
    return process_id(
        image_path,
        model_name="Passport",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
202
+
203
def voter_id(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a voter ID image."""
    return process_id(
        image_path,
        model_name="Voter_Id",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
206
+
207
def driving_license(image_path, save_json=True, output_json="detected_text.json", verbose=False):
    """Process a driving license image."""
    return process_id(
        image_path,
        model_name="Driving_License",
        save_json=save_json,
        output_json=output_json,
        verbose=verbose,
    )
210
+
211
# Command-line interface
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Indian ID Validator: Classify and extract fields from ID images."
    )
    parser.add_argument("image_path", help="Path to the input ID image")
    parser.add_argument(
        "--model",
        default=None,
        choices=["Aadhaar", "Pan_Card", "Passport", "Voter_Id", "Driving_License"],
        help="Specific model to use (default: auto-detect with Id_Classifier)",
    )
    parser.add_argument("--no-save-json", action="store_false", dest="save_json",
                        help="Disable saving to JSON")
    parser.add_argument("--output-json", default="detected_text.json",
                        help="Path to save JSON output")
    parser.add_argument("--verbose", action="store_true", help="Display visualizations")
    args = parser.parse_args()

    # Run the full pipeline and print the extracted fields as pretty JSON.
    result = process_id(
        args.image_path,
        model_name=args.model,
        save_json=args.save_json,
        output_json=args.output_json,
        verbose=args.verbose,
    )
    print("Extracted Text:")
    print(json.dumps(result, indent=4))