hysts HF Staff committed on
Commit
8cf6e41
·
1 Parent(s): fc87c68
Files changed (4) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -1
  4. app.py +41 -74
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📚
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import functools
7
  import os
8
  import pathlib
@@ -28,25 +27,10 @@ sys.path.insert(0, 'deep-head-pose/code')
28
  from hopenet import Hopenet
29
  from ibug.face_detection import RetinaFacePredictor
30
 
31
- TITLE = 'natanielruiz/deep-head-pose'
32
  DESCRIPTION = 'This is an unofficial demo for https://github.com/natanielruiz/deep-head-pose.'
33
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.hopenet" alt="visitor badge"/></center>'
34
 
35
- TOKEN = os.environ['TOKEN']
36
-
37
-
38
- def parse_args() -> argparse.Namespace:
39
- parser = argparse.ArgumentParser()
40
- parser.add_argument('--device', type=str, default='cpu')
41
- parser.add_argument('--theme', type=str)
42
- parser.add_argument('--live', action='store_true')
43
- parser.add_argument('--share', action='store_true')
44
- parser.add_argument('--port', type=int)
45
- parser.add_argument('--disable-queue',
46
- dest='enable_queue',
47
- action='store_false')
48
- parser.add_argument('--allow-flagging', type=str, default='never')
49
- return parser.parse_args()
50
 
51
 
52
  def load_sample_images() -> list[pathlib.Path]:
@@ -59,7 +43,7 @@ def load_sample_images() -> list[pathlib.Path]:
59
  path = huggingface_hub.hf_hub_download(dataset_repo,
60
  name,
61
  repo_type='dataset',
62
- use_auth_token=TOKEN)
63
  with tarfile.open(path) as f:
64
  f.extractall(image_dir.as_posix())
65
  return sorted(image_dir.rglob('*.jpg'))
@@ -68,7 +52,7 @@ def load_sample_images() -> list[pathlib.Path]:
68
  def load_model(model_name: str, device: torch.device) -> nn.Module:
69
  path = huggingface_hub.hf_hub_download('hysts/Hopenet',
70
  f'models/{model_name}.pkl',
71
- use_auth_token=TOKEN)
72
  state_dict = torch.load(path, map_location='cpu')
73
  model = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
74
  model.load_state_dict(state_dict)
@@ -162,57 +146,40 @@ def run(image: np.ndarray, model_name: str, face_detector: RetinaFacePredictor,
162
  return res[:, :, ::-1]
163
 
164
 
165
- def main():
166
- args = parse_args()
167
- device = torch.device(args.device)
168
-
169
- face_detector = RetinaFacePredictor(
170
- threshold=0.8,
171
- device=device,
172
- model=RetinaFacePredictor.get_model('mobilenet0.25'))
173
-
174
- model_names = [
175
- 'hopenet_alpha1',
176
- 'hopenet_alpha2',
177
- 'hopenet_robust_alpha1',
178
- ]
179
- models = {name: load_model(name, device) for name in model_names}
180
-
181
- transform = create_transform()
182
-
183
- func = functools.partial(run,
184
- face_detector=face_detector,
185
- models=models,
186
- transform=transform,
187
- device=device)
188
- func = functools.update_wrapper(func, run)
189
-
190
- image_paths = load_sample_images()
191
- examples = [[path.as_posix(), model_names[0]] for path in image_paths]
192
-
193
- gr.Interface(
194
- func,
195
- [
196
- gr.inputs.Image(type='numpy', label='Input'),
197
- gr.inputs.Radio(model_names,
198
- type='value',
199
- default=model_names[0],
200
- label='Model'),
201
- ],
202
- gr.outputs.Image(type='numpy', label='Output'),
203
- examples=examples,
204
- title=TITLE,
205
- description=DESCRIPTION,
206
- article=ARTICLE,
207
- theme=args.theme,
208
- allow_flagging=args.allow_flagging,
209
- live=args.live,
210
- ).launch(
211
- enable_queue=args.enable_queue,
212
- server_port=args.port,
213
- share=args.share,
214
- )
215
-
216
-
217
- if __name__ == '__main__':
218
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import os
7
  import pathlib
 
27
  from hopenet import Hopenet
28
  from ibug.face_detection import RetinaFacePredictor
29
 
30
+ TITLE = 'Hopenet'
31
  DESCRIPTION = 'This is an unofficial demo for https://github.com/natanielruiz/deep-head-pose.'
 
32
 
33
+ HF_TOKEN = os.getenv('HF_TOKEN')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  def load_sample_images() -> list[pathlib.Path]:
 
43
  path = huggingface_hub.hf_hub_download(dataset_repo,
44
  name,
45
  repo_type='dataset',
46
+ use_auth_token=HF_TOKEN)
47
  with tarfile.open(path) as f:
48
  f.extractall(image_dir.as_posix())
49
  return sorted(image_dir.rglob('*.jpg'))
 
52
  def load_model(model_name: str, device: torch.device) -> nn.Module:
53
  path = huggingface_hub.hf_hub_download('hysts/Hopenet',
54
  f'models/{model_name}.pkl',
55
+ use_auth_token=HF_TOKEN)
56
  state_dict = torch.load(path, map_location='cpu')
57
  model = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
58
  model.load_state_dict(state_dict)
 
146
  return res[:, :, ::-1]
147
 
148
 
149
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
150
+ face_detector = RetinaFacePredictor(
151
+ threshold=0.8,
152
+ device=device,
153
+ model=RetinaFacePredictor.get_model('mobilenet0.25'))
154
+
155
+ model_names = [
156
+ 'hopenet_alpha1',
157
+ 'hopenet_alpha2',
158
+ 'hopenet_robust_alpha1',
159
+ ]
160
+ models = {name: load_model(name, device) for name in model_names}
161
+ transform = create_transform()
162
+
163
+ func = functools.partial(run,
164
+ face_detector=face_detector,
165
+ models=models,
166
+ transform=transform,
167
+ device=device)
168
+
169
+ image_paths = load_sample_images()
170
+ examples = [[path.as_posix(), model_names[0]] for path in image_paths]
171
+
172
+ gr.Interface(
173
+ fn=func,
174
+ inputs=[
175
+ gr.Image(type='numpy', label='Input'),
176
+ gr.Radio(model_names,
177
+ type='value',
178
+ default=model_names[0],
179
+ label='Model'),
180
+ ],
181
+ outputs=gr.Image(type='numpy', label='Output'),
182
+ examples=examples,
183
+ title=TITLE,
184
+ description=DESCRIPTION,
185
+ ).launch(show_api=False)