Luigi commited on
Commit
6452205
·
1 Parent(s): 22f2362

improve model uploader ui

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -8,7 +8,6 @@ from PIL import Image
8
  spec = importlib.util.find_spec('mmdet')
9
  if spec and spec.origin:
10
  src = open(spec.origin, encoding='utf-8').read()
11
- # strip out the mmcv_minimum_version…assert… block up to __all__
12
  patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
13
  m = importlib.util.module_from_spec(spec)
14
  m.__loader__ = spec.loader
@@ -28,16 +27,18 @@ def load_inferencer(checkpoint_path=None, device=None):
28
 
29
  # ——— Gradio prediction function ———
30
  @spaces.GPU()
31
- def predict(image: Image.Image, checkpoint: str = None):
32
  # save upload to temp file
33
  inp_path = "/tmp/upload.jpg"
34
  image.save(inp_path)
35
 
 
 
 
36
  vis_dir = "/tmp/vis"
37
  os.makedirs(vis_dir, exist_ok=True)
38
 
39
- inferencer = load_inferencer(checkpoint_path=checkpoint, device=None)
40
- # run inference & visualization
41
  for result in inferencer(
42
  inputs=inp_path,
43
  bbox_thr=0.1,
@@ -48,7 +49,6 @@ def predict(image: Image.Image, checkpoint: str = None):
48
  ):
49
  pass
50
 
51
- # return the first visualization
52
  out_files = sorted(os.listdir(vis_dir))
53
  if out_files:
54
  return Image.open(os.path.join(vis_dir, out_files[0]))
@@ -59,11 +59,11 @@ demo = gr.Interface(
59
  fn=predict,
60
  inputs=[
61
  gr.Image(type="pil", label="Upload Image"),
62
- gr.Textbox(label="RTMO PyTorch Checkpoint Path (optional)")
63
  ],
64
  outputs=gr.Image(type="pil", label="Annotated Image"),
65
  title="RTMO Pose Demo",
66
- description="Upload an image, optionally supply a RTMO .pth checkpoint, and see 2D pose annotation.",
67
  )
68
 
69
  def main():
 
8
  spec = importlib.util.find_spec('mmdet')
9
  if spec and spec.origin:
10
  src = open(spec.origin, encoding='utf-8').read()
 
11
  patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
12
  m = importlib.util.module_from_spec(spec)
13
  m.__loader__ = spec.loader
 
27
 
28
  # ——— Gradio prediction function ———
29
  @spaces.GPU()
30
+ def predict(image: Image.Image, checkpoint):
31
  # save upload to temp file
32
  inp_path = "/tmp/upload.jpg"
33
  image.save(inp_path)
34
 
35
+ # determine checkpoint path if user uploaded
36
+ ckpt_path = checkpoint.name if checkpoint else None
37
+
38
  vis_dir = "/tmp/vis"
39
  os.makedirs(vis_dir, exist_ok=True)
40
 
41
+ inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
 
42
  for result in inferencer(
43
  inputs=inp_path,
44
  bbox_thr=0.1,
 
49
  ):
50
  pass
51
 
 
52
  out_files = sorted(os.listdir(vis_dir))
53
  if out_files:
54
  return Image.open(os.path.join(vis_dir, out_files[0]))
 
59
  fn=predict,
60
  inputs=[
61
  gr.Image(type="pil", label="Upload Image"),
62
+ gr.File(file_types=['.pth'], label="Upload RTMO .pth Checkpoint (optional)")
63
  ],
64
  outputs=gr.Image(type="pil", label="Annotated Image"),
65
  title="RTMO Pose Demo",
66
+ description="Upload an image and (optionally) a RTMO .pth checkpoint to get 2D pose annotation.",
67
  )
68
 
69
  def main():