| |
| import os |
|
|
# Opt in to the faster ``hf_transfer`` download backend of huggingface_hub.
# This is set before huggingface_hub is imported below, because the library
# reads its configuration from the environment when it is imported.
# The flag is only set when ``hf_transfer`` is actually installed. With the
# flag on and the package missing, huggingface_hub raises an error when
# downloading, which would break ``hf_hub_download`` in ``load_model``.
import importlib.util

if importlib.util.find_spec("hf_transfer") is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
|
| |
| import io |
| import streamlit as st |
| import requests |
| from PIL import Image, ImageDraw |
| from ultralytics import YOLO |
| import supervision as sv |
| from huggingface_hub import hf_hub_download |
| import matplotlib.pyplot as plt |
|
|
| |
# Short alias for Streamlit's per-session key/value store. It is used below to
# keep the API key and the last fetched map image between script reruns.
state = st.session_state
|
|
|
|
| |
def get_static_map_image(
    lat_lon,
    api,
    *,
    zoom=16,
    size="640x640",
    scale=2,
    maptype="satellite",
    timeout=30,
):
    """Download a Google Maps Static API image centred on ``lat_lon``.

    Args:
        lat_lon: Map centre, e.g. ``"28.00,77.00"``. It is passed verbatim as
            the ``center`` query parameter.
        api: Google Maps API key (the ``key`` query parameter).
        zoom: Map zoom level.
        size: Requested image size as ``"WIDTHxHEIGHT"``.
        scale: Value of the ``scale`` query parameter.
        maptype: Google map type, e.g. ``"satellite"``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The raw response body, i.e. the encoded image bytes.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status, e.g. for
            an invalid key or malformed parameters.
        requests.RequestException: On connection errors or timeout.
    """
    base_url = "https://maps.googleapis.com/maps/api/staticmap"
    params = {
        "center": lat_lon,
        "zoom": zoom,
        "size": size,
        "maptype": maptype,
        "key": api,
        "scale": scale,
    }
    # requests never times out unless a timeout is given explicitly.
    response = requests.get(base_url, params=params, timeout=timeout)
    # An error response is not an image. Fail here with the HTTP status instead
    # of handing its body to PIL and getting an opaque decode error later.
    response.raise_for_status()
    return response.content
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_model(repo_id, filename):
    """Fetch ``filename`` from the Hugging Face repo ``repo_id`` and load it as a YOLO OBB model.

    ``st.cache_resource`` keeps one model object per ``(repo_id, filename)``
    pair, shared across reruns and sessions. Download and weight loading
    therefore happen only on the first call for each pair.
    """
    weights_file = hf_hub_download(repo_id=repo_id, filename=filename)
    return YOLO(weights_file, task="obb")
|
|
|
|
| |
# Streamlit re-executes this whole script on every widget interaction, so
# results that are expensive to recompute are kept in ``state`` between runs.


def _fetch_map_image(lat_lon, api_key):
    """Return the satellite image for ``lat_lon``, reusing the cached copy when unchanged.

    The cache (``state.lat_lon`` / ``state.image``) is written only after a
    successful download and decode. A failed request therefore never leaves
    the previous location's image paired with the new coordinates.

    Raises:
        requests.RequestException: If the HTTP request fails.
        OSError: If the response body cannot be decoded as an image. PIL's
            ``UnidentifiedImageError`` is an ``OSError`` subclass.
    """
    if "lat_lon" in state and state.lat_lon == lat_lon and "image" in state:
        return state.image
    image = Image.open(io.BytesIO(get_static_map_image(lat_lon, api_key)))
    # Image.open only parses the header; decode now so corrupt data fails here.
    image.load()
    state.image = image
    state.lat_lon = lat_lon
    return image


def _render_detections(image, detections):
    """Draw the oriented boxes in ``detections`` over ``image`` and return the plot as a PIL image.

    Boxes whose ``class_name`` is ``"FCBK"`` are outlined in red; every other
    class is outlined in blue.
    """
    png = io.BytesIO()
    image.save(png, format="PNG")
    png.seek(0)
    pixels = plt.imread(png)

    fig, ax = plt.subplots()
    try:
        ax.axis("off")
        ax.imshow(pixels)
        for detection in detections:
            # The last element of each item yielded by sv.Detections is its
            # per-detection data dict.
            data = detection[-1]
            # 8 values = the 4 (x, y) corners of the oriented box.
            x1, y1, x2, y2, x3, y3, x4, y4 = data["xyxyxyxy"].ravel()
            color = "red" if data["class_name"] == "FCBK" else "blue"
            # One closed outline: the last point returns to the first corner.
            ax.plot(
                [x1, x2, x3, x4, x1],
                [y1, y2, y3, y4, y1],
                color=color,
                linewidth=1,
            )
        rendered = io.BytesIO()
        fig.savefig(rendered, format="PNG", bbox_inches="tight", pad_inches=0, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each rerun of the page would leak one figure.
        plt.close(fig)
    rendered.seek(0)
    return Image.open(rendered)


st.header("Brick Kiln Detection (Single Location)")

# Google Maps API key: asked for once, then reused from session state.
if "GMS_KEY" in state and state.GMS_KEY != "":
    st.write("Using cached Google Maps API Key")
    api_key = state.GMS_KEY
else:
    api_key = st.text_input("Enter Google Maps API Key", type="password")
    if api_key == "":
        st.stop()
    state.GMS_KEY = api_key
    # Rerun so the page continues down the cached-key branch above.
    st.rerun()

col1, col2 = st.columns(2)
repo = col1.text_input("HF Repo", "Vannsh/v8x-obb")
path = col2.text_input("Model Path", "obb3.pt")

lat_lon = st.text_input("Enter Latitude, Longitude (comma separated)", placeholder="28.00,77.00")
if lat_lon == "":
    st.stop()

try:
    image = _fetch_map_image(lat_lon, api_key)
except (requests.RequestException, OSError) as err:
    # A bad key or bad coordinates yield an HTTP error or a non-image body.
    # Report it on the page instead of crashing with a traceback.
    st.error(f"Could not fetch the map image for {lat_lon!r}: {err}")
    st.stop()

# NOTE(review): newer Streamlit releases deprecate ``use_column_width`` in
# favour of ``use_container_width``; confirm against the pinned version.
st.image(image, caption=lat_lon, use_column_width=True)

model = load_model(repo, path)
result = model(image)[0]
detections = sv.Detections.from_ultralytics(result)

st.image(_render_detections(image, detections), use_column_width=True)

st.write("<span style='color:red'>Red: FCBK (Fixed Chimney Bull's Trench Kiln)</span>", unsafe_allow_html=True)
st.write("<span style='color:blue'>Blue: Zigzag Kiln</span>", unsafe_allow_html=True)
|
|