# jableable's picture
# Update new_app.py
# e85428d verified
import streamlit as st
import keras
import numpy as np
from PIL import Image
import io, urllib.request, os
# Use the wide page layout; this is the first Streamlit call made by the script.
st.set_page_config(layout="wide")
@st.cache_data(show_spinner=False, ttl=600)
def fetch_satellite_tile(lat, lng, zoom=16, size="640x640", api_key=""):
    """Download a Google Static Maps satellite tile and return it as an RGB image.

    The function is wrapped in ``st.cache_data`` with ``ttl=600`` and the
    spinner disabled, so repeated calls with the same arguments are served
    from Streamlit's cache.

    Args:
        lat: Latitude of the tile centre.
        lng: Longitude of the tile centre.
        zoom: Static Maps zoom level.
        size: Requested image size as a ``"WIDTHxHEIGHT"`` string.
        api_key: Google Static Maps API key.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        RuntimeError: If ``api_key`` is empty.
        Exception: Network/HTTP errors from ``urllib`` and decode errors from
            PIL propagate unchanged to the caller.
    """
    # Local import keeps this block self-contained; urllib.parse is stdlib.
    from urllib.parse import urlencode

    if not api_key:
        raise RuntimeError("Missing Google Static Maps API key in env var 'goog_api'.")

    # Encode every parameter so unusual characters cannot corrupt the query
    # string. safe="," keeps the "lat,lng" separator literal, so the URL is
    # unchanged for ordinary inputs.
    query = urlencode(
        {
            "center": f"{lat},{lng}",
            "zoom": zoom,
            "size": size,
            "maptype": "satellite",
            "key": api_key,
        },
        safe=",",
    )
    url = f"https://maps.googleapis.com/maps/api/staticmap?{query}"

    # Close the HTTP response deterministically instead of leaking the socket.
    with urllib.request.urlopen(url, timeout=10) as response:
        payload = response.read()

    return Image.open(io.BytesIO(payload)).convert("RGB")
@st.cache_resource
def get_model():
    """Load the trained Keras model from ``0.0008-0.92.keras``.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the loaded
    model object instead of reloading it on each call.
    """
    model_path = "0.0008-0.92.keras"
    # The model is only used for prediction here; compile=False skips
    # rebuilding the optimizer/training state, which saves load time.
    model = keras.models.load_model(model_path, compile=False)
    return model
# Page-wide style overrides: container padding, tighter <hr> margins, and
# centring for every st.image element.
PAGE_CSS = """
<style>
.block-container {
padding-top: 2.2rem;
padding-bottom: 0rem;
padding-left: 5rem;
padding-right: 5rem;
}
hr {
margin-top: 0.0cm;
margin-bottom: 0.25cm;
}
[data-testid="stImage"] {
display: flex !important;
justify-content: center !important;
}
[data-testid="stImage"] img {
margin: auto;
display: block;
}
</style>
"""

# unsafe_allow_html is needed for the raw <style> tag to be emitted.
st.markdown(PAGE_CSS, unsafe_allow_html=True)
# Title row: headline and tagline on the left half, logo image on the right half.
TITLE_HTML = "<h1 style='text-align: center; margin-bottom: 0.1em;'>Overpass Detector</h1>"
TAGLINE_HTML = """
<div style="
text-align:center;
font-size:1.25rem;
margin-top:-0.3rem;
margin-bottom:0.6rem;
opacity:0.9;
">
This page runs a trained machine learning model to detect road overpasses in satellite imagery.
</div>
"""

header_left, header_right = st.columns(2)

with header_left:
    # Only the second of two sub-columns is used, pushing the text toward
    # the middle of the page.
    _, text_slot = st.columns(2)
    with text_slot:
        st.markdown(TITLE_HTML, unsafe_allow_html=True)
        st.markdown(TAGLINE_HTML, unsafe_allow_html=True)

with header_right:
    # Middle of three sub-columns holds the logo.
    _, logo_slot, _ = st.columns(3)
    with logo_slot:
        st.image('overpass.png')

# Separator between the header and the controls below.
st.write("---")
# Image dimensions used when the tile is reshaped into a model input batch.
# NOTE(review): no resizing happens in this script; the fetched tile is
# assumed to already have these dimensions.
img_height = 640
img_width = 640

state = st.session_state
state.loaded_model = get_model()

# Seed per-session defaults on the first run only; later reruns keep
# whatever the user (or a previous fetch) stored.
for key, default in (
    ("lat", 39.11),
    ("lng", -86.56),
    ("coords_submitted", False),
    ("img", None),
):
    if key not in state:
        state[key] = default

# Preload a tile for the default coordinates so the page is not empty on
# first load. A failure is reported as an info box rather than stopping the app.
if state.img is None:
    try:
        state.img = fetch_satellite_tile(
            state.lat, state.lng, api_key=os.getenv("goog_api", "")
        )
    except Exception as preload_error:
        st.info(f"Couldn’t fetch default tile: {preload_error}")
# Main row: prediction text (col1), satellite tile (col2), coordinate form (col3).
col1, col2, col3 = st.columns([0.8, 1.8, 1.0])  # adjust ratios to taste

with col3:
    st.subheader('Enter latitude/longitude coordinates:')

    # The number inputs are bound to session-state keys "lat" and "lng".
    with st.form("coords_form"):
        c1, c2 = st.columns(2)
        with c1:
            st.number_input('Latitude', key="lat", min_value=-90.0, max_value=90.0, step=0.01, format="%.2f")
            st.write('The current lat/long are:')
        with c2:
            st.number_input('Longitude', key="lng", min_value=-180.0, max_value=180.0, step=0.01, format="%.2f")
            st.write(f"{st.session_state.lat:.2f}, {st.session_state.lng:.2f}")
        submitted = st.form_submit_button("Get Image and Prediction")

    if submitted:
        try:
            api_key = os.getenv("goog_api", "")
            state.img = fetch_satellite_tile(st.session_state.lat, st.session_state.lng, api_key=api_key)
        except Exception as e:
            st.error(f"Error fetching image: {e}")
        else:
            # Rerun only after a successful fetch, and outside the try body,
            # so the handler above guards the fetch alone and can never
            # intercept or mislabel the rerun request.
            st.rerun()  # ensures the updated lat/lng render immediately
with col2:
    # NOTE(review): nothing visible in this file ever sets coords_submitted
    # to True, so this refetch branch looks unreachable. Confirm before removing.
    if state.coords_submitted:
        state.coords_submitted = False
        try:
            state.img = fetch_satellite_tile(
                state.lat, state.lng, api_key=os.getenv("goog_api", "")
            )
        except Exception as fetch_error:
            st.error(f"Error fetching image: {fetch_error}")

    tile = state.img
    if tile is not None:
        # NOTE(review): the opening and closing <div> are emitted as two
        # separate markdown elements around the image. Whether they actually
        # wrap it in the rendered page should be verified.
        st.markdown(
            "<div style='display:flex; justify-content:center;'>",
            unsafe_allow_html=True,
        )
        st.image(tile, width=640)
        st.markdown("</div>", unsafe_allow_html=True)
with col1:
    if state.img is not None:
        # Turn the tile into a single-item batch: (1, img_height, img_width, 3).
        # The reshape assumes the tile already has exactly that many pixels.
        batch_size = 1
        pixels = np.array(state.img).reshape((batch_size, img_height, img_width, 3))

        with st.spinner("Running inference..."):
            result = state.loaded_model.predict(pixels)

        # Index 1 of the first prediction row is shown as the overpass
        # probability, expressed as a percentage rounded to two decimals.
        pct = np.round(result[0][1] * 100, decimals=2)

        if pct >= 50:
            label = "Overpass Likely"
        else:
            label = "Overpass Unlikely"

        verdict_html = f"""
<div style="text-align:center;">
<div style="font-size:2.2rem; font-weight:700; margin-bottom:0.4rem;">
{label}
</div>
<div style="font-size:1.35rem;">
Model probability: <b>{pct}%</b>
</div>
</div>
"""
        st.markdown(verdict_html, unsafe_allow_html=True)