SondosM commited on
Commit
d18cd55
·
verified ·
1 Parent(s): 9e2d6e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -143
app.py CHANGED
@@ -6,6 +6,7 @@ import os
6
  import types
7
  import shutil
8
  from unittest.mock import MagicMock
 
9
 
10
  import numpy as np
11
  import cv2
@@ -24,7 +25,7 @@ from fastapi.responses import JSONResponse
24
  import uvicorn
25
  from huggingface_hub import hf_hub_download
26
 
27
- # --- Compatibility Patches for Numpy and Inspect ---
28
  if not hasattr(inspect, "getargspec"):
29
  inspect.getargspec = inspect.getfullargspec
30
 
@@ -33,7 +34,7 @@ for attr, typ in [("int", int), ("float", float), ("complex", complex),
33
  if not hasattr(np, attr):
34
  setattr(np, attr, typ)
35
 
36
- # --- Pyrender / OpenGL Mock (Headless Environment Fix) ---
37
  pyrender_mock = types.ModuleType("pyrender")
38
  for _attr in ["Scene", "Mesh", "Node", "PerspectiveCamera", "DirectionalLight",
39
  "PointLight", "SpotLight", "OffscreenRenderer", "RenderFlags",
@@ -52,38 +53,27 @@ os.environ["PYOPENGL_PLATFORM"] = "osmesa"
52
  REPO_ID = "SondosM/api_GP"
53
 
54
  def get_hf_file(filename, is_mano=False):
55
- print(f"Downloading {filename} from {REPO_ID}...")
56
  temp_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
57
-
58
  if is_mano:
59
- # Create local folder structure expected by WiLoR
60
  os.makedirs("./mano_data", exist_ok=True)
61
  target_path = os.path.join("./mano_data", os.path.basename(filename))
62
  if not os.path.exists(target_path):
63
  shutil.copy(temp_path, target_path)
64
- print(f"Copied {filename} to {target_path}")
65
  return target_path
66
-
67
  return temp_path
68
 
69
- # --- Map paths according to your Repo list ---
70
- print("Initializing model file paths...")
71
-
72
- # MANO Files
73
  get_hf_file("mano_data/mano_data/mano_mean_params.npz", is_mano=True)
74
  get_hf_file("mano_data/mano_data/MANO_LEFT.pkl", is_mano=True)
75
  get_hf_file("mano_data/mano_data/MANO_RIGHT.pkl", is_mano=True)
76
 
77
  WILOR_REPO_PATH = "./WiLoR"
78
- # Model weights
79
  WILOR_CKPT = get_hf_file("pretrained_models/pretrained_models/wilor_final.ckpt")
80
  WILOR_CFG = get_hf_file("pretrained_models/pretrained_models/model_config.yaml")
81
  DETECTOR_PATH = get_hf_file("pretrained_models/pretrained_models/detector.pt")
82
- # Classifier
83
  CLASSIFIER_PATH = get_hf_file("classifier.pkl")
84
 
85
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
86
-
87
  WILOR_TRANSFORM = transforms.Compose([
88
  transforms.ToTensor(),
89
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
@@ -99,15 +89,9 @@ def load_models():
99
  from wilor.models import load_wilor
100
  from ultralytics import YOLO
101
 
102
- print(f"Loading WiLoR on {DEVICE}...")
103
  wilor_model, _ = load_wilor(checkpoint_path=WILOR_CKPT, cfg_path=WILOR_CFG)
104
- wilor_model.to(DEVICE)
105
- wilor_model.eval()
106
-
107
- print(f"Loading YOLO detector...")
108
  yolo_detector = YOLO(DETECTOR_PATH)
109
-
110
- print("Loading RandomForest classifier...")
111
  classifier = joblib.load(CLASSIFIER_PATH)
112
  print("✅ All models loaded successfully!")
113
 
@@ -116,143 +100,78 @@ async def lifespan(app: FastAPI):
116
  load_models()
117
  yield
118
 
119
- app = FastAPI(title="Arabic Sign Language Interpreter", lifespan=lifespan)
120
-
121
- app.add_middleware(
122
- CORSMiddleware,
123
- allow_origins=["*"],
124
- allow_methods=["*"],
125
- allow_headers=["*"],
126
- )
127
-
128
- def extract_features(crop_rgb: np.ndarray) -> np.ndarray | None:
129
- img_input = cv2.resize(crop_rgb, (256, 256))
130
- img_tensor = WILOR_TRANSFORM(img_input).unsqueeze(0).to(DEVICE)
131
-
132
- with torch.no_grad():
133
- output = wilor_model({"img": img_tensor})
134
-
135
- if "pred_mano_params" not in output or "pred_keypoints_3d" not in output:
136
- return None
137
-
138
- mano = output["pred_mano_params"]
139
- hand_pose = mano["hand_pose"][0].cpu().numpy().flatten()
140
- global_orient = mano["global_orient"][0].cpu().numpy().flatten()
141
- theta = np.concatenate([global_orient, hand_pose])
142
-
143
- joints = output["pred_keypoints_3d"][0].cpu().numpy()
144
- tips = [4, 8, 12, 16, 20]
145
- hand_scale = distance.euclidean(joints[0], joints[9]) + 1e-8
146
-
147
- dist_feats = []
148
- for i in range(1, 5):
149
- dist_feats.append(distance.euclidean(joints[tips[0]], joints[tips[i]]) / hand_scale)
150
- for i in range(1, 4):
151
- dist_feats.append(distance.euclidean(joints[tips[i]], joints[tips[i+1]]) / hand_scale)
152
-
153
- return np.concatenate([theta, dist_feats])
154
-
155
- def get_3d_joints(crop_rgb: np.ndarray) -> np.ndarray:
156
- img_input = cv2.resize(crop_rgb, (256, 256))
157
- img_tensor = WILOR_TRANSFORM(img_input).unsqueeze(0).to(DEVICE)
158
- with torch.no_grad():
159
- output = wilor_model({"img": img_tensor})
160
- return output["pred_keypoints_3d"][0].cpu().numpy()
161
-
162
- def read_image_from_upload(file_bytes: bytes) -> np.ndarray:
163
- arr = np.frombuffer(file_bytes, np.uint8)
164
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
165
- if img is None:
166
- raise HTTPException(status_code=400, detail="Invalid image format.")
167
- return img
168
-
169
- @app.get("/")
170
- def root():
171
- return {"status": "running", "device": DEVICE}
172
-
173
- @app.post("/predict")
174
- async def predict(file: UploadFile = File(...)):
175
- raw = await file.read()
176
- img_bgr = read_image_from_upload(raw)
177
- img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
178
 
 
179
  results = yolo_detector.predict(img_rgb, conf=0.5, verbose=False, device=DEVICE)
180
  if not results[0].boxes:
181
- raise HTTPException(status_code=422, detail="No hand detected.")
182
 
183
  box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
184
- label_id = int(results[0].boxes.cls[0].cpu().item())
185
- hand_side = "left" if label_id == 0 else "right"
186
-
187
  x1, y1, x2, y2 = box
188
  h, w = img_rgb.shape[:2]
189
- x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
190
- crop = img_rgb[y1:y2, x1:x2]
191
 
192
- if crop.size == 0:
193
- raise HTTPException(status_code=422, detail="Empty hand crop.")
 
 
 
194
 
195
- features = extract_features(crop)
196
- if features is None:
197
- raise HTTPException(status_code=500, detail="Feature extraction failed.")
198
 
199
- expected_cols = classifier.feature_names_in_
200
- final_vector = np.zeros(len(expected_cols))
201
- limit = min(len(features), len(final_vector))
202
- final_vector[:limit] = features[:limit]
203
 
204
- feat_df = pd.DataFrame([final_vector], columns=expected_cols)
 
 
 
 
 
 
 
 
205
  prediction = classifier.predict(feat_df)[0]
206
- proba = classifier.predict_proba(feat_df)[0]
207
 
208
- return JSONResponse({
209
  "prediction": str(prediction),
210
- "confidence": round(float(proba.max()), 4),
211
- "hand_side": hand_side,
212
- "bbox": [int(x1), int(y1), int(x2), int(y2)],
213
- })
214
-
215
- @app.post("/predict_with_skeleton")
216
- async def predict_with_skeleton(file: UploadFile = File(...)):
217
- raw = await file.read()
218
- img_bgr = read_image_from_upload(raw)
219
- img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
220
 
221
- results = yolo_detector.predict(img_rgb, conf=0.5, verbose=False, device=DEVICE)
222
- if not results[0].boxes:
223
- raise HTTPException(status_code=422, detail="No hand detected.")
224
-
225
- box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
226
- label_id = int(results[0].boxes.cls[0].cpu().item())
227
- hand_side = "left" if label_id == 0 else "right"
228
- x1, y1, x2, y2 = box
229
- h, w = img_rgb.shape[:2]
230
- x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
231
- crop = img_rgb[y1:y2, x1:x2]
232
-
233
- features = extract_features(crop)
234
- joints = get_3d_joints(crop)
235
-
236
- expected_cols = classifier.feature_names_in_
237
- final_vector = np.zeros(len(expected_cols))
238
- limit = min(len(features), len(final_vector))
239
- final_vector[:limit] = features[:limit]
240
-
241
- feat_df = pd.DataFrame([final_vector], columns=expected_cols)
242
- prediction = classifier.predict(feat_df)[0]
243
- proba = classifier.predict_proba(feat_df)[0]
244
-
245
- _, buf = cv2.imencode(".png", cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
246
- crop_b64 = base64.b64encode(buf).decode("utf-8")
247
 
248
- return JSONResponse({
249
- "prediction": str(prediction),
250
- "confidence": round(float(proba.max()), 4),
251
- "hand_side": hand_side,
252
- "bbox": [int(x1), int(y1), int(x2), int(y2)],
253
- "joints_3d": joints.tolist(),
254
- "crop_b64": crop_b64,
255
- })
256
 
257
  if __name__ == "__main__":
258
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
 
6
  import types
7
  import shutil
8
  from unittest.mock import MagicMock
9
+ from typing import List
10
 
11
  import numpy as np
12
  import cv2
 
25
  import uvicorn
26
  from huggingface_hub import hf_hub_download
27
 
28
+ # --- Compatibility Patches ---
29
  if not hasattr(inspect, "getargspec"):
30
  inspect.getargspec = inspect.getfullargspec
31
 
 
34
  if not hasattr(np, attr):
35
  setattr(np, attr, typ)
36
 
37
+ # --- Pyrender / OpenGL Mock (Headless Fix) ---
38
  pyrender_mock = types.ModuleType("pyrender")
39
  for _attr in ["Scene", "Mesh", "Node", "PerspectiveCamera", "DirectionalLight",
40
  "PointLight", "SpotLight", "OffscreenRenderer", "RenderFlags",
 
53
  REPO_ID = "SondosM/api_GP"
54
 
55
  def get_hf_file(filename, is_mano=False):
 
56
  temp_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
 
57
  if is_mano:
 
58
  os.makedirs("./mano_data", exist_ok=True)
59
  target_path = os.path.join("./mano_data", os.path.basename(filename))
60
  if not os.path.exists(target_path):
61
  shutil.copy(temp_path, target_path)
 
62
  return target_path
 
63
  return temp_path
64
 
65
+ # Resolve paths
 
 
 
66
  get_hf_file("mano_data/mano_data/mano_mean_params.npz", is_mano=True)
67
  get_hf_file("mano_data/mano_data/MANO_LEFT.pkl", is_mano=True)
68
  get_hf_file("mano_data/mano_data/MANO_RIGHT.pkl", is_mano=True)
69
 
70
  WILOR_REPO_PATH = "./WiLoR"
 
71
  WILOR_CKPT = get_hf_file("pretrained_models/pretrained_models/wilor_final.ckpt")
72
  WILOR_CFG = get_hf_file("pretrained_models/pretrained_models/model_config.yaml")
73
  DETECTOR_PATH = get_hf_file("pretrained_models/pretrained_models/detector.pt")
 
74
  CLASSIFIER_PATH = get_hf_file("classifier.pkl")
75
 
76
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
77
  WILOR_TRANSFORM = transforms.Compose([
78
  transforms.ToTensor(),
79
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
89
  from wilor.models import load_wilor
90
  from ultralytics import YOLO
91
 
 
92
  wilor_model, _ = load_wilor(checkpoint_path=WILOR_CKPT, cfg_path=WILOR_CFG)
93
+ wilor_model.to(DEVICE).eval()
 
 
 
94
  yolo_detector = YOLO(DETECTOR_PATH)
 
 
95
  classifier = joblib.load(CLASSIFIER_PATH)
96
  print("✅ All models loaded successfully!")
97
 
 
100
  load_models()
101
  yield
102
 
103
+ app = FastAPI(title="Arabic Sign Language Batch API", lifespan=lifespan)
104
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ def process_single_image(img_rgb):
107
  results = yolo_detector.predict(img_rgb, conf=0.5, verbose=False, device=DEVICE)
108
  if not results[0].boxes:
109
+ return None, "No hand detected"
110
 
111
  box = results[0].boxes.xyxy[0].cpu().numpy().astype(int)
 
 
 
112
  x1, y1, x2, y2 = box
113
  h, w = img_rgb.shape[:2]
114
+ crop = img_rgb[max(0, y1):min(h, y2), max(0, x1):min(w, x2)]
 
115
 
116
+ img_input = cv2.resize(crop, (256, 256))
117
+ img_tensor = WILOR_TRANSFORM(img_input).unsqueeze(0).to(DEVICE)
118
+
119
+ with torch.no_grad():
120
+ output = wilor_model({"img": img_tensor})
121
 
122
+ if "pred_mano_params" not in output:
123
+ return None, "Feature extraction failed"
 
124
 
125
+ mano = output["pred_mano_params"]
126
+ theta = np.concatenate([mano["global_orient"][0].cpu().numpy().flatten(),
127
+ mano["hand_pose"][0].cpu().numpy().flatten()])
 
128
 
129
+ joints = output["pred_keypoints_3d"][0].cpu().numpy()
130
+ hand_scale = distance.euclidean(joints[0], joints[9]) + 1e-8
131
+ tips = [4, 8, 12, 16, 20]
132
+ dist_feats = [distance.euclidean(joints[tips[0]], joints[tips[i]])/hand_scale for i in range(1,5)]
133
+ dist_feats += [distance.euclidean(joints[tips[i]], joints[tips[i+1]])/hand_scale for i in range(1,4)]
134
+
135
+ features = np.concatenate([theta, dist_feats])
136
+
137
+ feat_df = pd.DataFrame([features], columns=classifier.feature_names_in_)
138
  prediction = classifier.predict(feat_df)[0]
139
+ confidence = float(classifier.predict_proba(feat_df)[0].max())
140
 
141
+ return {
142
  "prediction": str(prediction),
143
+ "confidence": round(confidence, 4),
144
+ "bbox": [int(x1), int(y1), int(x2), int(y2)]
145
+ }, None
 
 
 
 
 
 
 
146
 
147
+ @app.post("/predict")
148
+ async def predict(files: List[UploadFile] = File(...)):
149
+ final_results = []
150
+ for file in files:
151
+ try:
152
+ raw = await file.read()
153
+ arr = np.frombuffer(raw, np.uint8)
154
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
155
+ if img is None:
156
+ final_results.append({"filename": file.filename, "error": "Invalid image format"})
157
+ continue
158
+
159
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
160
+ res, err = process_single_image(img_rgb)
161
+
162
+ if err:
163
+ final_results.append({"filename": file.filename, "error": err})
164
+ else:
165
+ res["filename"] = file.filename
166
+ final_results.append(res)
167
+ except Exception as e:
168
+ final_results.append({"filename": file.filename, "error": str(e)})
169
+
170
+ return JSONResponse({"results": final_results})
 
 
171
 
172
+ @app.get("/")
173
+ def root():
174
+ return {"status": "running", "batch_mode": True}
 
 
 
 
 
175
 
176
  if __name__ == "__main__":
177
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)