maxwoe's picture
Deploy inference-only Gradio demo
c2f46fa
"""HuggingFace Spaces demo for Image Rotation Angle Estimation.
Two-step interactive demo:
1. Upload an image and click "Random Rotate" to apply a random rotation
2. Click "Correct Orientation" to see the model predict and correct the angle
"""
import json
import gradio as gr
import torch
from PIL import Image
import os
import random
import numpy as np
from loguru import logger
from huggingface_hub import hf_hub_download
from model_cgd import CGDAngleEstimation
from architectures import get_default_input_size
from rotation_utils import rotate_image_crop_max_area
# HuggingFace Hub configuration.
# The model repo can be overridden with the HF_MODEL_REPO environment variable.
HF_REPO_ID = os.environ.get("HF_MODEL_REPO", "maxwoe/image-rotation-angle-estimation")

# Fetch config.json from the model repo (this is the HF download-tracking query file).
_config_path = hf_hub_download(repo_id=HF_REPO_ID, filename="config.json")
# JSON is UTF-8 by specification; don't rely on the platform's locale-default encoding.
with open(_config_path, encoding="utf-8") as _f:
    _config = json.load(_f)

# Mapping of model name -> entry dict; load_model() reads each entry's
# "filename" and "architecture" keys.
HF_MODELS = _config["models"]
# Name of the model preselected in the dropdown and loaded on first page visit.
HF_DEFAULT_MODEL = _config["default_model"]
# Global model state shared by every event handler: the active model instance
# and the name it was loaded under. Both stay None until load_model() succeeds.
model, current_model_name = None, None
def get_device():
    """Return the torch device string to run on: the first CUDA GPU if available, else the CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
def load_model(name):
    """Download a checkpoint from the HuggingFace Hub and make it the active model.

    Replaces the module-level ``model`` and ``current_model_name`` globals and
    reports the outcome through a Gradio notification.

    Args:
        name: Key into ``HF_MODELS`` selecting the checkpoint to load.

    Returns:
        The result of the ``gr.Info`` / ``gr.Warning`` call; the notification
        side effect is what matters to the caller.
    """
    global model, current_model_name

    # Nothing to do if this exact model is already active.
    if name == current_model_name and model is not None:
        return gr.Info(f"Model already loaded: {name}")
    if name not in HF_MODELS:
        return gr.Warning(f"Unknown model: {name}")

    entry = HF_MODELS[name]
    checkpoint_file = entry["filename"]
    logger.info(f"Downloading {checkpoint_file} from {HF_REPO_ID}...")
    checkpoint_path = hf_hub_download(repo_id=HF_REPO_ID, filename=checkpoint_file)

    # The network input resolution is derived from the architecture named in config.json.
    input_size = get_default_input_size(entry["architecture"])
    logger.info(f"Loading model from {checkpoint_path}...")
    loaded = CGDAngleEstimation.try_load(checkpoint_path=checkpoint_path, image_size=input_size)
    loaded.eval()

    # Only GPU placement is explicit; on CPU the model is left where try_load put it.
    device = get_device()
    if device.startswith("cuda"):
        loaded = loaded.to(device)

    # Publish the new model only once it is fully prepared.
    model, current_model_name = loaded, name
    logger.info(f"Model loaded: {name} on {device}")
    return gr.Info(f"Loaded: {name} ({device})")
def store_original(image):
    """Keep the uploaded image as the un-rotated source for later rotations.

    Args:
        image: The uploaded image (PIL image or NumPy array), or None.

    Returns:
        ``(original, rotation_text)``: the image, converted to PIL if it was a
        NumPy array (None stays None), and an empty string for the
        "Applied Rotation" textbox.
    """
    converted = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    return converted, ""
def random_rotate(original):
    """Rotate the stored original image by a random angle drawn from uniform(0, 360).

    Args:
        original: The stored original image, or None if nothing was uploaded.

    Returns:
        ``(rotated_image, angle, angle_text)``, where ``angle_text`` is the angle
        with one decimal and a degree sign. Returns ``(None, None, "")`` when
        there is no original image.
    """
    if original is None:
        return None, None, ""
    degrees = random.uniform(0, 360)
    # rotate_image_crop_max_area takes and returns a NumPy array; wrap the result back into PIL.
    rotated = Image.fromarray(rotate_image_crop_max_area(np.array(original), degrees))
    return rotated, degrees, f"{degrees:.1f}°"
def correct_orientation(image):
    """Predict the image's rotation angle with the active model and undo it.

    Args:
        image: The image to correct, as a PIL image or NumPy array, or None.

    Returns:
        ``(corrected_image, message)``. ``corrected_image`` is None when there is
        no input image or no model has been loaded yet. ``message`` is either a
        status hint or the predicted angle formatted to two decimals.
    """
    if image is None:
        return None, "Please upload and rotate an image first."
    # Read the shared global once so the None-check and the prediction are
    # guaranteed to use the same model object.
    active_model = model
    if active_model is None:
        return None, "Model is still loading, please wait..."
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Inference only: never record an autograd graph, regardless of whether
    # predict_angle already disables gradients internally.
    with torch.no_grad():
        predicted_angle = float(active_model.predict_angle(image))
    # PIL's rotate() turns counter-clockwise for positive angles; rotating by the
    # negated prediction is meant to undo it (assumes predict_angle uses the same
    # convention). "white" is resolved per image mode (255 for "L",
    # (255, 255, 255) for "RGB", ...), whereas an RGB tuple raises TypeError
    # for single-band images such as grayscale arrays.
    corrected = image.rotate(-predicted_angle, expand=True, fillcolor="white")
    return corrected, f"Predicted rotation: {predicted_angle:.2f}°"
def _load_default_model_once():
    """Page-load hook: load the default model only if no model is active yet.

    ``model`` is a single module-level global shared by every browser session.
    Calling ``load_model(HF_DEFAULT_MODEL)`` on every page load would therefore
    switch the model back to the default for everyone, even while another
    session's dropdown still shows its own selection.

    Returns:
        The name of the active model, used to sync this session's dropdown.
    """
    if model is None:
        load_model(HF_DEFAULT_MODEL)
    return current_model_name or HF_DEFAULT_MODEL


# Build UI
app = gr.Blocks(title="Image Rotation Angle Estimation")
with app:
    gr.HTML("<h1>Image Rotation Angle Estimation</h1>")
    gr.Markdown(
        "Upload an image, apply a random rotation, and see the model predict and correct the angle.\n\n"
        "Uses the **CGD** (Circular Gaussian Distribution) method with **MambaOut Base** architecture. "
    )
    # Per-session state: the un-rotated source image and the last applied angle.
    original_image_state = gr.State(value=None)
    actual_angle_state = gr.State(value=None)

    model_dropdown = gr.Dropdown(
        choices=list(HF_MODELS.keys()),
        value=HF_DEFAULT_MODEL,
        label="Model",
    )
    # Selecting a model replaces the single global model used by all sessions.
    model_dropdown.change(load_model, inputs=[model_dropdown])

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image", height=400, format="png")
            rotate_btn = gr.Button("Random Rotate", variant="secondary", size="lg")
        with gr.Column():
            corrected_image = gr.Image(label="Corrected Image", height=400, interactive=False, format="png")
            correct_btn = gr.Button("Correct Orientation", variant="primary", size="lg")
    with gr.Row():
        rotation_info = gr.Textbox(label="Applied Rotation", lines=1, interactive=False)
        result_text = gr.Textbox(label="Predicted Rotation", lines=1, interactive=False)

    gr.Examples(
        examples=[
            "examples/COCO_val2014_000000168337.jpg",
            "examples/COCO_val2014_000000122166.jpg",
            "examples/COCO_val2014_000000446053.jpg",
            "examples/COCO_val2014_000000477919.jpg",
            "examples/COCO_val2014_000000016995.jpg",
        ],
        inputs=input_image,
        fn=store_original,
        outputs=[original_image_state, rotation_info],
        cache_examples=False,
        run_on_click=True,
    )

    input_image.upload(
        store_original,
        inputs=[input_image],
        outputs=[original_image_state, rotation_info],
    )
    # Clearing the input must also drop the stored original; otherwise
    # "Random Rotate" would bring back the previously uploaded image.
    input_image.clear(
        lambda: (None, ""),
        outputs=[original_image_state, rotation_info],
    )
    rotate_btn.click(
        random_rotate,
        inputs=[original_image_state],
        outputs=[input_image, actual_angle_state, rotation_info],
    )
    correct_btn.click(
        correct_orientation,
        inputs=[input_image],
        outputs=[corrected_image, result_text],
    )
    # Load the default model on the first visit only, and show each new session
    # the model that is actually active instead of resetting it for everyone.
    app.load(_load_default_model_once, outputs=[model_dropdown])

app.launch()