# German_Traffic_Sign — src/streamlit_app.py
# (Hugging Face Space page header removed: "Osmanerendgn's picture /
#  Update src/streamlit_app.py / 95de39d verified" — scrape residue, kept here
#  as a comment so the file remains valid Python.)
import os
import tempfile
import numpy as np
import streamlit as st
from PIL import Image
import tensorflow as tf
import cv2
from pathlib import Path
# ─────────────────────────────────────────────
# PAGE CONFIG
# ─────────────────────────────────────────────
# Must be the first Streamlit command executed in the script.
st.set_page_config(
    page_title="German Traffic Sign Recognition",
    page_icon="🚦",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ─────────────────────────────────────────────
# CUSTOM CSS — dark industrial / road aesthetic
# ─────────────────────────────────────────────
# Injected raw <style> tag; unsafe_allow_html is required for this to render.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Bebas+Neue&family=DM+Sans:wght@300;400;500;600&display=swap');
/* Base */
html, body, [class*="css"] {
font-family: 'DM Sans', sans-serif;
background-color: #0d0d0d;
color: #e8e8e0;
}
/* Hero header */
.hero-title {
font-family: 'Bebas Neue', sans-serif;
font-size: 3.6rem;
letter-spacing: 0.08em;
color: #f5c518;
line-height: 1.0;
margin-bottom: 0;
}
.hero-sub {
font-size: 0.95rem;
color: #888;
letter-spacing: 0.12em;
text-transform: uppercase;
margin-top: 4px;
}
/* Metric cards */
.metric-row { display: flex; gap: 16px; margin: 18px 0; }
.metric-card {
flex: 1;
background: #1a1a1a;
border: 1px solid #2a2a2a;
border-radius: 10px;
padding: 16px 20px;
text-align: center;
}
.metric-card .val {
font-family: 'Bebas Neue', sans-serif;
font-size: 2.2rem;
color: #f5c518;
line-height: 1;
}
.metric-card .lbl {
font-size: 0.72rem;
color: #666;
letter-spacing: 0.1em;
text-transform: uppercase;
margin-top: 4px;
}
/* Prediction result box */
.pred-box {
background: linear-gradient(135deg, #1a1a1a 0%, #141414 100%);
border: 2px solid #f5c518;
border-radius: 14px;
padding: 24px 28px;
margin: 16px 0;
}
.pred-label {
font-family: 'Bebas Neue', sans-serif;
font-size: 2.0rem;
color: #f5c518;
letter-spacing: 0.06em;
}
.pred-conf {
font-size: 1.1rem;
color: #aaa;
margin-top: 4px;
}
.pred-conf span { color: #4cff91; font-weight: 600; }
/* Confidence bar custom */
.conf-bar-wrap { margin: 6px 0; }
.conf-bar-label { font-size: 0.8rem; color: #aaa; margin-bottom: 2px; display: flex; justify-content: space-between; }
.conf-bar-bg { background: #1e1e1e; border-radius: 4px; height: 10px; overflow: hidden; }
.conf-bar-fill { height: 100%; border-radius: 4px; background: linear-gradient(90deg, #f5c518, #ff8c00); }
/* Section headers */
.section-head {
font-family: 'Bebas Neue', sans-serif;
font-size: 1.3rem;
letter-spacing: 0.1em;
color: #f5c518;
border-bottom: 1px solid #2a2a2a;
padding-bottom: 6px;
margin: 20px 0 12px 0;
}
/* Sidebar */
[data-testid="stSidebar"] {
background: #111 !important;
border-right: 1px solid #222;
}
/* Upload area */
[data-testid="stFileUploader"] {
background: #141414 !important;
border: 1px dashed #333 !important;
border-radius: 10px !important;
}
/* Tabs */
[data-testid="stTabs"] button {
font-family: 'Bebas Neue', sans-serif;
letter-spacing: 0.08em;
font-size: 1.05rem;
}
</style>
""", unsafe_allow_html=True)
# ─────────────────────────────────────────────
# CLASS NAMES (43 classes — correct & verified)
# ─────────────────────────────────────────────
# GTSRB class id -> human-readable English sign name.
# The ids must match the label encoding used when the CNN was trained.
CLASS_NAMES = {
    0: 'Speed limit (20km/h)',
    1: 'Speed limit (30km/h)',
    2: 'Speed limit (50km/h)',
    3: 'Speed limit (60km/h)',
    4: 'Speed limit (70km/h)',
    5: 'Speed limit (80km/h)',
    6: 'End of speed limit (80km/h)',
    7: 'Speed limit (100km/h)',
    8: 'Speed limit (120km/h)',
    9: 'No passing',
    10: 'No passing for vehicles over 3.5t',
    11: 'Right-of-way at next intersection',
    12: 'Priority road',
    13: 'Yield',
    14: 'Stop',
    15: 'No vehicles',
    16: 'Vehicles over 3.5t prohibited',
    17: 'No entry',
    18: 'General caution',
    19: 'Dangerous curve to the left',
    20: 'Dangerous curve to the right',
    21: 'Double curve',
    22: 'Bumpy road',
    23: 'Slippery road',
    24: 'Road narrows on the right',
    25: 'Road work',
    26: 'Traffic signals',
    27: 'Pedestrians',
    28: 'Children crossing',
    29: 'Bicycles crossing',
    30: 'Beware of ice/snow',
    31: 'Wild animals crossing',
    32: 'End of all speed and passing limits',
    33: 'Turn right ahead',
    34: 'Turn left ahead',
    35: 'Ahead only',
    36: 'Go straight or right',
    37: 'Go straight or left',
    38: 'Keep right',
    39: 'Keep left',
    40: 'Roundabout mandatory',
    41: 'End of no passing',
    42: 'End of no passing for vehicles over 3.5t',
}
# Category groupings for the sidebar info panel.
# Ids overlap deliberately (e.g. class 6 is both a speed-limit sign and an
# "end of restriction" sign); together the groups cover all 43 class ids.
CATEGORIES = {
    "πŸ”΄ Speed Limits": list(range(0, 9)),       # classes 0-8
    "β›” Prohibition": [9, 10, 15, 16, 17],
    "⚠️ Warning": [11] + list(range(18, 32)),   # was list(range(11, 12)) — just [11]
    "πŸ”΅ Mandatory": list(range(33, 41)),        # classes 33-40
    "⬜ End of restriction": [6, 32, 41, 42],
    "πŸ›‘ Priority": [12, 13, 14],
}
# ─────────────────────────────────────────────
# MODEL LOADING
# ─────────────────────────────────────────────
# Resolve the model path relative to this file so the app works regardless
# of the working directory Streamlit was launched from.
BASE_DIR = Path(__file__).resolve().parent
# NOTE(review): the filename looks mojibake-encoded ("sΔ±fΔ±rdan" ≈ Turkish
# "sıfırdan" = "from scratch") — confirm it matches the file on disk before changing it.
MODEL_PATH = BASE_DIR / "GermanTraffic-sΔ±fΔ±rdanCNN-OS.keras"
IMG_SIZE = (64, 64)  # must match the input resolution used during training
@st.cache_resource(show_spinner=False)
def load_model():
    """Load and cache the trained Keras CNN.

    Wrapped in st.cache_resource so the model is deserialized once per
    process rather than on every Streamlit rerun.  On failure a user-facing
    error is shown and the script run is aborted via st.stop().
    """
    try:
        return tf.keras.models.load_model(MODEL_PATH)
    except Exception as e:
        # Turkish UX copy: "Model could not be loaded ... make sure the model
        # file is in the same folder as app.py".
        # BUGFIX: repaired garbled character in the message ("klasârde" -> "klasörde").
        st.error(f"Model yüklenemedi: {e}\n\nModel dosyasının app.py ile aynı klasörde olduğundan emin ol: `{MODEL_PATH}`")
        st.stop()
# ─────────────────────────────────────────────
# PREPROCESSING
# ─────────────────────────────────────────────
def preprocess(pil_img: Image.Image) -> np.ndarray:
    """Replicate the training pipeline exactly.

    Converts to RGB, LANCZOS-resizes to 64x64, and scales pixel values
    into [0, 1].  Returns a float32 batch of shape (1, 64, 64, 3) that can
    be fed straight into ``model.predict``.
    """
    rgb = pil_img.convert("RGB").resize(IMG_SIZE, Image.LANCZOS)
    scaled = np.asarray(rgb, dtype=np.float32) / 255.0
    # Prepend the batch axis: (64, 64, 3) -> (1, 64, 64, 3).
    return scaled[np.newaxis, ...]
def predict(pil_img: Image.Image):
    """Classify one image.

    Returns ``(probs, top5)`` where ``probs`` is the full 43-way probability
    vector and ``top5`` holds the five class ids with the highest scores,
    best first.
    """
    network = load_model()
    batch = preprocess(pil_img)
    scores = network.predict(batch, verbose=0)[0]  # shape (43,)
    # Take the five largest entries of the ascending argsort, best first.
    best_five = np.argsort(scores)[-5:][::-1]
    return scores, best_five
# ─────────────────────────────────────────────
# VIDEO PROCESSING (from video_processor.py logic)
# ─────────────────────────────────────────────
def build_candidate_mask(frame: np.ndarray) -> np.ndarray:
    """Build a binary mask of likely traffic-sign regions in a BGR frame.

    Merges three detection techniques from computer_vision.ipynb:
      1. HSV color segmentation (Cells 18, 22) — isolates red & blue sign colors
      2. Canny edge detection (Cell 27)        — catches shape boundaries of any sign
      3. Dilation cleanup (Cell 41)            — fills small gaps, merges nearby blobs

    Red   -> speed limits, prohibitions, warnings (majority of GTSRB classes)
    Blue  -> mandatory signs (keep right/left, roundabout, ahead only; classes 33-40)
    Canny -> yellow/white signs that the HSV color mask misses
    """
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

    # Red hue wraps around 0°/180° in OpenCV's HSV space, so two ranges
    # are required (technique from Cells 18 & 22).
    low_red = cv2.inRange(hsv, np.array([0, 70, 50]), np.array([10, 255, 255]))
    high_red = cv2.inRange(hsv, np.array([170, 70, 50]), np.array([180, 255, 255]))
    mask = cv2.bitwise_or(low_red, high_red)

    # Blue hue band covers the circular mandatory signs (Cell 22).
    blue = cv2.inRange(hsv, np.array([100, 80, 50]), np.array([130, 255, 255]))
    mask = cv2.bitwise_or(mask, blue)

    # Edges: Gaussian blur first (Cell 28 pattern) to suppress noise, then
    # Canny with the same 50/120 thresholds as Cell 27.
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(cv2.GaussianBlur(gray, (5, 5), 0), 50, 120)
    mask = cv2.bitwise_or(mask, edges)

    # Dilate to close small gaps and merge neighbouring blobs (Cell 41).
    return cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=2)
def process_video_frames(video_path: str, conf_thresh: float = 0.60) -> str:
    """Frame-by-frame traffic sign detection on a video file.

    Detection pipeline (techniques from computer_vision.ipynb):
    HSV red mask + HSV blue mask + Canny edges -> dilate -> findContours -> CNN classify

    Args:
        video_path: path to a readable input video.
        conf_thresh: minimum classifier confidence required to draw a box.

    Returns:
        Path of the annotated output video, or None if the input could not
        be opened / contains no frames.
    """
    model = load_model()
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    fps = cap.get(cv2.CAP_PROP_FPS) or 20.0  # some containers report 0 fps
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # BUGFIX: the previous `video_path.replace(".", "_processed.")` rewrote
    # EVERY dot in the path (e.g. "/tmp/my.clip.mp4" ->
    # "/tmp/my_processed.clip_processed.mp4").  Split off only the extension.
    root, ext = os.path.splitext(video_path)
    out_path = f"{root}_processed{ext}"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

    # Motion-detection state (computer_vision.ipynb Cells 38-41):
    # absdiff between consecutive frames — only frames with motion are scanned.
    ret_prev, prev_frame = cap.read()
    if not ret_prev:
        cap.release()
        out.release()
        return None
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
    prev_blur = cv2.GaussianBlur(prev_gray, (21, 21), 0)

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Motion filter (Cell 39 pattern).
        curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        curr_blur = cv2.GaussianBlur(curr_gray, (21, 21), 0)
        diff = cv2.absdiff(prev_blur, curr_blur)
        motion_thresh = cv2.threshold(diff, 20, 255, cv2.THRESH_BINARY)[1]
        motion_pixels = np.sum(motion_thresh) / 255
        prev_blur = curr_blur  # advance the reference frame

        # Skip static frames: if fewer than 0.3% of pixels changed, assume no
        # new sign appeared and pass the frame through unannotated.
        if motion_pixels < (width * height * 0.003):
            out.write(frame)
            continue

        # Candidate regions from the combined HSV + Canny mask.
        candidate_mask = build_candidate_mask(frame)

        # Contour detection (computer_vision.ipynb Cell 42).
        contours, _ = cv2.findContours(
            candidate_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            # Size filter: ignore tiny noise and near-full-frame false positives.
            if w < 25 or h < 25 or w > width * 0.7 or h > height * 0.7:
                continue
            # Aspect-ratio filter: signs are roughly square.
            # (Comment previously claimed 0.5-2.0; the accepted range is 0.4-2.5.)
            aspect = w / h
            if aspect < 0.4 or aspect > 2.5:
                continue
            roi = frame[y:y+h, x:x+w]
            roi_rgb = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
            roi_pil = Image.fromarray(roi_rgb)
            roi_inp = preprocess(roi_pil)
            preds = model.predict(roi_inp, verbose=0)[0]
            class_id = int(np.argmax(preds))
            confidence = float(np.max(preds))
            if confidence >= conf_thresh:
                label = f"{CLASS_NAMES.get(class_id, '?')} {confidence*100:.0f}%"
                # BGR colors: green box for high confidence, amber for moderate.
                box_color = (0, 230, 50) if confidence >= 0.80 else (0, 200, 255)
                cv2.rectangle(frame, (x, y), (x+w, y+h), box_color, 2)
                # Draw a filled black backdrop behind the label for readability.
                (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
                cv2.rectangle(frame, (x, y - th - 8), (x + tw + 4, y), (0, 0, 0), -1)
                cv2.putText(frame, label, (x + 2, y - 4),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.55, box_color, 1, cv2.LINE_AA)
        out.write(frame)

    cap.release()
    out.release()
    return out_path
# ─────────────────────────────────────────────
# UI HELPERS
# ─────────────────────────────────────────────
def render_confidence_bars(probs: np.ndarray, top5: np.ndarray):
    """Render the top-5 class probabilities as styled horizontal HTML bars."""
    pieces = ['<div>']
    for rank, idx in enumerate(top5):
        pct = float(probs[idx]) * 100
        name = CLASS_NAMES.get(int(idx), "?")
        # Gold for the winner, dim grey for <5% scores, light grey otherwise.
        if rank == 0:
            color = "#f5c518"
        elif pct < 5:
            color = "#888"
        else:
            color = "#aaa"
        bar_width = max(pct, 1)  # keep a visible sliver even for ~0% scores
        star = 'β˜… ' if rank == 0 else ''
        fill = 'linear-gradient(90deg,#f5c518,#ff8c00)' if rank == 0 else '#333'
        pieces.append(f"""
<div class="conf-bar-wrap">
<div class="conf-bar-label">
<span style="color:{color}">{star}{name}</span>
<span style="color:{color};font-weight:600">{pct:.1f}%</span>
</div>
<div class="conf-bar-bg">
<div class="conf-bar-fill" style="width:{bar_width}%;background:{fill}"></div>
</div>
</div>""")
    pieces.append("</div>")
    st.markdown("".join(pieces), unsafe_allow_html=True)
def render_prediction(probs, top5):
    """Show the winning class in a highlight card, then the top-5 bar chart."""
    winner = int(top5[0])
    winner_name = CLASS_NAMES.get(winner, "Unknown")
    winner_pct = float(probs[winner]) * 100
    # Traffic-light coloring for the confidence figure.
    if winner_pct >= 80:
        conf_color = "#4cff91"
    elif winner_pct >= 50:
        conf_color = "#ffb347"
    else:
        conf_color = "#ff6b6b"
    st.markdown(f"""
<div class="pred-box">
<div style="font-size:0.75rem;color:#666;letter-spacing:0.12em;text-transform:uppercase;margin-bottom:6px">Predicted Sign</div>
<div class="pred-label">{winner_name}</div>
<div class="pred-conf">Confidence: <span style="color:{conf_color}">{winner_pct:.1f}%</span> &nbsp;Β·&nbsp; Class ID: {winner}</div>
</div>
""", unsafe_allow_html=True)
    st.markdown('<div class="section-head">Top 5 Predictions</div>', unsafe_allow_html=True)
    render_confidence_bars(probs, top5)
# ─────────────────────────────────────────────
# SIDEBAR
# ─────────────────────────────────────────────
with st.sidebar:
    # Branding header.
    st.markdown("""
<div style="text-align:center;padding:12px 0 20px">
<div style="font-family:'Bebas Neue',sans-serif;font-size:1.8rem;color:#f5c518;letter-spacing:0.1em">🚦 GTSRB</div>
<div style="font-size:0.7rem;color:#555;letter-spacing:0.12em;text-transform:uppercase">German Traffic Sign Recognition</div>
</div>
""", unsafe_allow_html=True)
    # Static headline metrics (hard-coded, not computed at runtime).
    st.markdown('<div class="section-head">Model Info</div>', unsafe_allow_html=True)
    st.markdown("""
<div class="metric-row">
<div class="metric-card"><div class="val">97.8%</div><div class="lbl">Test Acc</div></div>
<div class="metric-card"><div class="val">43</div><div class="lbl">Classes</div></div>
</div>
<div class="metric-row">
<div class="metric-card"><div class="val">64px</div><div class="lbl">Input Size</div></div>
<div class="metric-card"><div class="val">CNN</div><div class="lbl">Architecture</div></div>
</div>
""", unsafe_allow_html=True)
    # Expandable listing of every class id, grouped by CATEGORIES.
    st.markdown('<div class="section-head">Sign Categories</div>', unsafe_allow_html=True)
    for cat, ids in CATEGORIES.items():
        with st.expander(f"{cat} ({len(ids)} signs)"):
            for i in ids:
                st.markdown(f"<small style='color:#888'>**{i}** β€” {CLASS_NAMES[i]}</small>", unsafe_allow_html=True)
    # conf_thresh is read later by the video tab (module-level name).
    st.markdown('<div class="section-head">Video Detection</div>', unsafe_allow_html=True)
    conf_thresh = st.slider("Confidence threshold", 0.30, 0.95, 0.60, 0.05,
                            help="Minimum confidence to draw a bounding box on video frames")
    # Footer facts about the dataset/model.
    st.markdown("""
<div style="font-size:0.72rem;color:#444;margin-top:24px;line-height:1.7">
Dataset: GTSRB (Kaggle)<br>
39,209 training images<br>
Architecture: Custom CNN<br>
Trained locally on GPU
</div>
""", unsafe_allow_html=True)
# ─────────────────────────────────────────────
# MAIN HEADER
# ─────────────────────────────────────────────
st.markdown("""
<div class="hero-title">German Traffic Sign<br>Recognition</div>
<div class="hero-sub">Deep Learning Β· 43 Classes Β· 97.82% Test Accuracy</div>
<hr style="border-color:#1e1e1e;margin:18px 0">
""", unsafe_allow_html=True)
# ─────────────────────────────────────────────
# MAIN TABS
# ─────────────────────────────────────────────
# Three entry points: still-image upload, live camera capture, video detection.
tab_upload, tab_camera, tab_video = st.tabs(["πŸ“ IMAGE UPLOAD", "πŸ“· CAMERA", "🎬 VIDEO DETECTION"])
# ══════════════════════════════════════════════
# TAB 1 — IMAGE UPLOAD
# ══════════════════════════════════════════════
with tab_upload:
    st.markdown('<div class="section-head">Upload a Traffic Sign Image</div>', unsafe_allow_html=True)
    # NOTE(review): the visible hint lists JPG/JPEG/PNG but webp is also accepted.
    uploaded = st.file_uploader(
        "Drag & drop or browse β€” JPG, JPEG, PNG supported",
        type=["jpg", "jpeg", "png", "webp"],
        label_visibility="collapsed"
    )
    if uploaded:
        pil_img = Image.open(uploaded)
        # Side-by-side layout: image on the left, prediction card on the right.
        col_img, col_pred = st.columns([1, 1.4], gap="large")
        with col_img:
            st.markdown('<div class="section-head">Uploaded Image</div>', unsafe_allow_html=True)
            st.image(pil_img.convert("RGB"), use_container_width=True)
            st.caption(f"Size: {pil_img.size[0]}Γ—{pil_img.size[1]} px")
        with col_pred:
            with st.spinner("Running inference..."):
                probs, top5 = predict(pil_img)
            render_prediction(probs, top5)
    else:
        # Empty-state placeholder shown until a file is uploaded.
        st.markdown("""
<div style="text-align:center;padding:60px 0;color:#333">
<div style="font-size:3rem">🚧</div>
<div style="font-size:1rem;letter-spacing:0.06em;margin-top:8px">Upload an image to begin</div>
</div>
""", unsafe_allow_html=True)
# ══════════════════════════════════════════════
# TAB 2 — CAMERA INPUT
# ══════════════════════════════════════════════
with tab_camera:
    st.markdown('<div class="section-head">Take a Photo</div>', unsafe_allow_html=True)
    st.markdown('<small style="color:#666">Point your camera at a German traffic sign and capture</small>', unsafe_allow_html=True)
    # Empty label hidden via label_visibility; the markdown above acts as the label.
    cam_img = st.camera_input("", label_visibility="collapsed")
    if cam_img:
        pil_img = Image.open(cam_img)
        # Same two-column layout as the upload tab.
        col_c1, col_c2 = st.columns([1, 1.4], gap="large")
        with col_c1:
            st.markdown('<div class="section-head">Captured Image</div>', unsafe_allow_html=True)
            st.image(pil_img.convert("RGB"), use_container_width=True)
        with col_c2:
            with st.spinner("Classifying..."):
                probs, top5 = predict(pil_img)
            render_prediction(probs, top5)
# ══════════════════════════════════════════════
# TAB 3 — VIDEO DETECTION
# ══════════════════════════════════════════════
with tab_video:
    st.markdown('<div class="section-head">Upload a Video for Traffic Sign Detection</div>', unsafe_allow_html=True)
    st.markdown("""
<small style="color:#666">
The system detects red-colored regions in each frame using HSV segmentation,
crops candidate regions, classifies them with the CNN, and draws bounding boxes
when confidence β‰₯ threshold set in sidebar.
</small>
""", unsafe_allow_html=True)
    vid_file = st.file_uploader(
        "Upload MP4 / AVI / MOV",
        type=["mp4", "avi", "mov"],
        label_visibility="collapsed",
        key="vid_uploader"
    )
    if vid_file:
        st.markdown('<div class="section-head">Original Video</div>', unsafe_allow_html=True)
        st.video(vid_file)
        if st.button("πŸ” Run Detection", use_container_width=True):
            # Persist the upload to disk: OpenCV needs a real file path.
            # NOTE(review): neither tmp_in_path nor out_path is ever deleted, so
            # long-running deployments will accumulate temp files — consider cleanup.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_in:
                tmp_in.write(vid_file.read())
                tmp_in_path = tmp_in.name
            with st.spinner("Processing frames... this may take a moment ⏳"):
                # conf_thresh comes from the sidebar slider.
                out_path = process_video_frames(tmp_in_path, conf_thresh=conf_thresh)
            if out_path and os.path.exists(out_path):
                st.markdown('<div class="section-head">Processed Video</div>', unsafe_allow_html=True)
                with open(out_path, "rb") as f:
                    st.download_button(
                        "⬇️ Download Processed Video",
                        data=f,
                        file_name="gtsrb_detection_output.mp4",
                        mime="video/mp4",
                        use_container_width=True
                    )
                st.video(out_path)
            else:
                # Turkish: "Video could not be processed. Possibly a codec or file-format issue."
                st.error("Video işlenemedi. Codec veya dosya formatı sorunu olabilir.")
    else:
        # Empty-state placeholder shown until a video is uploaded.
        st.markdown("""
<div style="text-align:center;padding:60px 0;color:#333">
<div style="font-size:3rem">🎬</div>
<div style="font-size:1rem;letter-spacing:0.06em;margin-top:8px">Upload a video to detect traffic signs</div>
</div>
""", unsafe_allow_html=True)