Update app.py
Browse files
app.py
CHANGED
|
@@ -52,6 +52,10 @@ for _mod in ["OpenGL", "OpenGL.GL", "OpenGL.GL.framebufferobjects",
|
|
| 52 |
|
| 53 |
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
# --- Hugging Face Model Integration ---
|
| 56 |
REPO_ID = "SondosM/api_GP"
|
| 57 |
|
|
@@ -76,14 +80,11 @@ get_hf_file("mano_data/mano_data/mano_mean_params.npz", is_mano=True)
|
|
| 76 |
get_hf_file("mano_data/mano_data/MANO_LEFT.pkl", is_mano=True)
|
| 77 |
get_hf_file("mano_data/mano_data/MANO_RIGHT.pkl", is_mano=True)
|
| 78 |
|
| 79 |
-
WILOR_REPO_PATH
|
| 80 |
-
WILOR_CKPT
|
| 81 |
-
WILOR_CFG
|
| 82 |
-
DETECTOR_PATH
|
| 83 |
-
|
| 84 |
-
# ─── الفرق الأساسي: الكود الأول كان بيحمّل classifier.pkl من مسار محلي ثابت
|
| 85 |
-
# بدل ما يحمّله من HF زي باقي الملفات ─────────────────────────────────────────
|
| 86 |
-
CLASSIFIER_PATH = get_hf_file("classifier.pkl")
|
| 87 |
MLP_LETTERS_PATH = get_hf_file("MLP_letters.pkl")
|
| 88 |
MLP_NUMBERS_PATH = get_hf_file("MLP_numbers.pkl")
|
| 89 |
|
|
@@ -94,19 +95,20 @@ WILOR_TRANSFORM = transforms.Compose([
|
|
| 94 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 95 |
])
|
| 96 |
|
| 97 |
-
wilor_model
|
| 98 |
-
yolo_detector
|
| 99 |
-
|
| 100 |
-
mlp_letters
|
| 101 |
-
mlp_numbers
|
| 102 |
|
| 103 |
|
| 104 |
def load_models():
|
| 105 |
-
global wilor_model, yolo_detector,
|
| 106 |
|
| 107 |
sys.path.insert(0, WILOR_REPO_PATH)
|
| 108 |
from wilor.models import load_wilor
|
| 109 |
from ultralytics import YOLO
|
|
|
|
| 110 |
|
| 111 |
print(f"Loading WiLoR on {DEVICE}...")
|
| 112 |
wilor_model, _ = load_wilor(checkpoint_path=WILOR_CKPT, cfg_path=WILOR_CFG)
|
|
@@ -116,8 +118,10 @@ def load_models():
|
|
| 116 |
print("Loading YOLO detector...")
|
| 117 |
yolo_detector = YOLO(DETECTOR_PATH)
|
| 118 |
|
| 119 |
-
print("Loading
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
mlp_letters = joblib.load(MLP_LETTERS_PATH)
|
| 122 |
mlp_numbers = joblib.load(MLP_NUMBERS_PATH)
|
| 123 |
|
|
@@ -189,8 +193,6 @@ def read_image_from_upload(file_bytes: bytes) -> np.ndarray:
|
|
| 189 |
|
| 190 |
|
| 191 |
def _align_features(model, features: np.ndarray) -> np.ndarray:
|
| 192 |
-
# بعض الـ models (Pipeline مع StandardScaler) مش بيحفظوا feature_names_in_
|
| 193 |
-
# فبنستخدم n_features_in_ بدلها ونرجع array عادية مش DataFrame
|
| 194 |
if hasattr(model, "feature_names_in_"):
|
| 195 |
expected_cols = model.feature_names_in_
|
| 196 |
vec = np.zeros(len(expected_cols))
|
|
@@ -198,33 +200,25 @@ def _align_features(model, features: np.ndarray) -> np.ndarray:
|
|
| 198 |
vec[:limit] = features[:limit]
|
| 199 |
return pd.DataFrame([vec], columns=expected_cols)
|
| 200 |
else:
|
| 201 |
-
n
|
| 202 |
-
vec
|
| 203 |
limit = min(len(features), n)
|
| 204 |
vec[:limit] = features[:limit]
|
| 205 |
return vec.reshape(1, -1)
|
| 206 |
|
| 207 |
|
| 208 |
-
def run_two_stage(features: np.ndarray) -> dict:
|
| 209 |
-
# Stage 1:
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
cat_conf = float(classifier.predict_proba(feat_df)[0].max())
|
| 213 |
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
elif cat in ("number", "numbers", "digit", "digits", "رقم", "أرقام", "ارقام"):
|
| 219 |
-
model = mlp_numbers
|
| 220 |
-
else:
|
| 221 |
-
# fallback: pick whichever is more confident
|
| 222 |
-
feat_l = _align_features(mlp_letters, features)
|
| 223 |
-
feat_n = _align_features(mlp_numbers, features)
|
| 224 |
-
prob_l = float(mlp_letters.predict_proba(feat_l)[0].max())
|
| 225 |
-
prob_n = float(mlp_numbers.predict_proba(feat_n)[0].max())
|
| 226 |
-
model = mlp_letters if prob_l >= prob_n else mlp_numbers
|
| 227 |
|
|
|
|
|
|
|
| 228 |
feat_df = _align_features(model, features)
|
| 229 |
label = str(model.predict(feat_df)[0])
|
| 230 |
conf = float(model.predict_proba(feat_df)[0].max())
|
|
@@ -256,8 +250,8 @@ async def predict(file: UploadFile = File(...)):
|
|
| 256 |
if not results[0].boxes:
|
| 257 |
raise HTTPException(status_code=422, detail="No hand detected.")
|
| 258 |
|
| 259 |
-
box
|
| 260 |
-
label_id
|
| 261 |
hand_side = "left" if label_id == 0 else "right"
|
| 262 |
|
| 263 |
h, w = img_rgb.shape[:2]
|
|
@@ -271,7 +265,7 @@ async def predict(file: UploadFile = File(...)):
|
|
| 271 |
if features is None:
|
| 272 |
raise HTTPException(status_code=500, detail="Feature extraction failed.")
|
| 273 |
|
| 274 |
-
result = run_two_stage(features)
|
| 275 |
return JSONResponse({**result, "hand_side": hand_side, "bbox": [int(x1), int(y1), int(x2), int(y2)]})
|
| 276 |
|
| 277 |
|
|
@@ -285,8 +279,8 @@ async def predict_with_skeleton(file: UploadFile = File(...)):
|
|
| 285 |
if not results[0].boxes:
|
| 286 |
raise HTTPException(status_code=422, detail="No hand detected.")
|
| 287 |
|
| 288 |
-
box
|
| 289 |
-
label_id
|
| 290 |
hand_side = "left" if label_id == 0 else "right"
|
| 291 |
|
| 292 |
h, w = img_rgb.shape[:2]
|
|
@@ -296,7 +290,7 @@ async def predict_with_skeleton(file: UploadFile = File(...)):
|
|
| 296 |
features = extract_features(crop)
|
| 297 |
joints = get_3d_joints(crop)
|
| 298 |
|
| 299 |
-
result = run_two_stage(features)
|
| 300 |
_, buf = cv2.imencode(".png", cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
|
| 301 |
crop_b64 = base64.b64encode(buf).decode("utf-8")
|
| 302 |
|
|
|
|
| 52 |
|
| 53 |
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
| 54 |
|
| 55 |
+
# --- Router Model Classes ---
|
| 56 |
+
CLASSES = {0: "letter", 1: "number"}
|
| 57 |
+
IMG_SIZE = 64
|
| 58 |
+
|
| 59 |
# --- Hugging Face Model Integration ---
|
| 60 |
REPO_ID = "SondosM/api_GP"
|
| 61 |
|
|
|
|
| 80 |
get_hf_file("mano_data/mano_data/MANO_LEFT.pkl", is_mano=True)
|
| 81 |
get_hf_file("mano_data/mano_data/MANO_RIGHT.pkl", is_mano=True)
|
| 82 |
|
| 83 |
+
WILOR_REPO_PATH = "./WiLoR"
|
| 84 |
+
WILOR_CKPT = get_hf_file("pretrained_models/pretrained_models/wilor_final.ckpt")
|
| 85 |
+
WILOR_CFG = get_hf_file("pretrained_models/pretrained_models/model_config.yaml")
|
| 86 |
+
DETECTOR_PATH = get_hf_file("pretrained_models/pretrained_models/detector.pt")
|
| 87 |
+
ROUTER_MODEL_PATH = get_hf_file("router_model.keras")
|
|
|
|
|
|
|
|
|
|
| 88 |
MLP_LETTERS_PATH = get_hf_file("MLP_letters.pkl")
|
| 89 |
MLP_NUMBERS_PATH = get_hf_file("MLP_numbers.pkl")
|
| 90 |
|
|
|
|
| 95 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 96 |
])
|
| 97 |
|
| 98 |
+
wilor_model = None
|
| 99 |
+
yolo_detector = None
|
| 100 |
+
router_model_keras = None
|
| 101 |
+
mlp_letters = None
|
| 102 |
+
mlp_numbers = None
|
| 103 |
|
| 104 |
|
| 105 |
def load_models():
|
| 106 |
+
global wilor_model, yolo_detector, router_model_keras, mlp_letters, mlp_numbers
|
| 107 |
|
| 108 |
sys.path.insert(0, WILOR_REPO_PATH)
|
| 109 |
from wilor.models import load_wilor
|
| 110 |
from ultralytics import YOLO
|
| 111 |
+
from tensorflow.keras.models import load_model
|
| 112 |
|
| 113 |
print(f"Loading WiLoR on {DEVICE}...")
|
| 114 |
wilor_model, _ = load_wilor(checkpoint_path=WILOR_CKPT, cfg_path=WILOR_CFG)
|
|
|
|
| 118 |
print("Loading YOLO detector...")
|
| 119 |
yolo_detector = YOLO(DETECTOR_PATH)
|
| 120 |
|
| 121 |
+
print("Loading router model (Keras)...")
|
| 122 |
+
router_model_keras = load_model(ROUTER_MODEL_PATH)
|
| 123 |
+
|
| 124 |
+
print("Loading MLP classifiers...")
|
| 125 |
mlp_letters = joblib.load(MLP_LETTERS_PATH)
|
| 126 |
mlp_numbers = joblib.load(MLP_NUMBERS_PATH)
|
| 127 |
|
|
|
|
| 193 |
|
| 194 |
|
| 195 |
def _align_features(model, features: np.ndarray) -> np.ndarray:
|
|
|
|
|
|
|
| 196 |
if hasattr(model, "feature_names_in_"):
|
| 197 |
expected_cols = model.feature_names_in_
|
| 198 |
vec = np.zeros(len(expected_cols))
|
|
|
|
| 200 |
vec[:limit] = features[:limit]
|
| 201 |
return pd.DataFrame([vec], columns=expected_cols)
|
| 202 |
else:
|
| 203 |
+
n = model.n_features_in_
|
| 204 |
+
vec = np.zeros(n)
|
| 205 |
limit = min(len(features), n)
|
| 206 |
vec[:limit] = features[:limit]
|
| 207 |
return vec.reshape(1, -1)
|
| 208 |
|
| 209 |
|
| 210 |
+
def run_two_stage(features: np.ndarray, crop_rgb: np.ndarray) -> dict:
|
| 211 |
+
# Stage 1: router_model.keras يحدد حرف (0) أو رقم (1)
|
| 212 |
+
img_resized = cv2.resize(crop_rgb, (IMG_SIZE, IMG_SIZE))
|
| 213 |
+
img_array = np.expand_dims(img_resized, axis=0).astype("float32") / 255.0
|
|
|
|
| 214 |
|
| 215 |
+
prob = float(router_model_keras.predict(img_array, verbose=0)[0][0])
|
| 216 |
+
cls_idx = 1 if prob >= 0.5 else 0
|
| 217 |
+
category = CLASSES[cls_idx]
|
| 218 |
+
cat_conf = prob if cls_idx == 1 else 1.0 - prob
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
+
# Stage 2: اختار الموديل الصح بناءً على النتيجة
|
| 221 |
+
model = mlp_letters if category == "letter" else mlp_numbers
|
| 222 |
feat_df = _align_features(model, features)
|
| 223 |
label = str(model.predict(feat_df)[0])
|
| 224 |
conf = float(model.predict_proba(feat_df)[0].max())
|
|
|
|
| 250 |
if not results[0].boxes:
|
| 251 |
raise HTTPException(status_code=422, detail="No hand detected.")
|
| 252 |
|
| 253 |
+
box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
|
| 254 |
+
label_id = int(results[0].boxes.cls[0].cpu().item())
|
| 255 |
hand_side = "left" if label_id == 0 else "right"
|
| 256 |
|
| 257 |
h, w = img_rgb.shape[:2]
|
|
|
|
| 265 |
if features is None:
|
| 266 |
raise HTTPException(status_code=500, detail="Feature extraction failed.")
|
| 267 |
|
| 268 |
+
result = run_two_stage(features, crop)
|
| 269 |
return JSONResponse({**result, "hand_side": hand_side, "bbox": [int(x1), int(y1), int(x2), int(y2)]})
|
| 270 |
|
| 271 |
|
|
|
|
| 279 |
if not results[0].boxes:
|
| 280 |
raise HTTPException(status_code=422, detail="No hand detected.")
|
| 281 |
|
| 282 |
+
box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
|
| 283 |
+
label_id = int(results[0].boxes.cls[0].cpu().item())
|
| 284 |
hand_side = "left" if label_id == 0 else "right"
|
| 285 |
|
| 286 |
h, w = img_rgb.shape[:2]
|
|
|
|
| 290 |
features = extract_features(crop)
|
| 291 |
joints = get_3d_joints(crop)
|
| 292 |
|
| 293 |
+
result = run_two_stage(features, crop)
|
| 294 |
_, buf = cv2.imencode(".png", cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
|
| 295 |
crop_b64 = base64.b64encode(buf).decode("utf-8")
|
| 296 |
|