chawin.chen commited on
Commit
2cbce78
·
1 Parent(s): c6e4a1e
Files changed (1) hide show
  1. api_routes.py +41 -0
api_routes.py CHANGED
@@ -63,6 +63,47 @@ if DEEPFACE_AVAILABLE:
63
  try:
64
  from deepface import DeepFace
65
  deepface_module = DeepFace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  logger.info("DeepFace module imported successfully")
67
  except ImportError as e:
68
  logger.error(f"Failed to import DeepFace: {e}")
 
63
  try:
64
  from deepface import DeepFace
65
  deepface_module = DeepFace
66
+ try:
67
+ from deepface.models import FacialRecognition as df_facial_recognition
68
+ DeepFaceModel = getattr(df_facial_recognition, "Model", None)
69
+ _original_forward = df_facial_recognition.FacialRecognition.forward
70
+
71
+ def _patched_forward(self, img):
72
+ """
73
+ 兼容Keras 3 / tf_keras 返回SymbolicTensor的情况,必要时退回predict。
74
+ """
75
+ if DeepFaceModel is None or not isinstance(self.model, DeepFaceModel):
76
+ return _original_forward(self, img)
77
+
78
+ if img.ndim == 3:
79
+ img = np.expand_dims(img, axis=0)
80
+
81
+ if img.ndim == 4 and img.shape[0] == 1:
82
+ try:
83
+ outputs = self.model(img, training=False)
84
+ embeddings = outputs.numpy()
85
+ except Exception:
86
+ # Keras 3 调用 self.model(...) 可能返回SymbolicTensor,退回 predict
87
+ embeddings = self.model.predict(img, verbose=0)
88
+ elif img.ndim == 4 and img.shape[0] > 1:
89
+ embeddings = self.model.predict_on_batch(img)
90
+ else:
91
+ raise ValueError(
92
+ f"Input image must be (1, X, X, 3) shaped but it is {img.shape}"
93
+ )
94
+
95
+ embeddings = np.asarray(embeddings)
96
+ if embeddings.ndim == 0:
97
+ raise ValueError("Embeddings output is empty.")
98
+
99
+ if embeddings.shape[0] == 1:
100
+ return embeddings[0].tolist()
101
+ return embeddings.tolist()
102
+
103
+ df_facial_recognition.FacialRecognition.forward = _patched_forward
104
+ logger.info("Patched DeepFace FacialRecognition.forward for SymbolicTensor compatibility")
105
+ except Exception as patch_exc:
106
+ logger.warning(f"Failed to patch DeepFace forward method: {patch_exc}")
107
  logger.info("DeepFace module imported successfully")
108
  except ImportError as e:
109
  logger.error(f"Failed to import DeepFace: {e}")