# GyroScope / use_with_UI.py — Streamlit front-end for the LH-Tech-AI/GyroScope model.
# Last commit by LH-Tech-AI: "Update use_with_UI.py" (97805b4, verified).
import streamlit as st
import torch
import requests
from io import BytesIO
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification
# --- 1. UI Configuration ---
# A centered layout keeps the content column narrow instead of stretching it
# across very wide monitors.
st.set_page_config(
    page_title="GyroScope Rotation Corrector",
    page_icon="🔄",
    layout="centered",
)
# --- 2. Model Caching ---
# st.cache_resource memoises the returned objects, so Streamlit's
# rerun-on-interaction model does not reload the weights on every click.
@st.cache_resource
def load_model():
    """Load the GyroScope ResNet classifier ready for inference.

    Returns:
        tuple: ``(model, device)`` — the model switched to eval mode and moved
        onto ``device``, which is CUDA when available and CPU otherwise.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNetForImageClassification.from_pretrained("LH-Tech-AI/GyroScope")
    net.eval()
    net.to(target)
    return net, target
model, device = load_model()

# --- 3. Preprocessing & Logic ---
# Eval-time pipeline: resize the shorter side to 256, centre-crop to 224x224,
# convert to a float tensor, then normalise with the standard ImageNet
# per-channel mean / std.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Rotation (degrees) for each class index; indexed by the model's argmax, so
# presumably this order matches the model's label ids — confirm in its config.
ANGLES = list(range(0, 360, 90))
def predict_and_correct(img):
    """Classify the rotation of ``img`` and return an upright copy.

    Args:
        img: A PIL image in any mode; it is converted to RGB before use.

    Returns:
        tuple: ``(corrected_img, detected, correction, prob_dict)`` where
        ``detected`` is the entry of ``ANGLES`` with the highest softmax
        probability, ``correction`` is ``(360 - detected) % 360``,
        ``corrected_img`` is the RGB image rotated counter-clockwise by
        ``correction`` degrees (canvas expanded, nothing cropped), and
        ``prob_dict`` maps labels such as ``"90°"`` to that class's
        probability formatted as a string with 4 decimal places.
    """
    img = img.convert("RGB")
    # ToTensor yields (C, H, W); unsqueeze(0) adds the batch axis the model takes.
    batch = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(pixel_values=batch).logits
    # Softmax over the class axis; [0] selects the single image in the batch.
    probs = torch.softmax(logits, dim=1)[0]
    pred = probs.argmax().item()
    detected = ANGLES[pred]
    # NOTE(review): assumes the label is the CCW angle the image was rotated
    # by (consistent with the CCW manual pre-rotation in the UI), so rotating a
    # further (360 - detected) CCW restores 0° — confirm against training code.
    correction = (360 - detected) % 360
    # PIL's rotate() is counter-clockwise; expand=True keeps the full image.
    corrected_img = img.rotate(correction, expand=True)
    # Copy the probabilities to the host once instead of formatting 0-d
    # device tensors one by one (each of which forces a device sync on CUDA).
    prob_values = probs.tolist()
    prob_dict = {f"{angle}°": f"{p:.4f}" for angle, p in zip(ANGLES, prob_values)}
    return corrected_img, detected, correction, prob_dict
# --- 4. Frontend Layout ---
st.title("🔄 Auto Rotation Corrector")
st.markdown("Upload an image or provide a URL to automatically fix its orientation.")
st.divider()


@st.cache_data(ttl=3600, max_entries=32, show_spinner=False)
def _fetch_image_bytes(url):
    """Download ``url`` and return the raw response body as bytes.

    Cached because Streamlit reruns the whole script on every interaction
    (slider move, button click); without the cache the same image would be
    re-downloaded each time. Exceptions are not cached, so a failed fetch is
    retried on the next rerun.

    Raises:
        requests.RequestException: on network errors, timeouts, or a 4xx/5xx
            status (via ``raise_for_status``), so the user sees the real cause
            instead of a "cannot identify image file" error from PIL.
    """
    # NOTE(review): this fetches arbitrary user-supplied URLs from the server
    # (SSRF risk if the app is publicly exposed) and does not cap the download
    # size — consider an allow-list and a streamed size limit.
    response = requests.get(url, timeout=5)
    response.raise_for_status()
    return response.content


def _open_image(source):
    """Open ``source`` with PIL and decode it immediately.

    ``Image.open`` is lazy; calling ``load()`` here makes corrupt or truncated
    data fail inside the caller's ``try`` rather than later during rendering.
    """
    image = Image.open(source)
    image.load()
    return image


# Input Selection
input_method = st.radio("Select Image Source:", ["Upload a File", "Enter Image URL"], horizontal=True)
img = None

# Input Handling
if input_method == "Upload a File":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            img = _open_image(uploaded_file)
        except Exception as e:  # UI boundary: report any decode failure to the user
            st.error(f"Could not read the uploaded image. Error: {e}")
else:
    url = st.text_input("Enter Image URL:", placeholder="https://example.com/image.jpg").strip()
    if url:
        try:
            img = _open_image(BytesIO(_fetch_image_bytes(url)))
        except Exception as e:  # UI boundary: network, HTTP, or decode errors
            st.error(f"Could not load image from URL. Error: {e}")

# Preview & Processing Section
if img is not None:
    st.divider()
    # 360° is the same orientation as 0°, so the slider stops at 270°.
    manual_angle = st.slider("Manual Pre-rotation", min_value=0, max_value=270, value=0, step=90)
    if manual_angle != 0:
        img = img.rotate(manual_angle, expand=True)  # expand=True prevents cropping

    # Use columns to keep the UI compact and side-by-side
    col_left, col_right = st.columns(2)

    with col_left:
        st.subheader("Input Preview")
        st.image(img, use_container_width=True)
        # The primary action button
        process_btn = st.button("✨ Correct Rotation", type="primary", use_container_width=True)

    with col_right:
        st.subheader("Output Preview")
        if process_btn:
            with st.spinner("Analyzing..."):
                corrected_img, detected, correction, prob_dict = predict_and_correct(img)
            # Show result
            st.image(corrected_img, use_container_width=True)
            # Show stats
            st.success(f"✅ Detected: **{detected}°** | Correction: **{correction}°**")
            # Collapsed by default to keep the UI clean; details on demand
            with st.expander("📊 View Probability Details"):
                st.json(prob_dict)
        else:
            # Placeholder shown until the button is clicked
            st.info("Waiting for processing... Click the button on the left to correct the rotation.")