celik-muhammed commited on
Commit
4ee3014
Β·
verified Β·
1 Parent(s): 385004a

Upload 3 files

Browse files
Files changed (1) hide show
  1. app.py +398 -93
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # scikit-plots/ai-model Β· app.py v2.0.0
2
  #
3
  # PURPOSE
4
  # ───────
@@ -25,23 +25,38 @@
25
  # + Long-term maintainability
26
  # + Minimal hidden behavior
27
  #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # CRITICAL: ZEROGPU ARCHITECTURE REQUIREMENT
29
  # ──────────────────────────────────────────
30
  # On HuggingFace Spaces with sdk: gradio, ZeroGPU hooks attach to the
31
  # Gradio server lifecycle. Gradio MUST be the ASGI root application.
32
  #
33
- # CORRECT (v2.0.0):
34
  # Gradio (gr.Blocks) is the ASGI root.
35
  # REST routes are registered on Gradio's internal FastAPI instance
36
  # via gradio.routes.App.create_app(demo).
37
  # @spaces.GPU is active on _generate.
38
  #
39
- # WRONG (v1.x.x):
40
  # FastAPI was the ASGI root; Gradio was a child via gr.mount_gradio_app.
41
  # @spaces.GPU was commented out.
42
  # ZeroGPU hooks never activated.
43
  #
44
- # CRITICAL: ZERO-GPU MODEL PATTERN
45
  # ────────────────────────────────
46
  # CORRECT:
47
  # * Tokenizer loaded on CPU at import time.
@@ -49,14 +64,16 @@
49
  # * Model moved to GPU ONLY inside @spaces.GPU scope.
50
  # * Model returned to CPU in finally block after inference.
51
  # * torch.cuda.empty_cache() called after every request.
 
52
  #
53
  # WRONG:
54
  # * pipeline(... device_map="auto") at module level.
55
  # * model.to("cuda") outside @spaces.GPU scope.
56
  # * Holding GPU between requests.
57
  # * Blocking asyncio event loop with synchronous inference.
 
58
  #
59
- # ASSEMBLY DIAGRAM (v2.0.0)
60
  # ─────────────────────────
61
  #
62
  # HuggingFace Spaces
@@ -77,6 +94,49 @@
77
  # MAX_BODY_BYTES
78
  # Maximum accepted request size.
79
  #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # SPDX-License-Identifier: BSD-3-Clause
81
  # Authors: The scikit-plots developers
82
 
@@ -98,6 +158,10 @@ Developer note
98
  HuggingFace Spaces exports the ``app`` variable. It must be the
99
  Gradio-rooted ASGI application returned by ``App.create_app``.
100
 
 
 
 
 
101
  User note
102
  The Gradio UI at ``/`` is for manual testing only.
103
  Production traffic routes through the proxy Space.
@@ -109,9 +173,10 @@ import asyncio
109
  import json
110
  import logging
111
  import os
 
112
  import time
113
  import uuid
114
- from typing import Final
115
 
116
  import gradio as gr # type: ignore[]
117
  import spaces # type: ignore[] # ZeroGPU β€” must be imported before torch
@@ -231,6 +296,10 @@ _MAX_NEW_TOKENS_FLOOR: Final[int] = 1
231
  _MAX_NEW_TOKENS_CEIL: Final[int] = 4096
232
  _MAX_NEW_TOKENS_DEFAULT: Final[int] = 512
233
 
 
 
 
 
234
  DEFAULT_MAX_BODY_BYTES: Final[int] = 10 * 1024 * 1024
235
 
236
  MAX_BODY_BYTES: Final[int] = _safe_int(
@@ -295,15 +364,12 @@ def _clamp_max_tokens(
295
  parsed = int(value)
296
  except (TypeError, ValueError) as exc:
297
  raise ValueError(
298
- f"max_tokens must be integer, got {value!r}"
299
  ) from exc
300
 
301
  return max(
302
  _MAX_NEW_TOKENS_FLOOR,
303
- min(
304
- parsed,
305
- _MAX_NEW_TOKENS_CEIL,
306
- ),
307
  )
308
 
309
 
@@ -369,9 +435,119 @@ def _validate_messages(
369
  return validated
370
 
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  logger.info("Validation helpers initialized successfully.")
373
 
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  # ─────────────────────────────────────────────────────────────────────────────
376
  # Model loading
377
  # ─────────────────────────────────────────────────────────────────────────────
@@ -411,6 +587,8 @@ logger.info("Model loaded on CPU successfully.")
411
  # * GPU exists only inside @spaces.GPU scope.
412
  # * Model moved CPU β†’ GPU at entry; GPU β†’ CPU in finally.
413
  # * VRAM fully released after every request.
 
 
414
  # * This function is called from both:
415
  # - Gradio event handlers (direct sync call via _gradio_respond)
416
  # - FastAPI route handlers (via asyncio.to_thread in _generate_async)
@@ -420,6 +598,8 @@ logger.info("Model loaded on CPU successfully.")
420
  def _generate(
421
  messages: list[dict[str, str]],
422
  max_new_tokens: int = _MAX_NEW_TOKENS_DEFAULT,
 
 
423
  ) -> str:
424
  """
425
  Run generation using ZeroGPU.
@@ -432,6 +612,14 @@ def _generate(
432
  max_new_tokens : int, default=512
433
  Maximum generated tokens.
434
 
 
 
 
 
 
 
 
 
435
  Returns
436
  -------
437
  str
@@ -443,24 +631,36 @@ def _generate(
443
  On invalid inputs or missing chat template.
444
 
445
  RuntimeError
446
- On inference failure.
447
 
448
  Notes
449
  -----
450
  Developer note
451
  GPU is acquired automatically by ``@spaces.GPU``.
452
 
453
- Model is explicitly moved CPU β†’ GPU β†’ CPU during each request
454
- to avoid persistent VRAM ownership between requests.
 
 
 
 
 
 
 
 
 
 
 
455
 
456
  ``finally`` block ensures CPU return and cache clear even if
457
- inference raises. The inner ``try`` around ``_model.cpu()``
458
- guarantees ``torch.cuda.empty_cache()`` runs regardless of
459
- whether the CPU move itself fails.
 
460
 
461
- This function is intentionally synchronous. Async routes call
462
  it via ``_generate_async`` which wraps it with
463
- ``asyncio.to_thread``. Gradio event handlers call it directly
464
  because Gradio dispatches handlers in its own thread pool,
465
  outside the asyncio event loop.
466
 
@@ -469,7 +669,6 @@ def _generate(
469
  Use ``_generate_async`` from FastAPI routes.
470
  """
471
  validated_messages = _validate_messages(messages)
472
-
473
  max_new_tokens = _clamp_max_tokens(max_new_tokens)
474
 
475
  if not getattr(_tokenizer, "chat_template", None):
@@ -481,69 +680,99 @@ def _generate(
481
  logger.info(
482
  "GPU generation starting | "
483
  "messages=%d | "
484
- "max_new_tokens=%d",
 
 
485
  len(validated_messages),
486
  max_new_tokens,
 
 
487
  )
488
 
489
- try:
490
- logger.info("Moving model to GPU...")
 
491
 
492
- _model.cuda()
493
 
494
- input_ids = _tokenizer.apply_chat_template(
495
- validated_messages,
496
- add_generation_prompt=True,
497
- return_tensors="pt",
498
- )
 
499
 
500
- input_ids = input_ids.cuda()
501
 
502
- logger.info("Generation started.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
- with torch.no_grad():
505
- output_ids = _model.generate(
506
- input_ids,
507
- max_new_tokens=max_new_tokens,
508
- do_sample=True,
509
- temperature=0.7,
510
- pad_token_id=_tokenizer.eos_token_id,
511
  )
512
 
513
- new_token_ids = output_ids[0][input_ids.shape[-1]:]
514
-
515
- decoded = _tokenizer.decode(
516
- new_token_ids,
517
- skip_special_tokens=True,
518
- )
519
 
520
- logger.info("Generation completed successfully.")
 
 
 
 
521
 
522
- return decoded
523
 
524
- except ValueError:
525
- raise
526
 
527
- except Exception as exc:
528
- logger.exception("Inference failure.")
529
 
530
- raise RuntimeError(
531
- f"Inference failed: {exc}"
532
- ) from exc
533
 
534
- finally:
535
- logger.info(
536
- "Returning model to CPU "
537
- "and clearing CUDA cache..."
538
- )
539
 
540
- try:
541
- _model.cpu()
542
  finally:
543
- if torch.cuda.is_available():
544
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
- logger.info("GPU resources released.")
547
 
548
 
549
  # ─────────────────────────────────────────────────────────────────────────────
@@ -564,6 +793,8 @@ def _generate(
564
  async def _generate_async(
565
  messages: list[dict[str, str]],
566
  max_new_tokens: int,
 
 
567
  ) -> str:
568
  """
569
  Async wrapper for GPU generation.
@@ -576,6 +807,12 @@ async def _generate_async(
576
  max_new_tokens : int
577
  Generation token limit.
578
 
 
 
 
 
 
 
579
  Returns
580
  -------
581
  str
@@ -595,6 +832,8 @@ async def _generate_async(
595
  _generate,
596
  messages,
597
  max_new_tokens,
 
 
598
  )
599
 
600
 
@@ -651,7 +890,7 @@ async def _read_bounded_body(
651
 
652
  def _parse_request_body(
653
  raw: bytes,
654
- ) -> dict:
655
  """
656
  Decode and parse a UTF-8 JSON request body.
657
 
@@ -662,7 +901,7 @@ def _parse_request_body(
662
 
663
  Returns
664
  -------
665
- dict
666
  Parsed JSON payload.
667
 
668
  Raises
@@ -767,7 +1006,7 @@ def _build_completion_response(
767
  model_id: str,
768
  prompt_tokens: int,
769
  completion_tokens: int,
770
- ) -> dict:
771
  """
772
  Build an OpenAI-compatible chat completion response payload.
773
 
@@ -787,7 +1026,7 @@ def _build_completion_response(
787
 
788
  Returns
789
  -------
790
- dict
791
  OpenAI-compatible ``chat.completion`` object.
792
 
793
  Notes
@@ -800,6 +1039,9 @@ def _build_completion_response(
800
  current ``_generate`` implementation does not expose partial
801
  stop conditions. Extend this if streaming or early stopping
802
  is added.
 
 
 
803
 
804
  User note
805
  The returned dict is compatible with OpenAI Python SDK
@@ -810,11 +1052,21 @@ def _build_completion_response(
810
  .. [1] OpenAI API reference: Chat completions object
811
  https://platform.openai.com/docs/api-reference/chat/object
812
  """
 
 
 
 
 
 
 
 
 
813
  return {
814
  "id": f"chatcmpl-{uuid.uuid4().hex}",
815
  "object": "chat.completion",
816
  "created": int(time.time()),
817
  "model": model_id,
 
818
  "choices": [
819
  {
820
  "index": 0,
@@ -995,6 +1247,8 @@ def _gradio_respond(
995
  message: str,
996
  history: list,
997
  max_new_tokens: int,
 
 
998
  ) -> str:
999
  """
1000
  Gradio ``ChatInterface`` event handler.
@@ -1011,6 +1265,12 @@ def _gradio_respond(
1011
  max_new_tokens : int
1012
  Maximum tokens to generate, sourced from the UI slider.
1013
 
 
 
 
 
 
 
1014
  Returns
1015
  -------
1016
  str
@@ -1022,7 +1282,8 @@ def _gradio_respond(
1022
  If ``message`` is empty after stripping.
1023
 
1024
  RuntimeError
1025
- Propagated from ``_generate`` on inference failure.
 
1026
 
1027
  Notes
1028
  -----
@@ -1056,24 +1317,30 @@ def _gradio_respond(
1056
  logger.info(
1057
  "Gradio inference | "
1058
  "history_turns=%d | "
1059
- "max_new_tokens=%d",
 
 
1060
  len(messages) - 1,
1061
  max_new_tokens,
 
 
1062
  )
1063
 
1064
  return _generate(
1065
  messages,
1066
  max_new_tokens,
 
 
1067
  )
1068
 
1069
 
1070
  # ─────────────────────────────────────────────────────────────────────────────
1071
  # Gradio UI
1072
  # ─────────────────────────────────────────────────────────────────────────────
1073
- # v2.0.0: Gradio is the ASGI ROOT β€” not a child sub-app mounted on FastAPI.
1074
  # This is required for ZeroGPU to activate on HuggingFace Spaces.
1075
  #
1076
- # The Gradio UI is now served at / (root) instead of /ui.
1077
  # Custom REST routes are added to Gradio's internal FastAPI instance below.
1078
 
1079
  _UI_WARNING = """\
@@ -1103,10 +1370,26 @@ with _gradio_ui:
1103
  step=1,
1104
  label="max_tokens",
1105
  info=(
1106
- f"Range: {_MAX_NEW_TOKENS_FLOOR}-{_MAX_NEW_TOKENS_CEIL}. "
1107
  f"Default: {_MAX_NEW_TOKENS_DEFAULT}."
1108
  ),
1109
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1110
  ],
1111
  additional_inputs_accordion="Generation settings",
1112
  )
@@ -1120,14 +1403,8 @@ logger.info(
1120
  # ─────────────────────────────────────────────────────────────────────────────
1121
  # App assembly β€” HuggingFace Space export
1122
  # ─────────────────────────────────────────────────────────────────────────────
1123
- # v2.0.0 ARCHITECTURE CHANGE:
1124
- #
1125
- # v1.x.x (BROKEN):
1126
- # _api = FastAPI() ← FastAPI was ASGI root
1127
- # app = gr.mount_gradio_app(_api, ui, "/ui") ← Gradio was child
1128
- # ZeroGPU never activated β€” wrong root.
1129
  #
1130
- # v2.0.0 (CORRECT):
1131
  # app = _GradioApp.create_app(_gradio_ui) ← Gradio is ASGI root
1132
  # @app.get/post(...) ← routes on Gradio's FastAPI
1133
  # ZeroGPU activates correctly.
@@ -1157,6 +1434,7 @@ app.add_middleware(
1157
  allow_headers=[
1158
  "Content-Type",
1159
  ],
 
1160
  )
1161
 
1162
  logger.info(
@@ -1195,7 +1473,7 @@ async def health() -> JSONResponse:
1195
  Examples
1196
  --------
1197
  >>> # curl http://localhost:7860/health
1198
- ... # {"status": "ok", "model": "...", "version": "2.0.0"}
1199
  """
1200
  logger.info("GET /health")
1201
 
@@ -1203,7 +1481,7 @@ async def health() -> JSONResponse:
1203
  content={
1204
  "status": "ok",
1205
  "model": MODEL_ID,
1206
- "version": "2.0.0",
1207
  },
1208
  status_code=200,
1209
  )
@@ -1229,15 +1507,15 @@ async def chat_completions( # noqa: PLR0911
1229
  Returns
1230
  -------
1231
  JSONResponse
1232
- HTTP 200 with an OpenAI-compatible completion payload.
1233
 
1234
- Raises
1235
- ------
1236
- JSONResponse
1237
  HTTP 413 if the body exceeds ``MAX_BODY_BYTES``.
1238
- HTTP 400 if the body is not valid UTF-8 JSON.
1239
- HTTP 400 if ``messages`` or ``max_tokens`` fail validation.
1240
- HTTP 500 on inference failure.
 
 
 
1241
 
1242
  Notes
1243
  -----
@@ -1246,9 +1524,9 @@ async def chat_completions( # noqa: PLR0911
1246
 
1247
  1. Read and bound-check raw body bytes (413 guard).
1248
  2. Decode and parse JSON (400 guard).
1249
- 3. Extract ``messages`` and ``max_tokens`` fields.
1250
- 4. Validate with ``_validate_messages`` and
1251
- ``_clamp_max_tokens`` (400 guard).
1252
  5. Count prompt tokens on CPU (no GPU needed).
1253
  6. Dispatch to ``_generate_async`` which offloads to
1254
  ``@spaces.GPU`` via ``asyncio.to_thread``.
@@ -1261,6 +1539,10 @@ async def chat_completions( # noqa: PLR0911
1261
  * ``RuntimeError`` β†’ 500 (wrapped inference failure from ``_generate``)
1262
  * ``Exception`` β†’ 500 (unexpected catch-all, never leaks internals)
1263
 
 
 
 
 
1264
  User note
1265
  Compatible with the OpenAI Python SDK:
1266
 
@@ -1274,6 +1556,8 @@ async def chat_completions( # noqa: PLR0911
1274
  response = client.chat.completions.create(
1275
  model="any",
1276
  messages=[{"role": "user", "content": "Hello"}],
 
 
1277
  )
1278
  """
1279
  request_id = uuid.uuid4().hex
@@ -1324,12 +1608,25 @@ async def chat_completions( # noqa: PLR0911
1324
  "max_tokens",
1325
  _MAX_NEW_TOKENS_DEFAULT,
1326
  )
 
 
 
 
 
 
 
 
 
 
 
1327
 
1328
  # ── 4. Input validation ───────────────────────────────────────────────────
1329
 
1330
  try:
1331
  messages = _validate_messages(messages_raw)
1332
  max_new_tokens = _clamp_max_tokens(max_tokens_raw)
 
 
1333
  except ValueError as exc:
1334
  logger.warning(
1335
  "Validation error | request_id=%s | error=%s",
@@ -1346,11 +1643,17 @@ async def chat_completions( # noqa: PLR0911
1346
  logger.info(
1347
  "Dispatching inference | "
1348
  "request_id=%s | "
 
1349
  "messages=%d | "
1350
- "max_new_tokens=%d",
 
 
1351
  request_id,
 
1352
  len(messages),
1353
  max_new_tokens,
 
 
1354
  )
1355
 
1356
  # ── 5. Prompt token count (CPU, pre-dispatch) ────────────────────────────���
@@ -1363,6 +1666,8 @@ async def chat_completions( # noqa: PLR0911
1363
  content = await _generate_async(
1364
  messages,
1365
  max_new_tokens,
 
 
1366
  )
1367
 
1368
  except ValueError as exc:
@@ -1446,7 +1751,7 @@ logger.info(
1446
 
1447
  logger.info(
1448
  "scikit-plots ai-model Space initialized successfully.\n"
1449
- " version : 2.0.0\n"
1450
  " model : %s\n"
1451
  " CORS : %s\n"
1452
  " max_body : %s bytes\n"
 
1
+ # scikit-plots/ai-model Β· app.py v2.1.0
2
  #
3
  # PURPOSE
4
  # ───────
 
25
  # + Long-term maintainability
26
  # + Minimal hidden behavior
27
  #
28
+ # CRITICAL: SINGLE-WORKER REQUIREMENT
29
+ # ────────────────────────────────────
30
+ # This Space MUST run with a single uvicorn worker.
31
+ # The model (7B params, bfloat16) consumes ~14 GB of RAM on CPU.
32
+ # The ZeroGPU hard RAM limit is 16 GB.
33
+ #
34
+ # Two workers Γ— 14 GB = 28 GB β†’ OOM β†’ the OS kills the second process
35
+ # with a clean exit code (0), which HuggingFace reports as "runtime error".
36
+ #
37
+ # HuggingFace Spaces with sdk: gradio default to a single worker, which
38
+ # is the correct configuration. If you observe two initialization sequences
39
+ # in the container log, verify that no external launcher is adding workers
40
+ # (e.g., GRADIO_NUM_WORKERS, uvicorn --workers) and file a HuggingFace
41
+ # support ticket if the double-start persists.
42
+ #
43
  # CRITICAL: ZEROGPU ARCHITECTURE REQUIREMENT
44
  # ──────────────────────────────────────────
45
  # On HuggingFace Spaces with sdk: gradio, ZeroGPU hooks attach to the
46
  # Gradio server lifecycle. Gradio MUST be the ASGI root application.
47
  #
48
+ # CORRECT (v2.x):
49
  # Gradio (gr.Blocks) is the ASGI root.
50
  # REST routes are registered on Gradio's internal FastAPI instance
51
  # via gradio.routes.App.create_app(demo).
52
  # @spaces.GPU is active on _generate.
53
  #
54
+ # WRONG (v1.x):
55
  # FastAPI was the ASGI root; Gradio was a child via gr.mount_gradio_app.
56
  # @spaces.GPU was commented out.
57
  # ZeroGPU hooks never activated.
58
  #
59
+ # CRITICAL: ZEROGPU MODEL PATTERN
60
  # ────────────────────────────────
61
  # CORRECT:
62
  # * Tokenizer loaded on CPU at import time.
 
64
  # * Model moved to GPU ONLY inside @spaces.GPU scope.
65
  # * Model returned to CPU in finally block after inference.
66
  # * torch.cuda.empty_cache() called after every request.
67
+ # * _MODEL_LOCK serialises all model device transitions.
68
  #
69
  # WRONG:
70
  # * pipeline(... device_map="auto") at module level.
71
  # * model.to("cuda") outside @spaces.GPU scope.
72
  # * Holding GPU between requests.
73
  # * Blocking asyncio event loop with synchronous inference.
74
+ # * Concurrent model.cuda() / model.cpu() without a lock.
75
  #
76
+ # ASSEMBLY DIAGRAM (v2.1.0)
77
  # ─────────────────────────
78
  #
79
  # HuggingFace Spaces
 
94
  # MAX_BODY_BYTES
95
  # Maximum accepted request size.
96
  #
97
+ # CHANGES v2.0.0 β†’ v2.1.0
98
+ # ─────────────────────────
99
+ # [CRITICAL] Add _MODEL_LOCK (threading.Lock) to serialise all model
100
+ # device transitions (cuda/cpu) across concurrent inference
101
+ # calls. Without this, concurrent @spaces.GPU activations
102
+ # can corrupt model device state.
103
+ #
104
+ # [CRITICAL] Explicit GPU tensor cleanup (del input_ids, output_ids,
105
+ # new_token_ids) in the success path before _model.cpu() and
106
+ # torch.cuda.empty_cache(). Ensures VRAM is fully released
107
+ # before the ZeroGPU scope exits.
108
+ #
109
+ # [HIGH] Add `except RuntimeError: raise` to _generate exception
110
+ # chain so that RuntimeErrors (including the empty-response
111
+ # guard below) are not accidentally double-wrapped.
112
+ #
113
+ # [HIGH] Guard against empty model output: raise RuntimeError if
114
+ # the decoded string is empty after skip_special_tokens.
115
+ #
116
+ # [MEDIUM] temperature and top_p are now configurable from the
117
+ # request body (REST) and from sliders (Gradio UI).
118
+ # Defaults: temperature=0.7, top_p=1.0.
119
+ # temperature=0.0 β†’ greedy decoding (do_sample=False).
120
+ #
121
+ # [MEDIUM] Log the requested model field from the request body for
122
+ # proxy-routing diagnostics.
123
+ #
124
+ # [MEDIUM] Fix chat_completions docstring: JSONResponse error cases
125
+ # moved from the incorrect Raises section to Notes, because
126
+ # they are returned values, not raised exceptions.
127
+ #
128
+ # [LOW] _parse_request_body and _build_completion_response now
129
+ # carry precise dict[str, Any] return type annotations.
130
+ #
131
+ # [LOW] system_fingerprint field added to completion response for
132
+ # improved OpenAI SDK compatibility.
133
+ #
134
+ # [LOW] Explicit allow_credentials=False in CORS middleware.
135
+ #
136
+ # [DOC] Prominent single-worker warning added to module header
137
+ # (see above) explaining the double-startup / exit-0 OOM
138
+ # failure mode observed in the container log.
139
+ #
140
  # SPDX-License-Identifier: BSD-3-Clause
141
  # Authors: The scikit-plots developers
142
 
 
158
  HuggingFace Spaces exports the ``app`` variable. It must be the
159
  Gradio-rooted ASGI application returned by ``App.create_app``.
160
 
161
+ ``_MODEL_LOCK`` serialises all calls to ``_model.cuda()`` and
162
+ ``_model.cpu()``. A single ``_model`` object must not have its
163
+ device changed by two threads simultaneously.
164
+
165
  User note
166
  The Gradio UI at ``/`` is for manual testing only.
167
  Production traffic routes through the proxy Space.
 
173
  import json
174
  import logging
175
  import os
176
+ import threading
177
  import time
178
  import uuid
179
+ from typing import Any, Final
180
 
181
  import gradio as gr # type: ignore[]
182
  import spaces # type: ignore[] # ZeroGPU β€” must be imported before torch
 
296
  _MAX_NEW_TOKENS_CEIL: Final[int] = 4096
297
  _MAX_NEW_TOKENS_DEFAULT: Final[int] = 512
298
 
299
+ # Generation defaults β€” match OpenAI API defaults where applicable.
300
+ _DEFAULT_TEMPERATURE: Final[float] = 0.7
301
+ _DEFAULT_TOP_P: Final[float] = 1.0
302
+
303
  DEFAULT_MAX_BODY_BYTES: Final[int] = 10 * 1024 * 1024
304
 
305
  MAX_BODY_BYTES: Final[int] = _safe_int(
 
364
  parsed = int(value)
365
  except (TypeError, ValueError) as exc:
366
  raise ValueError(
367
+ f"max_tokens must be an integer, got {value!r}"
368
  ) from exc
369
 
370
  return max(
371
  _MAX_NEW_TOKENS_FLOOR,
372
+ min(parsed, _MAX_NEW_TOKENS_CEIL),
 
 
 
373
  )
374
 
375
 
 
435
  return validated
436
 
437
 
438
+ def _validate_temperature(
439
+ value: object,
440
+ ) -> float:
441
+ """
442
+ Validate and return a generation temperature value.
443
+
444
+ Parameters
445
+ ----------
446
+ value : object
447
+ Candidate temperature.
448
+
449
+ Returns
450
+ -------
451
+ float
452
+ Validated temperature in [0.0, 2.0].
453
+
454
+ Raises
455
+ ------
456
+ ValueError
457
+ If conversion fails or value is out of range.
458
+
459
+ Notes
460
+ -----
461
+ Developer note
462
+ ``temperature=0.0`` selects greedy decoding (``do_sample=False``).
463
+ The upper bound 2.0 matches the OpenAI API specification.
464
+
465
+ References
466
+ ----------
467
+ .. [1] OpenAI API reference: temperature parameter
468
+ https://platform.openai.com/docs/api-reference/chat/create#temperature
469
+ """
470
+ try:
471
+ parsed = float(value)
472
+ except (TypeError, ValueError) as exc:
473
+ raise ValueError(
474
+ f"temperature must be a number, got {value!r}"
475
+ ) from exc
476
+
477
+ if not (0.0 <= parsed <= 2.0):
478
+ raise ValueError(
479
+ f"temperature must be in [0.0, 2.0], got {parsed!r}"
480
+ )
481
+
482
+ return parsed
483
+
484
+
485
+ def _validate_top_p(
486
+ value: object,
487
+ ) -> float:
488
+ """
489
+ Validate and return a nucleus-sampling top_p value.
490
+
491
+ Parameters
492
+ ----------
493
+ value : object
494
+ Candidate top_p.
495
+
496
+ Returns
497
+ -------
498
+ float
499
+ Validated top_p in (0.0, 1.0].
500
+
501
+ Raises
502
+ ------
503
+ ValueError
504
+ If conversion fails or value is out of range.
505
+
506
+ Notes
507
+ -----
508
+ Developer note
509
+ ``top_p=1.0`` effectively disables nucleus sampling.
510
+ OpenAI recommends altering temperature or top_p but not both.
511
+
512
+ References
513
+ ----------
514
+ .. [1] OpenAI API reference: top_p parameter
515
+ https://platform.openai.com/docs/api-reference/chat/create#top_p
516
+ """
517
+ try:
518
+ parsed = float(value)
519
+ except (TypeError, ValueError) as exc:
520
+ raise ValueError(
521
+ f"top_p must be a number, got {value!r}"
522
+ ) from exc
523
+
524
+ if not (0.0 < parsed <= 1.0):
525
+ raise ValueError(
526
+ f"top_p must be in (0.0, 1.0], got {parsed!r}"
527
+ )
528
+
529
+ return parsed
530
+
531
+
532
  logger.info("Validation helpers initialized successfully.")
533
 
534
 
535
+ # ─────────────────────────────────────────────────────────────────────────────
536
+ # Model lock
537
+ # ─────────────────────────────────────────────────────────────────────────────
538
+ # Serialises all _model.cuda() / _model.cpu() transitions.
539
+ #
540
+ # A single _model object must not be moved to different devices by two
541
+ # threads simultaneously. @spaces.GPU does not prevent concurrent calls
542
+ # by itself (the Gradio queue or multiple in-flight async requests can
543
+ # dispatch _generate from multiple threads at the same time).
544
+ #
545
+ # Holding _MODEL_LOCK for the duration of the entire inference (cuda β†’
546
+ # generate β†’ cpu) is correct and safe: we are single-model, single-GPU.
547
+
548
+ _MODEL_LOCK: Final[threading.Lock] = threading.Lock()
549
+
550
+
551
  # ─────────────────────────────────────────────────────────────────────────────
552
  # Model loading
553
  # ─────────────────────────────────────────────────────────────────────────────
 
587
  # * GPU exists only inside @spaces.GPU scope.
588
  # * Model moved CPU β†’ GPU at entry; GPU β†’ CPU in finally.
589
  # * VRAM fully released after every request.
590
+ # * _MODEL_LOCK held for the full duration of the inference to prevent
591
+ # concurrent device transitions on the shared _model object.
592
  # * This function is called from both:
593
  # - Gradio event handlers (direct sync call via _gradio_respond)
594
  # - FastAPI route handlers (via asyncio.to_thread in _generate_async)
 
598
  def _generate(
599
  messages: list[dict[str, str]],
600
  max_new_tokens: int = _MAX_NEW_TOKENS_DEFAULT,
601
+ temperature: float = _DEFAULT_TEMPERATURE,
602
+ top_p: float = _DEFAULT_TOP_P,
603
  ) -> str:
604
  """
605
  Run generation using ZeroGPU.
 
612
  max_new_tokens : int, default=512
613
  Maximum generated tokens.
614
 
615
+ temperature : float, default=0.7
616
+ Sampling temperature in [0.0, 2.0].
617
+ ``0.0`` selects greedy decoding (do_sample=False).
618
+
619
+ top_p : float, default=1.0
620
+ Nucleus sampling cutoff in (0.0, 1.0].
621
+ ``1.0`` disables nucleus sampling.
622
+
623
  Returns
624
  -------
625
  str
 
631
  On invalid inputs or missing chat template.
632
 
633
  RuntimeError
634
+ On inference failure or empty model output.
635
 
636
  Notes
637
  -----
638
  Developer note
639
  GPU is acquired automatically by ``@spaces.GPU``.
640
 
641
+ ``_MODEL_LOCK`` is held for the entire inference duration
642
+ (cuda β†’ generate β†’ cpu) to prevent concurrent threads from
643
+ issuing conflicting device transitions on the shared ``_model``
644
+ object. ZeroGPU + Gradio queue can dispatch this function from
645
+ multiple threads simultaneously; the lock serialises them.
646
+
647
+ GPU tensors (``input_ids``, ``output_ids``, ``new_token_ids``)
648
+ are explicitly deleted in the success path before ``_model.cpu()``
649
+ and ``torch.cuda.empty_cache()``. This ensures VRAM is fully
650
+ reclaimed before the ``@spaces.GPU`` scope exits. On the
651
+ exception path, any tensors that were assigned before the error
652
+ remain alive until the function exits (acceptable: ZeroGPU
653
+ releases all GPU memory at ``@spaces.GPU`` scope exit).
654
 
655
  ``finally`` block ensures CPU return and cache clear even if
656
+ inference raises. The inner ``try/except`` around ``_model.cpu()``
657
+ logs and absorbs a potential CPU-move failure so that the
658
+ original inference exception is not masked; it still calls
659
+ ``torch.cuda.empty_cache()`` via its own nested ``finally``.
660
 
661
+ This function is intentionally synchronous. Async routes call
662
  it via ``_generate_async`` which wraps it with
663
+ ``asyncio.to_thread``. Gradio event handlers call it directly
664
  because Gradio dispatches handlers in its own thread pool,
665
  outside the asyncio event loop.
666
 
 
669
  Use ``_generate_async`` from FastAPI routes.
670
  """
671
  validated_messages = _validate_messages(messages)
 
672
  max_new_tokens = _clamp_max_tokens(max_new_tokens)
673
 
674
  if not getattr(_tokenizer, "chat_template", None):
 
680
  logger.info(
681
  "GPU generation starting | "
682
  "messages=%d | "
683
+ "max_new_tokens=%d | "
684
+ "temperature=%.2f | "
685
+ "top_p=%.2f",
686
  len(validated_messages),
687
  max_new_tokens,
688
+ temperature,
689
+ top_p,
690
  )
691
 
692
+ with _MODEL_LOCK:
693
+ try:
694
+ logger.info("Moving model to GPU...")
695
 
696
+ _model.cuda()
697
 
698
+ input_ids = _tokenizer.apply_chat_template(
699
+ validated_messages,
700
+ add_generation_prompt=True,
701
+ return_tensors="pt",
702
+ )
703
+ input_ids = input_ids.cuda()
704
 
705
+ logger.info("Generation started.")
706
 
707
+ # Build generation kwargs.
708
+ # temperature=0.0 β†’ greedy (do_sample=False, no temperature/top_p).
709
+ # temperature>0.0 β†’ sampling; top_p applied only when < 1.0.
710
+ generate_kwargs: dict[str, Any] = {
711
+ "max_new_tokens": max_new_tokens,
712
+ "pad_token_id": _tokenizer.eos_token_id,
713
+ }
714
+ if temperature > 0.0:
715
+ generate_kwargs["do_sample"] = True
716
+ generate_kwargs["temperature"] = temperature
717
+ if top_p < 1.0:
718
+ generate_kwargs["top_p"] = top_p
719
+
720
+ with torch.no_grad():
721
+ output_ids = _model.generate(
722
+ input_ids,
723
+ **generate_kwargs,
724
+ )
725
 
726
+ new_token_ids = output_ids[0][input_ids.shape[-1]:]
727
+ decoded = _tokenizer.decode(
728
+ new_token_ids,
729
+ skip_special_tokens=True,
 
 
 
730
  )
731
 
732
+ # Release GPU tensors before CPU move and cache clear.
733
+ # new_token_ids is a view of output_ids; deleting both here
734
+ # drops all references, freeing the underlying CUDA storage.
735
+ del input_ids, output_ids, new_token_ids
 
 
736
 
737
+ if not decoded.strip():
738
+ raise RuntimeError(
739
+ "Model returned an empty response. "
740
+ "Retry or reduce prompt length."
741
+ )
742
 
743
+ logger.info("Generation completed successfully.")
744
 
745
+ return decoded
 
746
 
747
+ except ValueError:
748
+ raise
749
 
750
+ except RuntimeError:
751
+ raise
 
752
 
753
+ except Exception as exc:
754
+ logger.exception("Inference failure.")
755
+ raise RuntimeError(
756
+ f"Inference failed: {exc}"
757
+ ) from exc
758
 
 
 
759
  finally:
760
+ logger.info(
761
+ "Returning model to CPU "
762
+ "and clearing CUDA cache..."
763
+ )
764
+ try:
765
+ _model.cpu()
766
+ except Exception: # noqa: BLE001
767
+ logger.exception(
768
+ "Failed to move model back to CPU. "
769
+ "VRAM may not be fully released."
770
+ )
771
+ finally:
772
+ if torch.cuda.is_available():
773
+ torch.cuda.empty_cache()
774
 
775
+ logger.info("GPU resources released.")
776
 
777
 
778
  # ─────────────────────────────────────────────────────────────────────────────
 
793
  async def _generate_async(
794
  messages: list[dict[str, str]],
795
  max_new_tokens: int,
796
+ temperature: float = _DEFAULT_TEMPERATURE,
797
+ top_p: float = _DEFAULT_TOP_P,
798
  ) -> str:
799
  """
800
  Async wrapper for GPU generation.
 
807
  max_new_tokens : int
808
  Generation token limit.
809
 
810
+ temperature : float, default=0.7
811
+ Sampling temperature forwarded to ``_generate``.
812
+
813
+ top_p : float, default=1.0
814
+ Nucleus sampling cutoff forwarded to ``_generate``.
815
+
816
  Returns
817
  -------
818
  str
 
832
  _generate,
833
  messages,
834
  max_new_tokens,
835
+ temperature,
836
+ top_p,
837
  )
838
 
839
 
 
890
 
891
  def _parse_request_body(
892
  raw: bytes,
893
+ ) -> dict[str, Any]:
894
  """
895
  Decode and parse a UTF-8 JSON request body.
896
 
 
901
 
902
  Returns
903
  -------
904
+ dict[str, Any]
905
  Parsed JSON payload.
906
 
907
  Raises
 
1006
  model_id: str,
1007
  prompt_tokens: int,
1008
  completion_tokens: int,
1009
+ ) -> dict[str, Any]:
1010
  """
1011
  Build an OpenAI-compatible chat completion response payload.
1012
 
 
1026
 
1027
  Returns
1028
  -------
1029
+ dict[str, Any]
1030
  OpenAI-compatible ``chat.completion`` object.
1031
 
1032
  Notes
 
1039
  current ``_generate`` implementation does not expose partial
1040
  stop conditions. Extend this if streaming or early stopping
1041
  is added.
1042
+ ``system_fingerprint`` is derived from the model ID slug to
1043
+ satisfy OpenAI SDK response parsing without exposing internal
1044
+ infrastructure details.
1045
 
1046
  User note
1047
  The returned dict is compatible with OpenAI Python SDK
 
1052
  .. [1] OpenAI API reference: Chat completions object
1053
  https://platform.openai.com/docs/api-reference/chat/object
1054
  """
1055
+ # Derive a deterministic, URL-safe fingerprint from the model ID.
1056
+ _model_slug = (
1057
+ model_id
1058
+ .lower()
1059
+ .replace("/", "-")
1060
+ .replace(".", "-")
1061
+ .replace("_", "-")
1062
+ )
1063
+
1064
  return {
1065
  "id": f"chatcmpl-{uuid.uuid4().hex}",
1066
  "object": "chat.completion",
1067
  "created": int(time.time()),
1068
  "model": model_id,
1069
+ "system_fingerprint": f"fp-{_model_slug}",
1070
  "choices": [
1071
  {
1072
  "index": 0,
 
1247
  message: str,
1248
  history: list,
1249
  max_new_tokens: int,
1250
+ temperature: float,
1251
+ top_p: float,
1252
  ) -> str:
1253
  """
1254
  Gradio ``ChatInterface`` event handler.
 
1265
  max_new_tokens : int
1266
  Maximum tokens to generate, sourced from the UI slider.
1267
 
1268
+ temperature : float
1269
+ Sampling temperature sourced from the UI slider.
1270
+
1271
+ top_p : float
1272
+ Nucleus sampling cutoff sourced from the UI slider.
1273
+
1274
  Returns
1275
  -------
1276
  str
 
1282
  If ``message`` is empty after stripping.
1283
 
1284
  RuntimeError
1285
+ Propagated from ``_generate`` on inference failure or empty
1286
+ model output.
1287
 
1288
  Notes
1289
  -----
 
1317
  logger.info(
1318
  "Gradio inference | "
1319
  "history_turns=%d | "
1320
+ "max_new_tokens=%d | "
1321
+ "temperature=%.2f | "
1322
+ "top_p=%.2f",
1323
  len(messages) - 1,
1324
  max_new_tokens,
1325
+ temperature,
1326
+ top_p,
1327
  )
1328
 
1329
  return _generate(
1330
  messages,
1331
  max_new_tokens,
1332
+ temperature,
1333
+ top_p,
1334
  )
1335
 
1336
 
1337
  # ─────────────────────────────────────────────────────────────────────────────
1338
  # Gradio UI
1339
  # ─────────────────────────────────────────────────────────────────────────────
1340
+ # v2.x: Gradio is the ASGI ROOT β€” not a child sub-app mounted on FastAPI.
1341
  # This is required for ZeroGPU to activate on HuggingFace Spaces.
1342
  #
1343
+ # The Gradio UI is served at / (root).
1344
  # Custom REST routes are added to Gradio's internal FastAPI instance below.
1345
 
1346
  _UI_WARNING = """\
 
1370
  step=1,
1371
  label="max_tokens",
1372
  info=(
1373
+ f"Range: {_MAX_NEW_TOKENS_FLOOR}–{_MAX_NEW_TOKENS_CEIL}. "
1374
  f"Default: {_MAX_NEW_TOKENS_DEFAULT}."
1375
  ),
1376
  ),
1377
+ gr.Slider(
1378
+ minimum=0.0,
1379
+ maximum=2.0,
1380
+ value=_DEFAULT_TEMPERATURE,
1381
+ step=0.05,
1382
+ label="temperature",
1383
+ info="0.0 = greedy, 0.7 = default, 2.0 = very random.",
1384
+ ),
1385
+ gr.Slider(
1386
+ minimum=0.01,
1387
+ maximum=1.0,
1388
+ value=_DEFAULT_TOP_P,
1389
+ step=0.01,
1390
+ label="top_p",
1391
+ info="Nucleus sampling cutoff. 1.0 = disabled.",
1392
+ ),
1393
  ],
1394
  additional_inputs_accordion="Generation settings",
1395
  )
 
1403
  # ─────────────────────────────────────────────────────────────────────────────
1404
  # App assembly β€” HuggingFace Space export
1405
  # ─────────────────────────────────────────────────────────────────────────────
1406
+ # v2.x ARCHITECTURE:
 
 
 
 
 
1407
  #
 
1408
  # app = _GradioApp.create_app(_gradio_ui) ← Gradio is ASGI root
1409
  # @app.get/post(...) ← routes on Gradio's FastAPI
1410
  # ZeroGPU activates correctly.
 
1434
  allow_headers=[
1435
  "Content-Type",
1436
  ],
1437
+ allow_credentials=False, # This Space does not use credential-bearing requests.
1438
  )
1439
 
1440
  logger.info(
 
1473
  Examples
1474
  --------
1475
  >>> # curl http://localhost:7860/health
1476
+ ... # {"status": "ok", "model": "...", "version": "2.1.0"}
1477
  """
1478
  logger.info("GET /health")
1479
 
 
1481
  content={
1482
  "status": "ok",
1483
  "model": MODEL_ID,
1484
+ "version": "2.1.0",
1485
  },
1486
  status_code=200,
1487
  )
 
1507
  Returns
1508
  -------
1509
  JSONResponse
1510
+ HTTP 200 with an OpenAI-compatible completion payload on success.
1511
 
 
 
 
1512
  HTTP 413 if the body exceeds ``MAX_BODY_BYTES``.
1513
+
1514
+ HTTP 400 if the body is not valid UTF-8 JSON, or if
1515
+ ``messages``, ``max_tokens``, ``temperature``, or ``top_p``
1516
+ fail validation.
1517
+
1518
+ HTTP 500 on inference failure or unexpected server error.
1519
 
1520
  Notes
1521
  -----
 
1524
 
1525
  1. Read and bound-check raw body bytes (413 guard).
1526
  2. Decode and parse JSON (400 guard).
1527
+ 3. Extract ``messages``, ``max_tokens``, ``temperature``,
1528
+ ``top_p``, and ``model`` fields.
1529
+ 4. Validate with field-specific validators (400 guard).
1530
  5. Count prompt tokens on CPU (no GPU needed).
1531
  6. Dispatch to ``_generate_async`` which offloads to
1532
  ``@spaces.GPU`` via ``asyncio.to_thread``.
 
1539
  * ``RuntimeError`` β†’ 500 (wrapped inference failure from ``_generate``)
1540
  * ``Exception`` β†’ 500 (unexpected catch-all, never leaks internals)
1541
 
1542
+ The requested ``model`` field is logged for proxy-routing
1543
+ diagnostics but does not affect which model is used; this Space
1544
+ always serves ``MODEL_ID``.
1545
+
1546
  User note
1547
  Compatible with the OpenAI Python SDK:
1548
 
 
1556
  response = client.chat.completions.create(
1557
  model="any",
1558
  messages=[{"role": "user", "content": "Hello"}],
1559
+ temperature=0.7,
1560
+ top_p=1.0,
1561
  )
1562
  """
1563
  request_id = uuid.uuid4().hex
 
1608
  "max_tokens",
1609
  _MAX_NEW_TOKENS_DEFAULT,
1610
  )
1611
+ temperature_raw: object = payload.get(
1612
+ "temperature",
1613
+ _DEFAULT_TEMPERATURE,
1614
+ )
1615
+ top_p_raw: object = payload.get(
1616
+ "top_p",
1617
+ _DEFAULT_TOP_P,
1618
+ )
1619
+ # Log requested model for proxy-routing diagnostics only.
1620
+ # This Space always serves MODEL_ID regardless of the field value.
1621
+ model_requested: object = payload.get("model", MODEL_ID)
1622
 
1623
  # ── 4. Input validation ───────────────────────────────────────────────────
1624
 
1625
  try:
1626
  messages = _validate_messages(messages_raw)
1627
  max_new_tokens = _clamp_max_tokens(max_tokens_raw)
1628
+ temperature = _validate_temperature(temperature_raw)
1629
+ top_p = _validate_top_p(top_p_raw)
1630
  except ValueError as exc:
1631
  logger.warning(
1632
  "Validation error | request_id=%s | error=%s",
 
1643
  logger.info(
1644
  "Dispatching inference | "
1645
  "request_id=%s | "
1646
+ "model_requested=%s | "
1647
  "messages=%d | "
1648
+ "max_new_tokens=%d | "
1649
+ "temperature=%.2f | "
1650
+ "top_p=%.2f",
1651
  request_id,
1652
+ model_requested,
1653
  len(messages),
1654
  max_new_tokens,
1655
+ temperature,
1656
+ top_p,
1657
  )
1658
 
1659
  # ── 5. Prompt token count (CPU, pre-dispatch) ────────────────────────────���
 
1666
  content = await _generate_async(
1667
  messages,
1668
  max_new_tokens,
1669
+ temperature,
1670
+ top_p,
1671
  )
1672
 
1673
  except ValueError as exc:
 
1751
 
1752
  logger.info(
1753
  "scikit-plots ai-model Space initialized successfully.\n"
1754
+ " version : 2.1.0\n"
1755
  " model : %s\n"
1756
  " CORS : %s\n"
1757
  " max_body : %s bytes\n"