Spaces:
Sleeping
Sleeping
| """ | |
| Streamlit Web UI for the Diatom Classifier Pipeline. | |
| """ | |
| import os | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image, ImageDraw, ImageFont | |
| from dotenv import load_dotenv | |
# Pull settings from a local .env file into os.environ (nothing happens if
# no .env file exists).
load_dotenv()

# -- Setup Environment & Paths --
# "development" makes load_models() read weights from disk; any other value
# (default "production") downloads them from the Hugging Face Hub.
APP_ENV = os.getenv("APP_ENV", "production")
REPO_ID = "kemalbsoylu/diatom-models"
# Directory two levels above this file; local model weights live under it.
BASE_DIR = Path(__file__).resolve().parent.parent

st.set_page_config(layout="wide", page_title="Diatom AI", page_icon="🔬")
@st.cache_resource(show_spinner=False)
def load_models():
    """Load the YOLO diatom detector and the fastai ResNet classifier.

    When APP_ENV == "development" the weights are read from BASE_DIR/models;
    otherwise they are fetched from the Hugging Face Hub repository REPO_ID.

    The result is cached with st.cache_resource, so the models are loaded
    once per server process and reused across reruns and sessions. Without
    it, every widget interaction (which reruns the whole script) would
    reload both models. The caller already wraps this call in its own
    spinner, so the cache's built-in spinner is disabled.

    Returns:
        (yolo, resnet): the ultralytics YOLO detector and the fastai learner.
    """
    # Heavy framework imports are deferred; with caching they only execute
    # on the first (uncached) call.
    from ultralytics import YOLO
    from fastai.vision.all import load_learner

    if APP_ENV == "development":
        # LOCAL: Use models stored on your hard drive
        yolo_path = BASE_DIR / "models" / "yolo_diatom_detector.pt"
        resnet_path = BASE_DIR / "models" / "v2_resnet18_weighted.pkl"
    else:
        # PRODUCTION: Download models from the Hugging Face Hub
        from huggingface_hub import hf_hub_download

        yolo_path = hf_hub_download(repo_id=REPO_ID, filename="yolo_diatom_detector.pt")
        resnet_path = hf_hub_download(repo_id=REPO_ID, filename="v2_resnet18_weighted.pkl")

    # Load them into memory.
    # NOTE(security): load_learner unpickles the .pkl file, which can execute
    # arbitrary code -- only point REPO_ID / BASE_DIR at sources you trust.
    yolo = YOLO(yolo_path)
    resnet = load_learner(resnet_path)
    return yolo, resnet
# -- Page Header --
_INTRO_MD = """
Upload a microscope image. Use **Full Slide Analysis** to automatically detect and classify multiple diatoms,
or use **Single Diatom Crop** if you already have a cropped image of a single diatom.
"""

st.title("🔬 Diatom Detection & Classification AI")
st.markdown(_INTRO_MD)

# Both models are needed by every analysis mode, so load them before
# rendering any controls; the spinner stays up until they are ready.
with st.spinner("Loading AI Models into memory..."):
    loaded_models = load_models()
yolo_model, resnet_model = loaded_models

# Make local-model runs visually distinct from production.
if APP_ENV == "development":
    st.sidebar.success("🔧 Running in Development Mode (Local Models)")
# -- Sidebar Controls --
# Detection confidence used as the slider's initial value, and as the value
# of conf_threshold when the slider is not shown (single-crop mode).
_DEFAULT_CONF = 0.25

st.sidebar.header("Configuration")
app_mode = st.sidebar.radio("Select Analysis Mode:", ["Full Slide Analysis", "Single Diatom Crop"])

conf_threshold = _DEFAULT_CONF
if app_mode == "Full Slide Analysis":
    conf_threshold = st.sidebar.slider(
        "Detection Confidence",
        min_value=0.1,
        max_value=1.0,
        value=_DEFAULT_CONF,
        step=0.05,
    )
    st.sidebar.markdown("*Lowering the threshold finds more diatoms but increases false positives. (Default: 0.25)*")

# -- Sidebar Footer (Portfolio & License) --
_LICENSE_HTML = """
<small>
<b>Licenses:</b> Code (MIT), Detector (AGPL-3.0).<br>
<b>Data:</b> Trained on dataset by Gündüz et al. (CC BY-NC-SA 4.0).
</small>
"""

with st.sidebar:
    st.markdown("---")
    st.markdown("### About")
    st.markdown("Developed by **Kemal Soylu**")
    st.markdown("[View Source Code on GitHub](https://github.com/kemalbsoylu/diatom-classifier)")
    # Raw HTML is needed for the <small>/<b>/<br> formatting.
    st.markdown(_LICENSE_HTML, unsafe_allow_html=True)
# -- Inference & Rendering Helpers --
# Fraction of a detected box's width/height added on every side before the
# crop is classified (the same 15% margin recommended for single-crop uploads).
CROP_MARGIN = 0.15
# Point size requested for the bounding-box labels.
LABEL_FONT_SIZE = 24


def _open_upload(upload):
    """Decode an uploaded file into an RGB PIL image.

    Returns None, after showing an error in the page, when the upload cannot
    be decoded. This keeps a raw traceback away from the user.
    """
    # Silence PyCharm type warning (the uploader is not in multi-file mode).
    assert not isinstance(upload, list)
    try:
        # convert() forces the lazy decode, so decoding errors surface here.
        return Image.open(upload).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # OSError covers PIL.UnidentifiedImageError (not an image) and
        # corrupt/truncated files. DecompressionBombError is raised for
        # images far beyond Pillow's pixel limit.
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
        return None


def _classify(model, image):
    """Classify a PIL image that is expected to contain a single diatom.

    Args:
        model: the fastai learner returned by load_models().
        image: PIL image to classify.

    Returns:
        (pred_class, confidence): the predicted label and the probability of
        that label as a percentage in [0, 100].
    """
    pred_class, pred_idx, probs = model.predict(np.array(image))
    return pred_class, probs[pred_idx].item() * 100


def _margin_crop_box(box, width, height, margin=CROP_MARGIN):
    """Grow an integer (x1, y1, x2, y2) box by `margin` of its size on each side.

    The result is clamped to a `width` x `height` image.
    """
    x1, y1, x2, y2 = box
    margin_x = int((x2 - x1) * margin)
    margin_y = int((y2 - y1) * margin)
    return (
        max(0, x1 - margin_x),
        max(0, y1 - margin_y),
        min(width, x2 + margin_x),
        min(height, y2 + margin_y),
    )


def _load_label_font(size=LABEL_FONT_SIZE):
    """Return the font used for box labels.

    Arial is tried first. If it is not installed (common on Linux hosts), the
    function falls back to Pillow's bundled default font. That font is scaled
    to `size` on Pillow versions whose load_default() accepts a size, and
    stays at its fixed native size otherwise.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        pass
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow: load_default() takes no size argument.
        return ImageFont.load_default()


def _draw_label(draw, x, y, text, font, pad=3):
    """Draw white `text` on a solid red background with its top-left corner at (x, y)."""
    # getbbox() is measured from the text origin, and its left/top are often
    # non-zero (e.g. the gap above the glyphs). Sizing the background from
    # (right - left, bottom - top) would let the glyphs spill past its lower
    # edge, so size it from right/bottom instead.
    _, _, right, bottom = font.getbbox(text)
    draw.rectangle([x, y, x + right + 2 * pad, y + bottom + 2 * pad], fill="red")
    draw.text((x + pad, y + pad), text, fill="white", font=font, stroke_width=0.5, stroke_fill="white")


def _render_single_crop(classifier, image):
    """MODE 1: classify the whole upload as one diatom and display the result."""
    with st.spinner("Classifying diatom..."):
        pred_class, conf = _classify(classifier, image)

    st.success("Classification Complete!")
    st.metric(label="Predicted Genus", value=f"**{pred_class}**", delta=f"{conf:.2f}% Confidence", border=True)
    st.info("Note: This mode bypassed the automatic detector and evaluated the entire image as a single diatom. For best results, ensure your image is cropped tightly around the diatom with a maximum of 15% background margin.")

    st.markdown("---")
    st.subheader("Image Viewer")
    st.image(image, caption="Original Upload", use_container_width=False)


def _annotate_slide(detector, classifier, image, conf_threshold):
    """Detect diatoms, classify each margin-padded crop and draw numbered boxes.

    Returns:
        (annotated_image, report_rows): a copy of `image` with labelled boxes,
        and one {"ID", "Genus", "Confidence"} dict per detection. IDs start
        at 1 and match the "#N" labels drawn on the image.
    """
    annotated = image.copy()
    rows = []
    results = detector(image, conf=conf_threshold, verbose=False)[0]
    draw = ImageDraw.Draw(annotated)
    font = _load_label_font()

    for diatom_id, box in enumerate(results.boxes.xyxy, start=1):
        x1, y1, x2, y2 = map(int, box.tolist())
        # Crops come from the untouched original, never from the annotated copy.
        crop = image.crop(_margin_crop_box((x1, y1, x2, y2), image.width, image.height))
        pred_class, conf = _classify(classifier, crop)

        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        # Label carries the same ID as the CSV report row.
        _draw_label(draw, x1, y1, f"#{diatom_id} {pred_class} ({conf:.1f}%)", font)

        rows.append({
            "ID": diatom_id,
            "Genus": pred_class,
            "Confidence": f"{conf:.2f}%",
        })
    return annotated, rows


def _render_full_slide(detector, classifier, image, conf_threshold, source_name):
    """MODE 2: YOLO detection + ResNet classification, with a CSV report and side-by-side images."""
    with st.spinner("Scanning slide & Classifying..."):
        display_image, report_data = _annotate_slide(detector, classifier, image, conf_threshold)

    if report_data:
        count = len(report_data)
        word = "diatom" if count == 1 else "diatoms"
        st.success(f"Successfully found {count} {word}!")
        df = pd.DataFrame(report_data)
        st.dataframe(df, use_container_width=True)
        st.download_button(
            label="📥 Download CSV Report",
            data=df.to_csv(index=False).encode('utf-8'),
            file_name=f"analysis_{source_name}.csv",
            mime="text/csv",
        )
    else:
        st.warning("No diatoms detected. Try lowering the detection confidence threshold in the sidebar.")

    st.markdown("---")
    st.subheader("Image Viewer")
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Upload", use_container_width=True)
    with col2:
        st.image(display_image, caption="Analyzed Image", use_container_width=True)


# -- Main File Uploader --
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

original_image = _open_upload(uploaded_file) if uploaded_file is not None else None
if original_image is not None:
    st.markdown("---")
    st.subheader("Analysis Results")
    if app_mode == "Single Diatom Crop":
        _render_single_crop(resnet_model, original_image)
    else:
        _render_full_slide(yolo_model, resnet_model, original_image, conf_threshold, uploaded_file.name)
# -- Footer AI Warning --
_DISCLAIMER_MD = (
    "**Disclaimer:** This application utilizes artificial intelligence and may produce inaccurate results. "
    "Always verify critical findings with a qualified domain expert."
)

st.markdown("---")
st.caption(_DISCLAIMER_MD)