File size: 4,061 Bytes
e95a6c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97805b4
 
 
 
e95a6c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
import torch
import requests
from io import BytesIO
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification

# --- 1. UI Configuration ---
# The "centered" layout keeps the content in a fixed-width column so the
# app does not stretch across very wide screens.
st.set_page_config(
    page_title="GyroScope Rotation Corrector",
    page_icon="🔄",
    layout="centered",
)

# --- 2. Model Caching ---
# @st.cache_resource prevents reloading the model every time the user interacts with the UI
@st.cache_resource
def load_model():
    """Load the GyroScope classifier and place it on the available device.

    Returns:
        A ``(model, device)`` tuple: the pretrained network switched to
        eval mode and moved onto ``device``, where ``device`` is CUDA when
        ``torch.cuda.is_available()`` reports it and the CPU otherwise.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    net = ResNetForImageClassification.from_pretrained("LH-Tech-AI/GyroScope")
    net.eval()
    net.to(target)
    return net, target

model, device = load_model()

# --- 3. Preprocessing & Logic ---
# Per-channel normalization statistics, applied after ToTensor().
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize, crop to 224x224, convert to a tensor, then normalize.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Rotation in degrees for each class index; predict_and_correct() indexes
# this list with the argmax of the model output.
ANGLES = [0, 90, 180, 270]

def predict_and_correct(img):
    """Predict the rotation of ``img`` and return a counter-rotated copy.

    Args:
        img: A PIL image; it is converted to RGB before any other step.

    Returns:
        A 4-tuple ``(corrected_img, detected, correction, prob_dict)``:
        the RGB image rotated by ``correction`` degrees, the entry of
        ``ANGLES`` with the highest probability, the angle passed to
        ``rotate`` to undo it, and a dict mapping ``"<angle>°"`` to the
        class probability formatted with four decimals.
    """
    rgb = img.convert("RGB")
    batch = preprocess(rgb).unsqueeze(0).to(device)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        scores = model(pixel_values=batch).logits
        probs = torch.softmax(scores, dim=1)[0]
        best = probs.argmax().item()

    detected = ANGLES[best]
    correction = (-detected) % 360

    # PIL rotates counter-clockwise; expand=True resizes the canvas to fit.
    corrected_img = rgb.rotate(correction, expand=True)

    # String-formatted probabilities for display in the UI.
    prob_dict = {}
    for angle, prob in zip(ANGLES, probs.tolist()):
        prob_dict[f"{angle}°"] = f"{prob:.4f}"

    return corrected_img, detected, correction, prob_dict

# --- 4. Frontend Layout ---
st.title("🔄 Auto Rotation Corrector")
st.markdown("Upload an image or provide a URL to automatically fix its orientation.")

st.divider()

# Input Selection
input_method = st.radio("Select Image Source:", ["Upload a File", "Enter Image URL"], horizontal=True)

# Stays None unless an image was opened AND fully decoded.
img = None

# Input Handling
# Image.open() only reads the header; img.load() forces the full decode so
# that corrupt or truncated data fails here, inside the try, instead of
# later when the image is rotated or converted.
# `except Exception` is deliberate: this is the UI boundary, and any failure
# (network, HTTP status, unsupported format, oversized image) should be
# shown to the user rather than crash the script.
if input_method == "Upload a File":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            img = Image.open(uploaded_file)
            img.load()
        except Exception as e:
            img = None  # discard a half-opened image
            st.error(f"Could not read the uploaded image. Error: {e}")
else:
    url = st.text_input("Enter Image URL:", placeholder="https://example.com/image.jpg")
    if url:
        try:
            response = requests.get(url, timeout=5)
            # Report 4xx/5xx responses as HTTP errors instead of letting PIL
            # fail on an HTML error page.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            img.load()
        except Exception as e:
            img = None  # discard a half-opened image
            st.error(f"Could not load image from URL. Error: {e}")

# Preview & Processing Section
if img:
    st.divider()

    # Optional manual rotation, applied before the preview and the model.
    manual_angle = st.slider(
        "Manual Pre-rotation",
        min_value=0,
        max_value=360,
        value=0,
        step=90,
    )
    if manual_angle:
        # expand=True grows the canvas so the rotated image is not cropped
        img = img.rotate(manual_angle, expand=True)

    # Two equal columns: input on the left, result on the right.
    input_col, output_col = st.columns(2)

    with input_col:
        st.subheader("Input Preview")
        st.image(img, use_container_width=True)

        # Primary action; True only on the rerun triggered by the click.
        run_clicked = st.button("✨ Correct Rotation", type="primary", use_container_width=True)

    with output_col:
        st.subheader("Output Preview")

        if not run_clicked:
            # Placeholder shown until the button is pressed
            st.info("Waiting for processing... Click the button on the left to correct the rotation.")
        else:
            with st.spinner("Analyzing..."):
                result = predict_and_correct(img)
                corrected_img, detected, correction, prob_dict = result

                # Corrected image first, then the summary line.
                st.image(corrected_img, use_container_width=True)
                st.success(f"✅ Detected: **{detected}°** | Correction: **{correction}°**")

                # Per-class probabilities stay collapsed unless requested.
                with st.expander("📊 View Probability Details"):
                    st.json(prob_dict)