Create use_with_UI.py
Browse files- use_with_UI.py +113 -0
use_with_UI.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import requests
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
from transformers import ResNetForImageClassification
|
| 8 |
+
|
| 9 |
+
# --- 1. UI Configuration ---
# Page chrome is configured once, before any other Streamlit call.
# The 'centered' layout keeps the app readable on very wide monitors.
st.set_page_config(
    page_icon="🔄",
    page_title="GyroScope Rotation Corrector",
    layout="centered",
)
|
| 12 |
+
|
| 13 |
+
# --- 2. Model Caching ---
@st.cache_resource
def load_model():
    """Load the GyroScope classifier once and pin it to the best device.

    Decorated with @st.cache_resource so Streamlit reuses the same model
    object across reruns instead of reloading it on every UI interaction.

    Returns:
        tuple: (model, device) — the eval-mode model and the torch.device
        it was moved to (CUDA when available, otherwise CPU).
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNetForImageClassification.from_pretrained("LH-Tech-AI/GyroScope")
    net.eval()
    net.to(target)
    return net, target


model, device = load_model()
|
| 24 |
+
|
| 25 |
+
# --- 3. Preprocessing & Logic ---
# Standard ImageNet normalization statistics used by ResNet checkpoints.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize the short side to 256, take the central 224x224 crop, convert to
# a float tensor, and normalize — the canonical ResNet input pipeline.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Class index -> rotation angle in degrees, in the model's label order.
ANGLES = [0, 90, 180, 270]
|
| 34 |
+
|
| 35 |
+
def predict_and_correct(img):
    """Detect the rotation of *img* and return an upright copy.

    Parameters
    ----------
    img : PIL.Image.Image
        Input image in any mode; converted to RGB before inference.

    Returns
    -------
    tuple
        (corrected_img, detected, correction, prob_dict):
        corrected_img — PIL image rotated back to upright (expand=True so
        nothing is cropped); detected — predicted rotation in degrees;
        correction — degrees applied to undo it; prob_dict — maps
        "<angle>°" labels to probability strings for the UI.
    """
    # Ensure image is RGB — the model expects 3-channel input.
    img = img.convert("RGB")
    tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values=tensor).logits
        probs = torch.softmax(logits, dim=1)[0]
        pred = probs.argmax().item()

    # NOTE(review): assumes the model's class indices map onto ANGLES in
    # order [0, 90, 180, 270] — confirm against the checkpoint's id2label.
    detected = ANGLES[pred]
    correction = (360 - detected) % 360

    # Apply correction (PIL rotate is counter-clockwise)
    corrected_img = img.rotate(correction, expand=True)

    # Extract plain Python floats via .item() before formatting: relying on
    # torch.Tensor.__format__ for 0-dim tensors ties the UI strings to
    # tensor formatting behavior; floats are unambiguous and portable.
    prob_dict = {f"{a}°": f"{p.item():.4f}" for a, p in zip(ANGLES, probs)}

    return corrected_img, detected, correction, prob_dict
|
| 55 |
+
|
| 56 |
+
# --- 4. Frontend Layout ---
st.title("🔄 Auto Rotation Corrector")
st.markdown("Upload an image or provide a URL to automatically fix its orientation.")

st.divider()

# Input Selection — the radio controls which acquisition widget is shown.
input_method = st.radio("Select Image Source:", ["Upload a File", "Enter Image URL"], horizontal=True)

# Populated by whichever input path succeeds; None means nothing to process.
img = None

# Input Handling
if input_method == "Upload a File":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file:
        img = Image.open(uploaded_file)
else:
    url = st.text_input("Enter Image URL:", placeholder="https://example.com/image.jpg")
    if url:
        try:
            response = requests.get(url, timeout=5)
            # Fail fast on HTTP errors (404/500/...): without this check an
            # HTML error page would reach Image.open and surface as a
            # confusing UnidentifiedImageError instead of an HTTP error.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
        except Exception as e:
            st.error(f"Could not load image from URL. Error: {e}")
|
| 80 |
+
|
| 81 |
+
# Preview & Processing Section
# NOTE: Streamlit reruns this script top-to-bottom on every interaction;
# widget creation order matters, so statement order here is significant.
if img:
    st.divider()

    # Use columns to keep the UI compact and side-by-side
    col_left, col_right = st.columns(2)

    with col_left:
        st.subheader("Input Preview")
        st.image(img, use_container_width=True)

        # The primary action button. Its boolean value is read in the
        # right-hand column below — buttons return True only on the rerun
        # triggered by their click.
        process_btn = st.button("✨ Correct Rotation", type="primary", use_container_width=True)

    with col_right:
        st.subheader("Output Preview")

        if process_btn:
            # Run inference behind a spinner so the UI shows progress.
            with st.spinner("Analyzing..."):
                corrected_img, detected, correction, prob_dict = predict_and_correct(img)

            # Show result
            st.image(corrected_img, use_container_width=True)

            # Show stats
            st.success(f"✅ Detected: **{detected}°** | Correction: **{correction}°**")

            # Hidden expander for clean UI, but available if the user wants details
            with st.expander("📊 View Probability Details"):
                st.json(prob_dict)
        else:
            # Placeholder container before the button is clicked
            st.info("Waiting for processing... Click the button on the left to correct the rotation.")