Create inference.py
inference.py (ADDED, +177 -0)
import json
import argparse
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
from paddleocr import PaddleOCR
from pathlib import Path


def load_config(config_path="config.json"):
    """Load configuration from JSON file."""
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file {config_path} not found.")
    with open(config_path, 'r') as f:
        return json.load(f)

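# The lookups in process_image below assume config.json roughly follows this shape.
# Model names, weight paths, and class lists here are illustrative placeholders, not
# the repository's actual values:
#
# {
#   "models": {
#     "id_classifier": {"path": "weights/id_classifier.pt",
#                       "classes": {"0": "aadhaar", "1": "pan_card"}},
#     "aadhaar_detector": {"path": "weights/aadhaar_detector.pt",
#                          "classes": ["name", "dob", "aadhaar_number"]}
#   },
#   "doc_type_to_model": {"aadhaar": "aadhaar_detector"}
# }
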
def preprocess_image(image):
    """Apply preprocessing steps to enhance OCR accuracy."""
    # Upscale so PaddleOCR has more pixels per character.
    scale_factor = 2
    image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)

    # Remove colour noise, then sharpen edges.
    image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)

    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    image = cv2.filter2D(image, -1, kernel)

    # Boost local contrast with CLAHE on the grayscale image.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    image = clahe.apply(gray)

    # Convert back to 3 channels so downstream code can keep treating it as BGR.
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    return image

def run_ocr(cropped_image, ocr):
    """Run PaddleOCR on a cropped image and return extracted text with confidence."""
    result = ocr.ocr(cropped_image, cls=True)
    if not result or not result[0]:
        return None, 0.0
    # Keep only the first detected line: result[0][0] is (box, (text, confidence)).
    text = result[0][0][1][0]
    confidence = result[0][0][1][1]
    return text, confidence

def visualize_yolo_output(image, boxes, class_names, save_path=None, show=False):
    """Visualize YOLO bounding boxes on the image."""
    img = image.copy()
    for box in boxes:
        x1, y1, x2, y2 = box.xyxy[0].numpy().astype(int)
        label = class_names[int(box.cls)]
        conf = box.conf[0].numpy()
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(img, f"{label}: {conf:.2f}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    if save_path:
        cv2.imwrite(save_path, img)
    if show:
        plt.imshow(img[:, :, ::-1])  # BGR -> RGB for matplotlib
        plt.axis('off')
        plt.show()
    return img

def visualize_ocr_output(cropped_image, ocr_result, text, confidence, save_path=None, show=False):
    """Visualize OCR bounding boxes and text on the cropped image."""
    img = cropped_image.copy()
    if ocr_result and ocr_result[0]:
        for line in ocr_result[0]:
            box = line[0]
            x1, y1 = int(box[0][0]), int(box[0][1])
            x2, y2 = int(box[2][0]), int(box[2][1])
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(img, f"{text} ({confidence:.2f})", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    if save_path:
        cv2.imwrite(save_path, img)
    if show:
        plt.imshow(img[:, :, ::-1])
        plt.axis('off')
        plt.show()
    return img

def process_image(image_path, config, model_choice=None, show_yolo=False, show_ocr=False, save_json=True, verbose=False):
    """Process an input image to classify document type, detect fields, and extract text."""
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image {image_path} not found.")

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to load image {image_path}.")

    ocr = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)

    # Classify the document type unless the caller forced a model choice.
    doc_type = model_choice
    if model_choice is None:
        classifier = YOLO(config["models"]["id_classifier"]["path"])
        results = classifier(image, verbose=verbose)
        top_class_idx = results[0].probs.top1
        doc_type = config["models"]["id_classifier"]["classes"][str(top_class_idx)]
        if verbose:
            print(f"Classified document as: {doc_type} (confidence: {results[0].probs.top1conf:.2f})")

    if doc_type not in config["doc_type_to_model"]:
        raise ValueError(f"Document type {doc_type} not supported.")
    model_name = config["doc_type_to_model"][doc_type]
    if model_name not in config["models"]:
        raise ValueError(f"Model {model_name} not found in config.")

    detector = YOLO(config["models"][model_name]["path"])
    class_names = config["models"][model_name]["classes"]
    results = detector(image, verbose=verbose)

    output = {}

    for i, box in enumerate(results[0].boxes):
        x1, y1, x2, y2 = box.xyxy[0].numpy().astype(int)
        label = class_names[int(box.cls)]
        conf = box.conf[0].numpy()

        cropped = image[y1:y2, x1:x2]
        if cropped.size == 0:
            continue

        preprocessed = preprocess_image(cropped)

        text, ocr_conf = run_ocr(preprocessed, ocr)
        if text:
            output[label] = {"text": text, "yolo_conf": float(conf), "ocr_conf": float(ocr_conf)}
            if verbose:
                print(f"Field: {label}, Text: {text}, YOLO Conf: {conf:.2f}, OCR Conf: {ocr_conf:.2f}")

        if show_ocr:  # the annotated crop is written to disk only when save_json is enabled
            ocr_result = ocr.ocr(preprocessed, cls=True)
            save_path = f"ocr_output_{label}_{i}.jpg" if save_json else None
            visualize_ocr_output(preprocessed, ocr_result, text, ocr_conf, save_path=save_path, show=show_ocr)

    if show_yolo:  # likewise, the annotated full image is saved only when save_json is enabled
        save_path = "yolo_output.jpg" if save_json else None
        visualize_yolo_output(image, results[0].boxes, class_names, save_path=save_path, show=show_yolo)

    if save_json:
        output_path = "detected_text.json"
        with open(output_path, 'w') as f:
            json.dump(output, f, indent=2)
        if verbose:
            print(f"Saved results to {output_path}")

    return output

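# For reference, the dictionary returned above (and written to detected_text.json) is
# shaped like the following; field names come from the detector's class list and the
# values here are purely illustrative:
# {"name": {"text": "JOHN DOE", "yolo_conf": 0.93, "ocr_conf": 0.88}, ...}
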
def main():
    """Command-line interface for inference."""
    parser = argparse.ArgumentParser(description="Indian ID Validator Inference Script")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--model", default=None, choices=["aadhaar", "pan_card", "passport", "voter_id", "driving_license"],
                        help="Specify detection model (default: auto via id_classifier)")
    parser.add_argument("--show-yolo", action="store_true", help="Display/save YOLO bounding box image")
    parser.add_argument("--show-ocr", action="store_true", help="Display/save OCR results for each field")
    parser.add_argument("--no-save-json", action="store_true", help="Disable saving detected_text.json")
    parser.add_argument("--verbose", action="store_true", help="Print detailed inference results")
    args = parser.parse_args()

    config = load_config()
    try:
        output = process_image(
            image_path=args.image,
            config=config,
            model_choice=args.model,
            show_yolo=args.show_yolo,
            show_ocr=args.show_ocr,
            save_json=not args.no_save_json,
            verbose=args.verbose
        )
        if not args.verbose:
            print("Detected Fields:")
            for label, data in output.items():
                print(f"{label}: {data['text']} (YOLO Conf: {data['yolo_conf']:.2f}, OCR Conf: {data['ocr_conf']:.2f})")
    except Exception as e:
        print(f"Error: {str(e)}")


if __name__ == "__main__":
    main()
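
# Example invocations (image paths are placeholders; adjust to your data and config.json):
#   python inference.py --image samples/aadhaar_front.jpg --show-yolo --verbose
#   python inference.py --image samples/pan_card.jpg --model pan_card --no-save-json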