# Zeel's picture
# manual repo input
# 98722f4
#### Set environment
import os

# Set before huggingface_hub is imported below; presumably the library reads
# this flag at import time, so keep this ordering.
# NOTE(review): with this flag on, model downloads presumably need the optional
# `hf_transfer` package installed — confirm it is in the requirements.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

#### Imports
import io
import streamlit as st
import requests
# NOTE(review): ImageDraw appears unused in this file.
from PIL import Image, ImageDraw
from ultralytics import YOLO
import supervision as sv
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt

#### State
# Short alias for Streamlit's per-session state. This script uses it to
# remember the API key (GMS_KEY), the last location (lat_lon) and the map
# tile fetched for it (image) across reruns.
state = st.session_state
#### Functions
def get_static_map_image(lat_lon, api, *, zoom=16, size="640x640", scale=2,
                         maptype="satellite", timeout=30):
    """Download a Google Static Maps tile centred on ``lat_lon``.

    Args:
        lat_lon: Map centre, sent verbatim as the ``center`` query parameter
            (e.g. ``"28.00,77.00"``).
        api: Google Maps API key, sent as the ``key`` query parameter.
        zoom: Value of the ``zoom`` query parameter.
        size: Value of the ``size`` query parameter (``"<width>x<height>"``).
        scale: Value of the ``scale`` query parameter.
        maptype: Value of the ``maptype`` query parameter.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The raw response body; callers decode it as an image.

    Raises:
        requests.HTTPError: If the API answers with an error status (e.g. a
            rejected key), rather than handing the error body back as if it
            were image bytes.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    base_url = "https://maps.googleapis.com/maps/api/staticmap"
    params = {
        "center": lat_lon,
        "zoom": zoom,
        "size": size,
        "maptype": maptype,
        "key": api,
        "scale": scale,
    }
    # Without an explicit timeout requests can wait forever, freezing the app.
    response = requests.get(base_url, params=params, timeout=timeout)
    response.raise_for_status()
    return response.content
@st.cache_resource(show_spinner=False)
def load_model(repo_id, filename):
    """Download a YOLO checkpoint from the Hugging Face Hub and load it for OBB.

    The ``st.cache_resource`` decorator memoises the returned model per
    ``(repo_id, filename)`` pair, so Streamlit reruns reuse it instead of
    downloading and constructing it again.
    """
    weights_file = hf_hub_download(repo_id=repo_id, filename=filename)
    return YOLO(weights_file, task="obb")
#### Title
st.header("Brick Kiln Detection (Single Location)")

#### Set API Key
# The key is kept in session state so it only has to be typed once per
# browser session; an absent or empty entry means "ask for it".
if state.get("GMS_KEY", "") == "":
    api_key = st.text_input("Enter Google Maps API Key", type="password")
    if api_key != "":
        # Remember the key and rerun, so the next pass takes the cached branch.
        state.GMS_KEY = api_key
        st.rerun()
    else:
        st.stop()
else:
    st.write("Using cached Google Maps API Key")
    api_key = state.GMS_KEY
#### Model selection
# Side-by-side inputs: the Hugging Face repo id, and the weights file inside
# that repo to load.
col1, col2 = st.columns(2)
repo = col1.text_input(label="HF Repo", value="Vannsh/v8x-obb")
path = col2.text_input(label="Model Path", value="obb3.pt")
#### Get Latitude, Longitude
lat_lon = st.text_input("Enter Latitude, Longitude (comma separated)", placeholder="28.00,77.00")
if lat_lon == "":
    st.stop()

# True when this rerun was triggered by something other than a new location
# (e.g. editing the model inputs), so the cached map tile can be reused.
SAME_LOCATION = "lat_lon" in state and state.lat_lon == lat_lon

#### Get and load image
if SAME_LOCATION and "image" in state:
    image = state.image
else:
    image_data = get_static_map_image(lat_lon, api_key)
    try:
        image = Image.open(io.BytesIO(image_data))
        # Decode now, so a bad payload fails here before anything is cached.
        image.load()
    except OSError:
        # NOTE(review): the Static Maps API appears to answer some failures
        # (rejected key, quota, malformed center) with a text body, not an image.
        st.error(
            "Could not load a map image for this location. Server response: "
            + image_data[:500].decode("utf-8", errors="replace")
        )
        st.stop()
    state.image = image

# Record the location only once its tile has been stored. This way a failed
# fetch can never leave an older location's tile paired with these
# coordinates on the next rerun.
state.lat_lon = lat_lon

### Show raw image
st.image(image, caption=lat_lon, use_column_width=True)
### Predict
model = load_model(repo, path)
result = model(image)[0]  # a single image goes in, so take its single result
detections = sv.Detections.from_ultralytics(result)

# Re-encode the tile as PNG and read it back with plt.imread, which returns an
# image array that imshow can display.
buffer = io.BytesIO()
image.save(buffer, format="PNG")
buffer.seek(0)
img = plt.imread(buffer)

fig, ax = plt.subplots()
ax.axis("off")
ax.imshow(img)

# Outline every oriented box as a closed quadrilateral: FCBK detections in
# red, all other classes in blue (the legend below must agree).
for detection in detections:
    fields = detection[-1]  # last item of each detection: its data dict
    x1, y1, x2, y2, x3, y3, x4, y4 = fields["xyxyxyxy"].ravel()
    color = "red" if fields["class_name"] == "FCBK" else "blue"
    ax.plot([x1, x2, x3, x4, x1], [y1, y2, y3, y4, y1], color=color, linewidth=1)

buffer = io.BytesIO()
fig.savefig(buffer, format="PNG", bbox_inches="tight", pad_inches=0, dpi=300)
# pyplot keeps every figure alive until it is closed. Streamlit reruns this
# script on each interaction, so without this the open figures pile up.
plt.close(fig)
buffer.seek(0)
img = Image.open(buffer)
st.image(img, use_column_width=True)

# Mention colors
st.write("<span style='color:red'>Red: FCBK (Fixed Chimney Bull's Trench Kiln)</span>", unsafe_allow_html=True)
st.write("<span style='color:blue'>Blue: Zigzag Kiln</span>", unsafe_allow_html=True)