| | """ |
| | Small demo application to explore Gradio. |
| | """ |
| |
|
| | import argparse |
| | import os |
| | from functools import partial |
| |
|
| | import gradio as gr |
| | from PIL import Image |
| | from huggingface_hub import hf_hub_download |
| |
|
| | from die_model import UNetDIEModel |
| | from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, \ |
| | remove_square_padding |
| |
|
| |
|
def die_inference(
        image_raw,
        num_of_die_iterations,
        die_model,
        device,
        max_image_size=1500,
):
    """
    Run the DIE model on a single image.

    :param image_raw: raw PIL image from the Gradio image input; Gradio may
        pass ``None`` when no image was uploaded
    :param num_of_die_iterations: number of DIE iterations (converted with
        ``int`` before use)
    :param die_model: DIE model; its ``enhance_document_image`` method is called
    :param device: device the input tensor is moved to (e.g. ``"cpu"``)
    :param max_image_size: size argument passed to ``resize_image`` before
        inference; defaults to 1500, the previously hard-coded value
    :return: cleaned image, as returned by ``remove_square_padding`` called
        with ``resize_back_to_original=True``
    :raises gradio.Error: if no image was provided
    """
    # Without an image the preprocessing helpers would fail with an opaque
    # error; gr.Error shows a readable message in the Gradio UI instead.
    if image_raw is None:
        raise gr.Error("Please upload a degraded document image.")

    # Preprocess: resize, pad to a square, convert to a torch tensor and move
    # it to the requested device.
    # NOTE(review): how resize_image interprets the size (longest side?) is
    # defined in utils — confirm there.
    image_raw_resized = resize_image(image_raw, max_image_size)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(
        image_raw_resized_square
    ).to(device)

    num_of_die_iterations = int(num_of_die_iterations)

    # enhance_document_image takes a list of images; pass the single image and
    # keep the first (only) result.
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations,
    )[0]

    # Strip the square padding added above, using the original image as the
    # reference for the target size.
    return remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True,
    )
| |
|
| |
|
def main():
    """
    Run the Gradio demo for the Document Image Enhancement (DIE) model.

    Parses the command-line arguments, downloads the model checkpoint from the
    ``gabar92/die`` Hugging Face repository (using the ``DIE_TOKEN`` environment
    variable as access token, if set), builds a ``UNetDIEModel`` and serves a
    ``gr.Interface`` on ``0.0.0.0:7860``.

    :return: None
    """
    args = parse_arguments()

    description = (
        "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
        "This interactive application showcases a specialized AI model developed by "
        "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
        "Our DIE model is designed to enhance and restore archival and aged document images "
        "by removing various types of degradation, thereby making historical documents more legible "
        "and suitable for Optical Character Recognition (OCR) processing.\n\n"
        "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
        "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
        "and unwanted background elements. "
        "By applying deep learning techniques, specifically a U-Net-based architecture, "
        "the model accurately cleans and clarifies text while preserving original details. "
        "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
        "pre-processing tool in digitization workflows.\n\n"
        "If you’re interested in learning more about the model’s capabilities or potential applications, "
        "please contact us at: gabar92@renyi.hu.\n\n"
    )

    # Choices offered by the "Number of DIE iterations" dropdown.
    num_of_die_iterations_list = [1, 2, 3]

    # Optional Hugging Face access token for the model repository.
    die_token = os.getenv("DIE_TOKEN")

    # Gradio examples: one inner list per sample, supplying only the image
    # input. File paths (instead of eagerly opened PIL images) let Gradio load
    # them itself and avoid keeping lazily-opened file handles alive.
    example_image_list = [
        [image_path]
        for image_path in _collect_example_image_paths(args.example_image_path)
    ]

    # Download the checkpoint and point args at the local copy; UNetDIEModel is
    # built from args and presumably loads args.die_model_path — confirm in
    # die_model.
    # `token` replaces the deprecated `use_auth_token` argument.
    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        token=die_token,
    )

    die_model = UNetDIEModel(args=args)

    # Bind the model and device so Gradio only supplies the two UI inputs.
    partial_die_inference = partial(die_inference, device=args.device, die_model=die_model)

    demo = gr.Interface(
        fn=partial_die_inference,
        inputs=[
            gr.Image(type="pil", label="Degraded Document Image"),
            gr.Dropdown(num_of_die_iterations_list, label="Number of DIE iterations", value=1),
        ],
        outputs=gr.Image(type="pil", label="Clean Document Image"),
        title="Document Image Enhancement (DIE) model",
        description=description,
        # None disables the examples section when no example images exist.
        examples=example_image_list or None,
    )

    demo.launch(server_name="0.0.0.0", server_port=7860)


def _collect_example_image_paths(example_image_dir):
    """
    Return the sorted paths of image files in ``example_image_dir``.

    Only regular files whose extension Pillow recognizes are kept, so stray
    files (e.g. READMEs or OS metadata) do not break startup. A missing
    directory yields an empty list, because examples are optional for the demo.

    :param example_image_dir: directory to scan (not recursive)
    :return: list of file paths, sorted by file name for a stable example order
    """
    if not os.path.isdir(example_image_dir):
        return []

    # Mapping of lower-case extensions (e.g. ".png") to Pillow format names.
    supported_extensions = Image.registered_extensions()

    image_paths = []
    for file_name in sorted(os.listdir(example_image_dir)):
        file_path = os.path.join(example_image_dir, file_name)
        extension = os.path.splitext(file_name)[1].lower()
        if os.path.isfile(file_path) and extension in supported_extensions:
            image_paths.append(file_path)
    return image_paths
| |
|
| |
|
def parse_arguments(argv=None):
    """
    Parse the command-line arguments of the demo.

    :param argv: list of argument strings to parse; ``None`` (the default)
        makes argparse read ``sys.argv[1:]``, as before
    :return: argument namespace with ``die_model_path``, ``device`` and
        ``example_image_path``
    """
    parser = argparse.ArgumentParser(
        description="Gradio demo for the Document Image Enhancement (DIE) model."
    )

    parser.add_argument(
        "--die_model_path",
        default="2024_08_09_model_epoch_89.pt",
        help="file name of the DIE model checkpoint in the 'gabar92/die' Hugging Face repository",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="device the model inputs are moved to (e.g. 'cpu')",
    )

    parser.add_argument(
        "--example_image_path",
        default="example_images",
        help="directory containing the example images shown in the demo",
    )

    return parser.parse_args(argv)
| |
|
| |
|
if __name__ == "__main__":
    # Launch the demo only when this file is executed as a script, not on import.
    main()
| |
|