File size: 3,795 Bytes
2357a36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import annotations

import importlib
import inspect
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import patch


# Directory one level above this file's directory, and its ``src`` subtree.
ROOT_DIR = Path(__file__).resolve().parents[1]
SRC_DIR = ROOT_DIR / "src"
# Prepend both in one step (``src`` first, then the root) — presumably so that
# ``app`` and the ``src`` package tree are importable ahead of anything else.
sys.path[0:0] = [str(SRC_DIR), str(ROOT_DIR)]


def _decorator_factory(_path: str | None = None, **_kwargs: object):
    def decorator(function):
        return function

    return decorator


class _ServerStub:
    def __init__(self, **_kwargs: object) -> None:
        pass

    def get(self, *_args: object, **_kwargs: object):
        return _decorator_factory()

    def post(self, *_args: object, **_kwargs: object):
        return _decorator_factory()

    def mount(self, *_args: object, **_kwargs: object) -> None:
        pass

    def launch(self, **_kwargs: object) -> None:
        pass


class _ObjectStub:
    def __init__(self, *_args: object, **_kwargs: object) -> None:
        pass


def _default_marker(*_args: object, default: object = None, **_kwargs: object) -> object:
    return default


def _app_import_stubs() -> dict[str, types.ModuleType]:
    """Build stub modules so ``app`` can be imported without its dependencies.

    The mapping is intended for ``patch.dict(sys.modules, ...)``. It replaces
    ``gradio``, ``fastapi`` (with its ``responses`` and ``staticfiles``
    submodules), ``pozify.exercise_catalog`` and ``pozify.pipeline``.

    Returns:
        Mapping of fully qualified module name to stub module.
    """
    gradio = types.ModuleType("gradio")
    gradio.Server = _ServerStub

    fastapi = types.ModuleType("fastapi")
    fastapi.File = _default_marker
    fastapi.Form = _default_marker
    fastapi.UploadFile = _ObjectStub

    class HTTPException(Exception):
        """Stub of ``fastapi.HTTPException`` that keeps ``status_code``/``detail``."""

        def __init__(
            self,
            status_code: object = None,
            detail: object = None,
            *_args: object,
            **_kwargs: object,
        ) -> None:
            # Exception.__new__ has already stored the positional args in
            # ``self.args``; only the attributes the real class exposes are added.
            self.status_code = status_code
            self.detail = detail

    fastapi.HTTPException = HTTPException

    responses = types.ModuleType("fastapi.responses")
    responses.FileResponse = _ObjectStub
    responses.HTMLResponse = _ObjectStub
    responses.StreamingResponse = _ObjectStub

    staticfiles = types.ModuleType("fastapi.staticfiles")
    staticfiles.StaticFiles = _ObjectStub

    # Bind submodules on the parent like the import system does, so that
    # ``import fastapi.responses`` followed by ``fastapi.responses.X`` also works.
    fastapi.responses = responses
    fastapi.staticfiles = staticfiles

    exercise_catalog = types.ModuleType("pozify.exercise_catalog")
    exercise_catalog.USER_SELECTABLE_EXERCISES = ["squat"]

    pipeline = types.ModuleType("pozify.pipeline")

    def run_pipeline(**_kwargs: object) -> dict[str, object]:
        # Recognisable marker so tests can tell whether this stub was reached.
        return {"source": "pipeline"}

    pipeline.run_pipeline = run_pipeline

    return {
        "gradio": gradio,
        "fastapi": fastapi,
        "fastapi.responses": responses,
        "fastapi.staticfiles": staticfiles,
        "pozify.exercise_catalog": exercise_catalog,
        "pozify.pipeline": pipeline,
    }


def _import_app_module():
    sys.modules.pop("app", None)
    return importlib.import_module("app")


class AppZeroGpuProgressTests(unittest.TestCase):
    """Tests for progress handling in ``app._run_analysis_pipeline``.

    ``app`` is imported against the stub modules from ``_app_import_stubs``,
    so neither gradio, FastAPI nor the real pipeline is required.
    """

    def tearDown(self) -> None:
        # Do not leak a stub-backed ``app`` module into later tests.
        sys.modules.pop("app", None)

    def test_analysis_pipeline_is_not_wrapped_at_api_layer(self) -> None:
        with patch.dict(sys.modules, _app_import_stubs()):
            app = _import_app_module()

        pipeline_fn = app._run_analysis_pipeline
        # functools.wraps copies __name__ and sets __wrapped__, and
        # inspect.signature follows __wrapped__, so the name and signature
        # checks alone cannot detect a wraps-based decorator.
        self.assertFalse(
            hasattr(pipeline_fn, "__wrapped__"),
            "_run_analysis_pipeline should not be wrapped by a decorator",
        )
        signature = inspect.signature(pipeline_fn)
        self.assertIn("progress", signature.parameters)
        self.assertEqual(pipeline_fn.__name__, "_run_analysis_pipeline")

    def test_analysis_pipeline_forwards_progress_callback_inside_api_process(self) -> None:
        with patch.dict(sys.modules, _app_import_stubs()):
            app = _import_app_module()

        progress_events: list[dict[str, object]] = []
        # Bind once: every ``progress_events.append`` access creates a new
        # bound-method object, and the identity check below needs this one.
        progress_callback = progress_events.append

        def local_pipeline(**kwargs: object) -> dict[str, object]:
            self.assertIs(kwargs["progress"], progress_callback)
            return {"source": "local"}

        with patch.object(app, "run_pipeline", side_effect=local_pipeline) as pipeline_mock:
            result = app._run_analysis_pipeline(
                "video.mp4",
                {"goal": "beginner_practice"},
                False,
                progress_callback,
            )

        # The patched pipeline (and thus the identity assertion) must run once.
        pipeline_mock.assert_called_once()
        self.assertEqual(result, {"source": "local"})


if __name__ == "__main__":
    # Allow running this file directly as a script, not only via a test runner.
    unittest.main(module="__main__")