# LightDiffusion-Next: src/user/app_instance.py
# Author: Aatricks
# Snapshot b701455: "Deploy ZeroGPU Gradio Space snapshot"
"""
App instance for managing UI state and real-time previews
"""
import os
import threading
import time
import tempfile
from typing import List, Any
from PIL import Image
class AppInstance:
    """Main application instance for managing UI state and previews.

    Holds the most recent sampler preview entries (PIL images or data-URI
    strings), step metadata for polling clients, a lazily filled cache of
    base64 data URIs, and a ``threading.Event`` that sampling loops can poll
    to stop early. Preview state is guarded by ``preview_lock``.
    """

    def __init__(self):
        self.previewer_var = PreviewerVar()
        requested_preview_dir = os.getenv("LD_PREVIEW_DIR") or os.path.join(".", "output", "preview")
        self.preview_dir = requested_preview_dir
        self.preview_lock = threading.Lock()
        self.preview_files = []
        self.preview_images = []  # Latest preview entries (PIL images or data-URI strings)
        self.preview_base64_cache = []  # Data URIs derived from preview_images; [] means "not built yet"
        self.last_preview_time = 0  # Milliseconds since the epoch of the last update_image call
        self.current_step = 0
        self.total_steps = 0
        self.progress = ProgressTracker()
        self._interrupt_event = threading.Event()
        # Prefer the configured preview directory, but fall back to a temp
        # location when the working tree is not writable (for example, during
        # constrained test runs or read-only deployments).
        try:
            os.makedirs(self.preview_dir, exist_ok=True)
        except OSError:
            self.preview_dir = os.path.join(tempfile.gettempdir(), "lightdiffusion-preview")
            os.makedirs(self.preview_dir, exist_ok=True)
        # Preview rendering/config options (tunable).
        # Only preview_format and preview_quality are read in this class;
        # the others are presumably consumed by the preview renderer elsewhere.
        self.preview_srgb = True  # apply sRGB curve to previews
        self.preview_format = "WEBP"  # 'WEBP' or 'JPEG' or 'PNG'
        self.preview_quality = 90  # quality for lossy formats (0-100)
        self.preview_resample = "LANCZOS"  # resampling preference name
        self.preview_apply_fast_autohdr = False  # lightweight autohdr for previews (disabled by default)

    def update_image(self, images: List[Any], step: int = 0, total_steps: int = 0) -> None:
        """Publish new preview entries for polling clients.

        No encoding happens here so the sampling loop stays cheap; conversion
        to base64 is deferred to ``get_latest_previews``.

        Args:
            images: Iterable of PIL.Image objects or ``data:image...`` strings.
                ``None`` is treated as "no previews".
            step: Current sampling step.
            total_steps: Total number of sampling steps.
        """
        # Snapshot the entries so later mutation of the caller's list cannot
        # change what readers see, and so None never reaches len() checks.
        snapshot = list(images) if images is not None else []
        with self.preview_lock:
            self.current_step = step
            self.total_steps = total_steps
            self.last_preview_time = int(time.time() * 1000)
            self.preview_images = snapshot
            # Invalidate cached data URIs; they are rebuilt lazily on read.
            self.preview_base64_cache = []

    def get_preview_metadata(self) -> dict:
        """Return step/timestamp info and whether previews exist, without encoding images."""
        with self.preview_lock:
            return {
                "step": self.current_step,
                "total_steps": self.total_steps,
                "timestamp": self.last_preview_time,
                "has_images": len(self.preview_images) > 0,
            }

    def _encode_preview(self, img: Any):
        """Return a data-URI string for one preview entry, or None if it cannot be encoded.

        Strings are accepted only if they already start with ``data:image``.
        Objects with a ``save`` method are encoded in ``preview_format`` at
        ``preview_quality``; if that fails, lossless PNG is tried instead.
        """
        import base64
        import io

        if isinstance(img, str):
            return img if img.startswith("data:image") else None
        if not hasattr(img, "save"):
            return None
        fmt = getattr(self, "preview_format", "WEBP")
        quality = getattr(self, "preview_quality", 90)
        try:
            buffered = io.BytesIO()
            img.save(buffered, format=fmt, quality=quality)
            mime = f"image/{fmt.lower()}"
        except Exception:
            # Fallback: lossless PNG if the configured format is not supported
            # (e.g. an RGBA image with JPEG).
            try:
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                mime = "image/png"
            except Exception:
                return None
        encoded = base64.b64encode(buffered.getvalue()).decode()
        return f"data:{mime};base64,{encoded}"

    def get_latest_previews(self) -> dict:
        """Return the latest previews as data URIs plus step metadata.

        Encoding happens lazily on the first call after ``update_image`` and is
        cached until the next update. Entries that cannot be encoded are
        skipped. The returned ``base64`` list is a copy, so callers may mutate
        it freely. On unexpected errors, an empty result with zeroed metadata
        is returned.
        """
        with self.preview_lock:
            try:
                if self.preview_images and not self.preview_base64_cache:
                    encoded = (self._encode_preview(img) for img in self.preview_images)
                    self.preview_base64_cache = [uri for uri in encoded if uri is not None]
                return {
                    "paths": [],  # Deprecated path-based previews
                    "base64": list(self.preview_base64_cache),
                    "step": self.current_step,
                    "total_steps": self.total_steps,
                    "timestamp": self.last_preview_time,
                }
            except Exception as e:
                print(f"Error loading preview images: {e}")
                return {"paths": [], "base64": [], "step": 0, "total_steps": 0, "timestamp": 0}

    def clear_preview_files(self) -> None:
        """Drop all in-memory preview data (images, cached data URIs, file list).

        Step/timestamp metadata is left untouched.
        """
        with self.preview_lock:
            self.preview_images = []
            self.preview_base64_cache = []
            self.preview_files = []

    def cleanup_all_previews(self) -> None:
        """Clear in-memory previews and delete ``preview_*.png``/``preview_*.webp`` files.

        Best-effort: files that cannot be removed are skipped, and a failure to
        list ``preview_dir`` is reported with ``print`` instead of raised.
        """
        self.clear_preview_files()
        try:
            if os.path.exists(self.preview_dir):
                for filename in os.listdir(self.preview_dir):
                    if filename.startswith("preview_") and filename.endswith((".png", ".webp")):
                        try:
                            os.remove(os.path.join(self.preview_dir, filename))
                        except OSError:
                            # Already gone or not removable; cleanup is best-effort.
                            pass
        except Exception as e:
            print(f"Error cleaning up preview directory: {e}")

    def cleanup(self) -> None:
        """Clear in-memory preview data and reset the interrupt flag."""
        self.clear_preview_files()
        self.clear_interrupt()

    @property
    def interrupt_flag(self) -> bool:
        """Return True when an interrupt has been requested."""
        return self._interrupt_event.is_set()

    def request_interrupt(self) -> None:
        """Signal sampling loops to stop by setting the interrupt event."""
        self._interrupt_event.set()

    def clear_interrupt(self) -> None:
        """Reset interrupt state after a run."""
        self._interrupt_event.clear()
class PreviewerVar:
    """Toggle deciding whether live previews should be produced.

    The flag starts enabled; ``set`` stores whatever value it receives
    without coercion, and ``get`` hands it back unchanged.
    """

    def __init__(self) -> None:
        enabled_by_default = True
        self._enabled = enabled_by_default

    def get(self) -> bool:
        """Return the stored preview-enabled flag."""
        current = self._enabled
        return current

    def set(self, value: bool) -> None:
        """Replace the preview-enabled flag with ``value``."""
        self._enabled = value
class ProgressTracker:
    """Holds one sampling-progress fraction, clamped into [0.0, 1.0]."""

    def __init__(self) -> None:
        self._progress = 0.0

    def set(self, value: float) -> None:
        """Store ``value`` after bounding it above by 1.0 and below by 0.0."""
        capped = min(1.0, value)
        self._progress = max(0.0, capped)

    def get(self) -> float:
        """Return the most recently stored (clamped) progress fraction."""
        return self._progress
# Module-level AppInstance singleton (presumably imported by other modules).
# Constructing it here runs AppInstance.__init__ at import time, which
# creates the preview directory.
app: AppInstance = AppInstance()