tree-counter / app.py
Mawube's picture
Increase width of Images and add label description
5d00887 verified
import streamlit as st
import wandb
from utils import preprocess_image, draw_boxes_on_image
import logging
import torch
import tempfile
from PIL import Image
import os
import io
# Configure the root logger at INFO (basicConfig's default handler writes to
# stderr) and obtain a logger named after this module for progress messages.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Initialize Weights & Biases.
# The API key is read from the WANDB_API_KEY environment variable (for example
# a Hugging Face Space secret) instead of being hard-coded. A key committed to
# source is readable by anyone with access to the repo. The key that used to
# be hard-coded here must be revoked and rotated. With key=None, wandb.login
# falls back to its usual credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
run = wandb.init(project="counting-trees", job_type="inference")
# Download the model artifact (pinned to version v64) into a local directory.
artifact = run.use_artifact('ahiamadzormaxwell7/counting-trees/model:v64', type='model')
artifact_dir = artifact.download()
logger.info(f"Model downloaded to directory: {artifact_dir}")
# Use Streamlit's wide page layout so the side-by-side image columns get more
# horizontal room.
st.set_page_config(layout="wide")

# Page heading followed by the widget that accepts the aerial image; the
# widget returns None until the user has uploaded something.
st.title(body="Detect Trees")
uploaded_file = st.file_uploader(label="Upload Aerial Image")
# Run detection only once the user has uploaded an image. The W&B run opened
# at startup is finished in `finally`, so it is closed even when
# preprocessing or inference raises.
try:
    if uploaded_file is not None:
        # Add captions: a CSS class used by the label legend below.
        st.markdown("""
<style>
.big-font {
font-size:20px !important;
}
</style>
""", unsafe_allow_html=True)
        # Legend describing which colour each label is drawn in.
        st.markdown('<p class="big-font">1 - Palm Tree (displayed in <span style="color:red">red</span>)<br>2 - Tree (displayed in <span style="color:blue">blue</span>)</p>', unsafe_allow_html=True)
        # Left column: original image; right column: detections.
        col1, col2 = st.columns(2)

        # preprocess_image takes a file path, so persist the upload to a temp
        # file. The original extension is kept (presumably so the loader can
        # infer the format). delete=False lets the file be reopened by name
        # after the `with` closes it; it is removed in the `finally` below.
        file_extension = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_file_path = tmp_file.name
        try:
            # Open and display original image
            image = Image.open(uploaded_file)
            col1.image(image, caption="Original Image", width=600)

            # Preprocess image
            img = preprocess_image(tmp_file_path)
            logger.info("Image preprocessed.")

            # Load model and run prediction on CPU.
            # NOTE(review): torch.load unpickles arbitrary Python objects, so
            # only load checkpoints from a trusted source. The loaded object
            # is used as a module (eval / call). On PyTorch >= 2.6 the default
            # weights_only=True rejects pickled full modules. TODO: confirm
            # the pinned torch version, or pass weights_only=False explicitly.
            # TODO(review): the model is reloaded on every script rerun;
            # consider caching it (e.g. st.cache_resource) if the installed
            # Streamlit version provides it.
            model_path = os.path.join(artifact_dir, "best_model.pth")
            model = torch.load(model_path, map_location=torch.device('cpu'))
            model.eval()
            with torch.no_grad():
                predictions = model(img)
            logger.info("Trees detection done.")

            # Presumably one prediction entry per batch item; take the first
            # (only) image. TODO: confirm against preprocess_image's output.
            preds = predictions[0]
            # Draw boxes on image and show it next to the original.
            image_bytes = draw_boxes_on_image(img[0], preds)
            col2.image(image_bytes, caption="Detected Trees", width=600)
        finally:
            # Always remove the temp file, even if preprocessing or inference
            # failed; previously an exception here leaked it on disk.
            os.unlink(tmp_file_path)
finally:
    wandb.finish()