| """HuggingFace Spaces demo for Image Rotation Angle Estimation.
|
|
|
| Two-step interactive demo:
|
| 1. Upload an image and click "Random Rotate" to apply a random rotation
|
| 2. Click "Correct Orientation" to see the model predict and correct the angle
|
| """
|
|
|
| import json
|
|
|
| import gradio as gr
|
| import torch
|
| from PIL import Image
|
| import os
|
| import random
|
| import numpy as np
|
| from loguru import logger
|
| from huggingface_hub import hf_hub_download
|
|
|
| from model_cgd import CGDAngleEstimation
|
| from architectures import get_default_input_size
|
| from rotation_utils import rotate_image_crop_max_area
|
|
|
|
|
# Hub repository holding the model checkpoints and the config.json that lists
# them; can be overridden with the HF_MODEL_REPO environment variable.
HF_REPO_ID = os.environ.get("HF_MODEL_REPO", "maxwoe/image-rotation-angle-estimation")

# Fetch the model catalogue at import time so the model dropdown can be
# populated before the UI is built. config.json is expected to provide
# "models" (name -> entry with at least "filename" and "architecture", as read
# by load_model) and "default_model".
_config_path = hf_hub_download(repo_id=HF_REPO_ID, filename="config.json")
# JSON is UTF-8 by specification; don't rely on the platform's default encoding.
with open(_config_path, encoding="utf-8") as _f:
    _config = json.load(_f)
HF_MODELS = _config["models"]
HF_DEFAULT_MODEL = _config["default_model"]

# The currently active model and its key in HF_MODELS. Both stay None until
# load_model() succeeds for the first time and are replaced together by it.
model = None
current_model_name = None
|
|
|
|
|
def get_device():
    """Pick the torch device for inference: the first CUDA GPU if available, otherwise the CPU."""
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        return "cuda:0"
    return "cpu"
|
|
|
|
|
def load_model(name):
    """Fetch the checkpoint for ``name`` from the Hub and make it the active model.

    The entry for ``name`` is looked up in ``HF_MODELS``. Its ``filename`` is
    downloaded from ``HF_REPO_ID`` and restored with
    ``CGDAngleEstimation.try_load`` at the default input size of its
    ``architecture``. The model is put in eval mode and moved to the GPU when
    one is available. On success the module-level ``model`` and
    ``current_model_name`` globals are replaced.

    Args:
        name: Key into ``HF_MODELS``.

    Returns:
        The result of the ``gr.Info`` / ``gr.Warning`` notification call. The
        event handlers wiring this function declare no outputs.
    """
    global model, current_model_name

    # Nothing to do if this exact model is already active.
    if name == current_model_name and model is not None:
        return gr.Info(f"Model already loaded: {name}")
    if name not in HF_MODELS:
        return gr.Warning(f"Unknown model: {name}")

    entry = HF_MODELS[name]
    checkpoint_file = entry["filename"]
    logger.info(f"Downloading {checkpoint_file} from {HF_REPO_ID}...")
    checkpoint_path = hf_hub_download(repo_id=HF_REPO_ID, filename=checkpoint_file)

    input_size = get_default_input_size(entry["architecture"])

    logger.info(f"Loading model from {checkpoint_path}...")
    candidate = CGDAngleEstimation.try_load(checkpoint_path=checkpoint_path, image_size=input_size)
    candidate.eval()

    # Only move off the CPU when a GPU was detected.
    target = get_device()
    if target.startswith("cuda"):
        candidate = candidate.to(target)

    # Swap in the new model only after it loaded successfully.
    model, current_model_name = candidate, name
    logger.info(f"Model loaded: {name} on {target}")
    return gr.Info(f"Loaded: {name} ({target})")
|
|
|
|
|
def store_original(image):
    """Remember a newly provided image as the source for random rotations.

    Args:
        image: The image from the input component (PIL image), a NumPy array,
            or ``None``.

    Returns:
        ``(image, "")``: the image (converted to PIL if it was an array, ``None``
        passes through unchanged) and an empty string that clears the
        "Applied Rotation" textbox.
    """
    if isinstance(image, np.ndarray):
        # Defensive: normalise raw arrays to PIL so downstream code sees one type.
        image = Image.fromarray(image)
    return image, ""
|
|
|
|
|
def random_rotate(original):
    """Rotate the stored original image by a random angle.

    The angle is drawn with ``random.uniform(0, 360)`` and applied through
    ``rotate_image_crop_max_area``. Judging by its name, that helper crops to the
    largest border-free area; confirm in rotation_utils.

    Args:
        original: The image saved by ``store_original``, or ``None``.

    Returns:
        ``(rotated_image, angle, label)``, where ``label`` is the angle in
        degrees with one decimal place. Returns ``(None, None, "")`` when no
        original image exists yet.
    """
    if original is not None:
        degrees = random.uniform(0, 360)
        pixels = rotate_image_crop_max_area(np.array(original), degrees)
        return Image.fromarray(pixels), degrees, f"{degrees:.1f}°"
    return None, None, ""
|
|
|
|
|
def correct_orientation(image):
    """Predict the rotation of ``image`` with the active model and undo it.

    Args:
        image: The image currently shown in the input component (PIL image or
            NumPy array), or ``None``.

    Returns:
        ``(corrected_image, message)``. ``corrected_image`` is ``image`` rotated
        by the negated predicted angle. Pillow's ``rotate`` treats positive
        angles as counter-clockwise, and ``expand=True`` grows the canvas to fit.
        Uncovered corners are filled with white. ``corrected_image`` is ``None``
        when there is no image or no model yet; ``message`` then explains why.
    """
    from PIL import ImageColor

    if image is None:
        return None, "Please upload and rotate an image first."
    if model is None:
        return None, "Model is still loading, please wait..."

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    predicted_angle = model.predict_angle(image)

    # Resolve "white" for this image's own mode. A hard-coded RGB triple makes
    # Pillow raise TypeError for single-band images (e.g. mode "L", which
    # Image.fromarray produces for 2-D grayscale arrays). For RGB/RGBA this
    # yields the same fill as (255, 255, 255).
    white = ImageColor.getcolor("white", image.mode)
    corrected = image.rotate(-predicted_angle, expand=True, fillcolor=white)

    return corrected, f"Predicted rotation: {predicted_angle:.2f}°"
|
|
|
|
|
|
|
# UI layout and event wiring. Components render in the order they are created
# inside the `with app:` context, so statement order here is the page layout.
app = gr.Blocks(title="Image Rotation Angle Estimation")
with app:
    gr.HTML("<h1>Image Rotation Angle Estimation</h1>")
    # NOTE(review): this text names a single architecture ("MambaOut Base"),
    # while the dropdown below offers every entry of HF_MODELS. Confirm it still
    # describes all selectable models.
    gr.Markdown(
        "Upload an image, apply a random rotation, and see the model predict and correct the angle.\n\n"
        "Uses the **CGD** (Circular Gaussian Distribution) method with **MambaOut Base** architecture. "
    )

    # Per-session state: the un-rotated image used as the source for every
    # "Random Rotate", and the angle last applied by random_rotate.
    original_image_state = gr.State(value=None)
    # NOTE(review): written by random_rotate but not read by any handler below.
    # It could be passed to correct_orientation to report the prediction error.
    actual_angle_state = gr.State(value=None)

    # Selecting a model calls load_model, which replaces the single
    # module-level `model` used by every session.
    model_dropdown = gr.Dropdown(
        choices=list(HF_MODELS.keys()),
        value=HF_DEFAULT_MODEL,
        label="Model",
    )
    model_dropdown.change(load_model, inputs=[model_dropdown])

    with gr.Row():
        with gr.Column():
            # Shows the upload, and after "Random Rotate" the rotated version.
            input_image = gr.Image(type="pil", label="Upload Image", height=400, format="png")
            rotate_btn = gr.Button("Random Rotate", variant="secondary", size="lg")
        with gr.Column():
            corrected_image = gr.Image(label="Corrected Image", height=400, interactive=False, format="png")
            correct_btn = gr.Button("Correct Orientation", variant="primary", size="lg")

    with gr.Row():
        rotation_info = gr.Textbox(label="Applied Rotation", lines=1, interactive=False)
        result_text = gr.Textbox(label="Predicted Rotation", lines=1, interactive=False)

    # Clicking an example loads it into input_image and immediately runs
    # store_original, so it becomes the source image for "Random Rotate".
    gr.Examples(
        examples=[
            "examples/COCO_val2014_000000168337.jpg",
            "examples/COCO_val2014_000000122166.jpg",
            "examples/COCO_val2014_000000446053.jpg",
            "examples/COCO_val2014_000000477919.jpg",
            "examples/COCO_val2014_000000016995.jpg",
        ],
        inputs=input_image,
        fn=store_original,
        outputs=[original_image_state, rotation_info],
        cache_examples=False,
        run_on_click=True,
    )

    # Record a user upload as the new original. `.upload` is used rather than
    # `.change`, presumably so that the rotated image random_rotate writes back
    # into input_image does not overwrite the stored original.
    input_image.upload(
        store_original,
        inputs=[input_image],
        outputs=[original_image_state, rotation_info],
    )
    # Rotate the stored original (not whatever input_image currently shows),
    # display the result in input_image and remember the applied angle.
    rotate_btn.click(
        random_rotate,
        inputs=[original_image_state],
        outputs=[input_image, actual_angle_state, rotation_info],
    )
    # Predict on the image currently displayed in input_image (the rotated one).
    correct_btn.click(
        correct_orientation,
        inputs=[input_image],
        outputs=[corrected_image, result_text],
    )

    # Runs each time a browser session loads the page.
    # NOTE(review): `model` is one module-level global shared by all sessions,
    # and load_model only short-circuits when the requested name is already
    # active. A new visitor therefore resets the active model to
    # HF_DEFAULT_MODEL even if another session had switched models. Consider
    # loading once at startup instead; confirm the intended behavior.
    app.load(lambda: load_model(HF_DEFAULT_MODEL))

app.launch()
|
|
|