# Spaces: Sleeping
"""Streamlit app that detects trees in an uploaded aerial image.

The script logs in to Weights & Biases, downloads the ``counting-trees``
model artifact, and, when the user uploads an image, runs the model on it
and shows the detected boxes next to the original.
"""
import io
import logging
import os
import tempfile

import streamlit as st
import torch
import wandb
from PIL import Image

from utils import draw_boxes_on_image, preprocess_image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Weights & Biases.
# The API key is taken from the environment instead of being hard-coded:
# a key committed to source is readable by anyone with access to the repo.
# With key=None, wandb.login falls back to its normal credential lookup.
# NOTE(review): the key that was previously hard-coded here has been exposed
# and should be revoked/rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
run = wandb.init(project="counting-trees", job_type="inference")

# Download the model artifact (pinned to version v64).
artifact = run.use_artifact(
    "ahiamadzormaxwell7/counting-trees/model:v64", type="model"
)
artifact_dir = artifact.download()
logger.info("Model downloaded to directory: %s", artifact_dir)
# Set page to wide mode. set_page_config must be the first Streamlit command
# the script executes, so keep it above every other st.* call.
st.set_page_config(layout="wide")

# Streamlit UI
st.title("Detect Trees")
uploaded_file = st.file_uploader("Upload Aerial Image")

# Path of the temporary copy of the upload; stays None until one is created.
tmp_file_path = None
try:
    if uploaded_file is not None:
        # Legend: enlarge the caption font and explain the box colours.
        st.markdown("""
<style>
.big-font {
font-size:20px !important;
}
</style>
""", unsafe_allow_html=True)
        st.markdown('<p class="big-font">1 - Palm Tree (displayed in <span style="color:red">red</span>)<br>2 - Tree (displayed in <span style="color:blue">blue</span>)</p>', unsafe_allow_html=True)
        col1, col2 = st.columns(2)

        # preprocess_image is called with a file path, so copy the upload to
        # a named temporary file. delete=False keeps the file after the
        # `with` closes it so it can be reopened by name; it is removed in
        # the `finally` below.
        file_extension = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
            # Record the path before writing so a failed write is cleaned up too.
            tmp_file_path = tmp_file.name
            tmp_file.write(uploaded_file.getvalue())

        # Open and display original image
        image = Image.open(uploaded_file)
        col1.image(image, caption="Original Image", width=600)

        # Preprocess image
        img = preprocess_image(tmp_file_path)
        logger.info("Image preprocessed.")

        # Load model and run prediction on CPU.
        model_path = os.path.join(artifact_dir, "best_model.pth")
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code; acceptable only because it comes from this project's
        # own W&B artifact. The file appears to hold a whole pickled module
        # (it is called and .eval()'d directly). torch >= 2.6 defaults to
        # weights_only=True, which rejects such files; pass weights_only=False
        # explicitly when upgrading. TODO: consider caching the loaded model
        # instead of reloading it on every run of the script.
        model = torch.load(model_path, map_location=torch.device('cpu'))
        model.eval()
        with torch.no_grad():
            predictions = model(img)
        logger.info("Trees detection done.")

        # Presumably one prediction entry per image in the batch; only the
        # first image is used here -- confirm against preprocess_image.
        preds = predictions[0]

        # Draw boxes on image
        image_bytes = draw_boxes_on_image(img[0], preds)
        col2.image(image_bytes, caption="Detected Trees", width=600)
finally:
    # Runs even when preprocessing or inference raises, so the temporary
    # file is not leaked and the W&B run started by wandb.init is closed.
    if tmp_file_path is not None:
        os.unlink(tmp_file_path)
    wandb.finish()