SondosM committed on
Commit
2d5e733
·
verified ·
1 Parent(s): 634b247

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -44
app.py CHANGED
@@ -52,6 +52,10 @@ for _mod in ["OpenGL", "OpenGL.GL", "OpenGL.GL.framebufferobjects",
52
 
53
  os.environ["PYOPENGL_PLATFORM"] = "osmesa"
54
 
 
 
 
 
55
  # --- Hugging Face Model Integration ---
56
  REPO_ID = "SondosM/api_GP"
57
 
@@ -76,14 +80,11 @@ get_hf_file("mano_data/mano_data/mano_mean_params.npz", is_mano=True)
76
  get_hf_file("mano_data/mano_data/MANO_LEFT.pkl", is_mano=True)
77
  get_hf_file("mano_data/mano_data/MANO_RIGHT.pkl", is_mano=True)
78
 
79
- WILOR_REPO_PATH = "./WiLoR"
80
- WILOR_CKPT = get_hf_file("pretrained_models/pretrained_models/wilor_final.ckpt")
81
- WILOR_CFG = get_hf_file("pretrained_models/pretrained_models/model_config.yaml")
82
- DETECTOR_PATH = get_hf_file("pretrained_models/pretrained_models/detector.pt")
83
-
84
- # ─── الفرق الأساسي: الكود الأول كان بيحمّل classifier.pkl من مسار محلي ثابت
85
- # بدل ما يحمّله من HF زي باقي الملفات ─────────────────────────────────────────
86
- CLASSIFIER_PATH = get_hf_file("classifier.pkl")
87
  MLP_LETTERS_PATH = get_hf_file("MLP_letters.pkl")
88
  MLP_NUMBERS_PATH = get_hf_file("MLP_numbers.pkl")
89
 
@@ -94,19 +95,20 @@ WILOR_TRANSFORM = transforms.Compose([
94
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
95
  ])
96
 
97
- wilor_model = None
98
- yolo_detector = None
99
- classifier = None
100
- mlp_letters = None
101
- mlp_numbers = None
102
 
103
 
104
  def load_models():
105
- global wilor_model, yolo_detector, classifier, mlp_letters, mlp_numbers
106
 
107
  sys.path.insert(0, WILOR_REPO_PATH)
108
  from wilor.models import load_wilor
109
  from ultralytics import YOLO
 
110
 
111
  print(f"Loading WiLoR on {DEVICE}...")
112
  wilor_model, _ = load_wilor(checkpoint_path=WILOR_CKPT, cfg_path=WILOR_CFG)
@@ -116,8 +118,10 @@ def load_models():
116
  print("Loading YOLO detector...")
117
  yolo_detector = YOLO(DETECTOR_PATH)
118
 
119
- print("Loading classifiers...")
120
- classifier = joblib.load(CLASSIFIER_PATH)
 
 
121
  mlp_letters = joblib.load(MLP_LETTERS_PATH)
122
  mlp_numbers = joblib.load(MLP_NUMBERS_PATH)
123
 
@@ -189,8 +193,6 @@ def read_image_from_upload(file_bytes: bytes) -> np.ndarray:
189
 
190
 
191
  def _align_features(model, features: np.ndarray) -> np.ndarray:
192
- # بعض الـ models (Pipeline مع StandardScaler) مش بيحفظوا feature_names_in_
193
- # فبنستخدم n_features_in_ بدلها ونرجع array عادية مش DataFrame
194
  if hasattr(model, "feature_names_in_"):
195
  expected_cols = model.feature_names_in_
196
  vec = np.zeros(len(expected_cols))
@@ -198,33 +200,25 @@ def _align_features(model, features: np.ndarray) -> np.ndarray:
198
  vec[:limit] = features[:limit]
199
  return pd.DataFrame([vec], columns=expected_cols)
200
  else:
201
- n = model.n_features_in_
202
- vec = np.zeros(n)
203
  limit = min(len(features), n)
204
  vec[:limit] = features[:limit]
205
  return vec.reshape(1, -1)
206
 
207
 
208
- def run_two_stage(features: np.ndarray) -> dict:
209
- # Stage 1: letter or number?
210
- feat_df = _align_features(classifier, features)
211
- category = str(classifier.predict(feat_df)[0])
212
- cat_conf = float(classifier.predict_proba(feat_df)[0].max())
213
 
214
- # Stage 2: which sign exactly?
215
- cat = category.lower().strip()
216
- if cat in ("letter", "letters", "حرف", "حروف"):
217
- model = mlp_letters
218
- elif cat in ("number", "numbers", "digit", "digits", "رقم", "أرقام", "ارقام"):
219
- model = mlp_numbers
220
- else:
221
- # fallback: pick whichever is more confident
222
- feat_l = _align_features(mlp_letters, features)
223
- feat_n = _align_features(mlp_numbers, features)
224
- prob_l = float(mlp_letters.predict_proba(feat_l)[0].max())
225
- prob_n = float(mlp_numbers.predict_proba(feat_n)[0].max())
226
- model = mlp_letters if prob_l >= prob_n else mlp_numbers
227
 
 
 
228
  feat_df = _align_features(model, features)
229
  label = str(model.predict(feat_df)[0])
230
  conf = float(model.predict_proba(feat_df)[0].max())
@@ -256,8 +250,8 @@ async def predict(file: UploadFile = File(...)):
256
  if not results[0].boxes:
257
  raise HTTPException(status_code=422, detail="No hand detected.")
258
 
259
- box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
260
- label_id = int(results[0].boxes.cls[0].cpu().item())
261
  hand_side = "left" if label_id == 0 else "right"
262
 
263
  h, w = img_rgb.shape[:2]
@@ -271,7 +265,7 @@ async def predict(file: UploadFile = File(...)):
271
  if features is None:
272
  raise HTTPException(status_code=500, detail="Feature extraction failed.")
273
 
274
- result = run_two_stage(features)
275
  return JSONResponse({**result, "hand_side": hand_side, "bbox": [int(x1), int(y1), int(x2), int(y2)]})
276
 
277
 
@@ -285,8 +279,8 @@ async def predict_with_skeleton(file: UploadFile = File(...)):
285
  if not results[0].boxes:
286
  raise HTTPException(status_code=422, detail="No hand detected.")
287
 
288
- box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
289
- label_id = int(results[0].boxes.cls[0].cpu().item())
290
  hand_side = "left" if label_id == 0 else "right"
291
 
292
  h, w = img_rgb.shape[:2]
@@ -296,7 +290,7 @@ async def predict_with_skeleton(file: UploadFile = File(...)):
296
  features = extract_features(crop)
297
  joints = get_3d_joints(crop)
298
 
299
- result = run_two_stage(features)
300
  _, buf = cv2.imencode(".png", cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
301
  crop_b64 = base64.b64encode(buf).decode("utf-8")
302
 
 
52
 
53
  os.environ["PYOPENGL_PLATFORM"] = "osmesa"
54
 
55
+ # --- Router Model Classes ---
56
+ CLASSES = {0: "letter", 1: "number"}
57
+ IMG_SIZE = 64
58
+
59
  # --- Hugging Face Model Integration ---
60
  REPO_ID = "SondosM/api_GP"
61
 
 
80
  get_hf_file("mano_data/mano_data/MANO_LEFT.pkl", is_mano=True)
81
  get_hf_file("mano_data/mano_data/MANO_RIGHT.pkl", is_mano=True)
82
 
83
+ WILOR_REPO_PATH = "./WiLoR"
84
+ WILOR_CKPT = get_hf_file("pretrained_models/pretrained_models/wilor_final.ckpt")
85
+ WILOR_CFG = get_hf_file("pretrained_models/pretrained_models/model_config.yaml")
86
+ DETECTOR_PATH = get_hf_file("pretrained_models/pretrained_models/detector.pt")
87
+ ROUTER_MODEL_PATH = get_hf_file("router_model.keras")
 
 
 
88
  MLP_LETTERS_PATH = get_hf_file("MLP_letters.pkl")
89
  MLP_NUMBERS_PATH = get_hf_file("MLP_numbers.pkl")
90
 
 
95
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
96
  ])
97
 
98
+ wilor_model = None
99
+ yolo_detector = None
100
+ router_model_keras = None
101
+ mlp_letters = None
102
+ mlp_numbers = None
103
 
104
 
105
  def load_models():
106
+ global wilor_model, yolo_detector, router_model_keras, mlp_letters, mlp_numbers
107
 
108
  sys.path.insert(0, WILOR_REPO_PATH)
109
  from wilor.models import load_wilor
110
  from ultralytics import YOLO
111
+ from tensorflow.keras.models import load_model
112
 
113
  print(f"Loading WiLoR on {DEVICE}...")
114
  wilor_model, _ = load_wilor(checkpoint_path=WILOR_CKPT, cfg_path=WILOR_CFG)
 
118
  print("Loading YOLO detector...")
119
  yolo_detector = YOLO(DETECTOR_PATH)
120
 
121
+ print("Loading router model (Keras)...")
122
+ router_model_keras = load_model(ROUTER_MODEL_PATH)
123
+
124
+ print("Loading MLP classifiers...")
125
  mlp_letters = joblib.load(MLP_LETTERS_PATH)
126
  mlp_numbers = joblib.load(MLP_NUMBERS_PATH)
127
 
 
193
 
194
 
195
  def _align_features(model, features: np.ndarray) -> np.ndarray:
 
 
196
  if hasattr(model, "feature_names_in_"):
197
  expected_cols = model.feature_names_in_
198
  vec = np.zeros(len(expected_cols))
 
200
  vec[:limit] = features[:limit]
201
  return pd.DataFrame([vec], columns=expected_cols)
202
  else:
203
+ n = model.n_features_in_
204
+ vec = np.zeros(n)
205
  limit = min(len(features), n)
206
  vec[:limit] = features[:limit]
207
  return vec.reshape(1, -1)
208
 
209
 
210
+ def run_two_stage(features: np.ndarray, crop_rgb: np.ndarray) -> dict:
211
+ # Stage 1: router_model.keras يحدد حرف (0) أو رقم (1)
212
+ img_resized = cv2.resize(crop_rgb, (IMG_SIZE, IMG_SIZE))
213
+ img_array = np.expand_dims(img_resized, axis=0).astype("float32") / 255.0
 
214
 
215
+ prob = float(router_model_keras.predict(img_array, verbose=0)[0][0])
216
+ cls_idx = 1 if prob >= 0.5 else 0
217
+ category = CLASSES[cls_idx]
218
+ cat_conf = prob if cls_idx == 1 else 1.0 - prob
 
 
 
 
 
 
 
 
 
219
 
220
+ # Stage 2: اختار الموديل الصح بناءً على النتيجة
221
+ model = mlp_letters if category == "letter" else mlp_numbers
222
  feat_df = _align_features(model, features)
223
  label = str(model.predict(feat_df)[0])
224
  conf = float(model.predict_proba(feat_df)[0].max())
 
250
  if not results[0].boxes:
251
  raise HTTPException(status_code=422, detail="No hand detected.")
252
 
253
+ box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
254
+ label_id = int(results[0].boxes.cls[0].cpu().item())
255
  hand_side = "left" if label_id == 0 else "right"
256
 
257
  h, w = img_rgb.shape[:2]
 
265
  if features is None:
266
  raise HTTPException(status_code=500, detail="Feature extraction failed.")
267
 
268
+ result = run_two_stage(features, crop)
269
  return JSONResponse({**result, "hand_side": hand_side, "bbox": [int(x1), int(y1), int(x2), int(y2)]})
270
 
271
 
 
279
  if not results[0].boxes:
280
  raise HTTPException(status_code=422, detail="No hand detected.")
281
 
282
+ box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
283
+ label_id = int(results[0].boxes.cls[0].cpu().item())
284
  hand_side = "left" if label_id == 0 else "right"
285
 
286
  h, w = img_rgb.shape[:2]
 
290
  features = extract_features(crop)
291
  joints = get_3d_joints(crop)
292
 
293
+ result = run_two_stage(features, crop)
294
  _, buf = cv2.imencode(".png", cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
295
  crop_b64 = base64.b64encode(buf).decode("utf-8")
296