darthvader2603 committed on
Commit
d0bfb7a
·
verified ·
1 Parent(s): 8ef0245

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -37
app.py CHANGED
@@ -5,7 +5,6 @@ import torchvision.transforms as T
5
  import gradio as gr
6
  import numpy as np
7
  import cv2
8
- from PIL import Image
9
 
10
  # ==========================================
11
  # 1. Model Architecture
@@ -34,7 +33,7 @@ class SimpleUNet(nn.Module):
34
  return x
35
 
36
  # ==========================================
37
- # 2. Load Model
38
  # ==========================================
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  model = SimpleUNet().to(device)
@@ -45,78 +44,126 @@ try:
45
  except FileNotFoundError:
46
  print("WARNING: 'iris_segmentation_model.pth' not found.")
47
 
 
 
 
48
  model.eval()
49
 
50
  # ==========================================
51
- # 3. Preprocessing & Logic
52
  # ==========================================
53
- transform = T.Compose([
54
- T.Resize((224, 224)),
55
- T.ToTensor(),
56
- T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
57
- ])
58
 
59
- def process_frame(frame):
 
 
 
 
60
  """
61
- Standard processing function.
62
- Used by both the Live Stream and the Snapshot button.
 
 
63
  """
64
  if frame is None:
65
- return None
66
 
67
  original_h, original_w = frame.shape[:2]
68
 
69
- # Enhance
 
 
70
  gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
71
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
72
  enhanced = clahe.apply(gray)
73
  enhanced_rgb = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)
74
- pil_img = Image.fromarray(enhanced_rgb)
 
 
 
 
 
 
75
 
76
- # Predict
77
- input_tensor = transform(pil_img).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
78
  with torch.no_grad():
79
  pred = model(input_tensor)
80
- pred_mask = pred.squeeze().cpu().numpy()
 
81
 
82
- # Mask
83
  binary_mask = (pred_mask > 0.5).astype(np.uint8)
84
  binary_mask_resized = cv2.resize(binary_mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
85
 
86
- # Colorize
87
  color_mask_bgr = cv2.applyColorMap(binary_mask_resized * 255, cv2.COLORMAP_JET)
88
  color_mask_rgb = cv2.cvtColor(color_mask_bgr, cv2.COLOR_BGR2RGB)
89
 
90
- # Blend
91
  blended = cv2.addWeighted(frame, 0.7, color_mask_rgb, 0.3, 0)
92
 
93
- # Flip for mirror effect (optional)
94
- return cv2.flip(blended, 1)
95
 
96
  # ==========================================
97
- # 4. Gradio Interface
98
  # ==========================================
 
 
 
99
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
100
- gr.Markdown("## 👁️ Live Iris Segmentation + Snapshot")
101
 
102
- # --- Live Stream Section ---
 
 
103
  with gr.Row():
104
- input_stream = gr.Image(sources=["webcam"], streaming=True, label="Live Webcam", mirror_webcam=True)
 
 
105
  output_stream = gr.Image(label="Live Segmentation", interactive=False)
 
 
106
 
107
- # --- Capture Controls ---
108
- # This button allows you to "Stop and See" the current result
109
- btn_snapshot = gr.Button("📸 Freeze Current Frame", variant="primary")
110
-
111
- # --- Static Result Section ---
112
- # This shows the single frozen frame
113
- static_output = gr.Image(label="Frozen Snapshot (Inspection View)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- # 1. Start the Live Stream
116
- input_stream.stream(fn=process_frame, inputs=input_stream, outputs=output_stream)
117
 
118
- # 2. Button Logic: Take current stream frame -> Process -> Show in Static Box
119
- btn_snapshot.click(fn=process_frame, inputs=input_stream, outputs=static_output)
120
 
121
  if __name__ == "__main__":
122
  demo.launch()
 
5
  import gradio as gr
6
  import numpy as np
7
  import cv2
 
8
 
9
  # ==========================================
10
  # 1. Model Architecture
 
33
  return x
34
 
35
  # ==========================================
36
+ # 2. Load Model & Optimize (FP16)
37
  # ==========================================
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  model = SimpleUNet().to(device)
 
44
  except FileNotFoundError:
45
  print("WARNING: 'iris_segmentation_model.pth' not found.")
46
 
47
+ # OPTIMIZATION: Convert to Half Precision (FP16) for speed
48
+ if device.type == 'cuda':
49
+ model.half()
50
  model.eval()
51
 
52
  # ==========================================
53
+ # 3. High-Speed Processing Logic
54
  # ==========================================
55
+ # Pre-calculate normalization constants for speed
56
+ mean = torch.tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
57
+ std = torch.tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
 
 
58
 
59
+ if device.type == 'cuda':
60
+ mean = mean.half()
61
+ std = std.half()
62
+
63
+ def process_frame_fast(frame):
64
  """
65
+ Optimized pipeline:
66
+ 1. Direct Numpy -> Tensor (No PIL)
67
+ 2. GPU Resize
68
+ 3. FP16 Inference
69
  """
70
  if frame is None:
71
+ return None, None
72
 
73
  original_h, original_w = frame.shape[:2]
74
 
75
+ # 1. Preprocessing (OpenCV is faster than PIL for basic ops)
76
+ # CLAHE (CPU side is usually fine, but moving to GPU tensor first is an option if CPU is bottleneck)
77
+ # Keeping CLAHE on CPU for stability with OpenCV
78
  gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
79
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
80
  enhanced = clahe.apply(gray)
81
  enhanced_rgb = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)
82
+
83
+ # 2. To Tensor (Directly to GPU)
84
+ input_tensor = torch.from_numpy(enhanced_rgb).permute(2, 0, 1).to(device)
85
+
86
+ # 3. Normalize & Resize on GPU (Faster than CPU resize)
87
+ # We resize to 224x224
88
+ input_tensor = T.functional.resize(input_tensor, [224, 224], antialias=True)
89
 
90
+ # Convert to float (or half) and normalize
91
+ if device.type == 'cuda':
92
+ input_tensor = input_tensor.half()
93
+ else:
94
+ input_tensor = input_tensor.float()
95
+
96
+ input_tensor = input_tensor.div(255.0).unsqueeze(0)
97
+ input_tensor = (input_tensor - mean) / std
98
+
99
+ # 4. Inference
100
  with torch.no_grad():
101
  pred = model(input_tensor)
102
+ # Squeeze and bring mask back to CPU as float32 for OpenCV
103
+ pred_mask = pred.squeeze().float().cpu().numpy()
104
 
105
+ # 5. Post-Processing (Mask & Blend)
106
  binary_mask = (pred_mask > 0.5).astype(np.uint8)
107
  binary_mask_resized = cv2.resize(binary_mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
108
 
 
109
  color_mask_bgr = cv2.applyColorMap(binary_mask_resized * 255, cv2.COLORMAP_JET)
110
  color_mask_rgb = cv2.cvtColor(color_mask_bgr, cv2.COLOR_BGR2RGB)
111
 
 
112
  blended = cv2.addWeighted(frame, 0.7, color_mask_rgb, 0.3, 0)
113
 
114
+ # Return twice: one for display, one for 'latest_frame' state
115
+ return blended, blended
116
 
117
  # ==========================================
118
+ # 4. Gradio Interface (Stream + State + Cancel Pattern)
119
  # ==========================================
120
+ def capture_logic(image):
121
+ return image
122
+
123
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
124
+ gr.Markdown("## Fast Iris Segmentation")
125
 
126
+ # State to hold the latest image for capturing
127
+ latest_frame_state = gr.State()
128
+
129
  with gr.Row():
130
+ # Input: Webcam (mirroring enabled)
131
+ input_stream = gr.Image(sources=["webcam"], streaming=True, label="Webcam", mirror_webcam=True)
132
+ # Output: Live Result
133
  output_stream = gr.Image(label="Live Segmentation", interactive=False)
134
+ # Output: Snapshot
135
+ snapshot_output = gr.Image(label="Snapshot")
136
 
137
+ with gr.Row():
138
+ # Three minimal buttons
139
+ btn_start = gr.Button("▶️ Start / Restart", variant="primary")
140
+ btn_stop = gr.Button("⏹️ Stop", variant="stop")
141
+ btn_capture = gr.Button("📸 Capture", variant="secondary")
142
+
143
+ # --- Event Logic ---
144
+
145
+ # 1. START: distinct event that triggers the stream
146
+ # using input_stream.stream allows Gradio to handle the webcam loop efficiently
147
+ stream_event = input_stream.stream(
148
+ fn=process_frame_fast,
149
+ inputs=input_stream,
150
+ outputs=[output_stream, latest_frame_state],
151
+ show_progress=False
152
+ )
153
+
154
+ # 2. RESTART: Clicking start simply re-triggers the stream event
155
+ btn_start.click(
156
+ fn=process_frame_fast,
157
+ inputs=input_stream,
158
+ outputs=[output_stream, latest_frame_state],
159
+ show_progress=False
160
+ )
161
 
162
+ # 3. STOP: Cancels the stream event. This kills the process cleanly.
163
+ btn_stop.click(fn=None, inputs=None, outputs=None, cancels=[stream_event])
164
 
165
+ # 4. CAPTURE: Grabs the last frame from State
166
+ btn_capture.click(fn=capture_logic, inputs=latest_frame_state, outputs=snapshot_output)
167
 
168
  if __name__ == "__main__":
169
  demo.launch()