aziraarshad commited on
Commit
bf588cb
·
verified ·
1 Parent(s): b972955

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +274 -0
app.py CHANGED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import cv2
4
+ import gradio as gr
5
+ import torch
6
+ import torch.nn as nn
7
+ import mediapipe as mp
8
+
9
# ----------------------------
# Load labels (label.json)
# Supports:
#   1) ["label1","label2",...]
#   2) {"0":"label1","1":"label2",...}
# ----------------------------
def load_labels(path="label.json"):
    """Read the class-name list from a JSON file.

    Accepts either a plain JSON list of labels, or a JSON object whose
    keys are stringified indices; the latter is returned ordered by
    numeric index. Raises ValueError for any other JSON shape.
    """
    with open(path, "r", encoding="utf-8") as fh:
        data = json.load(fh)
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        # Sort keys numerically ("10" after "9", not after "1").
        ordered_keys = sorted(data, key=int)
        return [data[k] for k in ordered_keys]
    raise ValueError("label.json must be a list or a dict mapping index -> label.")
24
+
25
# Class names loaded once at import time; model output index i maps to LABELS[i].
LABELS = load_labels("label.json")
NUM_CLASSES = len(LABELS)
27
+
28
# ----------------------------
# MediaPipe helpers (from your notebook)
# ----------------------------
# Shorthand aliases for the MediaPipe Holistic solution (pose + hands + face
# tracking) and its landmark-drawing utilities.
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils
33
+
34
def mediapipe_detection(image, model):
    """Run a MediaPipe model on a BGR frame.

    Converts the frame to RGB (the channel order MediaPipe expects),
    marks it read-only during inference so MediaPipe can skip an
    internal copy, then converts back to BGR for OpenCV drawing.

    Returns (bgr_image, mediapipe_results).
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    rgb.flags.writeable = False
    results = model.process(rgb)
    rgb.flags.writeable = True
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR), results
41
+
42
def draw_styled_landmarks(image, results):
    """Draw pose and both hand landmark sets on `image` in place.

    The three overlays differ only in landmark set, connection topology,
    and landmark circle radius, so they are driven by one table instead
    of three near-identical call sites. Missing landmark sets (None) are
    handled by mp_drawing.draw_landmarks, which draws nothing for them.
    """
    # (landmark set, connection topology, landmark circle radius);
    # colors match the original styling: red points, olive connections.
    overlays = (
        (results.pose_landmarks, mp_holistic.POSE_CONNECTIONS, 1),
        (results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS, 2),
        (results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS, 2),
    )
    for landmarks, connections, radius in overlays:
        mp_drawing.draw_landmarks(
            image, landmarks, connections,
            mp_drawing.DrawingSpec(color=(0, 0, 255), thickness=1, circle_radius=radius),
            mp_drawing.DrawingSpec(color=(80, 110, 10), thickness=1, circle_radius=1),
        )
58
+
59
def extract_keypoints(results):
    """Flatten MediaPipe pose + both hands into a fixed 258-dim vector.

    Layout: pose 33*4 (x, y, z, visibility), then left hand 21*3 and
    right hand 21*3 (x, y, z). Missing parts contribute zeros so the
    length is constant regardless of detection quality.
    """
    def _hand(hand_landmarks):
        # 21 landmarks * (x, y, z); zeros when the hand was not detected.
        if hand_landmarks is None:
            return np.zeros(21 * 3)
        return np.array([(p.x, p.y, p.z) for p in hand_landmarks.landmark]).flatten()

    if results.pose_landmarks:
        pose = np.array(
            [(p.x, p.y, p.z, p.visibility) for p in results.pose_landmarks.landmark]
        ).flatten()
    else:
        pose = np.zeros(33 * 4)

    return np.concatenate(
        [pose, _hand(results.left_hand_landmarks), _hand(results.right_hand_landmarks)]
    )
67
+
68
# ----------------------------
# Model code (from your notebook)
# ----------------------------
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with residual connection and post-LayerNorm.

    forward(x) with x of shape (batch, seq, embed_dim) returns
    (normalized output, attention weights of shape (batch, heads, seq, seq)).

    NOTE: attribute names (query/key/value/out_proj/norm) are part of the
    checkpoint's state_dict and must not be renamed.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        b, t, _ = x.size()

        def split_heads(proj):
            # (b, t, embed) -> (b, heads, t, head_dim)
            return proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query)
        k = split_heads(self.key)
        v = split_heads(self.value)

        # Scaled dot-product attention, with dropout on the weights.
        attn = torch.softmax(q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)
        attn = self.dropout(attn)

        # Merge heads back to (b, t, embed_dim), project, add residual, normalize.
        ctx = (attn @ v).transpose(1, 2).contiguous().view(b, t, self.embed_dim)
        return self.norm(self.out_proj(ctx) + x), attn
101
+
102
class AttentionEnhancedLSTM(nn.Module):
    """An LSTM whose output sequence is refined by multi-head self-attention.

    forward(x) returns (attended sequence, (h_n, c_n), attention weights).

    NOTE: attribute names (lstm/attention) are part of the checkpoint's
    state_dict and must not be renamed.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        # Inter-layer dropout is only meaningful for stacked LSTMs.
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            batch_first=True, bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0,
        )
        # A bidirectional LSTM doubles the per-step output width.
        out_dim = hidden_size * (2 if bidirectional else 1)
        self.attention = MultiHeadSelfAttention(embed_dim=out_dim, num_heads=8, dropout=dropout)

    def forward(self, x):
        seq, state = self.lstm(x)          # state is the (h_n, c_n) tuple
        refined, weights = self.attention(seq)
        return refined, state, weights
120
+
121
class CNNLSTMAttention(nn.Module):
    """1D-CNN front end -> two attention-enhanced BiLSTMs -> attention pooling -> MLP head.

    forward(x) with x of shape (batch, seq_len, input_size) returns
    class logits of shape (batch, num_classes).

    NOTE: all submodule attribute names are part of the checkpoint's
    state_dict (loaded with strict=True) and must not be renamed.
    """

    def __init__(self, input_size, num_classes, dropout=0.4, num_attention_heads=8):
        super().__init__()
        # Convolutional feature extractor over the temporal axis.
        self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout_cnn = nn.Dropout(dropout)

        # Bidirectional, so per-step output widths double: 256 -> 512, 128 -> 256.
        self.ae_lstm1 = AttentionEnhancedLSTM(128, 256, num_layers=1, bidirectional=True, dropout=dropout)
        self.ae_lstm2 = AttentionEnhancedLSTM(512, 128, num_layers=1, bidirectional=True, dropout=dropout)
        self.dropout_lstm = nn.Dropout(dropout)

        # Learned soft pooling over the time axis.
        self.temporal_attention = MultiHeadSelfAttention(embed_dim=256, num_heads=num_attention_heads, dropout=dropout)
        self.attention_pool = nn.Linear(256, 1)

        # Classifier head.
        self.fc1 = nn.Linear(256, 128)
        self.bn_fc = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 64)
        self.dropout_fc = nn.Dropout(dropout)
        self.output_layer = nn.Linear(64, num_classes)

    def forward(self, x):
        # Conv1d expects channels-first: (batch, features, seq_len).
        h = x.permute(0, 2, 1)
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            h = self.dropout_cnn(torch.relu(bn(conv(h))))

        # Back to (batch, seq_len, channels) for the recurrent stack.
        h = h.permute(0, 2, 1)
        h, _, _ = self.ae_lstm1(h)   # (batch, seq_len, 512)
        h = self.dropout_lstm(h)
        h, _, _ = self.ae_lstm2(h)   # (batch, seq_len, 256)
        h = self.dropout_lstm(h)

        # Attention-weighted pooling over the time dimension.
        attended, _ = self.temporal_attention(h)                       # (batch, seq_len, 256)
        weights = torch.softmax(self.attention_pool(attended), dim=1)  # (batch, seq_len, 1)
        pooled = (weights * attended).sum(dim=1)                       # (batch, 256)

        out = self.dropout_fc(torch.relu(self.bn_fc(self.fc1(pooled))))
        out = self.dropout_fc(torch.relu(self.fc2(out)))
        return self.output_layer(out)
173
+
174
# ----------------------------
# Load trained weights
# ----------------------------
DEVICE = "cpu"     # CPU-only inference target for torch.load's map_location
INPUT_SIZE = 258   # keypoint vector length: 33*4 pose + 21*3 per hand
SEQ_LEN = 30       # number of frames buffered before a prediction is made

model = CNNLSTMAttention(INPUT_SIZE, NUM_CLASSES, dropout=0.4, num_attention_heads=8)
# strict=True: the checkpoint must match the architecture above exactly.
state = torch.load("trained_model.pth", map_location=DEVICE)
model.load_state_dict(state, strict=True)
# Inference mode: disables dropout, BatchNorm uses running statistics.
model.eval()
185
+
186
# One MediaPipe instance for the whole app (faster)
# NOTE(review): a single shared Holistic instance is presumably not safe for
# concurrent sessions — assumes frames arrive one at a time; confirm.
holistic = mp_holistic.Holistic(
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5
)
191
+
192
# ----------------------------
# Gradio inference with state
# ----------------------------
def run(frame, sequence_state):
    """Process one webcam frame and update the rolling keypoint buffer.

    Args:
        frame: RGB numpy array from the Gradio webcam stream.
        sequence_state: list of the most recent keypoint vectors
            (per-session Gradio state); None on the first call.

    Returns:
        (annotated RGB frame, {label: probability} dict, updated state list)
    """
    if sequence_state is None:
        sequence_state = []

    # Gradio gives RGB; MediaPipe helper expects BGR for cv2 conversions
    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    image_bgr, results = mediapipe_detection(frame_bgr, holistic)
    draw_styled_landmarks(image_bgr, results)

    # Append this frame's keypoints, keeping only the last SEQ_LEN frames.
    keypoints = extract_keypoints(results)
    sequence_state.append(keypoints)
    sequence_state = sequence_state[-SEQ_LEN:]

    probs_dict = {}
    pred_text = "Waiting..."
    conf = 0.0

    hands_present = (results.left_hand_landmarks is not None) or (results.right_hand_landmarks is not None)

    if not hands_present:
        pred_text = "No hands detected"
    elif len(sequence_state) == SEQ_LEN:
        # Batch of one full sequence: (1, SEQ_LEN, 258).
        x = torch.tensor(np.expand_dims(sequence_state, axis=0), dtype=torch.float32)
        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

        top_idx = int(np.argmax(probs))
        conf = float(probs[top_idx])
        pred_text = f"{LABELS[top_idx]} ({conf:.2%})"
        probs_dict = {LABELS[i]: float(probs[i]) for i in range(NUM_CLASSES)}

    # Overlay prediction banner across the full frame width.
    # (Was hard-coded to 640 px, which clipped or overshot at other
    # webcam resolutions; use the actual frame width instead.)
    banner_width = image_bgr.shape[1]
    cv2.rectangle(image_bgr, (0, 0), (banner_width, 45), (245, 117, 16), -1)
    cv2.putText(
        image_bgr,
        pred_text,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.9,
        (255, 255, 255),
        2,
        cv2.LINE_AA
    )

    # Back to RGB for Gradio display
    out_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # If probs_dict is empty (e.g., still warming up), show something stable
    if not probs_dict:
        probs_dict = {"(warming up)": 1.0}

    return out_rgb, probs_dict, sequence_state
254
+
255
# Gradio UI: webcam stream in, annotated frame + top-5 label probabilities out.
with gr.Blocks() as demo:
    gr.Markdown("# Live Sign Language Gesture Demo (CNN-LSTM + Multi-Head Attention)")
    gr.Markdown("Show your hand gesture to the webcam. Prediction starts after 30 frames are collected.")

    # Per-session rolling buffer of keypoint vectors, threaded through run().
    seq_state = gr.State([])

    with gr.Row():
        # NOTE(review): gr.Image(source="webcam", streaming=True) is the
        # Gradio 3.x API; Gradio 4+ changed this — confirm the pinned version.
        cam = gr.Image(source="webcam", streaming=True, type="numpy", label="Webcam")
        out_img = gr.Image(type="numpy", label="Output (Annotated)")

    out_label = gr.Label(num_top_classes=5, label="Probabilities (Top 5)")

    # Stream each webcam frame through run(); seq_state round-trips per session.
    cam.stream(
        fn=run,
        inputs=[cam, seq_state],
        outputs=[out_img, out_label, seq_state],
    )

if __name__ == "__main__":
    demo.launch()