cheenchan committed on
Commit
de0fca3
·
1 Parent(s): 7c7d36f

Lazy-load face detector and embedder

Browse files
frame_extraction/src/frame_extraction/face.py CHANGED
@@ -1,8 +1,7 @@
1
  from __future__ import annotations
2
 
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
- from typing import Iterable, List, Tuple
6
 
7
  import numpy as np
8
  import torch
@@ -10,18 +9,22 @@ from facenet_pytorch import InceptionResnetV1, MTCNN
10
  from PIL import Image
11
 
12
 
13
- @dataclass(slots=True)
14
  class FaceDetector:
15
- device: str = "cuda" if torch.cuda.is_available() else "cpu"
16
  min_face_size: int = 60
17
-
18
- def __post_init__(self) -> None:
19
- self.model = MTCNN(
20
- keep_all=True,
21
- device=self.device,
22
- min_face_size=self.min_face_size,
23
- post_process=False,
24
- )
 
 
 
 
25
 
26
  def detect(self, image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
27
  pil = Image.fromarray(cv2_to_rgb(image))
@@ -31,13 +34,17 @@ class FaceDetector:
31
  return boxes.astype(np.float32), probs.astype(np.float32)
32
 
33
 
34
- @dataclass(slots=True)
35
  class FaceEmbedder:
36
- device: str = "cuda" if torch.cuda.is_available() else "cpu"
37
  batch_size: int = 16
 
38
 
39
- def __post_init__(self) -> None:
40
- self.model = InceptionResnetV1(pretrained="vggface2").eval().to(self.device)
 
 
 
41
 
42
  @torch.no_grad()
43
  def embed(self, crops: Iterable[Image.Image]) -> np.ndarray:
 
1
  from __future__ import annotations
2
 
3
+ from dataclasses import dataclass, field
4
+ from typing import Iterable, List
 
5
 
6
  import numpy as np
7
  import torch
 
9
  from PIL import Image
10
 
11
 
12
+ @dataclass
13
  class FaceDetector:
14
+ device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
15
  min_face_size: int = 60
16
+ _model: MTCNN | None = field(init=False, default=None, repr=False)
17
+
18
+ @property
19
+ def model(self) -> MTCNN:
20
+ if self._model is None:
21
+ self._model = MTCNN(
22
+ keep_all=True,
23
+ device=self.device,
24
+ min_face_size=self.min_face_size,
25
+ post_process=False,
26
+ )
27
+ return self._model
28
 
29
  def detect(self, image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
30
  pil = Image.fromarray(cv2_to_rgb(image))
 
34
  return boxes.astype(np.float32), probs.astype(np.float32)
35
 
36
 
37
+ @dataclass
38
  class FaceEmbedder:
39
+ device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
40
  batch_size: int = 16
41
+ _model: InceptionResnetV1 | None = field(init=False, default=None, repr=False)
42
 
43
+ @property
44
+ def model(self) -> InceptionResnetV1:
45
+ if self._model is None:
46
+ self._model = InceptionResnetV1(pretrained="vggface2").eval().to(self.device)
47
+ return self._model
48
 
49
  @torch.no_grad()
50
  def embed(self, crops: Iterable[Image.Image]) -> np.ndarray: