# afshin-dini's picture
# Add a sample image to demo as default
# cc8568c
"""Streamlit demo for detecting and counting packages in an image."""
from dataclasses import dataclass, field
from pathlib import Path
import tempfile
import streamlit as st
import pandas as pd
from PIL import Image
from src.deep_package_detection.detector import PackageDetectorInference
@dataclass
class DemoPackageDetection:
    """Streamlit app that detects and counts packages in a single image.

    Streamlit re-executes the whole script on every widget interaction, so a
    page build uploads an image (or falls back to a bundled sample), runs the
    detector when the button is pressed, and finally removes any temporary
    file it created for the upload.
    """

    # Path of the image currently shown/processed: a temporary copy of the
    # uploaded file, or the bundled default sample.
    image: str = field(init=False)
    # Temporary file written for the current run ("" when none). Removed at the
    # end of ``design_page`` so reruns do not accumulate files on disk.
    _temp_path: str = field(default="", init=False, repr=False)

    @staticmethod
    def _temp_suffix(filename: str) -> str:
        """Return the lower-cased extension of *filename*, or ".jpg" if it has none."""
        suffix = Path(filename).suffix.lower()
        return suffix if suffix else ".jpg"

    @staticmethod
    def _build_count_table(counts) -> pd.DataFrame:
        """Flatten detector counts into a table with Image/Type/Count columns.

        ``counts`` is read as a mapping from an image key to a list of dicts
        with "class" and "count" entries (presumably the shape returned by
        ``PackageDetectorInference.count_packages`` - confirm against the
        detector). A falsy value yields an empty table that still carries the
        three column headers.
        """
        rows = [
            {"Image": image_key, "Type": detection["class"], "Count": detection["count"]}
            for image_key, detections in (counts or {}).items()
            for detection in detections
        ]
        return pd.DataFrame(rows, columns=["Image", "Type", "Count"]).round(2)

    def _cleanup_temp_file(self) -> None:
        """Delete the temporary upload copy, if one was created."""
        if self._temp_path:
            Path(self._temp_path).unlink(missing_ok=True)
            self._temp_path = ""

    def upload_image(self) -> None:
        """Let the user upload an image, falling back to the bundled sample.

        The detector takes a filesystem path, so an upload is copied to a named
        temporary file that keeps the upload's own extension. The chosen image
        is then displayed on the page.
        """
        # Drop any copy left from an earlier call on this instance.
        self._cleanup_temp_file()
        uploaded_file = st.file_uploader(
            "Upload an image or use the default one...", type=["jpg", "png", "jpeg"]
        )
        if uploaded_file is not None:
            suffix = self._temp_suffix(uploaded_file.name)
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                tmp_file.write(uploaded_file.getbuffer())
            self._temp_path = tmp_file.name
            self.image = tmp_file.name
        else:
            self.image = "tests/test_data/5.jpg"
        # Close the underlying file handle once Streamlit has consumed the image.
        with Image.open(self.image) as original:
            st.image(original, caption="Original/Uploaded Image", width="stretch")

    def process_image(self) -> None:
        """Run package detection on ``self.image`` when the button is pressed.

        Shows the annotated result image and a table of per-image, per-class
        counts. Does nothing if the button was not pressed or if the detector
        returns no result image.
        """
        if not st.button("Detect/Count Packages"):
            return
        inferer = PackageDetectorInference(
            model_path=Path("./src/deep_package_detection/model/package_detection.pt"),
        )
        segmentations = inferer.inference(data_path=self.image)
        result_image = inferer.single_inference(segmentations)
        if result_image is None:
            return
        st.markdown("<h3>Segmented Results</h3>", unsafe_allow_html=True)
        st.image(result_image, caption="Detected Packages", width="stretch")

        table = self._build_count_table(inferer.count_packages(segmentations))
        # NOTE(review): each st.markdown call renders as its own element, so
        # this open/close div pair likely does not wrap the table - confirm.
        st.markdown('<div class="center-container">', unsafe_allow_html=True)
        st.markdown(
            "<h3>Detailed Information of Detections</h3>", unsafe_allow_html=True
        )
        st.markdown(
            """
    <style>
    table {
        width: 100%;
    }
    th, td {
        text-align: center !important;
    }
    </style>
    """,
            unsafe_allow_html=True,
        )
        st.table(table)
        st.markdown("</div>", unsafe_allow_html=True)

    def design_page(self) -> None:
        """Build the full page: title, image input, and detection results.

        The temporary upload copy is always removed afterwards, including
        when Streamlit interrupts the run.
        """
        st.title("Package Detection and Counting")
        try:
            self.upload_image()
            self.process_image()
        finally:
            self._cleanup_temp_file()
if __name__ == "__main__":
    # ``streamlit run`` executes this script as ``__main__`` on every rerun, so
    # the page is still built there. A plain ``import`` of the module (for
    # example from tests) no longer triggers Streamlit UI calls.
    demo = DemoPackageDetection()
    demo.design_page()