KYM384 committed on
Commit
e0b4e79
·
verified ·
1 Parent(s): da3b582

bfloat16 to float16

Browse files
Files changed (1) hide show
  1. utils.py +218 -218
utils.py CHANGED
@@ -1,142 +1,142 @@
1
- import torch
2
- import torchvision
3
- import numpy as np
4
  import argparse
5
  import copy
6
  import cv2
7
  import os
8
  from contextlib import nullcontext
9
  from huggingface_hub import hf_hub_download
10
-
11
- from facenet_pytorch import MTCNN
12
- from models import MobileGenerator, MobileNetV3MultiTask
13
-
14
-
15
- class Face:
16
- def __init__(self, keypoint: list[tuple[int, int]]):
17
- self.keypoint = keypoint
18
-
19
- e0, e1, n, m0, m1 = keypoint
20
- x_ = e1 - e0
21
- y_ = 0.5 * (e0 + e1) - 0.5 * (m0 + m1)
22
- c = 0.5 * (e0 + e1) - 0.1 * y_
23
- cx, cy = int(c[0]), int(c[1])
24
-
25
- theta = np.arctan2(x_[1], x_[0])
26
-
27
- s = max(4.0 * np.linalg.norm(x_), 3.6 * np.linalg.norm(y_))
28
- s = int(s)
29
-
30
- # bbox: (x, y, w, h)
31
- self.bbox = (cx-s//2, cy-s//2, s, s)
32
- self.theta = theta
33
-
34
- def get_center(self):
35
- return self.bbox[0] + self.bbox[2] // 2, self.bbox[1] + self.bbox[3] // 2
36
-
37
- def get_size(self):
38
- return self.bbox[2]
39
-
40
- def set_attributes(self, age: int, gender: str):
41
- self.age = age
42
- self.gender = gender
43
-
44
- def update(self, keypoint: list[tuple[int, int]]):
45
- self.__init__(keypoint)
46
-
47
- def calc_iou(self, other) -> float:
48
- x1 = max(self.bbox[0], other.bbox[0])
49
- y1 = max(self.bbox[1], other.bbox[1])
50
- x2 = min(self.bbox[0] + self.bbox[2], other.bbox[0] + other.bbox[2])
51
- y2 = min(self.bbox[1] + self.bbox[3], other.bbox[1] + other.bbox[3])
52
-
53
- inter_area = max(0, x2 - x1) * max(0, y2 - y1)
54
- union_area = self.bbox[2] * self.bbox[3] + other.bbox[2] * other.bbox[3] - inter_area
55
-
56
- if union_area == 0:
57
- return 0.0
58
- return inter_area / union_area
59
-
60
-
61
- class FaceSet:
62
- latent_ids = np.load(
63
- hf_hub_download(
64
- repo_id=os.getenv("HF_GEN_REPO_ID"),
65
- filename="latent_ids.npz",
66
- token=os.getenv("HF_HUB_TOKEN")
67
- )
68
- )
69
-
70
- def __init__(self):
71
- self.faces = []
72
- self.nonused_counter = []
73
-
74
- def append(self, face: Face):
75
- self.faces.append(face)
76
- self.nonused_counter.append(0)
77
-
78
- def set_attributes(self, i: int, age: int, gender: str):
79
- self.faces[i].set_attributes(age, gender)
80
- if age[0] == 80 and gender[0] == "M":
81
- age[0] = 70
82
- self.faces[i].latent_id = self.latent_ids[f"{age[0]}_{gender[0]}_jp"]
83
-
84
- def __len__(self) -> int:
85
- # s = sum(c == 0 for c in self.nonused_counter)
86
- # return s
87
- return len(self.faces)
88
-
89
- def __getitem__(self, idx: int) -> Face:
90
- return self.faces[idx]
91
-
92
- def __iter__(self):
93
- # s = sum(c == 0 for c in self.nonused_counter)
94
- # return iter(self.faces[:s])
95
- return iter(self.faces)
96
-
97
- def update(self, other, reset_nonused_threshold: int):
98
- matched_self_indices = []
99
-
100
- for i, other_face in enumerate(other):
101
- max_iou = 0
102
- max_j = -1
103
- for j, self_face in enumerate(self.faces):
104
- iou = other_face.calc_iou(self_face)
105
- if iou > max_iou:
106
- max_iou = iou
107
- max_j = j
108
-
109
- if max_iou > 0.3:
110
- self.faces[max_j].update(other_face.keypoint)
111
- self.nonused_counter[max_j] = 0
112
- matched_self_indices.append(max_j)
113
- else:
114
- self.append(other_face)
115
- matched_self_indices.append(len(self.faces)-1)
116
-
117
- for j in range(len(self.faces)):
118
- if j not in matched_self_indices:
119
- self.nonused_counter[j] += 1
120
-
121
- argsort = np.argsort(self.nonused_counter)[::-1]
122
- self.faces = [self.faces[j] for j in argsort]
123
- self.nonused_counter = [self.nonused_counter[j] for j in argsort]
124
-
125
- self.faces = [face for j, face in enumerate(self.faces) if self.nonused_counter[j] < reset_nonused_threshold]
126
- self.nonused_counter = [count for count in self.nonused_counter if count < reset_nonused_threshold]
127
-
128
-
129
- class FaceCropper:
130
- def __init__(self):
131
- self.size = 256
132
- self.crop_size = 224
133
- self.detector = MTCNN(select_largest=False, keep_all=True, device="cuda" if torch.cuda.is_available() else "cpu")
134
-
135
- mask = np.zeros((self.crop_size, self.crop_size), dtype=np.uint8)
136
- mask[8:-8, 8:-8] = 255
137
- mask = cv2.GaussianBlur(mask, (31, 31), 0)
138
- self.mask = mask
139
-
140
  def detect_keypoints(self, image: np.ndarray) -> FaceSet:
141
  height, width = image.shape[:2]
142
 
@@ -149,90 +149,90 @@ class FaceCropper:
149
  for i in range(len(points)):
150
  left_eye = points[i][0]
151
  right_eye = points[i][1]
152
- nose = points[i][2]
153
- left_mouth = points[i][3]
154
- right_mouth = points[i][4]
155
-
156
- faces_list.append(Face(keypoint=[left_eye, right_eye, nose, left_mouth, right_mouth]))
157
-
158
- return faces_list
159
-
160
- def crop_and_resize(self, image: np.ndarray, face: Face) -> np.ndarray:
161
- cx, cy = face.get_center()
162
- theta = face.theta
163
- s = face.get_size()
164
-
165
- M = cv2.getRotationMatrix2D((cx, cy), np.degrees(theta), self.size / s * 1.14)
166
- M[0, 2] += self.crop_size // 2 - cx
167
- M[1, 2] += self.crop_size // 2 - cy
168
-
169
- cropped = cv2.warpAffine(image, M, (self.crop_size, self.crop_size), flags=cv2.INTER_LINEAR)
170
- return cropped
171
-
172
- def invert_image(self, image: np.ndarray, cropped: np.ndarray, face: Face) -> np.ndarray:
173
- cx, cy = face.get_center()
174
- theta = face.theta
175
- s = face.get_size()
176
-
177
- x0 = max(0, int(np.floor(cx - s)))
178
- y0 = max(0, int(np.floor(cy - s)))
179
- x1 = min(image.shape[1], int(np.ceil(cx + s)))
180
- y1 = min(image.shape[0], int(np.ceil(cy + s)))
181
-
182
- if x0 >= x1 or y0 >= y1:
183
- return image
184
-
185
- cropped_image = image[y0:y1, x0:x1]
186
- cx_local = cx - x0
187
- cy_local = cy - y0
188
-
189
- M = cv2.getRotationMatrix2D((cx_local, cy_local), np.degrees(theta), self.size / s * 1.14)
190
- M[0, 2] += self.crop_size // 2 - cx_local
191
- M[1, 2] += self.crop_size // 2 - cy_local
192
-
193
- M_inv = cv2.invertAffineTransform(M)
194
- inverted = cv2.warpAffine(cropped, M_inv, (x1-x0, y1-y0), flags=cv2.INTER_LINEAR)
195
-
196
- mask = cv2.warpAffine(self.mask, M_inv, (x1-x0, y1-y0))
197
- mask = mask.astype(np.float32)[:, :, None] / 255.0
198
-
199
- blended = cropped_image.astype(np.float32) * (1 - mask) + inverted.astype(np.float32) * mask
200
- result = image.copy()
201
- result[y0:y1, x0:x1] = blended.astype(np.uint8)
202
- return result
203
-
204
-
205
- class FaceSwapper:
206
- def __init__(self, model_path: str, classifier_checkpoint: str):
207
- self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
208
-
209
- self.generator = MobileGenerator(input_nc=3, output_nc=3, latent_dim=512, n_blocks=6)
210
- self.generator.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"), weights_only=False))
211
- self.generator.to(self.device).eval()
212
-
213
- self.classifier = MobileNetV3MultiTask(model_name="mobilenetv3_small_100", num_age_classes=10, num_gender_classes=2)
214
- self.classifier.to(self.device).eval()
215
- self.classifier.load_state_dict(torch.load(classifier_checkpoint, map_location=torch.device("cpu"), weights_only=False)["model_state_dict"])
216
-
217
- self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1)
218
- self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1)
219
-
220
- def np2tensor(self, imgs: np.ndarray) -> torch.Tensor:
221
- if not isinstance(imgs, list):
222
- imgs = [imgs]
223
-
224
- imgs = np.stack(imgs, axis=0)
225
- imgs = torch.from_numpy(imgs.astype(np.float32) / 255).permute(0, 3, 1, 2)
226
- return (imgs - self.mean) / self.std
227
-
228
- def tensor2np(self, imgs: torch.Tensor) -> np.ndarray:
229
- imgs = imgs * self.std + self.mean
230
- imgs = imgs.permute(0, 2, 3, 1).detach().numpy()
231
- imgs = np.clip(imgs, 0, 1)
232
- return (imgs * 255).astype(np.uint8)
233
-
234
  def classify(self, img: np.ndarray) -> list[tuple[int, str]]:
235
- autocast_context = torch.autocast("cuda", torch.bfloat16) if self.device.type == "cuda" else nullcontext()
236
  with torch.no_grad(), autocast_context:
237
  img_tensor = self.np2tensor(img).to(self.device)
238
  ages, genders = self.classifier(img_tensor)
@@ -248,11 +248,11 @@ class FaceSwapper:
248
  return attributes
249
 
250
  def swap(self, img_att: np.ndarray, latent_ids: list[np.ndarray]) -> np.ndarray:
251
- autocast_context = torch.autocast("cuda", torch.bfloat16) if self.device.type == "cuda" else nullcontext()
252
  with torch.no_grad(), autocast_context:
253
  img_att = self.np2tensor(img_att).to(self.device)
254
  latent_ids = torch.from_numpy(np.vstack(latent_ids)).to(self.device)
255
 
256
  output = self.generator(img_att, latent_ids)
257
  return self.tensor2np(output.to("cpu"))
258
-
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
  import argparse
5
  import copy
6
  import cv2
7
  import os
8
  from contextlib import nullcontext
9
  from huggingface_hub import hf_hub_download
10
+
11
+ from facenet_pytorch import MTCNN
12
+ from models import MobileGenerator, MobileNetV3MultiTask
13
+
14
+
15
+ class Face:
16
+ def __init__(self, keypoint: list[tuple[int, int]]):
17
+ self.keypoint = keypoint
18
+
19
+ e0, e1, n, m0, m1 = keypoint
20
+ x_ = e1 - e0
21
+ y_ = 0.5 * (e0 + e1) - 0.5 * (m0 + m1)
22
+ c = 0.5 * (e0 + e1) - 0.1 * y_
23
+ cx, cy = int(c[0]), int(c[1])
24
+
25
+ theta = np.arctan2(x_[1], x_[0])
26
+
27
+ s = max(4.0 * np.linalg.norm(x_), 3.6 * np.linalg.norm(y_))
28
+ s = int(s)
29
+
30
+ # bbox: (x, y, w, h)
31
+ self.bbox = (cx-s//2, cy-s//2, s, s)
32
+ self.theta = theta
33
+
34
+ def get_center(self):
35
+ return self.bbox[0] + self.bbox[2] // 2, self.bbox[1] + self.bbox[3] // 2
36
+
37
+ def get_size(self):
38
+ return self.bbox[2]
39
+
40
+ def set_attributes(self, age: int, gender: str):
41
+ self.age = age
42
+ self.gender = gender
43
+
44
+ def update(self, keypoint: list[tuple[int, int]]):
45
+ self.__init__(keypoint)
46
+
47
+ def calc_iou(self, other) -> float:
48
+ x1 = max(self.bbox[0], other.bbox[0])
49
+ y1 = max(self.bbox[1], other.bbox[1])
50
+ x2 = min(self.bbox[0] + self.bbox[2], other.bbox[0] + other.bbox[2])
51
+ y2 = min(self.bbox[1] + self.bbox[3], other.bbox[1] + other.bbox[3])
52
+
53
+ inter_area = max(0, x2 - x1) * max(0, y2 - y1)
54
+ union_area = self.bbox[2] * self.bbox[3] + other.bbox[2] * other.bbox[3] - inter_area
55
+
56
+ if union_area == 0:
57
+ return 0.0
58
+ return inter_area / union_area
59
+
60
+
61
+ class FaceSet:
62
+ latent_ids = np.load(
63
+ hf_hub_download(
64
+ repo_id=os.getenv("HF_GEN_REPO_ID"),
65
+ filename="latent_ids.npz",
66
+ token=os.getenv("HF_HUB_TOKEN")
67
+ )
68
+ )
69
+
70
+ def __init__(self):
71
+ self.faces = []
72
+ self.nonused_counter = []
73
+
74
+ def append(self, face: Face):
75
+ self.faces.append(face)
76
+ self.nonused_counter.append(0)
77
+
78
+ def set_attributes(self, i: int, age: int, gender: str):
79
+ self.faces[i].set_attributes(age, gender)
80
+ if age[0] == 80 and gender[0] == "M":
81
+ age[0] = 70
82
+ self.faces[i].latent_id = self.latent_ids[f"{age[0]}_{gender[0]}_jp"]
83
+
84
+ def __len__(self) -> int:
85
+ # s = sum(c == 0 for c in self.nonused_counter)
86
+ # return s
87
+ return len(self.faces)
88
+
89
+ def __getitem__(self, idx: int) -> Face:
90
+ return self.faces[idx]
91
+
92
+ def __iter__(self):
93
+ # s = sum(c == 0 for c in self.nonused_counter)
94
+ # return iter(self.faces[:s])
95
+ return iter(self.faces)
96
+
97
+ def update(self, other, reset_nonused_threshold: int):
98
+ matched_self_indices = []
99
+
100
+ for i, other_face in enumerate(other):
101
+ max_iou = 0
102
+ max_j = -1
103
+ for j, self_face in enumerate(self.faces):
104
+ iou = other_face.calc_iou(self_face)
105
+ if iou > max_iou:
106
+ max_iou = iou
107
+ max_j = j
108
+
109
+ if max_iou > 0.3:
110
+ self.faces[max_j].update(other_face.keypoint)
111
+ self.nonused_counter[max_j] = 0
112
+ matched_self_indices.append(max_j)
113
+ else:
114
+ self.append(other_face)
115
+ matched_self_indices.append(len(self.faces)-1)
116
+
117
+ for j in range(len(self.faces)):
118
+ if j not in matched_self_indices:
119
+ self.nonused_counter[j] += 1
120
+
121
+ argsort = np.argsort(self.nonused_counter)[::-1]
122
+ self.faces = [self.faces[j] for j in argsort]
123
+ self.nonused_counter = [self.nonused_counter[j] for j in argsort]
124
+
125
+ self.faces = [face for j, face in enumerate(self.faces) if self.nonused_counter[j] < reset_nonused_threshold]
126
+ self.nonused_counter = [count for count in self.nonused_counter if count < reset_nonused_threshold]
127
+
128
+
129
+ class FaceCropper:
130
+ def __init__(self):
131
+ self.size = 256
132
+ self.crop_size = 224
133
+ self.detector = MTCNN(select_largest=False, keep_all=True, device="cuda" if torch.cuda.is_available() else "cpu")
134
+
135
+ mask = np.zeros((self.crop_size, self.crop_size), dtype=np.uint8)
136
+ mask[8:-8, 8:-8] = 255
137
+ mask = cv2.GaussianBlur(mask, (31, 31), 0)
138
+ self.mask = mask
139
+
140
  def detect_keypoints(self, image: np.ndarray) -> FaceSet:
141
  height, width = image.shape[:2]
142
 
 
149
  for i in range(len(points)):
150
  left_eye = points[i][0]
151
  right_eye = points[i][1]
152
+ nose = points[i][2]
153
+ left_mouth = points[i][3]
154
+ right_mouth = points[i][4]
155
+
156
+ faces_list.append(Face(keypoint=[left_eye, right_eye, nose, left_mouth, right_mouth]))
157
+
158
+ return faces_list
159
+
160
+ def crop_and_resize(self, image: np.ndarray, face: Face) -> np.ndarray:
161
+ cx, cy = face.get_center()
162
+ theta = face.theta
163
+ s = face.get_size()
164
+
165
+ M = cv2.getRotationMatrix2D((cx, cy), np.degrees(theta), self.size / s * 1.14)
166
+ M[0, 2] += self.crop_size // 2 - cx
167
+ M[1, 2] += self.crop_size // 2 - cy
168
+
169
+ cropped = cv2.warpAffine(image, M, (self.crop_size, self.crop_size), flags=cv2.INTER_LINEAR)
170
+ return cropped
171
+
172
+ def invert_image(self, image: np.ndarray, cropped: np.ndarray, face: Face) -> np.ndarray:
173
+ cx, cy = face.get_center()
174
+ theta = face.theta
175
+ s = face.get_size()
176
+
177
+ x0 = max(0, int(np.floor(cx - s)))
178
+ y0 = max(0, int(np.floor(cy - s)))
179
+ x1 = min(image.shape[1], int(np.ceil(cx + s)))
180
+ y1 = min(image.shape[0], int(np.ceil(cy + s)))
181
+
182
+ if x0 >= x1 or y0 >= y1:
183
+ return image
184
+
185
+ cropped_image = image[y0:y1, x0:x1]
186
+ cx_local = cx - x0
187
+ cy_local = cy - y0
188
+
189
+ M = cv2.getRotationMatrix2D((cx_local, cy_local), np.degrees(theta), self.size / s * 1.14)
190
+ M[0, 2] += self.crop_size // 2 - cx_local
191
+ M[1, 2] += self.crop_size // 2 - cy_local
192
+
193
+ M_inv = cv2.invertAffineTransform(M)
194
+ inverted = cv2.warpAffine(cropped, M_inv, (x1-x0, y1-y0), flags=cv2.INTER_LINEAR)
195
+
196
+ mask = cv2.warpAffine(self.mask, M_inv, (x1-x0, y1-y0))
197
+ mask = mask.astype(np.float32)[:, :, None] / 255.0
198
+
199
+ blended = cropped_image.astype(np.float32) * (1 - mask) + inverted.astype(np.float32) * mask
200
+ result = image.copy()
201
+ result[y0:y1, x0:x1] = blended.astype(np.uint8)
202
+ return result
203
+
204
+
205
+ class FaceSwapper:
206
+ def __init__(self, model_path: str, classifier_checkpoint: str):
207
+ self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
208
+
209
+ self.generator = MobileGenerator(input_nc=3, output_nc=3, latent_dim=512, n_blocks=6)
210
+ self.generator.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"), weights_only=False))
211
+ self.generator.to(self.device).eval()
212
+
213
+ self.classifier = MobileNetV3MultiTask(model_name="mobilenetv3_small_100", num_age_classes=10, num_gender_classes=2)
214
+ self.classifier.to(self.device).eval()
215
+ self.classifier.load_state_dict(torch.load(classifier_checkpoint, map_location=torch.device("cpu"), weights_only=False)["model_state_dict"])
216
+
217
+ self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1)
218
+ self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1)
219
+
220
+ def np2tensor(self, imgs: np.ndarray) -> torch.Tensor:
221
+ if not isinstance(imgs, list):
222
+ imgs = [imgs]
223
+
224
+ imgs = np.stack(imgs, axis=0)
225
+ imgs = torch.from_numpy(imgs.astype(np.float32) / 255).permute(0, 3, 1, 2)
226
+ return (imgs - self.mean) / self.std
227
+
228
+ def tensor2np(self, imgs: torch.Tensor) -> np.ndarray:
229
+ imgs = imgs * self.std + self.mean
230
+ imgs = imgs.permute(0, 2, 3, 1).detach().numpy()
231
+ imgs = np.clip(imgs, 0, 1)
232
+ return (imgs * 255).astype(np.uint8)
233
+
234
  def classify(self, img: np.ndarray) -> list[tuple[int, str]]:
235
+ autocast_context = torch.autocast("cuda", torch.float16) if self.device.type == "cuda" else nullcontext()
236
  with torch.no_grad(), autocast_context:
237
  img_tensor = self.np2tensor(img).to(self.device)
238
  ages, genders = self.classifier(img_tensor)
 
248
  return attributes
249
 
250
  def swap(self, img_att: np.ndarray, latent_ids: list[np.ndarray]) -> np.ndarray:
251
+ autocast_context = torch.autocast("cuda", torch.float16) if self.device.type == "cuda" else nullcontext()
252
  with torch.no_grad(), autocast_context:
253
  img_att = self.np2tensor(img_att).to(self.device)
254
  latent_ids = torch.from_numpy(np.vstack(latent_ids)).to(self.device)
255
 
256
  output = self.generator(img_att, latent_ids)
257
  return self.tensor2np(output.to("cpu"))
258
+