SondosM committed on
Commit
e651a0f
·
verified ·
1 Parent(s): 3b7f1f4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -0
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import inspect
4
+ import sys
5
+ import os
6
+ import types
7
+ from unittest.mock import MagicMock
8
+
9
+ import numpy as np
10
+ import cv2
11
+ import torch
12
+ import joblib
13
+ import pandas as pd
14
+ from pathlib import Path
15
+ from scipy.spatial import distance
16
+ from torchvision import transforms
17
+ from PIL import Image
18
+ from contextlib import asynccontextmanager
19
+
20
+ from fastapi import FastAPI, File, UploadFile, HTTPException
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from fastapi.responses import JSONResponse
23
+ import uvicorn
24
+ from huggingface_hub import hf_hub_download
25
+
26
# --- Compatibility Patches ---
# inspect.getargspec was removed in Python 3.11; alias it to getfullargspec
# for legacy dependencies that still call the old name.
if not hasattr(inspect, "getargspec"):
    inspect.getargspec = inspect.getfullargspec

# NumPy >= 1.24 removed the deprecated scalar aliases (np.int, np.float, ...).
# Restore them only when missing, for libraries that still reference them.
for attr, typ in [("int", int), ("float", float), ("complex", complex),
                  ("bool", bool), ("object", object), ("str", str), ("unicode", str)]:
    if not hasattr(np, attr):
        setattr(np, attr, typ)

# --- Pyrender / OpenGL Mock (Headless Fix) ---
# This server runs headless (no display / GL context), so stub out pyrender:
# every class WiLoR might touch resolves to MagicMock.
pyrender_mock = types.ModuleType("pyrender")
for _attr in ["Scene", "Mesh", "Node", "PerspectiveCamera", "DirectionalLight",
              "PointLight", "SpotLight", "OffscreenRenderer", "RenderFlags",
              "Viewer", "MetallicRoughnessMaterial"]:
    setattr(pyrender_mock, _attr, MagicMock)
sys.modules["pyrender"] = pyrender_mock

# Pre-register empty OpenGL modules so `import OpenGL...` succeeds headlessly.
for _mod in ["OpenGL", "OpenGL.GL", "OpenGL.GL.framebufferobjects",
             "OpenGL.platform", "OpenGL.error"]:
    if _mod not in sys.modules:
        sys.modules[_mod] = types.ModuleType(_mod)

# If real PyOpenGL is ever loaded, force the software (OSMesa) backend.
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
49
+
50
# --- Hugging Face Model Integration ---
# All model artifacts live in this Hugging Face Hub repository.
REPO_ID = "SondosM/api_GP"

def get_hf_file(filename):
    """Download *filename* from REPO_ID and return the local cached path."""
    return hf_hub_download(repo_id=REPO_ID, filename=filename)

# Resolve paths from Hugging Face Repo
# NOTE(review): these downloads run at import time, so importing this module
# requires network access to the Hub (cached after the first run).
WILOR_REPO_PATH = "./WiLoR"
WILOR_CKPT = get_hf_file("pretrained_models/pretrained_models/wilor_final.ckpt")
WILOR_CFG = get_hf_file("pretrained_models/pretrained_models/model_config.yaml")
DETECTOR_PATH = get_hf_file("pretrained_models/pretrained_models/detector.pt")
CLASSIFIER_PATH = get_hf_file("classifier.pkl")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Per-crop preprocessing: to-tensor + ImageNet mean/std normalization.
WILOR_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Populated once by load_models() during app startup (see lifespan()).
wilor_model = None
yolo_detector = None
classifier = None
73
+
74
def load_models():
    """Load the WiLoR hand model, YOLO hand detector, and sign classifier.

    Populates the module-level globals ``wilor_model``, ``yolo_detector`` and
    ``classifier``. Intended to be called once from the FastAPI lifespan
    handler; safe to call again (idempotent sys.path handling).
    """
    global wilor_model, yolo_detector, classifier
    # FIX: guard the insert so repeated calls don't keep prepending the same
    # path to sys.path.
    if WILOR_REPO_PATH not in sys.path:
        sys.path.insert(0, WILOR_REPO_PATH)
    # Deferred imports: wilor is only importable after the path tweak above,
    # and ultralytics is heavy to import.
    from wilor.models import load_wilor
    from ultralytics import YOLO

    print(f"Loading WiLoR on {DEVICE}...")
    wilor_model, _ = load_wilor(checkpoint_path=WILOR_CKPT, cfg_path=WILOR_CFG)
    wilor_model.to(DEVICE)
    wilor_model.eval()

    print(f"Loading YOLO from: {DETECTOR_PATH}")
    yolo_detector = YOLO(DETECTOR_PATH)

    print("Loading Classifier...")
    classifier = joblib.load(CLASSIFIER_PATH)
    print("Models loaded!")
91
+
92
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load all models once before the app starts serving requests;
    # nothing needs to be released on shutdown.
    load_models()
    yield
96
+
97
app = FastAPI(title="Arabic Sign Language Interpreter", lifespan=lifespan)

# Permissive CORS: allow browser clients from any origin to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
105
+
106
def extract_features(crop_rgb: np.ndarray) -> np.ndarray | None:
    """Run WiLoR on a hand crop and build the classifier's feature vector.

    The vector is the MANO pose parameters (global orientation + hand pose)
    concatenated with 7 scale-normalized fingertip distances. Returns None
    when the model output lacks the required keys.
    """
    resized = cv2.resize(crop_rgb, (256, 256))
    batch = WILOR_TRANSFORM(resized).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        out = wilor_model({"img": batch})

    if any(key not in out for key in ("pred_mano_params", "pred_keypoints_3d")):
        return None

    mano_params = out["pred_mano_params"]
    pose_vec = mano_params["hand_pose"][0].cpu().numpy().flatten()
    orient_vec = mano_params["global_orient"][0].cpu().numpy().flatten()
    theta = np.concatenate([orient_vec, pose_vec])

    joints = out["pred_keypoints_3d"][0].cpu().numpy()
    tips = [4, 8, 12, 16, 20]  # fingertip joint indices (thumb..pinky)
    # Normalize distances by the joint-0 to joint-9 length to remove scale;
    # epsilon avoids division by zero on degenerate outputs.
    hand_scale = distance.euclidean(joints[0], joints[9]) + 1e-8

    # Thumb tip to each other fingertip, then adjacent fingertip pairs.
    dist_feats = [
        distance.euclidean(joints[tips[0]], joints[tips[i]]) / hand_scale
        for i in range(1, 5)
    ] + [
        distance.euclidean(joints[tips[i]], joints[tips[i + 1]]) / hand_scale
        for i in range(1, 4)
    ]

    return np.concatenate([theta, dist_feats])
132
+
133
def get_3d_joints(crop_rgb: np.ndarray) -> np.ndarray:
    """Return WiLoR's predicted 3D hand keypoints for a 256x256-resized crop."""
    tensor_in = WILOR_TRANSFORM(cv2.resize(crop_rgb, (256, 256))).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        prediction = wilor_model({"img": tensor_in})
    return prediction["pred_keypoints_3d"][0].cpu().numpy()
139
+
140
def read_image_from_upload(file_bytes: bytes) -> np.ndarray:
    """Decode uploaded bytes into a BGR image array; raise 400 if undecodable."""
    decoded = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise HTTPException(status_code=400, detail="Invalid image.")
    return decoded
146
+
147
@app.get("/")
def root():
    """Health-check endpoint reporting service status and inference device."""
    return dict(status="running", device=DEVICE)
150
+
151
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Detect a hand in the uploaded image and classify the sign.

    Pipeline: YOLO hand detection -> bounded crop -> WiLoR feature
    extraction -> tabular classifier.

    Raises:
        HTTPException 400: the upload is not a decodable image.
        HTTPException 422: no hand detected, or the clamped crop is empty.
        HTTPException 500: WiLoR output lacked the expected keys.
    """
    raw = await file.read()
    img_bgr = read_image_from_upload(raw)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    results = yolo_detector.predict(img_rgb, conf=0.5, verbose=False, device=DEVICE)
    if not results[0].boxes:
        raise HTTPException(status_code=422, detail="No hand detected.")

    # Only the first detection is used.
    box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
    label_id = int(results[0].boxes.cls[0].cpu().item())
    hand_side = "left" if label_id == 0 else "right"

    # Clamp the box to image bounds before cropping.
    x1, y1, x2, y2 = box
    h, w = img_rgb.shape[:2]
    x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
    crop = img_rgb[y1:y2, x1:x2]

    if crop.size == 0:
        raise HTTPException(status_code=422, detail="Empty crop.")

    features = extract_features(crop)
    if features is None:
        raise HTTPException(status_code=500, detail="Feature extraction failed.")

    # Zero-pad / truncate the feature vector to the classifier's training schema.
    expected_cols = classifier.feature_names_in_
    final_vector = np.zeros(len(expected_cols))
    limit = min(len(features), len(final_vector))
    final_vector[:limit] = features[:limit]

    feat_df = pd.DataFrame([final_vector], columns=expected_cols)
    prediction = classifier.predict(feat_df)[0]
    proba = classifier.predict_proba(feat_df)[0]

    # FIX: predict() may return a NumPy scalar, which the default JSON encoder
    # cannot serialize; .item() converts it to the native Python equivalent.
    if hasattr(prediction, "item"):
        prediction = prediction.item()

    return JSONResponse({
        "prediction": prediction,
        "confidence": round(float(proba.max()), 4),
        "hand_side": hand_side,
        "bbox": [int(x1), int(y1), int(x2), int(y2)],
    })
192
+
193
@app.post("/predict_with_skeleton")
async def predict_with_skeleton(file: UploadFile = File(...)):
    """Like /predict, but also returns 3D joints and the crop as base64 PNG.

    Raises:
        HTTPException 400: the upload is not a decodable image.
        HTTPException 422: no hand detected, or the clamped crop is empty.
        HTTPException 500: WiLoR output lacked the expected keys.
    """
    raw = await file.read()
    img_bgr = read_image_from_upload(raw)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    results = yolo_detector.predict(img_rgb, conf=0.5, verbose=False, device=DEVICE)
    if not results[0].boxes:
        raise HTTPException(status_code=422, detail="No hand detected.")

    box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
    label_id = int(results[0].boxes.cls[0].cpu().item())
    hand_side = "left" if label_id == 0 else "right"
    x1, y1, x2, y2 = box
    h, w = img_rgb.shape[:2]
    x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
    crop = img_rgb[y1:y2, x1:x2]

    # FIX: /predict has these guards but this endpoint did not, so a
    # degenerate detection crashed with an unhandled exception (HTTP 500
    # with no useful detail). Mirror /predict's error handling.
    if crop.size == 0:
        raise HTTPException(status_code=422, detail="Empty crop.")

    features = extract_features(crop)
    if features is None:
        raise HTTPException(status_code=500, detail="Feature extraction failed.")

    joints = get_3d_joints(crop)

    # Zero-pad / truncate features to the classifier's training schema.
    expected_cols = classifier.feature_names_in_
    final_vector = np.zeros(len(expected_cols))
    limit = min(len(features), len(final_vector))
    final_vector[:limit] = features[:limit]

    feat_df = pd.DataFrame([final_vector], columns=expected_cols)
    prediction = classifier.predict(feat_df)[0]
    proba = classifier.predict_proba(feat_df)[0]

    # FIX: convert a possible NumPy scalar to a native Python type so the
    # default JSON encoder can serialize it.
    if hasattr(prediction, "item"):
        prediction = prediction.item()

    # Crop is RGB in memory; convert back to BGR for OpenCV PNG encoding.
    _, buf = cv2.imencode(".png", cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
    crop_b64 = base64.b64encode(buf).decode("utf-8")

    return JSONResponse({
        "prediction": prediction,
        "confidence": round(float(proba.max()), 4),
        "hand_side": hand_side,
        "bbox": [int(x1), int(y1), int(x2), int(y2)],
        "joints_3d": joints.tolist(),
        "crop_b64": crop_b64,
    })
234
+
235
if __name__ == "__main__":
    # Hugging Face Spaces port is 7860; bind all interfaces for the container.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)