XiaoBai1221 commited on
Commit
3bf3b96
·
1 Parent(s): d68547d

🔧 使用 pydantic==2.10.6 修復 schema 錯誤

Browse files
Files changed (2) hide show
  1. app.py +236 -68
  2. requirements.txt +2 -2
app.py CHANGED
@@ -1,90 +1,258 @@
1
  import os
2
  import cv2
3
  import numpy as np
 
 
4
  import gradio as gr
 
 
 
5
 
6
- # 檢查檔案是否存在
7
- model_path = "tsflow/models/best_model.pt"
8
- config_path = "tsflow/results/test_results.json"
 
9
 
10
- # 如果是在SignView2.0目錄下運行,調整路徑
11
- if not os.path.exists(model_path):
12
- model_path = "../tsflow/models/best_model.pt"
13
- config_path = "../tsflow/results/test_results.json"
14
 
15
- try:
16
- from realtime_sign_prediction import RealtimeSignPredictor
17
-
18
- # 初始化預測器
19
- print("🚀 正在初始化手語辨識系統...")
20
- predictor = RealtimeSignPredictor(
21
- model_path=model_path,
22
- config_path=config_path,
23
- sequence_length=50,
24
- use_segmentation=True
25
- )
26
- print("✅ 手語辨識系統初始化完成!")
27
- MODEL_LOADED = True
28
-
29
- except Exception as e:
30
- print(f"⚠️ 模型載入失敗: {e}")
31
- print("🔄 使用模擬模式運行...")
32
- MODEL_LOADED = False
33
 
34
- def predict_sign_language(image):
35
- """簡化的預測函數,返回單一字串結果"""
36
- if image is None:
37
- return "請上傳影像"
38
-
39
- if not MODEL_LOADED:
40
- return "⚠️ 模型未載入,無法進行預測"
41
-
42
- try:
43
- # 處理畫面
44
- results, keypoints, flow_features = predictor.process_frame(image)
45
 
46
- # 獲取預測結果
47
- top_predictions = predictor.get_top_predictions(top_k=3)
 
48
 
49
- # 格式化預測結果為簡單字串
50
- if top_predictions:
51
- result_text = "🎯 手語辨識結果:\n\n"
52
- for i, (label, confidence) in enumerate(top_predictions, 1):
53
- result_text += f"{i}. {label}: {confidence:.2%}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- result_text += f"\n📊 序列進度: {len(predictor.keypoint_sequence)}/{predictor.sequence_length}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  else:
57
- result_text = "📡 正在收集動作序列...\n請確保手語動作清晰可見"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- return result_text
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  except Exception as e:
62
- return f"處理錯誤: {str(e)}"
63
 
64
- # 使用最簡單的Interface避開所有schema問題
65
  demo = gr.Interface(
66
- fn=predict_sign_language,
67
- inputs=gr.Image(),
68
- outputs=gr.Textbox(lines=10),
 
 
 
69
  title="🤟 SignView2.0 - 手語辨識系統",
70
- description="支援34種手語詞彙的即時辨識系統,準確率達94.25%\n\n上傳影像即可開始辨識手語動作",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  flagging_mode="never"
72
  )
73
 
74
  if __name__ == "__main__":
75
- print("🎉 SignView2.0 手語辨識系統已啟動!")
76
-
77
- # 根據環境自動選擇最佳配置
78
- import os
79
- try:
80
- # 嘗試最簡單的launch,讓Gradio自己處理
81
- demo.launch()
82
- except Exception as e:
83
- print(f"預設啟動失敗,嘗試備用方案: {e}")
84
- try:
85
- # 如果在Spaces環境,強制使用share=True
86
- demo.launch(share=True)
87
- except Exception as e2:
88
- print(f"備用方案也失敗: {e2}")
89
- # 最後嘗試基本配置
90
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
1
  import os
2
  import cv2
3
  import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
  import gradio as gr
7
+ from pathlib import Path
8
+ import mediapipe as mp
9
+ import pickle
10
 
11
+ # MediaPipe設定
12
+ mp_pose = mp.solutions.pose
13
+ mp_hands = mp.solutions.hands
14
+ mp_face_mesh = mp.solutions.face_mesh
15
 
16
+ # 設定設備
17
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ print(f"使用設備: {device}")
 
19
 
20
+ # 載入標籤映射
21
+ label_to_idx = {'again': 0, 'all': 1, 'apple': 2, 'bad': 3, 'bathroom': 4, 'beautiful': 5, 'bird': 6, 'black': 7, 'blue': 8, 'book': 9, 'bored': 10, 'boy': 11, 'brother': 12, 'brown': 13, 'but': 14, 'computer': 15, 'cousin': 16, 'dance': 17, 'day': 18, 'deaf': 19, 'doctor': 20, 'dog': 21, 'draw': 22, 'drink': 23, 'eat': 24, 'english': 25, 'family': 26, 'father': 27, 'fine': 28, 'finish': 29, 'fish': 30, 'forget': 31, 'friend': 32, 'girl': 33}
22
+ idx_to_label = {v: k for k, v in label_to_idx.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ class BiLSTMWithAttention(nn.Module):
25
+ def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout=0.5):
26
+ super(BiLSTMWithAttention, self).__init__()
27
+ self.hidden_size = hidden_size
28
+ self.num_layers = num_layers
29
+
30
+ self.bilstm = nn.LSTM(input_size, hidden_size, num_layers,
31
+ batch_first=True, bidirectional=True, dropout=dropout)
32
+
33
+ # 注意力機制
34
+ self.attention = nn.Linear(hidden_size * 2, 1)
35
 
36
+ # 分類層
37
+ self.classifier = nn.Linear(hidden_size * 2, num_classes)
38
+ self.dropout = nn.Dropout(dropout)
39
 
40
+ def forward(self, x):
41
+ batch_size = x.size(0)
42
+
43
+ # LSTM前向傳播
44
+ lstm_out, _ = self.bilstm(x)
45
+
46
+ # 注意力權重計算
47
+ attention_weights = torch.softmax(self.attention(lstm_out), dim=1)
48
+
49
+ # 加權平均
50
+ context_vector = torch.sum(attention_weights * lstm_out, dim=1)
51
+
52
+ # 分類
53
+ output = self.classifier(self.dropout(context_vector))
54
+
55
+ return output
56
+
57
+ # 初始化模型
58
+ input_size = 258 # keypoints (75*2) + optical_flow (108)
59
+ hidden_size = 256
60
+ num_layers = 3
61
+ num_classes = len(label_to_idx)
62
+
63
+ model = BiLSTMWithAttention(input_size, hidden_size, num_layers, num_classes)
64
+ model = model.to(device)
65
+
66
+ # 載入模型權重
67
+ model_path = Path("tsflow/models/best_model.pt")
68
+ if model_path.exists():
69
+ try:
70
+ checkpoint = torch.load(model_path, map_location=device)
71
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
72
+ model.load_state_dict(checkpoint['model_state_dict'])
73
+ else:
74
+ model.load_state_dict(checkpoint)
75
+ model.eval()
76
+ print("✅ 模型載入成功")
77
+ except Exception as e:
78
+ print(f"❌ 模型載入失敗: {e}")
79
+ raise
80
+ else:
81
+ print(f"❌ 找不到模型檔案: {model_path}")
82
+ raise FileNotFoundError(f"模型檔案不存在: {model_path}")
83
+
84
+ def extract_keypoints_from_frame(frame):
85
+ """從單個frame提取關鍵點"""
86
+ try:
87
+ with mp_pose.Pose(static_image_mode=True, model_complexity=1) as pose, \
88
+ mp_hands.Hands(static_image_mode=True, max_num_hands=2) as hands:
89
+
90
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
91
 
92
+ keypoints = []
93
+
94
+ # 提取姿勢關鍵點
95
+ pose_results = pose.process(rgb_frame)
96
+ if pose_results.pose_landmarks:
97
+ pose_points = []
98
+ for landmark in pose_results.pose_landmarks.landmark:
99
+ pose_points.extend([landmark.x, landmark.y])
100
+ keypoints.extend(pose_points)
101
+ else:
102
+ keypoints.extend([0.0] * 66) # 33個姿勢點 * 2
103
+
104
+ # 提取手部關鍵點
105
+ hands_results = hands.process(rgb_frame)
106
+ if hands_results.multi_hand_landmarks:
107
+ hand_points = []
108
+ for hand_landmarks in hands_results.multi_hand_landmarks:
109
+ for landmark in hand_landmarks.landmark:
110
+ hand_points.extend([landmark.x, landmark.y])
111
+ if len(hand_points) >= 42: # 至少一隻手
112
+ keypoints.extend(hand_points[:42])
113
+ else:
114
+ keypoints.extend(hand_points + [0.0] * (42 - len(hand_points)))
115
+ else:
116
+ keypoints.extend([0.0] * 42) # 21個手部點 * 2
117
+
118
+ return np.array(keypoints, dtype=np.float32)
119
+ except Exception as e:
120
+ print(f"關鍵點提取錯誤: {e}")
121
+ return np.zeros(150, dtype=np.float32)
122
+
123
+ def calculate_optical_flow_features(frames):
124
+ """計算光流特徵"""
125
+ try:
126
+ if len(frames) < 2:
127
+ return np.zeros(108, dtype=np.float32)
128
+
129
+ flow_features = []
130
+ for i in range(len(frames) - 1):
131
+ gray1 = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
132
+ gray2 = cv2.cvtColor(frames[i + 1], cv2.COLOR_BGR2GRAY)
133
+
134
+ flow = cv2.calcOpticalFlowPyrLK(
135
+ gray1, gray2, None, None,
136
+ winSize=(15, 15),
137
+ maxLevel=2,
138
+ criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03)
139
+ )
140
+
141
+ if flow[0] is not None and len(flow[0]) > 0:
142
+ flow_features.extend(flow[0].flatten()[:54])
143
+ else:
144
+ flow_features.extend([0.0] * 54)
145
+
146
+ if len(flow_features) >= 108:
147
+ return np.array(flow_features[:108], dtype=np.float32)
148
  else:
149
+ return np.array(flow_features + [0.0] * (108 - len(flow_features)), dtype=np.float32)
150
+ except Exception as e:
151
+ print(f"光流計算錯誤: {e}")
152
+ return np.zeros(108, dtype=np.float32)
153
+
154
+ def predict_sign_language(video_path):
155
+ """預測手語影片"""
156
+ try:
157
+ cap = cv2.VideoCapture(video_path)
158
+ frames = []
159
+
160
+ while True:
161
+ ret, frame = cap.read()
162
+ if not ret:
163
+ break
164
+ frames.append(frame)
165
+
166
+ cap.release()
167
+
168
+ if len(frames) == 0:
169
+ return "錯誤:無法讀取影片幀", 0.0
170
+
171
+ # 提取特徵
172
+ keypoints_sequence = []
173
+ for frame in frames:
174
+ keypoints = extract_keypoints_from_frame(frame)
175
+ keypoints_sequence.append(keypoints)
176
+
177
+ optical_flow = calculate_optical_flow_features(frames)
178
+
179
+ # 確保序列長度為104
180
+ target_length = 104
181
+ if len(keypoints_sequence) > target_length:
182
+ keypoints_sequence = keypoints_sequence[:target_length]
183
+ elif len(keypoints_sequence) < target_length:
184
+ last_frame = keypoints_sequence[-1] if keypoints_sequence else np.zeros(150)
185
+ while len(keypoints_sequence) < target_length:
186
+ keypoints_sequence.append(last_frame)
187
+
188
+ # 組合特徵
189
+ features_sequence = []
190
+ for i, keypoints in enumerate(keypoints_sequence):
191
+ if i < len(optical_flow) // 54:
192
+ flow_feature = optical_flow[i*54:(i+1)*54]
193
+ else:
194
+ flow_feature = np.zeros(54)
195
+
196
+ combined_features = np.concatenate([keypoints, flow_feature, np.zeros(54)])
197
+ features_sequence.append(combined_features)
198
 
199
+ # 轉換為tensor並預測
200
+ features_tensor = torch.tensor([features_sequence], dtype=torch.float32).to(device)
201
 
202
+ with torch.no_grad():
203
+ outputs = model(features_tensor)
204
+ probabilities = torch.softmax(outputs, dim=1)
205
+ predicted_class = torch.argmax(probabilities, dim=1).item()
206
+ confidence = probabilities[0][predicted_class].item()
207
+
208
+ predicted_label = idx_to_label.get(predicted_class, "未知")
209
+
210
+ return f"預測結果: {predicted_label}", confidence
211
+
212
+ except Exception as e:
213
+ print(f"預測錯誤: {e}")
214
+ return f"預測失敗: {str(e)}", 0.0
215
+
216
+ def gradio_predict(video):
217
+ """Gradio介面的預測函數"""
218
+ if video is None:
219
+ return "請上傳影片", "信心度: 0%"
220
+
221
+ try:
222
+ result, confidence = predict_sign_language(video)
223
+ confidence_text = f"信心度: {confidence:.2%}"
224
+ return result, confidence_text
225
  except Exception as e:
226
+ return f"處理錯誤: {str(e)}", "信心度: 0%"
227
 
228
+ # 建立Gradio介面
229
  demo = gr.Interface(
230
+ fn=gradio_predict,
231
+ inputs=gr.Video(label="上傳手語影片"),
232
+ outputs=[
233
+ gr.Textbox(label="預測結果"),
234
+ gr.Textbox(label="信心度")
235
+ ],
236
  title="🤟 SignView2.0 - 手語辨識系統",
237
+ description="""
238
+ ### 歡迎使用 SignView2.0 手語辨識系統!
239
+
240
+ **系統特色:**
241
+ - 🎯 準確率:94.25%
242
+ - 📚 支援34種手語詞彙
243
+ - 🧠 使用BiLSTM + 注意力機制
244
+ - 👁️ MediaPipe + 光流特徵融合
245
+
246
+ **使用方法:**
247
+ 1. 上傳手語影片(建議3-4秒)
248
+ 2. 點擊提交進行辨識
249
+ 3. 查看預測結果和信心度
250
+
251
+ **支援詞彙:** again, all, apple, bad, bathroom, beautiful, bird, black, blue, book, bored, boy, brother, brown, but, computer, cousin, dance, day, deaf, doctor, dog, draw, drink, eat, english, family, father, fine, finish, fish, forget, friend, girl
252
+ """,
253
+ examples=[],
254
  flagging_mode="never"
255
  )
256
 
257
  if __name__ == "__main__":
258
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- gradio==4.38.1
2
  torch>=2.0.0
3
  torchvision>=0.15.0
4
  opencv-python>=4.8.0
5
  mediapipe>=0.10.0
6
  numpy>=1.24.0
7
  Pillow>=9.5.0
8
- scipy>=1.10.0
 
1
+ gradio==4.44.0
2
  torch>=2.0.0
3
  torchvision>=0.15.0
4
  opencv-python>=4.8.0
5
  mediapipe>=0.10.0
6
  numpy>=1.24.0
7
  Pillow>=9.5.0
8
+ scipy>=1.10.0