"""Streamlit app that detects trees in an uploaded aerial image.

The detection model is pulled from a Weights & Biases artifact and run on CPU;
predicted boxes are drawn onto the image and shown next to the original.
"""
import io  # noqa: F401  (kept from the original file; currently unused)
import logging
import os
import tempfile

import streamlit as st
import torch
import wandb
from PIL import Image

from utils import draw_boxes_on_image, preprocess_image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WANDB_PROJECT = "counting-trees"
MODEL_ARTIFACT = "ahiamadzormaxwell7/counting-trees/model:v64"
MODEL_FILENAME = "best_model.pth"

# Legend shown above the images. Two trailing spaces + "\n" is a markdown line
# break. NOTE(review): colours should match what draw_boxes_on_image uses — confirm.
LEGEND_MARKDOWN = (
    "1 - Palm Tree (displayed in red)  \n"
    "2 - Tree (displayed in blue)"
)


@st.cache_resource
def load_model():
    """Download the model artifact from W&B and return the model in eval mode.

    Cached with ``st.cache_resource`` so the W&B login, run creation, artifact
    download and ``torch.load`` happen once per server process instead of on
    every Streamlit rerun (Streamlit re-executes this file on each interaction).
    """
    # Read the API key from the environment rather than hard-coding it in the
    # source. Any key previously committed to the repository should be revoked.
    # With key=None, wandb.login uses its own credential lookup instead.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    run = wandb.init(project=WANDB_PROJECT, job_type="inference")
    try:
        artifact = run.use_artifact(MODEL_ARTIFACT, type="model")
        artifact_dir = artifact.download()
        logger.info("Model downloaded to directory: %s", artifact_dir)
    finally:
        # The run is only needed to fetch the artifact; close it immediately so
        # it is not left open (or duplicated) across reruns.
        wandb.finish()

    model_path = os.path.join(artifact_dir, MODEL_FILENAME)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. That is only acceptable because the artifact comes from our own W&B
    # project. Newer torch releases default to weights_only=True, which rejects a
    # pickled full nn.Module; pass weights_only=False there if loading fails.
    model = torch.load(model_path, map_location=torch.device("cpu"))
    model.eval()
    return model


def _save_upload_to_tempfile(uploaded_file):
    """Write the uploaded bytes to a named temp file and return its path.

    The temp file keeps the upload's extension and is created with
    ``delete=False``; the caller is responsible for removing it.
    """
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        return tmp_file.name


def main():
    """Render the page and run tree detection on the uploaded image, if any."""
    # Must be the first Streamlit command of each script run.
    st.set_page_config(layout="wide")
    # Load the model up front (as the original script did); cached after the
    # first run. Called after set_page_config because the cache may render a
    # spinner element.
    model = load_model()

    st.title("Detect Trees")
    uploaded_file = st.file_uploader("Upload Aerial Image")
    if uploaded_file is None:
        return

    st.markdown(LEGEND_MARKDOWN)
    col1, col2 = st.columns(2)

    tmp_file_path = _save_upload_to_tempfile(uploaded_file)
    try:
        # Open and display original image
        image = Image.open(uploaded_file)
        col1.image(image, caption="Original Image", width=600)

        # preprocess_image takes a file path, hence the temp file.
        img = preprocess_image(tmp_file_path)
        logger.info("Image preprocessed.")

        with torch.no_grad():
            predictions = model(img)
        logger.info("Trees detection done.")

        # The result is indexed as a batch; only the first image is shown.
        preds = predictions[0]
        image_bytes = draw_boxes_on_image(img[0], preds)
        col2.image(image_bytes, caption="Detected Trees", width=600)
    finally:
        # Always remove the temp file, even if preprocessing/inference fails.
        os.unlink(tmp_file_path)


# Streamlit executes this file top to bottom on every rerun.
main()