polats committed on
Commit
f77d660
·
1 Parent(s): db6b273

Add Klein ZeroGPU portrait option

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +96 -8
  3. requirements.txt +1 -0
  4. web/imagen.js +2 -2
  5. web/imagenServer.js +9 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  __pycache__/
2
  *.pyc
3
  .venv/
 
 
1
  __pycache__/
2
  *.pyc
3
  .venv/
4
+ logs/
app.py CHANGED
@@ -20,6 +20,20 @@ import json as _json
20
  import os
21
  import threading
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  import gradio as gr
24
  import uvicorn
25
  from fastapi import FastAPI, Request
@@ -35,6 +49,7 @@ import prompts
35
 
36
  HERE = os.path.dirname(os.path.abspath(__file__))
37
  WEB = os.path.join(HERE, "web")
 
38
 
39
  # The Sprite tab's character picker + controls are built entirely by the shared
40
  # playground (web/playground.js) from /sprites/characters.json — no Python-side
@@ -176,7 +191,7 @@ def diary(unit, traits):
176
  yield header + f"Today I held the line. _(model unavailable: {e})_"
177
 
178
 
179
- with gr.Blocks(title="Tiny Army") as demo:
180
  gr.HTML(SIDEBAR_HTML)
181
  with gr.Tabs():
182
  with gr.Tab("Battle") as battle_tab:
@@ -201,7 +216,7 @@ with gr.Blocks(title="Tiny Army") as demo:
201
  # (footer "Settings" → ?view=settings) by web/settingsPanel.js — not a tab.
202
 
203
  # Mount Gradio on FastAPI so we can also serve the JS module + the sprite assets.
204
- fastapi_app = FastAPI()
205
 
206
 
207
  # Behind HF's custom-domain proxy Gradio emits its theme.css <link> as http://
@@ -413,6 +428,12 @@ _MIN_IMAGE_BYTES = 15_000 # smaller than this = a blank/safety-blocked frame
413
 
414
  _img_pipe = None
415
  _img_lock = threading.Lock()
 
 
 
 
 
 
416
 
417
 
418
  def _load_image_pipe():
@@ -461,6 +482,50 @@ def _local_portrait(prompt, seed=None, width=1024, height=1024, steps=9):
461
  return out.getvalue()
462
 
463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  def _nim_portrait(prompt, provider="flux-schnell", width=1024, height=1024):
465
  import random
466
  p = _NIM_PROVIDERS.get(provider, _NIM_PROVIDERS["flux-schnell"])
@@ -524,10 +589,15 @@ async def portrait(request: Request):
524
  prompt = (body.get("prompt") or "").strip()
525
  seed = body.get("seed")
526
  provider = body.get("provider") or "" # cloud sub-provider hint (e.g. flux-dev)
527
- engine = (body.get("engine") or "").strip().lower() # 'local' | 'cloud' | '' = auto
528
  if not prompt:
529
  return Response("prompt required", status_code=400)
530
  want_local = engine == "local" or (not engine and IMAGE_MODE == "local")
 
 
 
 
 
531
  if want_local: # in-process open weights on your GPU (dev)
532
  if IMAGE_MODE != "local":
533
  return Response("local image mode not enabled (run with TINY_IMAGE_MODE=local)", status_code=503)
@@ -536,6 +606,15 @@ async def portrait(request: Request):
536
  except Exception as e: # noqa: BLE001 — surface a clear setup hint
537
  return Response(f"local image error (pip install 'git+https://github.com/huggingface/diffusers' accelerate?): {e}", status_code=500)
538
  return Response(png, media_type="image/png", headers={"Cache-Control": "no-store"})
 
 
 
 
 
 
 
 
 
539
  # Cloud: prefer NVIDIA NIM (woid's FLUX path), else HF Inference (our HF_TOKEN).
540
  if NIM_KEY:
541
  png, err = await asyncio.to_thread(_nim_portrait, prompt, provider or "flux-schnell")
@@ -633,14 +712,23 @@ async def persona_generate_stream(request: Request):
633
  })
634
 
635
 
636
- app = gr.mount_gradio_app(fastapi_app, demo, path="/", head=HEAD, theme=gr.themes.Soft())
 
637
 
638
 
639
  if __name__ == "__main__":
640
  # The default UI runs the model IN THE BROWSER (wllama). The Python llama.cpp path
641
  # stays as a lazy fallback (only loads if /persona/generate/stream is hit), so we
642
  # don't pre-download it here.
643
- # proxy_headers + trusting forwarded IPs lets Gradio honour X-Forwarded-Proto
644
- # from HF's edge, so it generates https (not http) asset URLs behind the proxy.
645
- uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")),
646
- proxy_headers=True, forwarded_allow_ips="*")
 
 
 
 
 
 
 
 
 
20
  import os
21
  import threading
22
 
23
# The ZeroGPU `spaces` shim has to be imported ahead of torch. When that import
# fails (local runs, non-ZeroGPU hardware) a do-nothing stand-in for
# `spaces.GPU` is defined instead.
try:
    import spaces  # type: ignore

    GPU = spaces.GPU
except Exception:  # pragma: no cover
    def GPU(*dargs, **dkwargs):  # noqa: N802 - mirror spaces.GPU
        """No-op replacement for ``spaces.GPU`` that leaves functions untouched.

        Supports both the bare form (``@GPU``) and the parameterised form
        (``@GPU(duration=60)``).
        """
        used_bare = not dkwargs and len(dargs) == 1 and callable(dargs[0])
        if used_bare:
            # ``@GPU`` directly on a function: hand the function straight back.
            return dargs[0]

        def passthrough(fn):
            return fn

        return passthrough
36
+
37
  import gradio as gr
38
  import uvicorn
39
  from fastapi import FastAPI, Request
 
49
 
50
  HERE = os.path.dirname(os.path.abspath(__file__))
51
  WEB = os.path.join(HERE, "web")
52
+ USE_GRADIO_SERVER = os.environ.get("TINY_GRADIO_SERVER", "").lower() in ("1", "true", "yes")
53
 
54
  # The Sprite tab's character picker + controls are built entirely by the shared
55
  # playground (web/playground.js) from /sprites/characters.json — no Python-side
 
191
  yield header + f"Today I held the line. _(model unavailable: {e})_"
192
 
193
 
194
+ with gr.Blocks(title="Tiny Army") as ui:
195
  gr.HTML(SIDEBAR_HTML)
196
  with gr.Tabs():
197
  with gr.Tab("Battle") as battle_tab:
 
216
  # (footer "Settings" → ?view=settings) by web/settingsPanel.js — not a tab.
217
 
218
  # Mount Gradio on FastAPI so we can also serve the JS module + the sprite assets.
219
+ fastapi_app = gr.Server() if USE_GRADIO_SERVER else FastAPI()
220
 
221
 
222
  # Behind HF's custom-domain proxy Gradio emits its theme.css <link> as http://
 
428
 
429
  _img_pipe = None
430
  _img_lock = threading.Lock()
431
# FLUX.2 [klein] portrait engine state and settings.
# The pipeline object is created on first use and cached here; the lock guards that.
_klein_pipe = None
_klein_lock = threading.Lock()
# Environment overrides (each falls back to the default shown).
_KLEIN_MODEL_ID = os.getenv("TINY_KLEIN_MODEL", "black-forest-labs/FLUX.2-klein-4B")
_KLEIN_STEPS = int(os.getenv("TINY_KLEIN_STEPS", "4"))
_KLEIN_GUIDANCE = float(os.getenv("TINY_KLEIN_GUIDANCE", "1.0"))
# Non-empty => portraits are requested from this remote Space instead of in-process.
_KLEIN_SPACE = os.getenv("TINY_KLEIN_SPACE", "").strip()
437
 
438
 
439
  def _load_image_pipe():
 
482
  return out.getvalue()
483
 
484
 
485
def _load_klein_pipe():
    """Build the FLUX.2 [klein] pipeline for ``_KLEIN_MODEL_ID`` with bfloat16 weights.

    torch and diffusers are imported inside the function rather than at module
    import time.
    """
    import torch
    from diffusers import Flux2KleinPipeline

    weights_dtype = torch.bfloat16
    pipeline = Flux2KleinPipeline.from_pretrained(_KLEIN_MODEL_ID, torch_dtype=weights_dtype)
    return pipeline
489
+
490
+
491
def _remote_klein_portrait(prompt, seed=None):
    """Generate a portrait through the Gradio Space named by ``_KLEIN_SPACE``.

    Calls the Space's ``/generate`` endpoint with ``(prompt, seed)`` and returns
    the bytes of the file it hands back.

    Args:
        prompt: Text prompt forwarded to the Space.
        seed: Integer-like seed. ``None`` draws a fresh random seed, matching
            ``_klein_portrait``; previously it was pinned to 42, so every
            unseeded request for the same prompt produced the same image.

    Returns:
        The raw bytes of the image file returned by the Space.

    Raises:
        ValueError / TypeError: if ``seed`` cannot be converted to ``int``.
        Any exception raised by ``gradio_client`` or by reading the result file.
    """
    import random
    from gradio_client import Client

    if seed is None:
        seed = random.randint(0, 2_147_483_647)
    client = Client(_KLEIN_SPACE, token=HF_TOKEN or None)
    result = client.predict(prompt, int(seed), api_name="/generate")
    # The endpoint may return a bare path or a tuple/list whose first item is the path.
    path = result[0] if isinstance(result, (tuple, list)) else result
    with open(os.fspath(path), "rb") as f:
        return f.read()
499
+
500
+
501
+ @GPU(duration=60)
502
+ def _klein_portrait(prompt, seed=None, width=1024, height=1024):
503
+ """FLUX.2 [klein] 4B for ZeroGPU-backed portrait generation."""
504
+ global _klein_pipe
505
+ import io
506
+ import random
507
+ import torch
508
+ with _klein_lock:
509
+ if _klein_pipe is None:
510
+ _klein_pipe = _load_klein_pipe()
511
+ dev = "cuda" if torch.cuda.is_available() else "cpu"
512
+ _klein_pipe.to(dev)
513
+ s = int(seed if seed is not None else random.randint(0, 2_147_483_647))
514
+ img = _klein_pipe(
515
+ prompt=prompt, width=width, height=height,
516
+ num_inference_steps=_KLEIN_STEPS, guidance_scale=_KLEIN_GUIDANCE,
517
+ generator=torch.Generator(device=dev).manual_seed(s),
518
+ ).images[0]
519
+ if dev == "cuda":
520
+ _klein_pipe.to("cpu")
521
+ try:
522
+ torch.cuda.empty_cache()
523
+ except Exception:
524
+ pass
525
+ out = io.BytesIO(); img.save(out, format="PNG")
526
+ return out.getvalue()
527
+
528
+
529
  def _nim_portrait(prompt, provider="flux-schnell", width=1024, height=1024):
530
  import random
531
  p = _NIM_PROVIDERS.get(provider, _NIM_PROVIDERS["flux-schnell"])
 
589
  prompt = (body.get("prompt") or "").strip()
590
  seed = body.get("seed")
591
  provider = body.get("provider") or "" # cloud sub-provider hint (e.g. flux-dev)
592
+ engine = (body.get("engine") or "").strip().lower() # 'local' | 'klein' | 'cloud' | '' = auto
593
  if not prompt:
594
  return Response("prompt required", status_code=400)
595
  want_local = engine == "local" or (not engine and IMAGE_MODE == "local")
596
+ want_klein = (
597
+ engine in ("klein", "zerogpu")
598
+ or provider in ("flux-klein-4b", "klein-4b")
599
+ or (not engine and IMAGE_MODE in ("klein", "zerogpu", "klein-zerogpu"))
600
+ )
601
  if want_local: # in-process open weights on your GPU (dev)
602
  if IMAGE_MODE != "local":
603
  return Response("local image mode not enabled (run with TINY_IMAGE_MODE=local)", status_code=503)
 
606
  except Exception as e: # noqa: BLE001 — surface a clear setup hint
607
  return Response(f"local image error (pip install 'git+https://github.com/huggingface/diffusers' accelerate?): {e}", status_code=500)
608
  return Response(png, media_type="image/png", headers={"Cache-Control": "no-store"})
609
+ if want_klein:
610
+ try:
611
+ if _KLEIN_SPACE:
612
+ png = await asyncio.to_thread(_remote_klein_portrait, prompt, seed)
613
+ else:
614
+ png = await asyncio.to_thread(_klein_portrait, prompt, seed)
615
+ except Exception as e: # noqa: BLE001
616
+ return Response(f"klein image error: {e}", status_code=500)
617
+ return Response(png, media_type="image/png", headers={"Cache-Control": "no-store"})
618
  # Cloud: prefer NVIDIA NIM (woid's FLUX path), else HF Inference (our HF_TOKEN).
619
  if NIM_KEY:
620
  png, err = await asyncio.to_thread(_nim_portrait, prompt, provider or "flux-schnell")
 
712
  })
713
 
714
 
715
+ app = gr.mount_gradio_app(fastapi_app, ui, path="/", head=HEAD, theme=gr.themes.Soft())
716
+ demo = app if USE_GRADIO_SERVER else ui
717
 
718
 
719
if __name__ == "__main__":
    # The default UI runs the model IN THE BROWSER (wllama). The Python llama.cpp path
    # is only a lazy fallback (it loads when /persona/generate/stream is hit), so
    # nothing is pre-downloaded here.
    listen_port = int(os.environ.get("PORT", "7860"))
    if not USE_GRADIO_SERVER:
        # proxy_headers + trusting forwarded IPs lets Gradio honour X-Forwarded-Proto
        # from HF's edge, so it generates https (not http) asset URLs behind the proxy.
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=listen_port,
            proxy_headers=True,
            forwarded_allow_ips="*",
        )
    else:
        # TINY_GRADIO_SERVER mode: `app` is launched directly instead of via uvicorn.
        app.launch(
            server_name="0.0.0.0",
            server_port=listen_port,
            head=HEAD,
            theme=gr.themes.Soft(),
        )
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  gradio==6.15.2
2
  huggingface_hub
3
  # llama.cpp runtime for the persona + war-diary model. The CPU wheel index ships a
 
1
+ --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
2
  gradio==6.15.2
3
  huggingface_hub
4
  # llama.cpp runtime for the persona + war-diary model. The CPU wheel index ships a
web/imagen.js CHANGED
@@ -1,10 +1,10 @@
1
  // Image facade — mirrors tts.js. Picks the active portrait engine (local Z-Image on your
2
  // GPU, or cloud FLUX; in-browser SD-Turbo / Janus get added here later) and exposes one
3
  // generatePortrait(). The persona panel + the Settings image bar import only from here.
4
- import { engineLocal as zimagelocal, engineCloud as flux, engineCloudDev as fluxdev, isLocalhost } from '/web/imagenServer.js'
5
  import { engine as bonsai } from '/web/imagenBonsai.js'
6
 
7
- const ENGINES = [zimagelocal, bonsai, flux, fluxdev]
8
  // Default: local Z-Image on localhost (your GPU), cloud FLUX in prod. Persisted across
9
  // refreshes; a saved choice wins if it's still available.
10
  const KEY = 'tinyarmy.imageEngine'
 
1
  // Image facade — mirrors tts.js. Picks the active portrait engine (local Z-Image on your
2
  // GPU, or cloud FLUX; in-browser SD-Turbo / Janus get added here later) and exposes one
3
  // generatePortrait(). The persona panel + the Settings image bar import only from here.
4
+ import { engineLocal as zimagelocal, engineKleinZeroGpu as klein, engineCloud as flux, engineCloudDev as fluxdev, isLocalhost } from '/web/imagenServer.js'
5
  import { engine as bonsai } from '/web/imagenBonsai.js'
6
 
7
+ const ENGINES = [zimagelocal, klein, bonsai, flux, fluxdev]
8
  // Default: local Z-Image on localhost (your GPU), cloud FLUX in prod. Persisted across
9
  // refreshes; a saved choice wins if it's still available.
10
  const KEY = 'tinyarmy.imageEngine'
web/imagenServer.js CHANGED
@@ -29,6 +29,15 @@ export const engineLocal = {
29
  backendLabel: () => '🖥 local model',
30
  }
31
 
 
 
 
 
 
 
 
 
 
32
  // CLOUD: FLUX via the backend proxy (NVIDIA NIM, else HF Inference). `provider` picks the
33
  // NIM sub-model. schnell = fast (4 steps), dev = higher quality (28 steps).
34
  export const engineCloud = {
 
29
  backendLabel: () => '🖥 local model',
30
  }
31
 
32
// KLEIN: FLUX.2 klein 4B via the backend proxy. Always offered; each request is posted
// with engine 'klein' and provider 'flux-klein-4b' so the server picks its klein path.
export const engineKleinZeroGpu = {
  ...common,
  id: 'klein-zerogpu',
  label: 'FLUX.2 klein 4B · ZeroGPU',
  available() {
    return true
  },
  generate(prompt, options = {}) {
    const { seed } = options
    return postPortrait({ prompt, seed, engine: 'klein', provider: 'flux-klein-4b' })
  },
  backendLabel() {
    return 'ZeroGPU FLUX.2 klein 4B'
  },
}
40
+
41
  // CLOUD: FLUX via the backend proxy (NVIDIA NIM, else HF Inference). `provider` picks the
42
  // NIM sub-model. schnell = fast (4 steps), dev = higher quality (28 steps).
43
  export const engineCloud = {