| import streamlit as st |
| import torch |
| import requests |
| from io import BytesIO |
| from PIL import Image |
| from torchvision import transforms |
| from transformers import ResNetForImageClassification |
|
|
| |
| |
# Page-wide settings (browser tab title/icon and centered layout). Issued
# before any other Streamlit command, which older Streamlit releases require.
st.set_page_config(
    page_icon="🔄",
    page_title="GyroScope Rotation Corrector",
    layout="centered",
)
|
|
| |
| |
@st.cache_resource
def load_model():
    """Load the GyroScope rotation classifier and place it on the best device.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of reloading the weights.

    Returns:
        tuple: ``(model, device)`` where ``model`` is the
        ``ResNetForImageClassification`` in eval mode, already moved to
        ``device``, and ``device`` is CUDA when available, otherwise CPU.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier = ResNetForImageClassification.from_pretrained("LH-Tech-AI/GyroScope")
    # Inference only: disable training-time behavior (dropout, batch-norm updates).
    classifier.eval()
    classifier.to(target)
    return classifier, target
|
|
# Build (or fetch from Streamlit's resource cache) the classifier and its device.
model, device = load_model()

# Evaluation-time preprocessing: resize the shorter side to 256, center-crop
# to 224x224, convert to a float tensor, then normalize per channel with the
# standard ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Rotation (degrees) associated with each class index, in index order:
# [0, 90, 180, 270].
# NOTE(review): assumes the model's label order is 0/90/180/270 — confirm
# against model.config.id2label.
ANGLES = list(range(0, 360, 90))
|
|
def predict_and_correct(img):
    """Predict how *img* is rotated and return an upright copy.

    Uses the module-level ``model``, ``device``, ``preprocess`` and ``ANGLES``.

    Args:
        img: A PIL image in any mode.

    Returns:
        tuple: ``(corrected_img, detected, correction, prob_dict)`` where
        ``detected`` is the rotation (degrees) the model predicted,
        ``correction`` is the angle passed to ``Image.rotate`` to undo it,
        ``corrected_img`` is the rotated image (RGBA if the input had
        transparency, otherwise RGB), and ``prob_dict`` maps labels such as
        ``"90°"`` to the softmax probability formatted with four decimals.
    """
    # The classifier consumes 3-channel input, so always feed it an RGB copy.
    rgb_img = img.convert("RGB")
    tensor = preprocess(rgb_img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values=tensor).logits
        probs = torch.softmax(logits, dim=1)[0]
        pred = probs.argmax().item()

    detected = ANGLES[pred]
    # PIL's rotate() turns counter-clockwise; rotating by (360 - detected)
    # undoes the detection, assuming the model's label is the CCW angle the
    # image was rotated by. NOTE(review): confirm against the model card.
    correction = (360 - detected) % 360

    # Keep transparency in the output (e.g. RGBA/LA PNGs, palette images with
    # a transparent color) instead of flattening it to RGB; every other mode
    # is normalized to RGB as before.
    has_alpha = "A" in img.getbands() or "transparency" in img.info
    output_img = img.convert("RGBA") if has_alpha else rgb_img
    corrected_img = output_img.rotate(correction, expand=True)

    # Copy the probabilities to the host once rather than syncing per element.
    prob_dict = {f"{a}°": f"{p:.4f}" for a, p in zip(ANGLES, probs.tolist())}

    return corrected_img, detected, correction, prob_dict
|
|
| |
def _open_image(source):
    """Open *source* with PIL and decode it immediately.

    ``Image.open`` only reads the header; calling ``load()`` here makes
    corrupt or truncated data fail inside the caller's ``try`` instead of
    later during preview or inference.
    """
    image = Image.open(source)
    image.load()
    return image


def _fetch_image(url):
    """Download *url* and return it as a decoded PIL image.

    Raises:
        requests.RequestException: on network errors or a non-2xx status.
        OSError: if the payload cannot be decoded as an image.
    """
    # NOTE(review): this fetches an arbitrary user-supplied URL server-side
    # (SSRF risk) and reads the whole body into memory; consider restricting
    # hosts and capping the download size if the app is publicly exposed.
    response = requests.get(url, timeout=5)
    # Report HTTP errors (404, 403, ...) directly instead of a confusing
    # "cannot identify image file" error from PIL on an HTML error page.
    response.raise_for_status()
    return _open_image(BytesIO(response.content))


st.title("🔄 Auto Rotation Corrector")
st.markdown("Upload an image or provide a URL to automatically fix its orientation.")

st.divider()

input_method = st.radio("Select Image Source:", ["Upload a File", "Enter Image URL"], horizontal=True)

img = None

if input_method == "Upload a File":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            img = _open_image(uploaded_file)
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error(f"Could not open the uploaded file as an image. Error: {e}")
else:
    url = st.text_input("Enter Image URL:", placeholder="https://example.com/image.jpg")
    if url:
        try:
            img = _fetch_image(url)
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error(f"Could not load image from URL. Error: {e}")

if img is not None:
    st.divider()

    # 360° is the same as 0°, so the slider stops at 270°.
    manual_angle = st.slider("Manual Pre-rotation", min_value=0, max_value=270, value=0, step=90)
    if manual_angle != 0:
        img = img.rotate(manual_angle, expand=True)

    col_left, col_right = st.columns(2)

    with col_left:
        st.subheader("Input Preview")
        st.image(img, use_container_width=True)

        process_btn = st.button("✨ Correct Rotation", type="primary", use_container_width=True)

    with col_right:
        st.subheader("Output Preview")

        if process_btn:
            with st.spinner("Analyzing..."):
                corrected_img, detected, correction, prob_dict = predict_and_correct(img)

            st.image(corrected_img, use_container_width=True)

            st.success(f"✅ Detected: **{detected}°** | Correction: **{correction}°**")

            with st.expander("📊 View Probability Details"):
                st.json(prob_dict)
        else:
            st.info("Waiting for processing... Click the button on the left to correct the rotation.")