File size: 3,499 Bytes
292bfa3
b322655
292bfa3
 
 
 
 
 
 
 
 
 
 
9a44f00
292bfa3
 
b322655
 
 
 
 
 
 
 
 
 
 
292bfa3
 
 
 
 
 
 
b322655
292bfa3
9a44f00
292bfa3
 
 
b322655
292bfa3
b322655
292bfa3
b322655
 
292bfa3
b322655
 
292bfa3
 
 
b322655
292bfa3
 
b322655
 
 
 
 
 
 
 
 
 
 
 
292bfa3
 
b322655
292bfa3
b322655
 
292bfa3
b322655
 
 
 
292bfa3
 
 
b322655
 
292bfa3
 
 
 
b322655
 
 
292bfa3
b322655
292bfa3
 
b322655
292bfa3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
# Force Keras 2 logic to prevent recursion/quantization errors from Kaggle .h5 files
os.environ["TF_USE_LEGACY_KERAS"] = "1" 

import gradio as gr
import tensorflow as tf
import tf_keras as keras 
import numpy as np
import cv2
from PIL import Image
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
# Hugging Face Hub repo and file holding the pretrained segmentation weights.
REPO_ID = "mediaportal/Roadsegmentation" 
MODEL_FILENAME = "trained_model_33_cpu.h5"

# BDD100K class-index -> RGB colour used to paint the full semantic map.
# Classes not listed here stay black (the map buffer starts as zeros).
COLOR_DICT = {
    0: (128, 128, 128),  # road - gray
    1: (230, 230, 50),   # sidewalk - yellow
    8: (50, 150, 50),    # vegetation - green
    10: (128, 180, 255), # sky - light blue
    11: (255, 0, 0),     # person - red
    13: (0, 0, 255),     # car - blue
    19: (0, 0, 0)        # unknown - black
}

# Optional Hub auth token (needed only for private repos); None when unset.
hf_token = os.getenv("HF_TOKEN")
# Populated by load_model() at app startup; stays None if loading fails.
model = None

def load_model():
    """Fetch the .h5 weights from the Hub and load them into the global
    ``model``. Returns a human-readable status string for the UI banner."""
    global model
    try:
        weights_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_FILENAME,
            token=hf_token,
        )
        # compile=False is used because the notebook uses SparseCategoricalCrossentropy
        # and we only need the network for inference here.
        model = keras.models.load_model(weights_path, compile=False)
    except Exception as e:
        return f"❌ Error: {str(e)}"
    return "✅ Road Segmentation Model Loaded"

def segment_road(img):
    """Segment a dashboard image into road / scene classes.

    Args:
        img: RGB image as an (H, W, 3) uint8 numpy array (Gradio ``numpy``
            type), or None when the user submitted without an image.

    Returns:
        Tuple ``(highlighted_road, full_semantic_map)`` at the original image
        resolution, or ``(None, None)`` when the model is not loaded yet or
        no image was provided.
    """
    # Guard both failure modes: model still loading, or empty submission
    # (the original crashed on img.shape when img was None).
    if model is None or img is None:
        return None, None

    # Some browsers/uploads deliver RGBA; the network expects 3 channels.
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]

    # 1. Remember the original size so outputs can be scaled back for display.
    h_orig, w_orig = img.shape[:2]

    # 2. Preprocess to the network input size (width=256, height=192, per the
    #    training notebook) and normalise pixel values to [0, 1].
    net_w, net_h = 256, 192
    img_resized = cv2.resize(img, (net_w, net_h))
    img_array = img_resized.astype('float32') / 255.0
    img_array = np.expand_dims(img_array, axis=0)

    # 3. Predict per-pixel class scores; argmax picks the most likely class.
    prediction = model.predict(img_array)[0]
    mask = np.argmax(prediction, axis=-1).astype(np.uint8)

    # 4A. Full semantic map: paint each known class with its colour
    #     (classes absent from COLOR_DICT remain black).
    full_mask_color = np.zeros((net_h, net_w, 3), dtype=np.uint8)
    for class_idx, color in COLOR_DICT.items():
        full_mask_color[mask == class_idx] = color

    # 4B. Road highlight overlay (class 0 is road). Nearest-neighbour
    #     resizing keeps the binary mask crisp at full resolution.
    road_mask = (mask == 0).astype(np.uint8) * 255
    road_mask_resized = cv2.resize(road_mask, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)

    overlay = img.copy()
    overlay[road_mask_resized > 0] = [0, 255, 0] # Highlight road in green

    # Blend: 70% original image, 30% green highlight
    highlighted_road = cv2.addWeighted(img, 0.7, overlay, 0.3, 0)

    # Resize full mask back to original resolution for display.
    full_mask_resized = cv2.resize(full_mask_color, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)

    return highlighted_road, full_mask_resized

# --- GRADIO INTERFACE ---
# Three-column layout: input image, road-only overlay, full class map.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚗 ADAS Road & Scene Segmentation")
    gr.Markdown("Upload a dashboard image to identify the drivable road surface and other objects.")
    
    # Status banner; replaced by load_model()'s return value once loading ends.
    status = gr.Markdown("⏳ Initializing system...")
    
    with gr.Row():
        input_img = gr.Image(label="Input Dashboard Image", type="numpy")
        output_overlay = gr.Image(label="Drivable Road (Green Highlight)")
        output_full = gr.Image(label="Full Semantic Map")
    
    btn = gr.Button("Analyze Scene", variant="primary")
    
    # Download/load the model when the page first renders, not at import time.
    demo.load(load_model, outputs=status)
    btn.click(fn=segment_road, inputs=input_img, outputs=[output_overlay, output_full])

if __name__ == "__main__":
    # queue() serialises predictions so concurrent users don't race the model.
    demo.queue().launch()