"""Synesthesia generation loop module.
Bridge between the Streamlit workbench UI and the rolling clip scheduler
worker process. The scheduler runs as an isolated subprocess β€” the UI only
reads JSON state files and writes steering / stop sentinels.
IPC channel
-----------
::
Streamlit UI
β”‚
β”‚ reads: loop_state.json, clips_manifest.json, events.jsonl
β”‚ writes: pending_steering.json, stop.requested
β”‚
β–Ό
Worker Process (ML_Pipeline/workbench_worker.py --spec spec.json)
β”œβ”€β”€ SynesthesiaRuntime (GPU validation + model acquisition)
β”œβ”€β”€ RollingClipScheduler (monitor + generator + metrics threads)
β”œβ”€β”€ CrossfadePlayer (ring buffer audio output)
└── MetricsReporter (writes loop_state.json every second)
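
Typical UI-side flow, as an illustrative sketch (the steering keys and the
prompt text below are placeholder values)::

    info = start_loop_session(LoopConfig(), {"mode_preset": "ambient"}, "warm evolving pads")
    state = load_loop_state(info["output_root"])
    target = next_safe_clip_index(state["playing_clip_index"], is_running=True)
    queue_steering_update(info["output_root"], {"natural_language_prompt": "brighter"}, target)
    stop_loop_session(info["output_root"])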
"""
from __future__ import annotations
import json
import sys
import time
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Optional
# Ensure ML_Pipeline and runtime are importable
_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(_ROOT))
from ML_Pipeline.workbench_core import (
JobSpec,
LoopSessionSpec,
SteeringRequest,
launch_worker,
new_job,
read_events,
read_json,
utcnow,
write_json,
)
# ---------------------------------------------------------------------------
# LoopConfig dataclass
# ---------------------------------------------------------------------------
@dataclass
class LoopConfig:
"""Configuration for a loop generation session."""
clip_duration: float = 20.0
generation_lead: float = 10.0
crossfade_duration: float = 3.0
runtime: str = "torch"
precision: str = "4bit"
active_model: str = "base"
chunk_duration: float = 2.0
guidance_weight: float = 4.0
temperature: float = 1.0
model_dir: str = "Content/MLModels"
top_k: int = 0
def to_session_spec(
self,
output_root: str,
steering_payload: dict[str, Any],
) -> LoopSessionSpec:
return LoopSessionSpec(
model_variant=self.active_model,
quantization_tier=self.precision,
chunk_duration=self.chunk_duration,
total_duration=0.0, # infinite / rolling
no_cpu_fallback=True,
output_root=output_root,
steering_payload=steering_payload,
model_dir=self.model_dir,
guidance_weight=self.guidance_weight,
temperature=self.temperature,
top_k=self.top_k,
)
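
# Illustrative example (placeholder output root and steering payload):
#   LoopConfig().to_session_spec("Output/session", {"mode_preset": "ambient"})
# With the defaults above this yields a rolling spec (total_duration == 0.0)
# that the worker treats as unbounded.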
# ---------------------------------------------------------------------------
# Session management
# ---------------------------------------------------------------------------
def start_loop_session(
loop_config: LoopConfig,
initial_steering: dict[str, Any],
prompt: str,
) -> dict[str, Any]:
"""Start the rolling scheduler as an isolated subprocess worker.
The worker process runs independently. If the Streamlit UI reloads,
the worker continues generating.
Parameters
----------
loop_config : LoopConfig
Timing and model configuration.
initial_steering : dict
Initial steering payload (prompt, mode_preset, etc.).
prompt : str
Natural-language prompt for the first generation.
Returns
-------
dict
``{"job_id": ..., "output_root": ..., "status": "launched"}``
"""
# Merge prompt into steering payload
steering = {**initial_steering, "natural_language_prompt": prompt}
# Create job spec
spec = new_job(
job_type="run_loop_session",
runtime=loop_config.runtime,
payload={
"loop_config": asdict(loop_config),
"steering": steering,
"prompt": prompt,
"scheduler": "rolling", # tell the worker to use rolling scheduler
},
)
output_root = Path(spec.output_root)
# Write initial loop state file
_write_loop_state(output_root, {
"status": "launching",
"job_id": spec.job_id,
"prompt": prompt,
"buffer_depth": 0,
"buffer_state": "priming",
"playing_clip_index": None,
"queued_clip_index": None,
"generating_clip_index": None,
"underrun_count": 0,
"latest_metrics": {},
"started_at": utcnow(),
})
# Launch as isolated subprocess (start_new_session=True)
proc = launch_worker(spec)
return {
"job_id": spec.job_id,
"output_root": str(output_root),
"pid": proc.pid,
"status": "launched",
}
def stop_loop_session(output_root: str) -> None:
"""Signal the worker to stop by writing a sentinel file."""
sentinel = Path(output_root) / "stop.requested"
sentinel.touch()
def loop_job_running(output_root: str) -> bool:
"""Check if the worker process is still alive.
Reads from the PID file or checks process state.
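
    Illustrative UI polling pattern (the interval and ``root`` are placeholders)::

        while loop_job_running(root):
            state = load_loop_state(root)
            print(state["buffer_state"], state["buffer_depth"])
            time.sleep(2.0)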
"""
out = Path(output_root)
# Check stop sentinel
if (out / "stop.requested").exists():
return False
    # Check worker logs: if events.jsonl has a terminal event, it's done
events = read_events(out)
for ev in reversed(events):
if ev.get("status") in ("completed", "failed", "cancelled"):
return False
# Check if loop_state.json was recently updated (within last 10s)
state_path = out / "loop_state.json"
if state_path.exists():
mtime = state_path.stat().st_mtime
if time.time() - mtime < 10.0:
return True
return False
# ---------------------------------------------------------------------------
# State readers
# ---------------------------------------------------------------------------
def load_loop_state(output_root: str) -> dict[str, Any]:
"""Read the current loop state from the IPC state file.
Written by the worker's MetricsReporter every second.
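
    Falls back to ``_empty_loop_state()`` when the file is missing or unreadable.
    Illustrative usage, assuming ``info`` is the dict returned by
    ``start_loop_session``::

        state = load_loop_state(info["output_root"])
        if state["buffer_state"] == "priming":
            ...  # e.g. show a warming-up indicator in the UI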
"""
state_path = Path(output_root) / "loop_state.json"
if not state_path.exists():
return _empty_loop_state()
try:
return json.loads(state_path.read_text())
except (json.JSONDecodeError, OSError):
return _empty_loop_state()
def load_clips_manifest(output_root: str) -> dict[str, Any]:
"""Read the clips manifest (list of generated clips and their metadata)."""
manifest_path = Path(output_root) / "clips_manifest.json"
if not manifest_path.exists():
return {"clips": [], "total_clips": 0}
try:
return json.loads(manifest_path.read_text())
except (json.JSONDecodeError, OSError):
return {"clips": [], "total_clips": 0}
def _empty_loop_state() -> dict[str, Any]:
return {
"status": "not_started",
"buffer_depth": 0,
"buffer_state": "priming",
"playing_clip_index": None,
"queued_clip_index": None,
"generating_clip_index": None,
"underrun_count": 0,
"latest_metrics": {},
}
# ---------------------------------------------------------------------------
# Steering
# ---------------------------------------------------------------------------
def next_safe_clip_index(current_index: Optional[int], is_running: bool) -> int:
"""Calculate the next safe clip index for hot-swap steering.
Targets the *next* queued clip (not the currently playing one)
to avoid audible glitches.
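
    Examples
    --------
    Both results follow directly from the rule above:

    >>> next_safe_clip_index(5, True)    # clip 5 playing, clip 6 queued
    7
    >>> next_safe_clip_index(None, False)
    0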
"""
if not is_running or current_index is None:
return 0
return current_index + 2 # skip playing + queued, target next-to-generate
def queue_steering_update(
output_root: str,
steering: dict[str, Any],
target_clip_index: int,
) -> dict[str, Any]:
"""Write a pending steering update for the worker to pick up.
The worker polls for ``pending_steering.json`` once per generation cycle.
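
    Illustrative usage; ``root`` and the steering payload are placeholder values::

        state = load_loop_state(root)
        target = next_safe_clip_index(state["playing_clip_index"], loop_job_running(root))
        queue_steering_update(root, {"natural_language_prompt": "add rain textures"}, target)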
"""
update = {
"steering": steering,
"target_clip_index": target_clip_index,
"queued_at": utcnow(),
"id": uuid.uuid4().hex[:8],
}
pending_path = Path(output_root) / "pending_steering.json"
# Atomic write via temp file
tmp = pending_path.with_suffix(".tmp")
tmp.write_text(json.dumps(update))
tmp.replace(pending_path)
return update
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _write_loop_state(output_root: Path, state: dict[str, Any]) -> None:
state_path = output_root / "loop_state.json"
tmp = state_path.with_suffix(".tmp")
tmp.write_text(json.dumps(state, default=str))
tmp.replace(state_path)