ricklon commited on
Commit
2c2efb5
·
1 Parent(s): 2987995

Handle more example payload shapes in workspace loader

Browse files
Files changed (2) hide show
  1. app.py +22 -2
  2. tests/test_example_loader.py +77 -0
app.py CHANGED
@@ -1484,9 +1484,29 @@ def load_image_with_size(file_path, page_num=1, workspace_scale=WORKSPACE_DEFAUL
1484
  def load_example_into_workspace(example_value):
1485
  if example_value is None:
1486
  return None, None, None
1487
- if isinstance(example_value, str):
1488
- img = load_image(example_value, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1489
  return _prepare_workspace_image(img, WORKSPACE_DEFAULT_SCALE)
 
1490
  if isinstance(example_value, Image.Image):
1491
  img = example_value
1492
  else:
 
1484
  def load_example_into_workspace(example_value):
1485
  if example_value is None:
1486
  return None, None, None
1487
+
1488
+ file_path = None
1489
+ if isinstance(example_value, os.PathLike):
1490
+ file_path = os.fspath(example_value)
1491
+ elif isinstance(example_value, str):
1492
+ file_path = example_value
1493
+ elif isinstance(example_value, dict):
1494
+ path_candidate = example_value.get("path") or example_value.get("name")
1495
+ if isinstance(path_candidate, os.PathLike):
1496
+ file_path = os.fspath(path_candidate)
1497
+ elif isinstance(path_candidate, str):
1498
+ file_path = path_candidate
1499
+ elif isinstance(example_value, (list, tuple)) and example_value:
1500
+ first = example_value[0]
1501
+ if isinstance(first, os.PathLike):
1502
+ file_path = os.fspath(first)
1503
+ elif isinstance(first, str):
1504
+ file_path = first
1505
+
1506
+ if file_path:
1507
+ img = load_image(file_path, 1)
1508
  return _prepare_workspace_image(img, WORKSPACE_DEFAULT_SCALE)
1509
+
1510
  if isinstance(example_value, Image.Image):
1511
  img = example_value
1512
  else:
tests/test_example_loader.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import os
3
+ import pathlib
4
+ import types
5
+ import unittest
6
+
7
+
8
+ def _load_example_loader():
9
+ app_path = pathlib.Path(__file__).resolve().parents[1] / "app.py"
10
+ source = app_path.read_text(encoding="utf-8")
11
+ module = ast.parse(source, filename=str(app_path))
12
+
13
+ wanted = {
14
+ "_to_rgba_image",
15
+ "_scale_workspace_image",
16
+ "_prepare_workspace_image",
17
+ "load_example_into_workspace",
18
+ }
19
+ fn_nodes = [n for n in module.body if isinstance(n, ast.FunctionDef) and n.name in wanted]
20
+ fn_nodes.sort(key=lambda n: n.lineno)
21
+
22
+ test_mod = ast.Module(body=fn_nodes, type_ignores=[])
23
+ code = compile(test_mod, filename=str(app_path), mode="exec")
24
+
25
+ class _FakeLoadedImage:
26
+ def __init__(self, width=1068, height=3074):
27
+ self.width = width
28
+ self.height = height
29
+
30
+ class _FakeImageModule:
31
+ class Image: # pragma: no cover - marker type for isinstance checks
32
+ pass
33
+
34
+ @staticmethod
35
+ def open(path):
36
+ return _FakeLoadedImage()
37
+
38
+ fake_np = types.SimpleNamespace(
39
+ ndarray=type("ndarray", (), {}),
40
+ uint8=int,
41
+ stack=lambda *args, **kwargs: None,
42
+ full_like=lambda *args, **kwargs: None,
43
+ concatenate=lambda *args, **kwargs: None,
44
+ )
45
+
46
+ scope = {
47
+ "os": os,
48
+ "np": fake_np,
49
+ "Image": _FakeImageModule,
50
+ "WORKSPACE_DEFAULT_SCALE": 89,
51
+ "load_image": lambda file_path, page_num=1: _FakeLoadedImage(),
52
+ }
53
+ exec(code, scope)
54
+ return scope["load_example_into_workspace"]
55
+
56
+
57
+ class ExampleLoaderTests(unittest.TestCase):
58
+ def test_accepts_common_gradio_payload_shapes(self):
59
+ loader = _load_example_loader()
60
+ sample = pathlib.Path("examples/2022-0922 Section 15 Notes.png")
61
+
62
+ for payload in (
63
+ str(sample),
64
+ sample,
65
+ {"path": str(sample)},
66
+ {"name": str(sample)},
67
+ [str(sample)],
68
+ (str(sample),),
69
+ ):
70
+ display, size, base = loader(payload)
71
+ self.assertIsNotNone(display)
72
+ self.assertIsNotNone(base)
73
+ self.assertEqual((1068, 3074), size)
74
+
75
+
76
+ if __name__ == "__main__":
77
+ unittest.main()