"""
Streamlit Application for Automated Tablet Defect Detection
"""
import streamlit as st
import torch
import numpy as np
from PIL import Image
import sys
from pathlib import Path
import io
# Add parent directory to path
sys.path.append(str(Path(__file__).parent.parent))
import config
from src.feature_extractor import FeatureExtractor, extract_embeddings
from src.padim import PaDiM
from src.visualize import apply_heatmap
@st.cache_resource
def load_model():
    """Load the PaDiM model and the backbone feature extractor.

    Decorated with ``st.cache_resource`` so the (expensive) objects are built
    once and reused across script reruns and sessions.

    If the trained model file is missing, an error plus training hint is shown
    and the script run is stopped via ``st.stop()``.

    Returns:
        Tuple ``(padim_model, extractor, device)`` where ``device`` is CUDA when
        available, otherwise CPU, and ``extractor`` has been moved to it.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the fitted PaDiM statistics produced by train.py.
    model_path = config.MODEL_DIR / "padim_model.pkl"
    if not model_path.exists():
        st.error("❌ Model file not found. Please train the model first.")
        st.info("To train the model, run: `python train.py` in your terminal")
        st.stop()

    padim_model = PaDiM()
    # NOTE(review): if PaDiM.load unpickles the file, only load trusted models.
    padim_model.load(model_path)

    # Build the feature extractor on the selected device.
    extractor = FeatureExtractor(
        backbone=config.BACKBONE,
        layers=config.FEATURE_LAYERS,
    ).to(device)
    # Inference only: force evaluation mode so layers such as BatchNorm use
    # their stored statistics instead of per-batch ones (and do not update
    # them). eval() is idempotent, so this is harmless if the extractor already
    # switches itself to eval mode.
    if isinstance(extractor, torch.nn.Module):
        extractor.eval()

    return padim_model, extractor, device
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised, batched tensor for the extractor.

    The image is converted to RGB when needed (e.g. grayscale "L" or "RGBA"
    uploads), then resized to ``config.IMAGE_SIZE``, converted to a tensor and
    normalised with ``config.MEAN`` / ``config.STD``.

    Args:
        image: Input PIL image in any mode.

    Returns:
        Tensor of shape ``(1, C, H, W)`` (a batch of one image).
    """
    from torchvision import transforms

    # The backbone and the normalisation statistics presumably expect three
    # channels; other modes would otherwise fail inside Normalize / the model.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(config.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.MEAN, std=config.STD),
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension
def predict_defect(image: Image.Image, padim_model, extractor, device):
    """Score a single image for anomalies.

    Args:
        image: PIL image to inspect.
        padim_model: Fitted PaDiM model exposing ``predict``.
        extractor: Feature extractor passed to ``extract_embeddings``.
        device: Torch device the input batch is moved to.

    Returns:
        ``(anomaly_score, anomaly_map)`` as produced by ``padim_model.predict``.
    """
    batch = preprocess_image(image).to(device)

    # Feature extraction is inference-only, so autograd tracking is disabled.
    with torch.no_grad():
        features = extract_embeddings(extractor, batch)

    # padim_model.predict receives a NumPy array, so move the features to CPU.
    host_features = features.cpu().numpy()
    score, score_map = padim_model.predict(host_features)
    return score, score_map
# Session-state key remembering the selected demo image across reruns.
_DEMO_STATE_KEY = "demo_image_path"
# Defect categories (sub-directories of config.TEST_DIR) used for demos/gallery.
_DEMO_CATEGORIES = ("squeeze", "poke", "crack")


def _clear_demo_image():
    """Forget the remembered demo image so a newly uploaded file is shown."""
    if _DEMO_STATE_KEY in st.session_state:
        del st.session_state[_DEMO_STATE_KEY]


def _pick_demo_image():
    """Pick a random PNG from a random demo defect category.

    Returns:
        The chosen image ``Path``, or ``None`` (after showing a message) when
        the category directory is missing or contains no PNG files.
    """
    import random

    demo_category = random.choice(_DEMO_CATEGORIES)
    demo_dir = config.TEST_DIR / demo_category
    if not demo_dir.exists():
        st.error(f"Demo category '{demo_category}' not found.")
        return None
    demo_images = sorted(demo_dir.glob("*.png"))
    if not demo_images:
        st.warning(f"Demo category '{demo_category}' contains no PNG images.")
        return None
    return random.choice(demo_images)


def _render_sidebar():
    """Draw the settings/info sidebar.

    Returns:
        Tuple ``(threshold, show_heatmap, heatmap_alpha)`` from the widgets.
    """
    with st.sidebar:
        st.image("https://img.icons8.com/fluency/96/pill.png", width=80)
        st.title("⚙️ Settings")
        threshold = st.slider(
            "Anomaly Threshold",
            min_value=0.0,
            max_value=30.0,
            value=15.0,
            step=0.5,
            help="Adjust sensitivity: lower = more sensitive to defects (typical range: 10-20)",
        )
        show_heatmap = st.checkbox("Show Anomaly Heatmap", value=True)
        heatmap_alpha = st.slider("Heatmap Opacity", 0.0, 1.0, 0.4, 0.05)

        st.divider()
        st.subheader("📊 Model Info")
        layers = ", ".join(map(str, config.FEATURE_LAYERS))
        device_name = "GPU" if torch.cuda.is_available() else "CPU"
        st.markdown(
            "- **Method:** PaDiM\n"
            f"- **Backbone:** {config.BACKBONE}\n"
            f"- **Layers:** {layers}\n"
            f"- **Device:** {device_name}"
        )

        st.divider()
        st.subheader("ℹ️ About")
        st.markdown(
            "This system uses **PaDiM** (Patch Distribution Modeling) for\n"
            "unsupervised anomaly detection in pharmaceutical tablets.\n\n"
            "**Features:**\n"
            "- ✅ Image-level defect classification\n"
            "- 🎯 Pixel-level defect localization\n"
            "- 📊 Anomaly score quantification\n"
            "- 🚀 CPU-friendly inference"
        )

        st.divider()
        st.warning("⚠️ **Model Limitation:** This model is trained specifically on the Actavis 500mg capsule dataset. It will NOT work accurately on other tablet/capsule types, shapes, or colors.")

    return threshold, show_heatmap, heatmap_alpha


def _build_annotated_png(image, anomaly_map, anomaly_score, is_defective, heatmap_alpha):
    """Render the heatmap overlay with a verdict label and encode it as PNG bytes."""
    import cv2

    # Use the same colormap as the on-screen heatmap so the download matches it.
    result_img = apply_heatmap(
        np.array(image),
        anomaly_map,
        alpha=heatmap_alpha,
        colormap=config.HEATMAP_COLORMAP,
    )
    # cv2.putText draws in place and requires a C-contiguous buffer.
    result_img = np.ascontiguousarray(result_img)

    prediction_text = "DEFECTIVE" if is_defective else "NORMAL"
    # Red / green, assuming apply_heatmap returns RGB like its RGB input.
    color = (255, 0, 0) if is_defective else (0, 255, 0)
    cv2.putText(result_img, f"{prediction_text} ({anomaly_score:.3f})",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                1, color, 2, cv2.LINE_AA)

    buf = io.BytesIO()
    Image.fromarray(result_img).save(buf, format="PNG")
    return buf.getvalue()


def _render_results(image, anomaly_score, anomaly_map, threshold, show_heatmap, heatmap_alpha):
    """Show the verdict, metrics, optional heatmap and a download button."""
    st.divider()
    st.subheader("🎯 Inspection Results")

    is_defective = anomaly_score > threshold
    if is_defective:
        st.error("⚠️ DEFECTIVE TABLET DETECTED")
    else:
        st.success("✅ NORMAL TABLET (No Defects)")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric(
            label="Anomaly Score",
            value=f"{anomaly_score:.4f}",
            delta="Defect" if is_defective else "Normal",
            # A string delta without a leading "-" always counts as positive,
            # so pick the colour mode: red for defects, green for normal.
            delta_color="inverse" if is_defective else "normal",
        )
    with col2:
        st.metric(
            label="Threshold",
            value=f"{threshold:.3f}",
            delta=f"{(anomaly_score / threshold - 1) * 100:+.1f}%" if threshold > 0 else "N/A",
        )
    with col3:
        # Relative distance from the threshold, capped at 100%.
        confidence = abs(anomaly_score - threshold) / threshold if threshold > 0 else 0
        st.metric(
            label="Confidence",
            value=f"{min(confidence * 100, 100):.1f}%",
        )

    if show_heatmap:
        st.divider()
        st.subheader("🔥 Anomaly Heatmap")
        st.markdown("*Highlighted regions indicate potential defects*")
        heatmap_overlay = apply_heatmap(
            np.array(image),
            anomaly_map,
            alpha=heatmap_alpha,
            colormap=config.HEATMAP_COLORMAP,
        )
        left, right = st.columns(2)
        with left:
            st.image(image, caption="Original", use_column_width=True)
        with right:
            st.image(heatmap_overlay, caption="Defect Localization",
                     use_column_width=True)

    # Offer the download directly: a separate trigger button would need an
    # extra rerun, during which the inspected image could be lost.
    st.divider()
    st.download_button(
        label="⬇️ Download Annotated Image",
        data=_build_annotated_png(image, anomaly_map, anomaly_score,
                                  is_defective, heatmap_alpha),
        file_name="defect_detection_result.png",
        mime="image/png",
    )


def _render_example_gallery():
    """Show the first (sorted) PNG of each demo defect category, if present."""
    st.divider()
    st.subheader("📋 Example Defect Types")
    cols = st.columns(3)
    for idx, category in enumerate(_DEMO_CATEGORIES):
        defect_dir = config.TEST_DIR / category
        if not defect_dir.exists():
            continue
        images = sorted(defect_dir.glob("*.png"))
        if not images:
            continue
        with cols[idx % 3]:
            # Close the file handle once Streamlit has consumed the image.
            with Image.open(images[0]) as example_img:
                st.image(example_img, caption=category.title(), use_column_width=True)


def main():
    """Build the Streamlit page: settings, image input, inference and results."""
    st.set_page_config(
        page_title="Tablet Defect Detection",
        page_icon="💊",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    st.title("💊 Automated Tablet Defect Detection")
    st.markdown("Unsupervised Computer Vision Quality Inspection System")

    threshold, show_heatmap, heatmap_alpha = _render_sidebar()

    with st.spinner("Loading model..."):
        padim_model, extractor, device = load_model()

    st.divider()

    # Uploading a new file discards any previously chosen demo image.
    uploaded_file = st.file_uploader(
        "Upload a tablet image for inspection",
        type=["png", "jpg", "jpeg"],
        help="Supported formats: PNG, JPG, JPEG",
        on_change=_clear_demo_image,
    )

    _, demo_col = st.columns([3, 1])
    with demo_col:
        if st.button("🎲 Try Demo Image"):
            picked = _pick_demo_image()
            if picked is not None:
                # st.button is only True for one run; remember the choice so
                # slider changes or downloads do not drop the demo image.
                st.session_state[_DEMO_STATE_KEY] = str(picked)

    demo_path = st.session_state.get(_DEMO_STATE_KEY)
    image_source = Path(demo_path) if demo_path else uploaded_file

    if image_source is None:
        st.info("👆 Please upload an image or click 'Try Demo Image' to start inspection.")
        _render_example_gallery()
        return

    # Both a Path and an UploadedFile are accepted by Image.open.
    try:
        with Image.open(image_source) as raw_image:
            image = raw_image.convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the image: {err}")
        return

    st.subheader("📸 Uploaded Image")
    _, center, _ = st.columns([1, 2, 1])
    with center:
        st.image(image, use_column_width=True)

    with st.spinner("🔍 Analyzing image..."):
        anomaly_score, anomaly_map = predict_defect(
            image, padim_model, extractor, device
        )

    # Plain float so comparisons/formatting behave for NumPy scalars too.
    _render_results(image, float(anomaly_score), anomaly_map,
                    threshold, show_heatmap, heatmap_alpha)


# Run the app when this module is executed directly (not when imported).
if __name__ == "__main__":
    main()