logasanjeev committed on
Commit
97960a9
·
verified ·
1 Parent(s): fe82aa3

Create inference.py

Files changed (1)
  1. inference.py +177 -0
inference.py ADDED
@@ -0,0 +1,177 @@
+ import json
+ import argparse
+ import os
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from ultralytics import YOLO
+ from paddleocr import PaddleOCR
+ from pathlib import Path
+
+ def load_config(config_path="config.json"):
+     """Load configuration from JSON file."""
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(f"Config file {config_path} not found.")
+     with open(config_path, 'r') as f:
+         return json.load(f)
+
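+ # For reference, the keys read below imply a config.json shaped roughly like the
+ # sketch here; model paths, class maps, and the detector name are placeholders,
+ # the real values come from config.json:
+ #
+ # {
+ #   "models": {
+ #     "id_classifier": {"path": "weights/id_classifier.pt",
+ #                       "classes": {"0": "aadhaar", "1": "pan_card", ...}},
+ #     "aadhaar_detector": {"path": "weights/aadhaar.pt",
+ #                          "classes": ["name", "dob", "id_number", ...]}
+ #   },
+ #   "doc_type_to_model": {"aadhaar": "aadhaar_detector", ...}
+ # }
+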
+ def preprocess_image(image):
+     """Apply preprocessing steps to enhance OCR accuracy."""
+     # Upscale 2x so small text covers more pixels for the OCR engine.
+     scale_factor = 2
+     image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)
+
+     # Remove colour noise while preserving edges.
+     image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
+
+     # Sharpen with a 3x3 kernel to crisp up character strokes.
+     kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
+     image = cv2.filter2D(image, -1, kernel)
+
+     # Boost local contrast with CLAHE on the grayscale image.
+     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+     image = clahe.apply(gray)
+
+     # Convert back to 3 channels, since the OCR step expects a BGR image.
+     image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+     return image
+
+ def run_ocr(cropped_image, ocr):
+     """Run PaddleOCR on a cropped image and return extracted text with confidence."""
+     result = ocr.ocr(cropped_image, cls=True)
+     if not result or not result[0]:
+         return None, 0.0
+     # Only the first detected text line in the crop is kept; each crop is
+     # expected to contain a single field.
+     text = result[0][0][1][0]
+     confidence = result[0][0][1][1]
+     return text, confidence
+
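+ # The indexing above assumes the nested result layout of PaddleOCR 2.x, roughly:
+ #
+ #   result = [                                  # one entry per input image
+ #       [                                       # one entry per detected text line
+ #           [[x, y], [x, y], [x, y], [x, y]],   # four corner points of the text box
+ #           ("EXTRACTED TEXT", 0.97),           # (text, confidence) -- illustrative values
+ #       ],
+ #       ...
+ #   ]
+ #
+ # so result[0][0][1] is the (text, confidence) pair of the first line.
+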
+ def visualize_yolo_output(image, boxes, class_names, save_path=None, show=False):
+     """Visualize YOLO bounding boxes on the image."""
+     img = image.copy()
+     for box in boxes:
+         # .cpu() keeps the conversion safe if inference ran on a GPU.
+         x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
+         label = class_names[int(box.cls)]
+         conf = float(box.conf[0])
+         cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
+         cv2.putText(img, f"{label}: {conf:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
+     if save_path:
+         cv2.imwrite(save_path, img)
+     if show:
+         # OpenCV images are BGR; reverse channels for matplotlib display.
+         plt.imshow(img[:, :, ::-1])
+         plt.axis('off')
+         plt.show()
+     return img
+
+ def visualize_ocr_output(cropped_image, ocr_result, text, confidence, save_path=None, show=False):
+     """Visualize OCR bounding boxes and text on the cropped image."""
+     img = cropped_image.copy()
+     if ocr_result and ocr_result[0]:
+         for line in ocr_result[0]:
+             box = line[0]
+             x1, y1 = int(box[0][0]), int(box[0][1])
+             x2, y2 = int(box[2][0]), int(box[2][1])
+             cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+             cv2.putText(img, f"{text} ({confidence:.2f})", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
+     if save_path:
+         cv2.imwrite(save_path, img)
+     if show:
+         plt.imshow(img[:, :, ::-1])
+         plt.axis('off')
+         plt.show()
+     return img
+
+ def process_image(image_path, config, model_choice=None, show_yolo=False, show_ocr=False, save_json=True, verbose=False):
+     """Process an input image to classify document type, detect fields, and extract text."""
+     if not os.path.exists(image_path):
+         raise FileNotFoundError(f"Image {image_path} not found.")
+
+     image = cv2.imread(image_path)
+     if image is None:
+         raise ValueError(f"Failed to load image {image_path}.")
+
+     ocr = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
+
+     # Classify the document type unless the caller forces a specific model.
+     doc_type = model_choice
+     if model_choice is None:
+         classifier = YOLO(config["models"]["id_classifier"]["path"])
+         results = classifier(image, verbose=verbose)
+         top_class_idx = results[0].probs.top1
+         doc_type = config["models"]["id_classifier"]["classes"][str(top_class_idx)]
+         if verbose:
+             print(f"Classified document as: {doc_type} (confidence: {results[0].probs.top1conf:.2f})")
+
+     if doc_type not in config["doc_type_to_model"]:
+         raise ValueError(f"Document type {doc_type} not supported.")
+     model_name = config["doc_type_to_model"][doc_type]
+     if model_name not in config["models"]:
+         raise ValueError(f"Model {model_name} not found in config.")
+
+     # Run the field detector for the resolved document type.
+     detector = YOLO(config["models"][model_name]["path"])
+     class_names = config["models"][model_name]["classes"]
+     results = detector(image, verbose=verbose)
+
+     output = {}
+
+     for i, box in enumerate(results[0].boxes):
+         # .cpu() keeps the conversion safe if inference ran on a GPU.
+         x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
+         label = class_names[int(box.cls)]
+         conf = float(box.conf[0])
+
+         cropped = image[y1:y2, x1:x2]
+         if cropped.size == 0:
+             continue
+
+         preprocessed = preprocess_image(cropped)
+
+         text, ocr_conf = run_ocr(preprocessed, ocr)
+         if text:
+             output[label] = {"text": text, "yolo_conf": conf, "ocr_conf": float(ocr_conf)}
+             if verbose:
+                 print(f"Field: {label}, Text: {text}, YOLO Conf: {conf:.2f}, OCR Conf: {ocr_conf:.2f}")
+
+         if show_ocr:
+             ocr_result = ocr.ocr(preprocessed, cls=True)
+             save_path = f"ocr_output_{label}_{i}.jpg" if save_json else None
+             visualize_ocr_output(preprocessed, ocr_result, text, ocr_conf, save_path=save_path, show=show_ocr)
+
+     if show_yolo:
+         save_path = "yolo_output.jpg" if save_json else None
+         visualize_yolo_output(image, results[0].boxes, class_names, save_path=save_path, show=show_yolo)
+
+     if save_json:
+         output_path = "detected_text.json"
+         with open(output_path, 'w') as f:
+             json.dump(output, f, indent=2)
+         if verbose:
+             print(f"Saved results to {output_path}")
+
+     return output
+
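+ # The returned dict (also written to detected_text.json) maps each detected field
+ # label to its text and confidences; field names depend on the model's classes in
+ # config.json. An illustrative (made-up) result:
+ #
+ # {
+ #   "name": {"text": "RAHUL SHARMA", "yolo_conf": 0.93, "ocr_conf": 0.88},
+ #   "dob": {"text": "01/01/1990", "yolo_conf": 0.91, "ocr_conf": 0.95}
+ # }
+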
+ def main():
+     """Command-line interface for inference."""
+     parser = argparse.ArgumentParser(description="Indian ID Validator Inference Script")
+     parser.add_argument("--image", required=True, help="Path to input image")
+     parser.add_argument("--model", default=None, choices=["aadhaar", "pan_card", "passport", "voter_id", "driving_license"],
+                         help="Specify detection model (default: auto via id_classifier)")
+     parser.add_argument("--show-yolo", action="store_true", help="Display/save YOLO bounding box image")
+     parser.add_argument("--show-ocr", action="store_true", help="Display/save OCR results for each field")
+     parser.add_argument("--no-save-json", action="store_true", help="Disable saving detected_text.json")
+     parser.add_argument("--verbose", action="store_true", help="Print detailed inference results")
+     args = parser.parse_args()
+
+     config = load_config()
+     try:
+         output = process_image(
+             image_path=args.image,
+             config=config,
+             model_choice=args.model,
+             show_yolo=args.show_yolo,
+             show_ocr=args.show_ocr,
+             save_json=not args.no_save_json,
+             verbose=args.verbose
+         )
+         if not args.verbose:
+             print("Detected Fields:")
+             for label, data in output.items():
+                 print(f"{label}: {data['text']} (YOLO Conf: {data['yolo_conf']:.2f}, OCR Conf: {data['ocr_conf']:.2f})")
+     except Exception as e:
+         print(f"Error: {e}")
+
+ if __name__ == "__main__":
+     main()
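+
+ # Example invocations (image paths below are placeholders; the flags match the
+ # argparse options defined in main()):
+ #
+ #   python inference.py --image samples/aadhaar_front.jpg
+ #   python inference.py --image samples/pan_card.jpg --model pan_card --show-yolo --verbose
+ #   python inference.py --image samples/passport.jpg --no-save-json --show-ocr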