"""HuggingFace Spaces demo for Image Rotation Angle Estimation.

Two-step interactive demo:
1. Upload an image and click "Random Rotate" to apply a random rotation.
2. Click "Correct Orientation" to see the model predict and correct the angle.
"""
import json
import gradio as gr
import torch
from PIL import Image
import os
import random
import numpy as np
from loguru import logger
from huggingface_hub import hf_hub_download
from model_cgd import CGDAngleEstimation
from architectures import get_default_input_size
from rotation_utils import rotate_image_crop_max_area
# HuggingFace Hub configuration. The model repo can be overridden via HF_MODEL_REPO.
HF_REPO_ID = os.environ.get("HF_MODEL_REPO", "maxwoe/image-rotation-angle-estimation")

# Fetch config.json from the model repo (this is the HF download-tracking query file).
# It lists the available models and names the default one.
_config_path = hf_hub_download(repo_id=HF_REPO_ID, filename="config.json")
# JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
with open(_config_path, encoding="utf-8") as _f:
    _config = json.load(_f)

# Display name -> model entry; load_model() reads each entry's "filename" and
# "architecture" keys.
HF_MODELS = _config["models"]
HF_DEFAULT_MODEL = _config["default_model"]

# Global model state: the active model and the HF_MODELS key it was loaded from.
# Both are replaced together by load_model().
model = None
current_model_name = None
def get_device():
    """Pick the torch device used for inference.

    Returns:
        "cuda:0" if torch reports an available CUDA device, else "cpu".
    """
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
def load_model(name):
    """Download a model checkpoint from the HuggingFace Hub and make it the active model.

    Args:
        name: Key into HF_MODELS selecting which checkpoint to load.

    Returns:
        Whatever the gr.Info / gr.Warning notification call returns; the
        notification itself is what informs the user of the outcome.
    """
    global model, current_model_name

    # Short-circuit when the requested model is already the active one.
    if model is not None and name == current_model_name:
        return gr.Info(f"Model already loaded: {name}")
    if name not in HF_MODELS:
        return gr.Warning(f"Unknown model: {name}")

    spec = HF_MODELS[name]
    checkpoint_file = spec["filename"]
    logger.info(f"Downloading {checkpoint_file} from {HF_REPO_ID}...")
    local_path = hf_hub_download(repo_id=HF_REPO_ID, filename=checkpoint_file)

    # The model's input resolution is derived from its backbone architecture.
    image_size = get_default_input_size(spec["architecture"])

    logger.info(f"Loading model from {local_path}...")
    loaded = CGDAngleEstimation.try_load(checkpoint_path=local_path, image_size=image_size)
    loaded.eval()

    device = get_device()
    # Only move the model when a CUDA device was selected; on CPU it is used as loaded.
    if device.startswith("cuda"):
        loaded = loaded.to(device)

    # Publish the new model only after it is fully prepared.
    model, current_model_name = loaded, name
    logger.info(f"Model loaded: {name} on {device}")
    return gr.Info(f"Loaded: {name} ({device})")
def store_original(image):
    """Keep a newly uploaded image as the un-rotated original.

    Args:
        image: The uploaded image, or None when nothing was uploaded.

    Returns:
        A pair ``(original, rotation_text)``. A numpy array is converted to a
        PIL image; any other value (including None) is passed through
        unchanged. ``rotation_text`` is always "", clearing the applied-
        rotation text box.
    """
    # None is not an ndarray, so it falls straight through to the return.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    return image, ""
def random_rotate(original):
    """Rotate the stored original image by a random angle.

    The rotation and cropping are delegated to rotate_image_crop_max_area,
    which operates on a numpy copy of the image.

    Args:
        original: The un-rotated image held in session state, or None.

    Returns:
        ``(rotated_image, angle, angle_text)``. ``angle`` comes from
        ``random.uniform(0, 360)`` (degrees) and ``angle_text`` is that angle
        formatted to one decimal place. ``(None, None, "")`` is returned when
        there is no original.
    """
    if original is None:
        return None, None, ""

    angle = random.uniform(0, 360)
    rotated = Image.fromarray(rotate_image_crop_max_area(np.array(original), angle))
    return rotated, angle, f"{angle:.1f}°"
def correct_orientation(image):
    """Predict the image's rotation angle and rotate it back by that amount.

    Args:
        image: The image currently shown in the input component (PIL image or
            numpy array), or None.

    Returns:
        ``(corrected_image, message)``. ``corrected_image`` is None when there
        is no image or the model has not been loaded yet; ``message`` is the
        text shown in the "Predicted Rotation" box.
    """
    if image is None:
        return None, "Please upload and rotate an image first."
    if model is None:
        return None, "Model is still loading, please wait..."
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    predicted_angle = model.predict_angle(image)

    if image.mode == "P":
        # For palette images the fill is interpreted as a palette index, not a
        # colour; convert so the exposed corners really come out white.
        image = image.convert("RGBA")
    # expand=True grows the canvas, exposing corner pixels painted with `fillcolor`.
    # Pillow requires the fill to match the image's band count (a 3-tuple raises
    # TypeError for "L"/"LA" images, e.g. a 2-D grayscale array converted above),
    # so use 255 in every band: white for 8-bit modes, (255, 255, 255) for RGB.
    fill = (255,) * len(image.getbands())
    corrected = image.rotate(-predicted_angle, expand=True, fillcolor=fill)
    return corrected, f"Predicted rotation: {predicted_angle:.2f}°"
# Build UI
with gr.Blocks(title="Image Rotation Angle Estimation") as app:
    gr.HTML("<h1>Image Rotation Angle Estimation</h1>")
    gr.Markdown(
        "Upload an image, apply a random rotation, and see the model predict and correct the angle.\n\n"
        "Uses the **CGD** (Circular Gaussian Distribution) method with **MambaOut Base** architecture. "
    )

    # Session state: the un-rotated upload, and the angle random_rotate last applied.
    original_image_state = gr.State(value=None)
    actual_angle_state = gr.State(value=None)

    model_dropdown = gr.Dropdown(
        choices=list(HF_MODELS.keys()),
        value=HF_DEFAULT_MODEL,
        label="Model",
    )
    # Selecting a model loads it into the global model state; no UI outputs.
    model_dropdown.change(load_model, inputs=[model_dropdown])

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image", height=400, format="png")
            rotate_btn = gr.Button("Random Rotate", variant="secondary", size="lg")
        with gr.Column():
            corrected_image = gr.Image(label="Corrected Image", height=400, interactive=False, format="png")
            correct_btn = gr.Button("Correct Orientation", variant="primary", size="lg")

    with gr.Row():
        rotation_info = gr.Textbox(label="Applied Rotation", lines=1, interactive=False)
        result_text = gr.Textbox(label="Predicted Rotation", lines=1, interactive=False)

    # Clicking an example stores it as the original, like a manual upload.
    gr.Examples(
        examples=[
            "examples/COCO_val2014_000000168337.jpg",
            "examples/COCO_val2014_000000122166.jpg",
            "examples/COCO_val2014_000000446053.jpg",
            "examples/COCO_val2014_000000477919.jpg",
            "examples/COCO_val2014_000000016995.jpg",
        ],
        inputs=input_image,
        fn=store_original,
        outputs=[original_image_state, rotation_info],
        cache_examples=False,
        run_on_click=True,
    )

    input_image.upload(
        store_original,
        inputs=[input_image],
        outputs=[original_image_state, rotation_info],
    )
    # Clearing the input must also forget the stored original; otherwise the next
    # "Random Rotate" click would bring the previously uploaded image back.
    input_image.clear(
        lambda: (None, None, ""),
        outputs=[original_image_state, actual_angle_state, rotation_info],
    )
    # Rotation always starts from the stored original, and the rotated result
    # replaces the input image.
    rotate_btn.click(
        random_rotate,
        inputs=[original_image_state],
        outputs=[input_image, actual_angle_state, rotation_info],
    )
    correct_btn.click(
        correct_orientation,
        inputs=[input_image],
        outputs=[corrected_image, result_text],
    )
    # Load the default model on page load; load_model returns early if it is
    # already the active model.
    app.load(lambda: load_model(HF_DEFAULT_MODEL))

app.launch()