fix: synchronize the complete application runtime

#5
by thangvip - opened
src/compliment_forest/config.py CHANGED
@@ -5,20 +5,32 @@ from pathlib import Path
5
  from typing import Literal
6
  from urllib.parse import urlparse
7
 
8
- from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 
9
 
10
 
11
  class AppConfig(BaseModel):
12
  model_config = ConfigDict(extra="forbid")
13
 
14
- text_backend: Literal["demo", "llama_cpp"] = "demo"
15
- image_backend: Literal["demo", "flux"] = "demo"
 
 
 
 
16
  llama_base_url: str = "http://127.0.0.1:8080"
17
  llama_model: str = "compliment-forest-minicpm5-1b"
18
  flux_model_id: str = "black-forest-labs/FLUX.1-dev"
19
  flux_lora_id: str = "build-small-hackathon/compliment-forest-flux-lora"
 
 
 
 
 
20
  local_files_only: bool = False
21
  default_seed: int = Field(default=3407, ge=0, le=2_147_483_647)
 
22
  trace_path: Path | None = None
23
 
24
  @model_validator(mode="after")
@@ -27,14 +39,54 @@ class AppConfig(BaseModel):
27
  hostname = urlparse(self.llama_base_url).hostname
28
  if hostname not in {"127.0.0.1", "localhost", "::1"}:
29
  raise ValueError("llama.cpp model server must be local")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  return self
31
 
32
  @classmethod
33
  def from_env(cls) -> AppConfig:
34
  trace_path = os.getenv("CF_TRACE_PATH")
 
 
 
 
 
 
 
 
35
  return cls(
36
- text_backend=os.getenv("CF_TEXT_BACKEND", "demo"),
37
- image_backend=os.getenv("CF_IMAGE_BACKEND", "demo"),
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  llama_base_url=os.getenv("CF_LLAMA_BASE_URL", "http://127.0.0.1:8080"),
39
  llama_model=os.getenv(
40
  "CF_LLAMA_MODEL",
@@ -48,7 +100,13 @@ class AppConfig(BaseModel):
48
  "CF_FLUX_LORA_ID",
49
  "build-small-hackathon/compliment-forest-flux-lora",
50
  ),
 
 
 
 
 
51
  local_files_only=os.getenv("CF_LOCAL_FILES_ONLY", "0") == "1",
52
  default_seed=int(os.getenv("CF_DEFAULT_SEED", "3407")),
 
53
  trace_path=Path(trace_path) if trace_path else None,
54
  )
 
5
  from typing import Literal
6
  from urllib.parse import urlparse
7
 
8
+ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
9
+
10
+ from .schema import ForestStyle
11
 
12
 
13
  class AppConfig(BaseModel):
14
  model_config = ConfigDict(extra="forbid")
15
 
16
+ text_backend: Literal["demo", "hf_inference", "llama_cpp", "transformers", "modal"] = "demo"
17
+ image_backend: Literal["demo", "flux", "hf_inference", "modal", "zerogpu"] = "demo"
18
+ music_backend: Literal["none", "modal"] = "none"
19
+ hf_text_model: str = "openbmb/MiniCPM4.1-8B"
20
+ transformers_text_model: str = "openbmb/MiniCPM4.1-8B"
21
+ hf_image_model: str = "black-forest-labs/FLUX.1-schnell"
22
  llama_base_url: str = "http://127.0.0.1:8080"
23
  llama_model: str = "compliment-forest-minicpm5-1b"
24
  flux_model_id: str = "black-forest-labs/FLUX.1-dev"
25
  flux_lora_id: str = "build-small-hackathon/compliment-forest-flux-lora"
26
+ modal_text_endpoint: str | None = None
27
+ modal_image_endpoint: str | None = None
28
+ modal_music_endpoint: str | None = None
29
+ modal_signing_key: SecretStr | None = None
30
+ upstream_space_url: str | None = None
31
  local_files_only: bool = False
32
  default_seed: int = Field(default=3407, ge=0, le=2_147_483_647)
33
+ default_style: ForestStyle = "surprise"
34
  trace_path: Path | None = None
35
 
36
  @model_validator(mode="after")
 
39
  hostname = urlparse(self.llama_base_url).hostname
40
  if hostname not in {"127.0.0.1", "localhost", "::1"}:
41
  raise ValueError("llama.cpp model server must be local")
42
+ if self.text_backend == "modal":
43
+ if not self.modal_text_endpoint or not self.modal_signing_key:
44
+ raise ValueError("modal text backend requires endpoint credentials")
45
+ if urlparse(self.modal_text_endpoint).scheme != "https":
46
+ raise ValueError("modal text endpoint must use HTTPS")
47
+ if self.image_backend == "modal":
48
+ if not self.modal_image_endpoint or not self.modal_signing_key:
49
+ raise ValueError("modal image backend requires endpoint credentials")
50
+ if urlparse(self.modal_image_endpoint).scheme != "https":
51
+ raise ValueError("modal image endpoint must use HTTPS")
52
+ if self.music_backend == "modal":
53
+ if not self.modal_music_endpoint or not self.modal_signing_key:
54
+ raise ValueError("modal music backend requires endpoint credentials")
55
+ if urlparse(self.modal_music_endpoint).scheme != "https":
56
+ raise ValueError("modal music endpoint must use HTTPS")
57
+ if self.upstream_space_url:
58
+ parsed_upstream = urlparse(self.upstream_space_url)
59
+ if parsed_upstream.scheme != "https" or not parsed_upstream.netloc:
60
+ raise ValueError("upstream Space URL must use HTTPS")
61
  return self
62
 
63
  @classmethod
64
  def from_env(cls) -> AppConfig:
65
  trace_path = os.getenv("CF_TRACE_PATH")
66
+ hosted_space = bool(os.getenv("SPACE_ID"))
67
+ submission_upstream = (
68
+ "https://thangvip-compliment-forest.hf.space"
69
+ if os.getenv("SPACE_ID") == "build-small-hackathon/compliment-forest"
70
+ else None
71
+ )
72
+ default_text_backend = "transformers" if hosted_space else "demo"
73
+ default_image_backend = "zerogpu" if hosted_space else "demo"
74
  return cls(
75
+ text_backend=os.getenv("CF_TEXT_BACKEND", default_text_backend),
76
+ image_backend=os.getenv("CF_IMAGE_BACKEND", default_image_backend),
77
+ music_backend=os.getenv("CF_MUSIC_BACKEND", "none"),
78
+ hf_text_model=os.getenv(
79
+ "CF_HF_TEXT_MODEL",
80
+ "openbmb/MiniCPM4.1-8B",
81
+ ),
82
+ transformers_text_model=os.getenv(
83
+ "CF_TRANSFORMERS_TEXT_MODEL",
84
+ "openbmb/MiniCPM4.1-8B",
85
+ ),
86
+ hf_image_model=os.getenv(
87
+ "CF_HF_IMAGE_MODEL",
88
+ "black-forest-labs/FLUX.1-schnell",
89
+ ),
90
  llama_base_url=os.getenv("CF_LLAMA_BASE_URL", "http://127.0.0.1:8080"),
91
  llama_model=os.getenv(
92
  "CF_LLAMA_MODEL",
 
100
  "CF_FLUX_LORA_ID",
101
  "build-small-hackathon/compliment-forest-flux-lora",
102
  ),
103
+ modal_text_endpoint=os.getenv("CF_MODAL_TEXT_ENDPOINT"),
104
+ modal_image_endpoint=os.getenv("CF_MODAL_IMAGE_ENDPOINT"),
105
+ modal_music_endpoint=os.getenv("CF_MODAL_MUSIC_ENDPOINT"),
106
+ modal_signing_key=(os.getenv("CF_MODAL_SIGNING_KEY") or os.getenv("HF_TOKEN")),
107
+ upstream_space_url=os.getenv("CF_UPSTREAM_SPACE_URL") or submission_upstream,
108
  local_files_only=os.getenv("CF_LOCAL_FILES_ONLY", "0") == "1",
109
  default_seed=int(os.getenv("CF_DEFAULT_SEED", "3407")),
110
+ default_style=os.getenv("CF_DEFAULT_STYLE", "surprise"),
111
  trace_path=Path(trace_path) if trace_path else None,
112
  )
src/compliment_forest/data_builder.py CHANGED
@@ -169,7 +169,21 @@ def validate_synthetic_example(example: dict[str, Any]) -> dict[str, Any] | None
169
 
170
  def build_sft_record(example: dict[str, Any]) -> dict[str, Any]:
171
  user_content = json.dumps(
172
- {"name": example["name"], "situation": example["situation"]},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  ensure_ascii=False,
174
  )
175
  assistant_content = json.dumps(example["forest"], ensure_ascii=False)
@@ -223,11 +237,14 @@ def template_forest(name: str, situation: str, variant: int) -> dict[str, Any]:
223
  CREATURES[: variant % len(CREATURES)]
224
  )
225
  selected = rotated[:5]
 
226
  clearings = []
227
  for clearing_index, (creature, strength, spell) in enumerate(selected):
228
  line_template = LINE_TEMPLATES[(variant + clearing_index) % len(LINE_TEMPLATES)]
229
  clearings.append(
230
  {
 
 
231
  "creature": creature,
232
  "strength": strength,
233
  "line": line_template.format(
@@ -289,6 +306,8 @@ def forest_batch_json_schema() -> dict[str, Any]:
289
  "type": "object",
290
  "additionalProperties": False,
291
  "required": [
 
 
292
  "creature",
293
  "strength",
294
  "line",
@@ -297,6 +316,8 @@ def forest_batch_json_schema() -> dict[str, Any]:
297
  "image_prompt",
298
  ],
299
  "properties": {
 
 
300
  "creature": {"type": "string"},
301
  "strength": {"type": "string"},
302
  "line": {"type": "string"},
@@ -368,7 +389,9 @@ class CohereForestGenerator:
368
  "clearings. Every line must repeat at least one concrete noun or phrase from "
369
  "its situation. Acknowledge difficulty without diagnosis, guarantees, hollow "
370
  "praise, or toxic positivity. Spells begin with 'I' and use at most 12 words. "
371
- "Image prompts describe one creature only and contain no style words."
 
 
372
  ),
373
  "requests": list(requests),
374
  "voice_hints": list(source_hints)[:8],
 
169
 
170
  def build_sft_record(example: dict[str, Any]) -> dict[str, Any]:
171
  user_content = json.dumps(
172
+ {
173
+ "name": example["name"],
174
+ "situation": example["situation"],
175
+ "validated_fact_plan": {
176
+ "faithful_summary": example["situation"],
177
+ "fact_anchors": [
178
+ {
179
+ "source_phrase": example["situation"],
180
+ "meaning": example["situation"],
181
+ }
182
+ ],
183
+ "central_uncertainty": "What will happen next",
184
+ "desired_direction": "Move with clarity and care",
185
+ },
186
+ },
187
  ensure_ascii=False,
188
  )
189
  assistant_content = json.dumps(example["forest"], ensure_ascii=False)
 
237
  CREATURES[: variant % len(CREATURES)]
238
  )
239
  selected = rotated[:5]
240
+ roles = ("arrive", "steady", "widen", "step", "carry")
241
  clearings = []
242
  for clearing_index, (creature, strength, spell) in enumerate(selected):
243
  line_template = LINE_TEMPLATES[(variant + clearing_index) % len(LINE_TEMPLATES)]
244
  clearings.append(
245
  {
246
+ "arc_role": roles[clearing_index],
247
+ "source_phrase": situation,
248
  "creature": creature,
249
  "strength": strength,
250
  "line": line_template.format(
 
306
  "type": "object",
307
  "additionalProperties": False,
308
  "required": [
309
+ "arc_role",
310
+ "source_phrase",
311
  "creature",
312
  "strength",
313
  "line",
 
316
  "image_prompt",
317
  ],
318
  "properties": {
319
+ "arc_role": {"type": "string"},
320
+ "source_phrase": {"type": "string"},
321
  "creature": {"type": "string"},
322
  "strength": {"type": "string"},
323
  "line": {"type": "string"},
 
389
  "clearings. Every line must repeat at least one concrete noun or phrase from "
390
  "its situation. Acknowledge difficulty without diagnosis, guarantees, hollow "
391
  "praise, or toxic positivity. Spells begin with 'I' and use at most 12 words. "
392
+ "Use arrive, steady, widen, step, and optional carry in order. Each "
393
+ "source_phrase must copy exact text from the situation. Image prompts "
394
+ "describe one coherent scene and contain no style words or text."
395
  ),
396
  "requests": list(requests),
397
  "voice_hints": list(source_hints)[:8],
src/compliment_forest/schema.py CHANGED
@@ -4,15 +4,39 @@ from typing import Literal
4
 
5
  from pydantic import BaseModel, ConfigDict, Field, field_validator
6
 
 
 
 
 
 
 
 
 
 
7
 
8
  class StrictModel(BaseModel):
9
  model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class Clearing(StrictModel):
13
- creature: str = Field(min_length=3, max_length=80)
 
 
 
 
14
  strength: str = Field(min_length=3, max_length=100)
15
- line: str = Field(min_length=12, max_length=360)
16
  reflection: str = Field(min_length=12, max_length=260)
17
  spell: str = Field(min_length=3, max_length=80)
18
  image_prompt: str = Field(min_length=8, max_length=300)
@@ -27,6 +51,25 @@ class Clearing(StrictModel):
27
  return value
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class ForestDraft(StrictModel):
31
  forest_title: str = Field(min_length=3, max_length=120)
32
  proposed_strengths: list[str] = Field(min_length=3, max_length=6)
@@ -73,7 +116,14 @@ class GuardResult(StrictModel):
73
 
74
 
75
  class StreamEvent(StrictModel):
76
- type: Literal["status", "support", "forest", "clearing", "complete", "error"]
 
 
 
 
 
 
 
 
77
  message: str = ""
78
  data: dict[str, object] = Field(default_factory=dict)
79
-
 
4
 
5
  from pydantic import BaseModel, ConfigDict, Field, field_validator
6
 
7
+ ForestStyle = Literal[
8
+ "surprise",
9
+ "watercolor",
10
+ "paper_cut",
11
+ "moonlit_gouache",
12
+ "botanical_ink",
13
+ ]
14
+ ArcRole = Literal["arrive", "steady", "widen", "step", "carry"]
15
+
16
 
17
  class StrictModel(BaseModel):
18
  model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
19
 
20
 
21
+ class FactAnchor(StrictModel):
22
+ source_phrase: str = Field(min_length=1, max_length=240)
23
+ meaning: str = Field(min_length=3, max_length=300)
24
+
25
+
26
+ class SituationPlan(StrictModel):
27
+ faithful_summary: str = Field(min_length=12, max_length=500)
28
+ fact_anchors: list[FactAnchor] = Field(min_length=1, max_length=4)
29
+ central_uncertainty: str = Field(min_length=3, max_length=300)
30
+ desired_direction: str = Field(min_length=3, max_length=300)
31
+
32
+
33
  class Clearing(StrictModel):
34
+ arc_role: ArcRole
35
+ source_phrase: str = Field(min_length=1, max_length=240)
36
+ scene_title: str = Field(min_length=3, max_length=80)
37
+ scene_intro: str = Field(min_length=12, max_length=240)
38
+ narration: str = Field(min_length=80, max_length=720)
39
  strength: str = Field(min_length=3, max_length=100)
 
40
  reflection: str = Field(min_length=12, max_length=260)
41
  spell: str = Field(min_length=3, max_length=80)
42
  image_prompt: str = Field(min_length=8, max_length=300)
 
51
  return value
52
 
53
 
54
+ class IntakeTurn(StrictModel):
55
+ question: str = Field(min_length=4, max_length=240)
56
+ answer: str = Field(min_length=1, max_length=240)
57
+
58
+
59
+ class IntakeQuestion(StrictModel):
60
+ question: str = Field(min_length=4, max_length=240)
61
+ options: list[str] = Field(min_length=3, max_length=4)
62
+ rationale: str = Field(default="", max_length=2000)
63
+
64
+ @field_validator("options")
65
+ @classmethod
66
+ def validate_unique_options(cls, values: list[str]) -> list[str]:
67
+ normalized = {value.casefold() for value in values}
68
+ if len(normalized) != len(values):
69
+ raise ValueError("options must be unique")
70
+ return values
71
+
72
+
73
  class ForestDraft(StrictModel):
74
  forest_title: str = Field(min_length=3, max_length=120)
75
  proposed_strengths: list[str] = Field(min_length=3, max_length=6)
 
116
 
117
 
118
  class StreamEvent(StrictModel):
119
+ type: Literal[
120
+ "status",
121
+ "support",
122
+ "forest",
123
+ "clearing",
124
+ "soundscape",
125
+ "complete",
126
+ "error",
127
+ ]
128
  message: str = ""
129
  data: dict[str, object] = Field(default_factory=dict)
 
src/compliment_forest/server.py CHANGED
@@ -1,18 +1,35 @@
1
  from __future__ import annotations
2
 
 
3
  from pathlib import Path
4
  from typing import Any
5
 
6
  import gradio as gr
 
 
7
  from fastapi.responses import FileResponse
8
  from fastapi.staticfiles import StaticFiles
9
  from pydantic import BaseModel, ConfigDict, Field
10
  from starlette.responses import StreamingResponse
11
 
12
- from .backends.image import DemoImageBackend, FluxImageBackend
13
- from .backends.text import DemoTextBackend, LlamaCppTextBackend
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from .config import AppConfig
15
- from .orchestrator import ForestOrchestrator
 
16
  from .trace import TraceRecorder
17
 
18
 
@@ -22,18 +39,67 @@ class ForestRequest(BaseModel):
22
  name: str = Field(min_length=1, max_length=80)
23
  situation: str = Field(min_length=1, max_length=1200)
24
  seed: int | None = Field(default=None, ge=0, le=2_147_483_647)
 
 
25
 
26
 
27
- def build_orchestrator(config: AppConfig) -> ForestOrchestrator:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  if config.text_backend == "llama_cpp":
29
  text_backend = LlamaCppTextBackend(
30
  base_url=config.llama_base_url,
31
  model=config.llama_model,
32
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  else:
34
  text_backend = DemoTextBackend()
35
 
36
- if config.image_backend == "flux":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  image_backend = FluxImageBackend(
38
  model_id=config.flux_model_id,
39
  lora_id=config.flux_lora_id,
@@ -41,10 +107,21 @@ def build_orchestrator(config: AppConfig) -> ForestOrchestrator:
41
  )
42
  else:
43
  image_backend = DemoImageBackend()
 
 
 
 
 
 
 
 
 
 
44
  trace_recorder = TraceRecorder(config.trace_path) if config.trace_path else None
45
  return ForestOrchestrator(
46
  text_backend=text_backend,
47
  image_backend=image_backend,
 
48
  trace_recorder=trace_recorder,
49
  )
50
 
@@ -54,9 +131,24 @@ def create_app(
54
  config: AppConfig | None = None,
55
  orchestrator: Any | None = None,
56
  frontend_dir: str | Path | None = None,
 
 
 
57
  ) -> gr.Server:
58
  runtime = config or AppConfig.from_env()
59
- forest = orchestrator or build_orchestrator(runtime)
 
 
 
 
 
 
 
 
 
 
 
 
60
  frontend = (
61
  Path(frontend_dir)
62
  if frontend_dir is not None
@@ -69,17 +161,31 @@ def create_app(
69
  redoc_url=None,
70
  )
71
 
 
 
 
 
 
 
72
  @app.get("/")
73
  def index() -> FileResponse:
74
- return FileResponse(frontend / "index.html")
75
 
76
  @app.get("/styles.css")
77
  def styles() -> FileResponse:
78
- return FileResponse(frontend / "styles.css", media_type="text/css")
 
 
 
 
79
 
80
  @app.get("/app.js")
81
  def javascript() -> FileResponse:
82
- return FileResponse(frontend / "app.js", media_type="text/javascript")
 
 
 
 
83
 
84
  assets = frontend / "assets"
85
  if assets.exists():
@@ -87,20 +193,146 @@ def create_app(
87
 
88
  @app.get("/health")
89
  def health() -> dict[str, object]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  return {
91
  "status": "ok",
92
  "text_backend": runtime.text_backend,
 
93
  "image_backend": runtime.image_backend,
94
- "off_grid": True,
95
- "model_parameter_budget_billions": 18,
 
 
 
 
96
  }
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  @app.post("/api/forest")
99
  def generate_forest(request: ForestRequest) -> StreamingResponse:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def stream():
 
101
  seed = request.seed if request.seed is not None else runtime.default_seed
102
- for event in forest.generate(request.name, request.situation, seed):
103
- yield event.model_dump_json() + "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  return StreamingResponse(stream(), media_type="application/x-ndjson")
106
 
 
1
  from __future__ import annotations
2
 
3
+ from collections.abc import Callable
4
  from pathlib import Path
5
  from typing import Any
6
 
7
  import gradio as gr
8
+ import httpx
9
+ from fastapi import HTTPException
10
  from fastapi.responses import FileResponse
11
  from fastapi.staticfiles import StaticFiles
12
  from pydantic import BaseModel, ConfigDict, Field
13
  from starlette.responses import StreamingResponse
14
 
15
+ from .backends.image import (
16
+ DemoImageBackend,
17
+ FluxImageBackend,
18
+ HfInferenceImageBackend,
19
+ ModalImageBackend,
20
+ ZeroGpuImageBackend,
21
+ )
22
+ from .backends.music import ModalMusicBackend, NoMusicBackend
23
+ from .backends.text import (
24
+ DemoTextBackend,
25
+ HfInferenceTextBackend,
26
+ LlamaCppTextBackend,
27
+ ModalTextBackend,
28
+ TransformersTextBackend,
29
+ )
30
  from .config import AppConfig
31
+ from .orchestrator import ForestOrchestrator, build_guided_situation
32
+ from .schema import ForestStyle, IntakeQuestion, IntakeTurn, StreamEvent
33
  from .trace import TraceRecorder
34
 
35
 
 
39
  name: str = Field(min_length=1, max_length=80)
40
  situation: str = Field(min_length=1, max_length=1200)
41
  seed: int | None = Field(default=None, ge=0, le=2_147_483_647)
42
+ style: ForestStyle | None = None
43
+ intake: list[IntakeTurn] = Field(default_factory=list, max_length=5)
44
 
45
 
46
+ class IntakeNextRequest(BaseModel):
47
+ model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
48
+
49
+ name: str = Field(min_length=1, max_length=80)
50
+ situation: str = Field(min_length=1, max_length=1200)
51
+ history: list[IntakeTurn] = Field(default_factory=list, max_length=5)
52
+ seed: int | None = Field(default=None, ge=0, le=2_147_483_647)
53
+
54
+
55
+ def build_orchestrator(
56
+ config: AppConfig,
57
+ *,
58
+ gpu_image_generator: Callable[[str, int, str], str] | None = None,
59
+ gpu_text_generator: Callable[[list[dict[str, str]], dict[str, object]], str] | None = None,
60
+ ) -> ForestOrchestrator:
61
  if config.text_backend == "llama_cpp":
62
  text_backend = LlamaCppTextBackend(
63
  base_url=config.llama_base_url,
64
  model=config.llama_model,
65
  )
66
+ elif config.text_backend == "hf_inference":
67
+ text_backend = HfInferenceTextBackend(model=config.hf_text_model)
68
+ elif config.text_backend == "transformers":
69
+ if gpu_text_generator is None:
70
+ raise ValueError("transformers text backend requires a GPU text generator")
71
+ text_backend = TransformersTextBackend(
72
+ model=config.transformers_text_model,
73
+ generator=gpu_text_generator,
74
+ )
75
+ elif config.text_backend == "modal":
76
+ assert config.modal_text_endpoint is not None
77
+ assert config.modal_signing_key is not None
78
+ text_backend = ModalTextBackend(
79
+ endpoint=config.modal_text_endpoint,
80
+ signing_key=config.modal_signing_key.get_secret_value(),
81
+ )
82
  else:
83
  text_backend = DemoTextBackend()
84
 
85
+ if config.image_backend == "modal":
86
+ assert config.modal_image_endpoint is not None
87
+ assert config.modal_signing_key is not None
88
+ image_backend = ModalImageBackend(
89
+ endpoint=config.modal_image_endpoint,
90
+ signing_key=config.modal_signing_key.get_secret_value(),
91
+ fallback=HfInferenceImageBackend(model=config.hf_image_model),
92
+ )
93
+ elif config.image_backend == "zerogpu":
94
+ if gpu_image_generator is None:
95
+ raise ValueError("zerogpu image backend requires a GPU image generator")
96
+ image_backend = ZeroGpuImageBackend(
97
+ gpu_image_generator,
98
+ fallback=HfInferenceImageBackend(model=config.hf_image_model),
99
+ )
100
+ elif config.image_backend == "hf_inference":
101
+ image_backend = HfInferenceImageBackend(model=config.hf_image_model)
102
+ elif config.image_backend == "flux":
103
  image_backend = FluxImageBackend(
104
  model_id=config.flux_model_id,
105
  lora_id=config.flux_lora_id,
 
107
  )
108
  else:
109
  image_backend = DemoImageBackend()
110
+
111
+ if config.music_backend == "modal":
112
+ assert config.modal_music_endpoint is not None
113
+ assert config.modal_signing_key is not None
114
+ music_backend = ModalMusicBackend(
115
+ endpoint=config.modal_music_endpoint,
116
+ signing_key=config.modal_signing_key.get_secret_value(),
117
+ )
118
+ else:
119
+ music_backend = NoMusicBackend()
120
  trace_recorder = TraceRecorder(config.trace_path) if config.trace_path else None
121
  return ForestOrchestrator(
122
  text_backend=text_backend,
123
  image_backend=image_backend,
124
+ music_backend=music_backend,
125
  trace_recorder=trace_recorder,
126
  )
127
 
 
131
  config: AppConfig | None = None,
132
  orchestrator: Any | None = None,
133
  frontend_dir: str | Path | None = None,
134
+ gpu_image_generator: Callable[[str, int, str], str] | None = None,
135
+ gpu_text_generator: Callable[[list[dict[str, str]], dict[str, object]], str] | None = None,
136
+ upstream_client: httpx.Client | None = None,
137
  ) -> gr.Server:
138
  runtime = config or AppConfig.from_env()
139
+ forest = None
140
+ if runtime.upstream_space_url is None:
141
+ forest = orchestrator or build_orchestrator(
142
+ runtime,
143
+ gpu_image_generator=gpu_image_generator,
144
+ gpu_text_generator=gpu_text_generator,
145
+ )
146
+ proxy = upstream_client
147
+ if runtime.upstream_space_url and proxy is None:
148
+ proxy = httpx.Client(
149
+ timeout=httpx.Timeout(600, connect=30),
150
+ follow_redirects=True,
151
+ )
152
  frontend = (
153
  Path(frontend_dir)
154
  if frontend_dir is not None
 
161
  redoc_url=None,
162
  )
163
 
164
+ # Browsers will heuristically cache static files for hours when no
165
+ # Cache-Control header is present, and HF Spaces does not set one for
166
+ # FastAPI-served files. Force revalidation so each Space rebuild is
167
+ # immediately visible without a cache wipe on the user's side.
168
+ _NO_CACHE = {"Cache-Control": "no-cache, must-revalidate"}
169
+
170
  @app.get("/")
171
  def index() -> FileResponse:
172
+ return FileResponse(frontend / "index.html", headers=_NO_CACHE)
173
 
174
  @app.get("/styles.css")
175
  def styles() -> FileResponse:
176
+ return FileResponse(
177
+ frontend / "styles.css",
178
+ media_type="text/css",
179
+ headers=_NO_CACHE,
180
+ )
181
 
182
  @app.get("/app.js")
183
  def javascript() -> FileResponse:
184
+ return FileResponse(
185
+ frontend / "app.js",
186
+ media_type="text/javascript",
187
+ headers=_NO_CACHE,
188
+ )
189
 
190
  assets = frontend / "assets"
191
  if assets.exists():
 
193
 
194
  @app.get("/health")
195
  def health() -> dict[str, object]:
196
+ if runtime.upstream_space_url:
197
+ return {
198
+ "status": "ok",
199
+ "runtime_mode": "upstream_proxy",
200
+ "upstream_space_url": runtime.upstream_space_url,
201
+ "off_grid": False,
202
+ "fresh_images": True,
203
+ "default_style": runtime.default_style,
204
+ "model_parameter_budget_billions": 25,
205
+ "phase1_model_parameter_budget_billions": 18,
206
+ }
207
+ hosted = bool(
208
+ {"hf_inference", "modal", "zerogpu", "transformers"}
209
+ & {runtime.text_backend, runtime.image_backend}
210
+ )
211
+ runtime_text_model = {
212
+ "demo": "demo",
213
+ "hf_inference": runtime.hf_text_model,
214
+ "llama_cpp": runtime.llama_model,
215
+ "transformers": runtime.transformers_text_model,
216
+ "modal": "openbmb/MiniCPM4.1-8B (Modal)",
217
+ }[runtime.text_backend]
218
+ phase1_budget = (
219
+ 18 if runtime.text_backend == "llama_cpp" and runtime.image_backend == "flux" else None
220
+ )
221
+ active_budget = phase1_budget
222
+ uses_minicpm = (
223
+ runtime.text_backend == "modal"
224
+ or (
225
+ runtime.text_backend == "transformers"
226
+ and runtime.transformers_text_model.endswith("MiniCPM4.1-8B")
227
+ )
228
+ or (
229
+ runtime.text_backend == "hf_inference"
230
+ and runtime.hf_text_model.endswith("MiniCPM4.1-8B")
231
+ )
232
+ )
233
+ if uses_minicpm:
234
+ active_budget = 25
235
  return {
236
  "status": "ok",
237
  "text_backend": runtime.text_backend,
238
+ "runtime_text_model": runtime_text_model,
239
  "image_backend": runtime.image_backend,
240
+ "music_backend": runtime.music_backend,
241
+ "off_grid": not hosted,
242
+ "fresh_images": runtime.image_backend != "demo",
243
+ "default_style": runtime.default_style,
244
+ "model_parameter_budget_billions": active_budget,
245
+ "phase1_model_parameter_budget_billions": 18,
246
  }
247
 
248
+ @app.post("/api/intake/next")
249
+ def next_intake(request: IntakeNextRequest) -> IntakeQuestion:
250
+ if runtime.upstream_space_url:
251
+ assert proxy is not None
252
+ try:
253
+ response = proxy.post(
254
+ f"{runtime.upstream_space_url}/api/intake/next",
255
+ json=request.model_dump(mode="json"),
256
+ )
257
+ response.raise_for_status()
258
+ return IntakeQuestion.model_validate(response.json())
259
+ except (httpx.HTTPError, ValueError) as error:
260
+ raise HTTPException(
261
+ status_code=502,
262
+ detail=f"The forest could not reach its generation service: {error}",
263
+ ) from error
264
+
265
+ from .safety import guard_input
266
+
267
+ assert forest is not None
268
+ guard = guard_input(request.name, request.situation)
269
+ if not guard.allowed:
270
+ raise HTTPException(status_code=400, detail=guard.message)
271
+ if len(request.history) >= 5:
272
+ raise HTTPException(status_code=400, detail="intake already complete")
273
+ seed = (request.seed if request.seed is not None else runtime.default_seed) + len(
274
+ request.history
275
+ )
276
+ try:
277
+ return forest.next_intake_question(
278
+ request.name,
279
+ request.situation,
280
+ request.history,
281
+ seed=seed,
282
+ )
283
+ except ValueError as error:
284
+ raise HTTPException(
285
+ status_code=502,
286
+ detail=f"The forest could not produce a question: {error}",
287
+ ) from error
288
+
289
  @app.post("/api/forest")
290
  def generate_forest(request: ForestRequest) -> StreamingResponse:
291
+ if runtime.upstream_space_url:
292
+
293
+ def proxy_stream():
294
+ assert proxy is not None
295
+ try:
296
+ with proxy.stream(
297
+ "POST",
298
+ f"{runtime.upstream_space_url}/api/forest",
299
+ json=request.model_dump(mode="json"),
300
+ ) as response:
301
+ response.raise_for_status()
302
+ yield from response.iter_bytes()
303
+ except httpx.HTTPError as error:
304
+ yield (
305
+ StreamEvent(
306
+ type="error",
307
+ message=(
308
+ "The forest could not reach its generation service: "
309
+ f"{error}"
310
+ ),
311
+ ).model_dump_json()
312
+ + "\n"
313
+ )
314
+
315
+ return StreamingResponse(proxy_stream(), media_type="application/x-ndjson")
316
+
317
  def stream():
318
+ assert forest is not None
319
  seed = request.seed if request.seed is not None else runtime.default_seed
320
+ style = request.style or runtime.default_style
321
+ model_situation = build_guided_situation(request.situation, request.intake)
322
+ try:
323
+ for event in forest.generate(
324
+ request.name,
325
+ request.situation,
326
+ seed,
327
+ style,
328
+ model_situation=model_situation,
329
+ ):
330
+ yield event.model_dump_json() + "\n"
331
+ except Exception as error:
332
+ yield StreamEvent(
333
+ type="error",
334
+ message=f"The forest could not grow: {error}",
335
+ ).model_dump_json() + "\n"
336
 
337
  return StreamingResponse(stream(), media_type="application/x-ndjson")
338
 
src/compliment_forest/style_data.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+
6
+ from .backends.image import STYLE_PROFILES, compose_flux_prompt
7
+
8
+ SceneCategory = Literal["animal", "human", "object", "environment"]
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class ForestScene:
13
+ slug: str
14
+ category: SceneCategory
15
+ prompt: str
16
+
17
+
18
+ FOREST_SCENES = (
19
+ ForestScene(
20
+ "fox-threshold",
21
+ "animal",
22
+ "a gentle red fox pausing at the edge of a fern-lined path",
23
+ ),
24
+ ForestScene(
25
+ "listening-owl",
26
+ "animal",
27
+ "a round barn owl listening from a low mossy branch",
28
+ ),
29
+ ForestScene(
30
+ "steady-deer",
31
+ "animal",
32
+ "a young deer standing calmly between silver birch trees",
33
+ ),
34
+ ForestScene(
35
+ "brave-snail",
36
+ "animal",
37
+ "a tiny snail crossing a dew-covered fern at dawn",
38
+ ),
39
+ ForestScene(
40
+ "singing-wren",
41
+ "animal",
42
+ "a small wren singing beside loose woodland flowers",
43
+ ),
44
+ ForestScene(
45
+ "river-otter",
46
+ "animal",
47
+ "a river otter holding one smooth stone beside quiet reeds",
48
+ ),
49
+ ForestScene(
50
+ "thoughtful-badger",
51
+ "animal",
52
+ "a thoughtful badger beside a lantern-shaped mushroom",
53
+ ),
54
+ ForestScene(
55
+ "patient-hare",
56
+ "animal",
57
+ "a patient brown hare resting beneath arching grasses",
58
+ ),
59
+ ForestScene(
60
+ "moonlit-moth",
61
+ "animal",
62
+ "a luna moth hovering near moonlit foxgloves",
63
+ ),
64
+ ForestScene(
65
+ "walking-turtle",
66
+ "animal",
67
+ "a small woodland turtle moving between clover and stones",
68
+ ),
69
+ ForestScene(
70
+ "person-open-window",
71
+ "human",
72
+ "an adult seen from behind opening a window to pale morning light",
73
+ ),
74
+ ForestScene(
75
+ "person-blank-notebook",
76
+ "human",
77
+ "an adult seated at a wooden desk with an open blank notebook",
78
+ ),
79
+ ForestScene(
80
+ "person-forked-path",
81
+ "human",
82
+ "a small human figure viewed from behind at a gentle fork in a path",
83
+ ),
84
+ ForestScene(
85
+ "person-train-platform",
86
+ "human",
87
+ "a quiet adult figure waiting on a misty train platform with one bag",
88
+ ),
89
+ ForestScene(
90
+ "person-moving-box",
91
+ "human",
92
+ "an adult carrying one moving box toward a sunlit doorway",
93
+ ),
94
+ ForestScene(
95
+ "person-footbridge",
96
+ "human",
97
+ "a side-view figure taking one step across a narrow wooden footbridge",
98
+ ),
99
+ ForestScene(
100
+ "person-doorway",
101
+ "human",
102
+ "a calm adult silhouette standing in an open doorway between two rooms",
103
+ ),
104
+ ForestScene(
105
+ "person-seedling",
106
+ "human",
107
+ "hands gently watering a small seedling on a windowsill",
108
+ ),
109
+ ForestScene(
110
+ "person-rain-shelter",
111
+ "human",
112
+ "an adult seen from the side resting on a bench beneath a rain shelter",
113
+ ),
114
+ ForestScene(
115
+ "person-dawn-hill",
116
+ "human",
117
+ "a distant human figure standing on a low hillside at dawn",
118
+ ),
119
+ ForestScene(
120
+ "lantern-crossroads",
121
+ "object",
122
+ "a small glowing lantern placed where two woodland paths meet",
123
+ ),
124
+ ForestScene(
125
+ "map-compass",
126
+ "object",
127
+ "an unfolded map and simple compass resting on a wooden table",
128
+ ),
129
+ ForestScene(
130
+ "open-notebook",
131
+ "object",
132
+ "an open blank notebook beside a pencil and one pressed leaf",
133
+ ),
134
+ ForestScene(
135
+ "stepping-stones",
136
+ "object",
137
+ "four smooth stepping stones crossing a narrow stream",
138
+ ),
139
+ ForestScene(
140
+ "warm-cup",
141
+ "object",
142
+ "a warm ceramic cup sending a thin curl of steam into morning light",
143
+ ),
144
+ ForestScene(
145
+ "woven-thread",
146
+ "object",
147
+ "loose green and gold threads gradually woven into one calm pattern",
148
+ ),
149
+ ForestScene(
150
+ "key-and-door",
151
+ "object",
152
+ "a simple brass key resting beside a small unopened wooden door",
153
+ ),
154
+ ForestScene(
155
+ "paper-boat",
156
+ "object",
157
+ "a single paper boat floating on still water beneath willow reflections",
158
+ ),
159
+ ForestScene(
160
+ "balanced-stones",
161
+ "object",
162
+ "three imperfect river stones balanced beside soft grasses",
163
+ ),
164
+ ForestScene(
165
+ "empty-chair-light",
166
+ "object",
167
+ "an empty wooden chair in a quiet patch of warm window light",
168
+ ),
169
+ ForestScene(
170
+ "winding-path",
171
+ "environment",
172
+ "a winding path disappearing gently through tall ferns and morning mist",
173
+ ),
174
+ ForestScene(
175
+ "river-crossing",
176
+ "environment",
177
+ "a shallow river crossing with stones visible beneath clear water",
178
+ ),
179
+ ForestScene(
180
+ "room-at-dawn",
181
+ "environment",
182
+ "a quiet room at dawn with curtains moving beside an open window",
183
+ ),
184
+ ForestScene(
185
+ "city-garden",
186
+ "environment",
187
+ "a small green garden between quiet city buildings after rain",
188
+ ),
189
+ ForestScene(
190
+ "misty-platform",
191
+ "environment",
192
+ "an empty train platform fading softly into early morning mist",
193
+ ),
194
+ ForestScene(
195
+ "clearing-after-rain",
196
+ "environment",
197
+ "a forest clearing just after rain with one bright opening in the clouds",
198
+ ),
199
+ ForestScene(
200
+ "hillside-trail",
201
+ "environment",
202
+ "a gradual hillside trail curving toward a pale open horizon",
203
+ ),
204
+ ForestScene(
205
+ "staircase-light",
206
+ "environment",
207
+ "a simple staircase with warm light falling across the next three steps",
208
+ ),
209
+ ForestScene(
210
+ "canopy-opening",
211
+ "environment",
212
+ "a dark green canopy opening into a circle of soft sky",
213
+ ),
214
+ ForestScene(
215
+ "shoreline-horizon",
216
+ "environment",
217
+ "a calm shoreline where fading clouds meet a wide quiet horizon",
218
+ ),
219
+ )
220
+
221
+ # Compatibility alias for callers that used the v1 name.
222
+ FOREST_SUBJECTS = FOREST_SCENES
223
+
224
+ TRAINED_STYLE_IDS = (
225
+ "watercolor",
226
+ "paper_cut",
227
+ "moonlit_gouache",
228
+ "botanical_ink",
229
+ )
230
+
231
+
232
+ def build_style_records(
233
+ *,
234
+ samples_per_style: int = 40,
235
+ base_seed: int = 9000,
236
+ ) -> list[dict[str, str | int]]:
237
+ if not 1 <= samples_per_style <= len(FOREST_SCENES):
238
+ raise ValueError(f"samples_per_style must be between 1 and {len(FOREST_SCENES)}")
239
+
240
+ records: list[dict[str, str | int]] = []
241
+ for style_offset, style in enumerate(TRAINED_STYLE_IDS):
242
+ profile = STYLE_PROFILES[style]
243
+ for scene_index, scene in enumerate(FOREST_SCENES[:samples_per_style]):
244
+ seed = base_seed + style_offset * 1000 + scene_index
245
+ prompt = compose_flux_prompt(
246
+ scene.prompt,
247
+ style=style, # type: ignore[arg-type]
248
+ seed=seed,
249
+ )
250
+ records.append(
251
+ {
252
+ "style": style,
253
+ "trigger": profile.trigger,
254
+ "category": scene.category,
255
+ "subject": scene.slug,
256
+ "seed": seed,
257
+ "prompt": prompt,
258
+ "text": (
259
+ f"{profile.trigger}, {scene.prompt}, "
260
+ f"{profile.label.lower()} storybook scene"
261
+ ),
262
+ "file_name": (
263
+ f"{scene_index:03d}-{scene.category}-{scene.slug}.png"
264
+ ),
265
+ }
266
+ )
267
+ return records
src/compliment_forest/training.py CHANGED
@@ -1,6 +1,6 @@
1
  from __future__ import annotations
2
 
3
- from typing import Any, Protocol
4
 
5
  from pydantic import BaseModel, ConfigDict, Field
6
 
@@ -89,6 +89,82 @@ class FluxTrainingConfig(BaseModel):
89
  )
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def format_training_example(
93
  example: dict[str, Any],
94
  tokenizer: ChatTemplateTokenizer,
 
1
  from __future__ import annotations
2
 
3
+ from typing import Any, Literal, Protocol
4
 
5
  from pydantic import BaseModel, ConfigDict, Field
6
 
 
89
  )
90
 
91
 
92
+ TrainedForestStyle = Literal[
93
+ "watercolor",
94
+ "paper_cut",
95
+ "moonlit_gouache",
96
+ "botanical_ink",
97
+ ]
98
+
99
+ _STYLE_TRAINING = {
100
+ "watercolor": {
101
+ "trigger": "cmprst_watercolor",
102
+ "repo_suffix": "watercolor",
103
+ "validation": "a gentle fox pausing beside ferns at dawn",
104
+ },
105
+ "paper_cut": {
106
+ "trigger": "cmprst_papercut",
107
+ "repo_suffix": "paper-cut",
108
+ "validation": "a thoughtful badger beside layered woodland leaves",
109
+ },
110
+ "moonlit_gouache": {
111
+ "trigger": "cmprst_moonlit",
112
+ "repo_suffix": "moonlit-gouache",
113
+ "validation": "a small owl resting in a moonlit pine clearing",
114
+ },
115
+ "botanical_ink": {
116
+ "trigger": "cmprst_inkwash",
117
+ "repo_suffix": "botanical-ink",
118
+ "validation": "a patient hare beneath sparse woodland flowers",
119
+ },
120
+ }
121
+
122
+
123
+ class FluxStyleTrainingConfig(BaseModel):
124
+ model_config = ConfigDict(extra="forbid", frozen=True)
125
+
126
+ style: TrainedForestStyle
127
+ base_model: str = "black-forest-labs/FLUX.1-schnell"
128
+ dataset_id: str = "thangvip/compliment-forest-multistyle-v2"
129
+ dataset_config_name: str
130
+ model_id: str
131
+ trigger_token: str
132
+ output_dir: str
133
+ resolution: int = 512
134
+ max_train_steps: int = 300
135
+ train_batch_size: int = 1
136
+ gradient_accumulation_steps: int = 1
137
+ learning_rate: float = 1e-4
138
+ rank: int = 16
139
+ lora_alpha: int = 16
140
+ repeats: int = 3
141
+ seed: int = 3407
142
+ guidance_scale: float = 0
143
+ validation_prompt: str
144
+
145
+ @classmethod
146
+ def for_style(
147
+ cls,
148
+ style: TrainedForestStyle,
149
+ *,
150
+ smoke: bool = False,
151
+ ) -> FluxStyleTrainingConfig:
152
+ spec = _STYLE_TRAINING[style]
153
+ config = cls(
154
+ style=style,
155
+ dataset_config_name=style,
156
+ model_id=(
157
+ f"thangvip/compliment-forest-{spec['repo_suffix']}-flux-lora-v2"
158
+ ),
159
+ trigger_token=spec["trigger"],
160
+ output_dir=f"/training/compliment-forest-{spec['repo_suffix']}-flux",
161
+ validation_prompt=f"{spec['trigger']}, {spec['validation']}",
162
+ )
163
+ if smoke:
164
+ return config.model_copy(update={"max_train_steps": 2, "repeats": 1})
165
+ return config
166
+
167
+
168
  def format_training_example(
169
  example: dict[str, Any],
170
  tokenizer: ChatTemplateTokenizer,