prodia3456 commited on
Commit
b666f4c
·
verified ·
1 Parent(s): a99e46e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -41,28 +41,46 @@ class Predictor:
41
  print("No face found")
42
  return None
43
 
 
44
  def predict(self, input_image, swap_image):
45
  """Run a single prediction on the model"""
 
 
 
 
46
  try:
47
- frame = cv2.imread(input_image.name)
48
- face = self.get_face(frame)
49
- source_face = self.get_face(cv2.imread(swap_image.name))
50
- try:
51
- print(frame.shape, face.shape, source_face.shape)
52
- except:
53
- print("printing shapes failed.")
54
- result = self.face_swapper.get(frame, face, source_face, paste_back=True)
55
-
 
 
 
 
 
 
 
 
 
56
  _, _, result = self.face_enhancer.enhance(
57
  result,
58
  paste_back=True
59
  )
 
 
60
  out_path = tempfile.mkdtemp() + f"/{str(int(time.time()))}.jpg"
61
  cv2.imwrite(out_path, result)
62
  return out_path
 
63
  except Exception as e:
64
- print(f"{e}")
65
- return None
 
66
 
67
 
68
  # Instantiate the Predictor class
 
41
  print("No face found")
42
  return None
43
 
44
+ # 这是修正后的predict函数,请用它替换掉您源代码中的旧版本
45
  def predict(self, input_image, swap_image):
46
  """Run a single prediction on the model"""
47
+ # 【新增】检查输入是否为空
48
+ if input_image is None or swap_image is None:
49
+ raise gr.Error("请确保同时上传了目标图片和源图片。") # 使用Gradio的方式优雅地报错
50
+
51
  try:
52
+ # 读取图片
53
+ target_img = cv2.imread(input_image.name)
54
+ swap_img = cv2.imread(swap_image.name)
55
+
56
+ # 分析人脸
57
+ target_face = self.get_face(target_img)
58
+ source_face = self.get_face(swap_img)
59
+
60
+ # 【关键修正】在执行换脸前,检查是否成功找到了人脸
61
+ if target_face is None:
62
+ raise gr.Error("在目标图片中未能检测到人脸,请更换图片后重试。")
63
+ if source_face is None:
64
+ raise gr.Error("在源图片中未能检测到人脸,请更换图片后重试。")
65
+
66
+ # 如果人脸都找到了,才执行核心换脸操作
67
+ result = self.face_swapper.get(target_img, target_face, source_face, paste_back=True)
68
+
69
+ # 增强画质
70
  _, _, result = self.face_enhancer.enhance(
71
  result,
72
  paste_back=True
73
  )
74
+
75
+ # 保存并返回结果
76
  out_path = tempfile.mkdtemp() + f"/{str(int(time.time()))}.jpg"
77
  cv2.imwrite(out_path, result)
78
  return out_path
79
+
80
  except Exception as e:
81
+ # 如果发生其他未知错误,也通过Gradio报错
82
+ print(f"An unexpected error occurred: {e}")
83
+ raise gr.Error(f"发生未知错误: {e}")
84
 
85
 
86
  # Instantiate the Predictor class