|
|
|
|
|
""" |
|
|
UVDoc Grid-Output Document Unwarping Example |
|
|
|
|
|
This script demonstrates how to use the UVDoc ONNX model with grid output |
|
|
for high-resolution document unwarping. |
|
|
|
|
|
The key advantage of this grid-output model over image-output models is that |
|
|
the coordinate grid can be upscaled to any resolution, preserving document |
|
|
quality when applied via cv2.remap(). |
|
|
|
|
|
Usage: |
|
|
python example.py input_image.jpg output_image.jpg |
|
|
python example.py input_image.jpg output_image.jpg --model path/to/UVDoc_grid.onnx |
|
|
|
|
|
Requirements: |
|
|
pip install onnxruntime opencv-python numpy |
|
|
|
|
|
Optional (for automatic model download): |
|
|
pip install huggingface_hub |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Resolution that preprocess_image() resizes every input to before inference
# (passed to cv2.resize as (MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT)).
MODEL_INPUT_HEIGHT = 720
MODEL_INPUT_WIDTH = 496
|
|
|
|
|
|
|
|
def load_model(model_path: str = None):
    """
    Create an ONNX Runtime inference session for the UVDoc grid model.

    Args:
        model_path: Location of the ONNX model file. When None, the file is
            fetched with huggingface_hub; if that package cannot be imported
            an error is printed and the process exits with status 1.

    Returns:
        ONNX Runtime InferenceSession using the CPU execution provider
    """
    # Imported lazily so the dependency is only needed once a model is loaded.
    import onnxruntime as ort

    if model_path is None:
        try:
            from huggingface_hub import hf_hub_download

            print("Downloading model from HuggingFace Hub...")
            # NOTE(review): the repo id below is a literal placeholder
            # ("YOUR_USERNAME/...") - presumably it has to be replaced with a
            # real repository before the download path can work; confirm.
            hub_file = {
                "repo_id": "YOUR_USERNAME/uvdoc-grid-onnx",
                "filename": "UVDoc_grid.onnx",
            }
            model_path = hf_hub_download(**hub_file)
            print(f"Model downloaded to: {model_path}")
        except ImportError:
            for message in (
                "Error: huggingface_hub not installed. Install it or provide --model path.",
                " pip install huggingface_hub",
            ):
                print(message)
            sys.exit(1)

    print(f"Loading model from: {model_path}")
    cpu_providers = ['CPUExecutionProvider']
    return ort.InferenceSession(model_path, providers=cpu_providers)
|
|
|
|
|
|
|
|
def preprocess_image(image: np.ndarray) -> np.ndarray:
    """
    Convert a BGR image into the input tensor for the UVDoc model.

    Args:
        image: BGR image from cv2.imread()

    Returns:
        float32 tensor of shape (1, 3, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH)
        in RGB channel order, with pixel values divided by 255
    """
    # cv2.resize takes the target size as (width, height).
    target_size = (MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT)
    rgb_small = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), target_size)

    # Scale to float32, then reorder HWC -> CHW.
    chw = (rgb_small.astype(np.float32) / 255.0).transpose(2, 0, 1)

    # Prepend the batch axis.
    return chw[np.newaxis, ...]
|
|
|
|
|
|
|
|
def apply_grid_unwarping(
    image: np.ndarray,
    grid: np.ndarray,
    interpolation: int = cv2.INTER_CUBIC
) -> np.ndarray:
    """
    Unwarp an image by sampling it through the model's coordinate grid.

    The low-resolution grid is linearly upscaled to the image size, its
    normalized coordinates are converted to pixel coordinates (-1 maps to the
    first pixel, +1 to the last), and the image is resampled with cv2.remap().

    Args:
        image: Original BGR image (any resolution)
        grid: Model output grid of shape (1, 2, 45, 31); channel 0 is used as
            the x coordinate and channel 1 as the y coordinate
        interpolation: OpenCV interpolation method used for the final remap

    Returns:
        Unwarped image at original resolution
    """
    height, width = image.shape[:2]

    # Drop the batch axis and move channels last: (2, H, W) -> (H, W, 2).
    coarse_grid = grid[0].transpose(1, 2, 0)

    # Upscale the coordinate field itself (not the image) to full resolution.
    dense_grid = cv2.resize(
        coarse_grid,
        (width, height),
        interpolation=cv2.INTER_LINEAR
    )

    def to_pixels(normalized: np.ndarray, extent: int) -> np.ndarray:
        # Map normalized coordinates onto pixel indices 0 .. extent - 1.
        return (((normalized + 1) / 2) * (extent - 1)).astype(np.float32)

    map_x = to_pixels(dense_grid[..., 0], width)
    map_y = to_pixels(dense_grid[..., 1], height)

    # Samples that fall outside the source image reuse the nearest edge pixel.
    return cv2.remap(
        image,
        map_x,
        map_y,
        interpolation=interpolation,
        borderMode=cv2.BORDER_REPLICATE
    )
|
|
|
|
|
|
|
|
def unwarp_document(
    image_path: str,
    output_path: str,
    model_path: str = None
) -> None:
    """
    Main function to unwarp a document image.

    Loads the image, runs the model to obtain a coordinate grid, applies the
    grid to the full-resolution image and writes the result. Progress is
    printed along the way. The process exits with status 1 if the input image
    cannot be read or the output image cannot be written.

    Args:
        image_path: Path to input warped document image
        output_path: Path to save unwarped result
        model_path: Optional path to ONNX model file (passed to load_model)
    """
    print(f"Loading image: {image_path}")
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        print(f"Error: Could not load image from {image_path}")
        sys.exit(1)

    h, w = image.shape[:2]
    print(f"Image size: {w}x{h}")

    session = load_model(model_path)

    # The input name is read from the model instead of being hard-coded.
    input_name = session.get_inputs()[0].name
    print(f"Model input name: {input_name}")

    print("Preprocessing image...")
    input_tensor = preprocess_image(image)
    print(f"Input tensor shape: {input_tensor.shape}")

    print("Running inference...")
    # Only the first model output (the coordinate grid) is used.
    result = session.run(None, {input_name: input_tensor})[0]
    print(f"Output grid shape: {result.shape}")
    print(f"Output grid range: [{result.min():.4f}, {result.max():.4f}]")

    print("Applying grid-based unwarping...")
    unwarped = apply_grid_unwarping(image, result)

    print(f"Saving result to: {output_path}")
    # cv2.imwrite reports failure through its return value; without this
    # check a failed write would still be reported as success.
    if not cv2.imwrite(output_path, unwarped):
        print(f"Error: Could not save image to {output_path}")
        sys.exit(1)

    print("Done!")
|
|
|
|
|
|
|
|
def main(argv=None):
    """
    Command-line entry point: parse arguments and unwarp one document image.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            the arguments from sys.argv[1:], so calling main() with no
            arguments behaves as a normal script entry point.

    Exits with status 1 if the input path is not an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Unwarp document images using UVDoc grid-output ONNX model",
        # Keep the epilog's line breaks instead of letting argparse re-wrap it.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python example.py warped_doc.jpg unwarped_doc.jpg
python example.py warped_doc.jpg unwarped_doc.jpg --model UVDoc_grid.onnx
"""
    )

    parser.add_argument(
        "input",
        help="Path to input warped document image"
    )

    parser.add_argument(
        "output",
        help="Path to save unwarped output image"
    )

    parser.add_argument(
        "--model", "-m",
        default=None,
        help="Path to UVDoc_grid.onnx model file (downloads from HuggingFace if not provided)"
    )

    args = parser.parse_args(argv)

    # is_file() rather than exists(): a directory is not a usable input image
    # and should be rejected here, before any model is loaded.
    if not Path(args.input).is_file():
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    unwarp_document(args.input, args.output, args.model)
|
|
|
|
|
|
|
|
# Run the command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|