akagtag commited on
Commit
b854243
·
1 Parent(s): 6b11a33

Use Blocks launch entrypoint for Space runtime

Browse files
Files changed (2) hide show
  1. app.py +89 -59
  2. tests/test_zero_gpu_contract.py +10 -6
app.py CHANGED
@@ -1,15 +1,17 @@
1
  from __future__ import annotations
2
 
 
3
  import os
4
  import sys
 
5
  import traceback
 
6
 
 
7
  from fastapi import File, UploadFile
8
  from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import HTMLResponse, RedirectResponse
10
- from gradio import Server
11
 
12
- from src.api.demo_page import render_demo
13
  from src.api.main import detect_image as api_detect_image
14
  from src.api.main import detect_video as api_detect_video
15
  from src.api.main import health as api_health
@@ -30,79 +32,107 @@ def _normalize_gradio_env() -> None:
30
  os.environ.pop("GRADIO_NODE_SERVER_PORT", None)
31
 
32
 
33
- _normalize_gradio_env()
34
- app = Server()
35
- app.add_middleware(
36
- CORSMiddleware,
37
- allow_origins=["*"],
38
- allow_methods=["*"],
39
- allow_headers=["*"],
40
- )
41
 
 
42
 
43
- @app.api(name="ping", queue=False)
44
- def ping() -> str:
45
- return "ok"
 
 
 
 
 
46
 
47
 
48
- @app.on_event("startup")
49
- async def startup() -> None:
50
- await api_preload()
51
-
52
-
53
- @app.get("/", response_class=HTMLResponse)
54
- async def root() -> HTMLResponse:
55
- return HTMLResponse(render_demo())
56
-
57
-
58
- @app.get("/gradio")
59
- async def gradio_compat_redirect() -> RedirectResponse:
60
- return RedirectResponse(url="/", status_code=307)
61
-
62
-
63
- @app.get("/health")
64
- async def health() -> dict:
65
- return await api_health()
66
-
67
-
68
- @app.get("/api/health")
69
- async def api_health_route() -> dict:
70
- return await health()
71
 
 
 
72
 
73
- @app.get("/health/models")
74
- async def health_models() -> dict[str, object]:
75
- return await api_health_models()
76
 
 
 
77
 
78
- @app.get("/api/health/models")
79
- async def api_health_models_route() -> dict[str, object]:
80
- return await health_models()
81
 
 
 
82
 
83
- @app.post("/detect/image", response_model=DetectionResponse)
84
- async def detect_image(file: UploadFile = File(...)) -> DetectionResponse:
85
- return await api_detect_image(file)
86
 
 
 
87
 
88
- @app.post("/api/detect/image", response_model=DetectionResponse)
89
- async def api_detect_image_route(file: UploadFile = File(...)) -> DetectionResponse:
90
- return await detect_image(file)
91
 
 
 
92
 
93
- @app.post("/detect/video", response_model=DetectionResponse)
94
- async def detect_video(file: UploadFile = File(...)) -> DetectionResponse:
95
- return await api_detect_video(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
- @app.post("/api/detect/video", response_model=DetectionResponse)
99
- async def api_detect_video_route(file: UploadFile = File(...)) -> DetectionResponse:
100
- return await detect_video(file)
101
 
102
 
103
  if __name__ == "__main__":
104
  _install_excepthook()
105
- app.launch(
106
- show_error=True,
107
- ssr_mode=False,
108
- )
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import asyncio
4
  import os
5
  import sys
6
+ import time
7
  import traceback
8
+ from typing import Any
9
 
10
+ import gradio as gr
11
  from fastapi import File, UploadFile
12
  from fastapi.middleware.cors import CORSMiddleware
13
+ from fastapi.responses import RedirectResponse
 
14
 
 
15
  from src.api.main import detect_image as api_detect_image
16
  from src.api.main import detect_video as api_detect_video
17
  from src.api.main import health as api_health
 
32
  os.environ.pop("GRADIO_NODE_SERVER_PORT", None)
33
 
34
 
35
+ def _build_demo() -> gr.Blocks:
36
+ with gr.Blocks(title="GenAI-DeepDetect") as demo:
37
+ gr.Markdown(
38
+ """
39
+ # GenAI-DeepDetect
 
 
 
40
 
41
+ Gradio frontend with API routes attached to the same app for your external frontend.
42
 
43
+ Available API endpoints:
44
+ - `GET /api/health`
45
+ - `GET /api/health/models`
46
+ - `POST /api/detect/image`
47
+ - `POST /api/detect/video`
48
+ """
49
+ )
50
+ return demo
51
 
52
 
53
+ def _attach_api_routes(app: Any) -> None:
54
+ app.add_middleware(
55
+ CORSMiddleware,
56
+ allow_origins=["*"],
57
+ allow_methods=["*"],
58
+ allow_headers=["*"],
59
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ async def health() -> dict:
62
+ return await api_health()
63
 
64
+ async def api_health_route() -> dict:
65
+ return await health()
 
66
 
67
+ async def health_models() -> dict[str, object]:
68
+ return await api_health_models()
69
 
70
+ async def api_health_models_route() -> dict[str, object]:
71
+ return await health_models()
 
72
 
73
+ async def detect_image(file: UploadFile = File(...)) -> DetectionResponse:
74
+ return await api_detect_image(file)
75
 
76
+ async def api_detect_image_route(file: UploadFile = File(...)) -> DetectionResponse:
77
+ return await detect_image(file)
 
78
 
79
+ async def detect_video(file: UploadFile = File(...)) -> DetectionResponse:
80
+ return await api_detect_video(file)
81
 
82
+ async def api_detect_video_route(file: UploadFile = File(...)) -> DetectionResponse:
83
+ return await detect_video(file)
 
84
 
85
+ async def gradio_compat_redirect() -> RedirectResponse:
86
+ return RedirectResponse(url="/", status_code=307)
87
 
88
+ app.add_api_route("/health", health, methods=["GET"])
89
+ app.add_api_route("/api/health", api_health_route, methods=["GET"])
90
+ app.add_api_route("/health/models", health_models, methods=["GET"])
91
+ app.add_api_route("/api/health/models", api_health_models_route, methods=["GET"])
92
+ app.add_api_route(
93
+ "/detect/image",
94
+ detect_image,
95
+ methods=["POST"],
96
+ response_model=DetectionResponse,
97
+ )
98
+ app.add_api_route(
99
+ "/api/detect/image",
100
+ api_detect_image_route,
101
+ methods=["POST"],
102
+ response_model=DetectionResponse,
103
+ )
104
+ app.add_api_route(
105
+ "/detect/video",
106
+ detect_video,
107
+ methods=["POST"],
108
+ response_model=DetectionResponse,
109
+ )
110
+ app.add_api_route(
111
+ "/api/detect/video",
112
+ api_detect_video_route,
113
+ methods=["POST"],
114
+ response_model=DetectionResponse,
115
+ )
116
+ app.add_api_route("/gradio", gradio_compat_redirect, methods=["GET"])
117
 
118
 
119
+ demo = _build_demo().queue()
 
 
120
 
121
 
122
  if __name__ == "__main__":
123
  _install_excepthook()
124
+ _normalize_gradio_env()
125
+ try:
126
+ asyncio.run(api_preload())
127
+ app, _, _ = demo.launch(
128
+ prevent_thread_lock=True,
129
+ show_error=True,
130
+ ssr_mode=False,
131
+ app_kwargs={"docs_url": "/docs", "redoc_url": "/redoc"},
132
+ )
133
+ _attach_api_routes(app)
134
+ while True:
135
+ time.sleep(60)
136
+ except Exception:
137
+ traceback.print_exc()
138
+ sys.exit(1)
tests/test_zero_gpu_contract.py CHANGED
@@ -34,13 +34,17 @@ def test_readme_declares_gradio_space_metadata():
34
  def test_app_mounts_gradio_onto_fastapi():
35
  source = (ROOT / "app.py").read_text(encoding="utf-8")
36
 
37
- assert "from gradio import Server" in source
38
- assert "app = Server()" in source
39
- assert '@app.get("/api/health")' in source
40
- assert '@app.post("/api/detect/image"' in source
41
- assert '@app.post("/api/detect/video"' in source
 
 
 
 
42
  assert "show_error=True" in source
43
- assert "app.launch(" in source
44
  assert "ssr_mode=False" in source
45
  assert 'GRADIO_SSR_MODE"] = "False"' in source
46
  assert "GRADIO_NODE_SERVER_PORT" in source
 
34
  def test_app_mounts_gradio_onto_fastapi():
35
  source = (ROOT / "app.py").read_text(encoding="utf-8")
36
 
37
+ assert "from gradio import Server" not in source
38
+ assert "demo = _build_demo().queue()" in source
39
+ assert "prevent_thread_lock=True" in source
40
+ assert "while True:" in source
41
+ assert 'app_kwargs={"docs_url": "/docs", "redoc_url": "/redoc"}' in source
42
+ assert 'app.add_api_route("/api/health"' in source
43
+ assert "/api/detect/image" in source
44
+ assert "/api/detect/video" in source
45
+ assert "def _attach_api_routes(app: Any) -> None:" in source
46
  assert "show_error=True" in source
47
+ assert "demo.launch(" in source
48
  assert "ssr_mode=False" in source
49
  assert 'GRADIO_SSR_MODE"] = "False"' in source
50
  assert "GRADIO_NODE_SERVER_PORT" in source