mediaportal committed on
Commit
b322655
·
verified ·
1 Parent(s): 9a44f00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -41
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
- # Force Keras 2 logic to prevent the 'recursion depth' and 'quantization' errors
3
- # common when loading Kaggle-trained .h5 files in new environments.
4
  os.environ["TF_USE_LEGACY_KERAS"] = "1"
5
 
6
  import gradio as gr
@@ -15,79 +14,85 @@ from huggingface_hub import hf_hub_download
15
  REPO_ID = "mediaportal/Roadsegmentation"
16
  MODEL_FILENAME = "trained_model_33_cpu.h5"
17
 
18
- # If your repo is PRIVATE, add your token to Space Secrets as 'HF_TOKEN'
 
 
 
 
 
 
 
 
 
 
19
  hf_token = os.getenv("HF_TOKEN")
20
  model = None
21
 
22
  def load_model():
23
  global model
24
  try:
25
- # Download the model file
26
  path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, token=hf_token)
27
-
28
- # Load using the Classic Keras engine
29
- # compile=False is required because segmentation models often use
30
- # custom Loss functions (like IoU or Dice) that are hard to reload.
31
  model = keras.models.load_model(path, compile=False)
32
  return "βœ… Road Segmentation Model Loaded"
33
  except Exception as e:
34
  return f"❌ Error: {str(e)}"
35
 
36
- def predict_segmentation(img):
37
  if model is None:
38
- return None
39
 
40
- # 1. Store original dimensions
41
- h, w = img.shape[:2]
42
 
43
- # 2. Preprocessing
44
- # BDD100K segmentation models typically use 256x256 or 512x512.
45
- # We resize to 256x256 based on common CPU-optimized configurations.
46
- input_size = (256, 256)
47
- img_resized = cv2.resize(img, input_size)
48
  img_array = img_resized.astype('float32') / 255.0
49
  img_array = np.expand_dims(img_array, axis=0)
50
 
51
- # 3. Predict the Mask
52
  prediction = model.predict(img_array)[0]
53
 
54
- # 4. Process the Mask
55
- # The output is a probability map. Threshold at 0.5 to get binary road/not-road.
56
- mask = (prediction > 0.5).astype(np.uint8) * 255
57
- if len(mask.shape) == 3:
58
- mask = np.squeeze(mask, axis=-1)
59
-
60
- # Resize mask back to original image size
61
- mask_full = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
 
 
 
 
62
 
63
- # 5. Create the Green Overlay
64
- # We create a green version of the original image
65
  overlay = img.copy()
66
- overlay[mask_full > 0] = [0, 255, 0] # Apply green color (RGB)
67
 
68
- # Blend original and green overlay (0.6 original + 0.4 green)
69
- output = cv2.addWeighted(img, 0.6, overlay, 0.4, 0)
70
 
71
- return output
 
 
 
72
 
73
  # --- GRADIO INTERFACE ---
74
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
75
- gr.Markdown("# πŸš— ADAS: Road Surface Segmentation")
76
- gr.Markdown("Upload a dashboard camera image to visualize the drivable road area detection.")
77
 
78
  status = gr.Markdown("⏳ Initializing system...")
79
 
80
  with gr.Row():
81
- input_img = gr.Image(label="Original View", type="numpy")
82
- output_img = gr.Image(label="Detected Road (Green Overlay)")
 
83
 
84
- btn = gr.Button("Analyze Road", variant="primary")
85
 
86
- # Automatically load model on start
87
  demo.load(load_model, outputs=status)
88
-
89
- # Connect button to prediction
90
- btn.click(fn=predict_segmentation, inputs=input_img, outputs=output_img)
91
 
92
  if __name__ == "__main__":
93
  demo.queue().launch()
 
1
  import os
2
+ # Force Keras 2 logic to prevent recursion/quantization errors from Kaggle .h5 files
 
3
  os.environ["TF_USE_LEGACY_KERAS"] = "1"
4
 
5
  import gradio as gr
 
14
  REPO_ID = "mediaportal/Roadsegmentation"
15
  MODEL_FILENAME = "trained_model_33_cpu.h5"
16
 
17
+ # BDD100K Color Dictionary from your notebook
18
+ COLOR_DICT = {
19
+ 0: (128, 128, 128), # road - gray
20
+ 1: (230, 230, 50), # sidewalk - yellow
21
+ 8: (50, 150, 50), # vegetation - green
22
+ 10: (128, 180, 255), # sky - blue
23
+ 11: (255, 0, 0), # person - red
24
+ 13: (0, 0, 255), # car - blue
25
+ 19: (0, 0, 0) # unknown - black
26
+ }
27
+
28
  hf_token = os.getenv("HF_TOKEN")
29
  model = None
30
 
31
  def load_model():
32
  global model
33
  try:
 
34
  path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, token=hf_token)
35
+ # compile=False is used because the notebook uses SparseCategoricalCrossentropy
 
 
 
36
  model = keras.models.load_model(path, compile=False)
37
  return "βœ… Road Segmentation Model Loaded"
38
  except Exception as e:
39
  return f"❌ Error: {str(e)}"
40
 
41
+ def segment_road(img):
42
  if model is None:
43
+ return None, None
44
 
45
+ # 1. Store original size for scaling back
46
+ h_orig, w_orig = img.shape[:2]
47
 
48
+ # 2. Preprocessing (Notebook uses 192 height, 256 width)
49
+ img_resized = cv2.resize(img, (256, 192))
 
 
 
50
  img_array = img_resized.astype('float32') / 255.0
51
  img_array = np.expand_dims(img_array, axis=0)
52
 
53
+ # 3. Predict (Returns 20 channels for 20 classes)
54
  prediction = model.predict(img_array)[0]
55
 
56
+ # Get the class index with the highest probability for each pixel
57
+ mask = np.argmax(prediction, axis=-1).astype(np.uint8)
58
+
59
+ # 4. Create Outputs
60
+ # A. Full Semantic Map (Colorizing all classes)
61
+ full_mask_color = np.zeros((192, 256, 3), dtype=np.uint8)
62
+ for class_idx, color in COLOR_DICT.items():
63
+ full_mask_color[mask == class_idx] = color
64
+
65
+ # B. Road Highlight Overlay (Class 0 is Road)
66
+ road_mask = (mask == 0).astype(np.uint8) * 255
67
+ road_mask_resized = cv2.resize(road_mask, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
68
 
 
 
69
  overlay = img.copy()
70
+ overlay[road_mask_resized > 0] = [0, 255, 0] # Highlight road in green
71
 
72
+ # Blend: 70% original image, 30% green highlight
73
+ highlighted_road = cv2.addWeighted(img, 0.7, overlay, 0.3, 0)
74
 
75
+ # Resize full mask back to original aspect ratio for display
76
+ full_mask_resized = cv2.resize(full_mask_color, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
77
+
78
+ return highlighted_road, full_mask_resized
79
 
80
  # --- GRADIO INTERFACE ---
81
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
82
+ gr.Markdown("# πŸš— ADAS Road & Scene Segmentation")
83
+ gr.Markdown("Upload a dashboard image to identify the drivable road surface and other objects.")
84
 
85
  status = gr.Markdown("⏳ Initializing system...")
86
 
87
  with gr.Row():
88
+ input_img = gr.Image(label="Input Dashboard Image", type="numpy")
89
+ output_overlay = gr.Image(label="Drivable Road (Green Highlight)")
90
+ output_full = gr.Image(label="Full Semantic Map")
91
 
92
+ btn = gr.Button("Analyze Scene", variant="primary")
93
 
 
94
  demo.load(load_model, outputs=status)
95
+ btn.click(fn=segment_road, inputs=input_img, outputs=[output_overlay, output_full])
 
 
96
 
97
  if __name__ == "__main__":
98
  demo.queue().launch()