# Source page header: nithishbasireddy — commit 03967a3 (verified)
# Commit message: Update app: auto-find model path, add spinner, improve landing page
"""
EL Defect Detection — Streamlit App (Production)

Runs with the trained U-Net++ model (no mock inference); if no checkpoint is
found, a simple threshold/edge heuristic is used instead.
Fixed grid detection: single cells stay single, full modules are properly segmented.

Usage:
    streamlit run app.py
"""
import sys
import os
import json
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import streamlit as st
import segmentation_models_pytorch as smp
from scipy.signal import find_peaks
from scipy.ndimage import distance_transform_edt
# scikit-image is an optional dependency. SKIMAGE_OK records whether it could
# be imported; analyze_cell only measures crack length when it is True.
try:
    from skimage.morphology import skeletonize
    from skimage.measure import label as sk_label, regionprops
    SKIMAGE_OK = True
except ImportError:
    SKIMAGE_OK = False
# ═══════════════════════════════════════════════════════════════
# LABEL REMAP (must match training)
# ═══════════════════════════════════════════════════════════════
# Lookup table from raw dataset label id (0..29) to the 5 training classes.
# Every id that is not listed below stays 0 (background).
LABEL_REMAP = np.zeros(30, dtype=np.uint8)
LABEL_REMAP[9] = 1                 # busbars
LABEL_REMAP[[10, 14]] = 2          # crack_rbn_edge, crack
LABEL_REMAP[[11, 17, 20]] = 3      # inactive, dead_cell, edge_dark
for lbl in (12, 13, 15, 16, 18, 19, 25, 26, 27, 28):
    LABEL_REMAP[lbl] = 4           # other_defect

# Class index -> name; the position in this list is the class id.
CLASS_NAMES = ["background", "busbar", "crack", "dark", "other_defect"]

# Overlay colour (RGB) per class name, listed in class-id order.
CLASS_COLORS_RGB = dict(zip(
    CLASS_NAMES,
    [
        (0, 0, 0),        # background
        (0, 200, 0),      # busbar: green
        (0, 100, 255),    # crack: blue
        (255, 50, 50),    # dark: red
        (255, 200, 0),    # other_defect: yellow
    ],
))
# ═══════════════════════════════════════════════════════════════
# FIND MODEL — check multiple locations
# ═══════════════════════════════════════════════════════════════
def find_model_path():
    """Locate the ``best_model.pth`` checkpoint.

    The candidate locations are probed in a fixed order and the first one that
    exists on disk is returned. Relative candidates are resolved against the
    current working directory. If none exists, the bare file name
    ``"best_model.pth"`` is returned so the caller can report it as missing.
    """
    fallback = "best_model.pth"
    search_order = (
        fallback,                        # working directory (Docker WORKDIR /app)
        "/app/best_model.pth",           # absolute path inside the container
        "output/best_model.pth",         # training output directory
        "/app/output/best_model.pth",
        # two directory levels above this source file
        os.path.join(os.path.dirname(__file__), "..", "..", "best_model.pth"),
    )
    return next((path for path in search_order if os.path.exists(path)), fallback)
# ═══════════════════════════════════════════════════════════════
# MODEL LOADING
# ═══════════════════════════════════════════════════════════════
@st.cache_resource
def load_model(model_path: str):
    """Load the trained segmentation model from a checkpoint file.

    Args:
        model_path: Path to the ``.pth`` checkpoint.

    Returns:
        ``(model, device, metadata)``. ``model`` is ``None`` and ``metadata``
        is empty when ``model_path`` does not exist. Otherwise ``metadata``
        holds the architecture/encoder names, the validation metrics and epoch
        stored in the checkpoint, plus ``missing_keys`` / ``unexpected_keys``:
        the parameter names that could not be matched while loading weights.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(model_path):
        return None, device, {}

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Missing hyper-parameters fall back to the defaults below.
    arch = checkpoint.get("architecture", "UnetPlusPlus")
    encoder = checkpoint.get("encoder", "efficientnet-b4")
    num_classes = checkpoint.get("num_classes", 5)

    ModelClass = getattr(smp, arch)
    model = ModelClass(
        encoder_name=encoder,
        encoder_weights=None,  # weights come from the checkpoint, not a download
        in_channels=1,
        classes=num_classes,
        decoder_attention_type="scse",
    )

    # A checkpoint may be a dict wrapping the weights, or the bare state dict.
    state_dict = checkpoint.get("model_state_dict", checkpoint)

    # strict=False tolerates key mismatches, but ignoring them entirely would
    # let a wrong checkpoint run with partly uninitialised weights and still
    # look like a successful load. Record and log whatever did not match.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing_keys = list(load_result.missing_keys)
    unexpected_keys = list(load_result.unexpected_keys)
    if missing_keys or unexpected_keys:
        import logging

        logging.getLogger(__name__).warning(
            "Checkpoint %s does not fully match %s/%s: %d missing, %d unexpected keys",
            model_path, arch, encoder, len(missing_keys), len(unexpected_keys),
        )

    model.to(device)
    model.eval()

    meta = {
        "architecture": arch,
        "encoder": encoder,
        "val_dice": checkpoint.get("val_dice", 0),
        "val_iou": checkpoint.get("val_iou", 0),
        "epoch": checkpoint.get("epoch", 0),
        "missing_keys": missing_keys,
        "unexpected_keys": unexpected_keys,
    }
    return model, device, meta
# ═══════════════════════════════════════════════════════════════
# PREPROCESSING
# ═══════════════════════════════════════════════════════════════
def _to_uint8(gray: np.ndarray) -> np.ndarray:
    """Convert a single-channel image of any dtype to uint8 in [0, 255]."""
    if gray.dtype == np.uint8:
        return gray
    if gray.dtype == np.bool_:
        return gray.astype(np.uint8) * 255

    values = gray.astype(np.float64)
    peak = float(values.max()) if values.size else 0.0
    if np.issubdtype(gray.dtype, np.floating) and 0.0 < peak <= 1.0:
        # Unit-range float image: same convention detect_grid/create_overlay use.
        values = values * 255.0
    elif peak > 255.0:
        # High bit-depth data (e.g. 16-bit TIFF): rescale by the observed
        # maximum so it is not saturated to white by the clip below.
        values = values * (255.0 / peak)
    return np.clip(values, 0, 255).astype(np.uint8)


def preprocess_image(img_np: np.ndarray, target_size: int = 512) -> Tuple[np.ndarray, np.ndarray]:
    """Preprocess an EL image for model input.

    Args:
        img_np: Image array, either 2-D grayscale or 3-D ``(H, W, C)``.
            Three channels are treated as RGB and four as RGBA; with one or
            two channels the first channel is used as the gray plane.
        target_size: Side length of the square model input.

    Returns:
        ``(model_input, display_gray)``: ``model_input`` is float32 with shape
        ``(1, 1, target_size, target_size)`` and values in [0, 1];
        ``display_gray`` is the uint8 grayscale image at the original size,
        before contrast enhancement.
    """
    # Convert to grayscale
    if img_np.ndim == 3:
        channels = img_np.shape[2]
        if channels == 4:
            gray = cv2.cvtColor(img_np, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        else:
            # 1 channel, or 2 (gray + alpha): cvtColor has no code for these.
            gray = img_np[..., 0].copy()
    else:
        gray = img_np.copy()

    gray = _to_uint8(gray)
    orig_gray = gray.copy()

    # CLAHE
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # Resize to model input
    resized = cv2.resize(enhanced, (target_size, target_size), interpolation=cv2.INTER_LINEAR)

    # Normalize: [0, 255] → [0, 1]
    normalized = resized.astype(np.float32) / 255.0

    # To tensor shape: (1, 1, H, W)
    tensor = normalized[np.newaxis, np.newaxis, ...]
    return tensor, orig_gray
# ═══════════════════════════════════════════════════════════════
# INFERENCE
# ═══════════════════════════════════════════════════════════════
def predict(model, device, tensor_input: np.ndarray) -> np.ndarray:
    """Run the model on one preprocessed input and return the class mask.

    The forward pass runs without gradient tracking, and under autocast only
    when ``device`` is a CUDA device. The per-pixel class is the argmax over
    dim 1 of the model output; a leading dimension of size one is squeezed
    away and the result is returned as a uint8 NumPy array.
    """
    batch = torch.from_numpy(tensor_input).to(device=device, dtype=torch.float32)
    use_amp = device.type == "cuda"
    with torch.no_grad():
        amp_context = torch.amp.autocast(device_type=device.type, enabled=use_amp)
        with amp_context:
            logits = model(batch)
        class_map = logits.argmax(dim=1).squeeze(0)
    return class_map.cpu().numpy().astype(np.uint8)
# ═══════════════════════════════════════════════════════════════
# GRID DETECTION — FIXED VERSION
# ═══════════════════════════════════════════════════════════════
@dataclass
class CellInfo:
    """One region produced by detect_grid: a grid cell, or the whole image."""

    cell_id: int  # running id assigned by detect_grid, starting at 1
    row: int  # 0-based row index in the detected grid
    col: int  # 0-based column index in the detected grid
    bbox: Tuple[int, int, int, int] # y1, x1, y2, x2
    image: Optional[np.ndarray] = None  # grayscale pixels inside bbox, when provided
def _find_gap_peaks(profile: np.ndarray, extent: int) -> np.ndarray:
    """Median-smooth a 1-D gap profile and return its prominent peak indices."""
    from scipy.signal import medfilt

    kernel = max(3, extent // 100) | 1  # ensure odd
    smoothed = medfilt(profile, kernel_size=kernel)
    value_range = smoothed.max() - smoothed.min()
    # Require prominent peaks (at least 20% of range). find_peaks rejects
    # distance < 1, which extent // 20 would produce for extents below 20 px.
    peaks, _ = find_peaks(
        smoothed,
        prominence=value_range * 0.2,
        distance=max(1, extent // 20),
    )
    return peaks


def detect_grid(gray: np.ndarray, min_cells: int = 4) -> List[CellInfo]:
    """Detect the cell grid in a module image.

    Dark gaps between cells are located as peaks of the inverted row and
    column intensity projections, then filtered for periodic spacing (at
    least three lines for rows, two for columns; an axis with fewer
    contributes no lines and therefore a single strip).

    Args:
        gray: 2-D grayscale image, uint8 or float in [0, 1].
        min_cells: Minimum ``rows * cols`` the detected lines must yield for
            the image to be split.

    Returns:
        One ``CellInfo`` per kept cell in row-major order. Regions that are
        too small or almost black are skipped. When the grid yields fewer than
        ``min_cells`` cells, or no region survives, a single ``CellInfo``
        covering the whole image is returned.
    """
    h, w = gray.shape
    whole_image = [CellInfo(cell_id=1, row=0, col=0, bbox=(0, 0, h, w), image=gray)]

    # Apply CLAHE for better grid contrast
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray if gray.dtype == np.uint8 else (gray * 255).astype(np.uint8))

    # Compute projections (inverted — dark gaps become peaks)
    inv = 255 - enhanced
    row_proj = inv.mean(axis=1).astype(np.float64)  # horizontal gaps
    col_proj = inv.mean(axis=0).astype(np.float64)  # vertical gaps

    # Validate periodicity — peaks must be roughly evenly spaced
    row_peaks = _validate_periodic(_find_gap_peaks(row_proj, h), min_count=2)
    col_peaks = _validate_periodic(_find_gap_peaks(col_proj, w), min_count=1)

    # Need enough grid lines to form min_cells
    total_cells = (len(row_peaks) + 1) * (len(col_peaks) + 1)
    if total_cells < min_cells:
        # Not enough grid → treat as single cell
        return whole_image

    # Extract cells
    row_bounds = np.concatenate([[0], row_peaks, [h]])
    col_bounds = np.concatenate([[0], col_peaks, [w]])
    min_dim = max(20, min(h, w) // 30)

    cells = []
    for i in range(len(row_bounds) - 1):
        y1, y2 = int(row_bounds[i]), int(row_bounds[i + 1])
        for j in range(len(col_bounds) - 1):
            x1, x2 = int(col_bounds[j]), int(col_bounds[j + 1])
            if y2 - y1 < min_dim or x2 - x1 < min_dim:
                continue  # sliver between a grid line and the image border
            cell_img = gray[y1:y2, x1:x2]
            if cell_img.mean() < 5:  # Skip pure black regions
                continue
            cells.append(CellInfo(
                cell_id=len(cells) + 1, row=i, col=j,
                bbox=(y1, x1, y2, x2), image=cell_img.copy(),
            ))

    return cells if cells else whole_image
def _validate_periodic(peaks: np.ndarray, min_count: int = 2) -> np.ndarray:
"""Keep only peaks that form a roughly periodic pattern."""
if len(peaks) < min_count + 1:
return np.array([], dtype=int)
spacings = np.diff(peaks)
if len(spacings) == 0:
return np.array([], dtype=int)
median_sp = np.median(spacings)
if median_sp < 10:
return np.array([], dtype=int)
# Keep peaks where spacing is within 40% of median
good = [peaks[0]]
for i in range(len(spacings)):
if abs(spacings[i] - median_sp) < median_sp * 0.4:
good.append(peaks[i + 1])
# If spacing is ~2x median, it's a missing line β€” still valid
elif abs(spacings[i] - 2 * median_sp) < median_sp * 0.4:
good.append(peaks[i + 1])
if len(good) < min_count + 1:
return np.array([], dtype=int)
return np.array(good)
# ═══════════════════════════════════════════════════════════════
# DEFECT ANALYSIS
# ═══════════════════════════════════════════════════════════════
def _severity_label(value: float, ladder) -> str:
    """Return the label of the first ``(threshold, label)`` rung that ``value`` exceeds."""
    for threshold, label in ladder:
        if value > threshold:
            return label
    return "none"


def analyze_cell(cell_img: np.ndarray, mask: np.ndarray, px_per_mm: float = 3.3) -> dict:
    """Analyze defects in one cell from its segmentation mask.

    Args:
        cell_img: Grayscale pixels of the cell. Currently unused; kept for
            interface compatibility.
        mask: 2-D class mask (1 busbar, 2 crack, 3 dark, 4 other defect).
        px_per_mm: Pixels per millimetre, used to convert crack length.

    Returns:
        Dict with the area percentage of each class, ``crack_length_mm``,
        ``num_cracks``, ``dark_severity``, ``crack_severity`` and a
        ``defect_score`` in [0, 100]. A zero-area mask yields all zeros.
        Crack length and count stay 0 when scikit-image is unavailable or the
        crack area is 5 pixels or fewer.
    """
    h, w = mask.shape
    total_px = h * w

    def area_pct(class_id: int) -> float:
        # A zero-area mask would otherwise divide by zero and produce NaN.
        if total_px == 0:
            return 0.0
        return float((mask == class_id).sum() / total_px * 100)

    # Class areas
    busbar_pct = area_pct(1)
    crack_pct = area_pct(2)
    dark_pct = area_pct(3)
    other_pct = area_pct(4)

    # Crack length via skeletonization: the number of skeleton pixels
    # approximates the length, and connected skeleton pieces are the count.
    crack_length_mm = 0.0
    num_cracks = 0
    crack_region = mask == 2
    if crack_region.sum() > 5 and SKIMAGE_OK:
        try:
            skeleton = skeletonize(crack_region)
            crack_length_mm = float(skeleton.sum() / px_per_mm)
            num_cracks = int(sk_label(skeleton.astype(np.uint8)).max())
        except Exception:
            # Best effort: keep whatever was measured, but leave a trace
            # instead of silently reporting a crack-free cell.
            import logging

            logging.getLogger(__name__).exception("Crack skeleton analysis failed")

    dark_severity = _severity_label(
        dark_pct, ((50, "critical"), (25, "severe"), (10, "moderate"), (2, "minor"))
    )
    crack_severity = _severity_label(
        crack_length_mm, ((30, "critical"), (15, "severe"), (5, "moderate"), (0.5, "minor"))
    )

    # Defect score (0-100): weighted sum of four terms, each capped at 100.
    score = min(
        100.0,
        0.35 * min(crack_length_mm / 50 * 100, 100)
        + 0.35 * min(dark_pct * 2, 100)
        + 0.15 * min(num_cracks * 15, 100)
        + 0.15 * min(other_pct * 3, 100),
    )

    return {
        "busbar_pct": round(busbar_pct, 2),
        "crack_pct": round(crack_pct, 2),
        "dark_pct": round(dark_pct, 2),
        "other_defect_pct": round(other_pct, 2),
        "crack_length_mm": round(crack_length_mm, 2),
        "num_cracks": int(num_cracks),
        "dark_severity": dark_severity,
        "crack_severity": crack_severity,
        "defect_score": round(score, 1),
    }
def module_decision(cell_results: List[dict], thresholds: dict) -> dict:
    """Derive the module-level PASS/FAIL decision from per-cell results.

    Args:
        cell_results: One dict per cell with at least ``defect_score``,
            ``crack_length_mm`` and ``dark_pct``.
        thresholds: Optional limits ``max_score`` (default 50),
            ``max_crack_mm`` (default 30) and ``max_dark_pct`` (default 40).

    Returns:
        Dict with ``decision`` ("PASS"/"FAIL"), ``score`` (mean defect score,
        one decimal), ``num_defective``, ``num_cells`` and ``reasons`` (one
        entry per exceeded limit; cells are numbered from 1). The module
        fails as soon as any cell exceeds any limit. For empty input the
        result is a PASS with zero counts, and it additionally carries the
        legacy ``cells`` key.
    """
    if not cell_results:
        # Same keys as the regular result so callers can read the counts
        # unconditionally; "cells" is kept for backward compatibility.
        return {
            "decision": "PASS",
            "score": 0,
            "num_defective": 0,
            "num_cells": 0,
            "reasons": [],
            "cells": [],
        }

    max_score = thresholds.get("max_score", 50)
    max_crack_mm = thresholds.get("max_crack_mm", 30)
    max_dark_pct = thresholds.get("max_dark_pct", 40)

    reasons = []
    defective = 0
    for cell_no, r in enumerate(cell_results, start=1):
        fails = []
        if r["defect_score"] > max_score:
            fails.append(f"Cell {cell_no}: score {r['defect_score']:.0f}")
        if r["crack_length_mm"] > max_crack_mm:
            fails.append(f"Cell {cell_no}: crack {r['crack_length_mm']:.1f}mm")
        if r["dark_pct"] > max_dark_pct:
            fails.append(f"Cell {cell_no}: dark {r['dark_pct']:.1f}%")
        if fails:
            defective += 1
            reasons.extend(fails)

    # Plain float (not a NumPy scalar) so the result is safe to serialize.
    avg_score = float(np.mean([r["defect_score"] for r in cell_results]))

    return {
        "decision": "FAIL" if reasons else "PASS",
        "score": round(avg_score, 1),
        "num_defective": defective,
        "num_cells": len(cell_results),
        "reasons": reasons,
    }
# ═══════════════════════════════════════════════════════════════
# VISUALIZATION
# ═══════════════════════════════════════════════════════════════
def create_overlay(gray: np.ndarray, mask: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Blend the class colours of ``mask`` over ``gray``.

    A 2-D ``gray`` is expanded to RGB (non-uint8 input is multiplied by 255
    and cast first); any other input is copied as-is. ``mask`` is resized with
    nearest-neighbour interpolation when its size differs from the image.
    Every non-background class is painted in its ``CLASS_COLORS_RGB`` colour
    and the painted copy is mixed with the plain image using weight ``alpha``.
    """
    if gray.ndim == 2:
        gray_u8 = gray if gray.dtype == np.uint8 else (gray * 255).astype(np.uint8)
        canvas = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
    else:
        canvas = gray.copy()

    height, width = canvas.shape[:2]
    if mask.shape[:2] != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    painted = canvas.copy()
    # Class 0 is background and is left unpainted.
    for class_id, class_name in enumerate(CLASS_NAMES[1:], start=1):
        painted[mask == class_id] = CLASS_COLORS_RGB[class_name]

    return cv2.addWeighted(canvas, 1 - alpha, painted, alpha, 0)
# ═══════════════════════════════════════════════════════════════
# STREAMLIT APP
# ═══════════════════════════════════════════════════════════════
# Page header.
st.set_page_config(page_title="EL Defect Detection", page_icon="🔬", layout="wide")
st.title("🔬 EL Defect Detection System")
st.markdown("**U-Net++ with EfficientNet-B4 | Trained on E-SCDD | Val Dice: 0.6297**")

# ── Sidebar ──────────────────────────────────────────────────
# Sliders are (label, min, max, default, step).
with st.sidebar:
    st.header("⚙️ Settings")
    st.subheader("Quality Thresholds")
    max_score = st.slider("Max defect score", 10, 90, 50, 5)
    max_crack_mm = st.slider("Max crack length (mm)", 5, 100, 30, 5)
    max_dark_pct = st.slider("Max dark area (%)", 5, 80, 40, 5)
    overlay_alpha = st.slider("Overlay opacity", 0.1, 0.9, 0.4, 0.1)
    st.subheader("Grid Detection")
    min_cells_for_grid = st.slider(
        "Min cells to segment", 2, 12, 4, 1,
        help="Only segment into grid if at least this many cells detected",
    )
    st.markdown("---")
    st.markdown("**Model Info**")
    st.markdown("- Architecture: U-Net++")
    st.markdown("- Encoder: EfficientNet-B4 + scSE")
    st.markdown("- Dataset: E-SCDD (903 images)")
    st.markdown("- Best Dice: 0.6297 (epoch 73)")

# ── Load model ───────────────────────────────────────────────
model_path = find_model_path()
model, device, meta = load_model(model_path)
if model is None:
    st.warning("⚠️ Model not found. Searched: best_model.pth, /app/best_model.pth, output/best_model.pth. "
               "Falling back to heuristic analysis.")
    HAS_MODEL = False
else:
    st.success(f"✅ Model loaded: {meta.get('architecture')} + {meta.get('encoder')} | "
               f"Val Dice: {meta.get('val_dice', 0):.4f} | Epoch: {meta.get('epoch', 0)}")
    HAS_MODEL = True

# ── Upload ───────────────────────────────────────────────────
uploaded = st.file_uploader("📤 Upload EL Image", type=["png", "jpg", "jpeg", "tif", "bmp"])
if uploaded is not None:
    # A corrupt or unsupported upload should produce a message, not a traceback.
    try:
        pil_img = Image.open(uploaded)
        pil_img.load()  # force decoding now so failures surface here
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        st.error(f"Could not read the uploaded image: {exc}")
        st.stop()

    # Palette, CMYK, gray+alpha, 1-bit, ... arrays do not mean "gray" or "RGB"
    # to preprocess_image, so normalise every other mode to RGB first.
    if pil_img.mode not in ("L", "RGB", "RGBA", "I", "F", "I;16", "I;16L", "I;16B", "I;16N"):
        pil_img = pil_img.convert("RGB")
    img_np = np.array(pil_img)

    # Preprocess
    tensor_input, gray = preprocess_image(img_np, target_size=512)
    st.markdown("---")

    # ── Run inference ────────────────────────────────────────
    with st.spinner("🔍 Running defect detection..."):
        if HAS_MODEL:
            mask_512 = predict(model, device, tensor_input)
        else:
            # Fallback: simple thresholding
            g = cv2.resize(gray, (512, 512))
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            g = clahe.apply(g)
            mask_512 = np.zeros((512, 512), dtype=np.uint8)
            mean_v = g.mean()
            mask_512[g < mean_v * 0.4] = 3  # dark
            edges = cv2.Canny(g, 30, 100)
            mask_512[edges > 0] = 2  # crack approx

    # Resize mask to original image size
    mask_full = cv2.resize(mask_512, (gray.shape[1], gray.shape[0]),
                           interpolation=cv2.INTER_NEAREST)

    # Create overlay on original
    overlay_full = create_overlay(gray, mask_full, alpha=overlay_alpha)

    # ── Display original + overlay ───────────────────────────
    st.subheader("🖼️ Results")
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("**Original**")
        st.image(gray, use_container_width=True, clamp=True)
    with col2:
        st.markdown("**Defect Overlay**")
        st.image(overlay_full, use_container_width=True, clamp=True)

    # ── Grid detection + per-cell analysis ───────────────────
    st.markdown("---")
    cells = detect_grid(gray, min_cells=min_cells_for_grid)
    st.subheader(f"📐 {len(cells)} cell(s) detected")

    # Estimate px/mm from cell size
    if len(cells) > 1:
        widths = [c.bbox[3] - c.bbox[1] for c in cells]
        px_per_mm = np.median(widths) / 156.0  # standard 156mm cell
    else:
        px_per_mm = max(gray.shape) / 156.0

    # Analyze each cell
    cell_results = []
    cell_overlays = []
    for cell in cells:
        y1, x1, y2, x2 = cell.bbox
        cell_mask = mask_full[y1:y2, x1:x2]
        cell_gray = gray[y1:y2, x1:x2]
        # Floor the scale at 0.5 px/mm so tiny images cannot inflate lengths.
        result = analyze_cell(cell_gray, cell_mask, px_per_mm=max(px_per_mm, 0.5))
        cell_results.append(result)
        cell_ov = create_overlay(cell_gray, cell_mask, alpha=overlay_alpha)
        cell_overlays.append(cell_ov)

    # Display cells in grid (at most 6 per display row)
    cols_per_row = min(6, len(cells))
    if cols_per_row > 0:
        for row_start in range(0, len(cells), cols_per_row):
            row_end = min(row_start + cols_per_row, len(cells))
            cols = st.columns(cols_per_row)
            for i, col in enumerate(cols[:row_end - row_start]):
                idx = row_start + i
                r = cell_results[idx]
                with col:
                    st.image(cell_overlays[idx], use_container_width=True, clamp=True)
                    score = r["defect_score"]
                    icon = "🟢" if score < 25 else ("🟡" if score < 50 else "🔴")
                    st.markdown(f"**Cell {idx+1}** {icon} {score:.0f}")
                    st.caption(f"Crack: {r['crack_length_mm']:.1f}mm | Dark: {r['dark_pct']:.1f}%")

    # ── Module decision ──────────────────────────────────────
    st.markdown("---")
    thresholds = {"max_score": max_score, "max_crack_mm": max_crack_mm, "max_dark_pct": max_dark_pct}
    decision = module_decision(cell_results, thresholds)
    if decision["decision"] == "PASS":
        st.success(f"✅ **PASS** — Module Score: {decision['score']:.1f}/100")
    else:
        st.error(f"❌ **FAIL** — Module Score: {decision['score']:.1f}/100 — "
                 f"{decision['num_defective']}/{decision['num_cells']} cells defective")
        with st.expander("Failure reasons"):
            for reason in decision["reasons"]:
                st.markdown(f"- {reason}")

    # ── Summary metrics ──────────────────────────────────────
    st.markdown("---")
    st.subheader("📊 Summary")
    c1, c2, c3, c4 = st.columns(4)
    c1.metric("Cells", len(cell_results))
    c2.metric("Avg Score", f"{decision['score']:.1f}")
    c3.metric("Total Cracks", sum(r["num_cracks"] for r in cell_results))
    c4.metric("Avg Dark %", f"{np.mean([r['dark_pct'] for r in cell_results]):.1f}%")

    # ── Detailed table ───────────────────────────────────────
    with st.expander("📋 Detailed Results"):
        import pandas as pd
        rows = []
        for i, r in enumerate(cell_results):
            rows.append({
                "Cell": i + 1,
                "Score": r["defect_score"],
                "Cracks": r["num_cracks"],
                "Crack mm": r["crack_length_mm"],
                "Dark %": r["dark_pct"],
                "Busbar %": r["busbar_pct"],
                "Crack Severity": r["crack_severity"],
                "Dark Severity": r["dark_severity"],
            })
        st.dataframe(pd.DataFrame(rows), use_container_width=True)

    # ── Color legend ─────────────────────────────────────────
    with st.expander("🎨 Color Legend"):
        st.markdown("""
| Color | Class | Description |
|-------|-------|-------------|
| 🟢 Green | Busbar | Metal busbar (feature, not defect) |
| 🔵 Blue | Crack | Micro-crack in silicon |
| 🔴 Red | Dark/Inactive | Area disconnected from circuit |
| 🟡 Yellow | Other Defect | Rings, material, gridline, corrosion, etc. |
""")

    # ── Download ─────────────────────────────────────────────
    st.markdown("---")
    col_d1, col_d2 = st.columns(2)
    with col_d1:
        report = {"decision": decision, "cells": cell_results}
        st.download_button("📄 Download JSON Report",
                           json.dumps(report, indent=2),
                           "el_report.json", "application/json")
    with col_d2:
        buf = BytesIO()
        Image.fromarray(overlay_full).save(buf, format="PNG")
        st.download_button("🖼️ Download Overlay",
                           buf.getvalue(), "el_overlay.png", "image/png")
else:
    # Landing page shown until an image is uploaded.
    st.info("👆 Upload an EL image to start analysis")
    st.markdown("""
### 🔬 What This Does
1. **Upload** an EL (Electroluminescence) image of a solar cell or module
2. **AI segmentation** detects cracks, dark areas, busbars, and other defects
3. **Grid detection** automatically segments modules into individual cells
4. **Analysis** measures crack length, dark area percentage, and defect severity
5. **PASS/FAIL** decision based on configurable quality thresholds
### Supported Inputs
- **Full module** images (6×10, 6×12, etc.) — automatically segments into cells
- **Single cell** images — analyzed as-is (no grid segmentation)
- Any brightness, any size, PNG/JPG/TIFF/BMP
### Model Details
| Property | Value |
|----------|-------|
| Architecture | U-Net++ with scSE attention |
| Encoder | EfficientNet-B4 (ImageNet pretrained) |
| Dataset | E-SCDD (903 EL images, 512×512) |
| Classes | Background, Busbar, Crack, Dark, Other Defect |
| Best Val Dice | 0.6297 |
| Training | 100 epochs, Dice+Focal loss, AMP |
""")

st.markdown("---")
st.caption("EL Defect Detection | U-Net++ + EfficientNet-B4 + scSE | Trained on E-SCDD | Val Dice: 0.6297")