File size: 7,478 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
App instance for managing UI state and real-time previews
"""

import os
import threading
import time
import tempfile
from typing import List, Any
from PIL import Image


class AppInstance:
    """Main application instance for managing UI state and previews.

    Keeps the most recent preview images pushed via ``update_image``, lazily
    converts them to base64 data URIs in ``get_latest_previews``, and exposes an
    interrupt event that sampling loops can poll. Preview state is read and
    written under ``preview_lock``.
    """

    def __init__(self):
        self.previewer_var = PreviewerVar()
        requested_preview_dir = os.getenv("LD_PREVIEW_DIR") or os.path.join(".", "output", "preview")
        self.preview_dir = requested_preview_dir
        self.preview_lock = threading.Lock()
        self.preview_files = []
        self.preview_images = []  # PIL images and/or data-URI strings, stored as given
        self.preview_base64_cache = []  # data URIs derived lazily from preview_images
        self.last_preview_time = 0  # milliseconds since the epoch of the last update_image call
        self.current_step = 0
        self.total_steps = 0
        self.progress = ProgressTracker()
        self._interrupt_event = threading.Event()

        # Prefer the configured preview directory, but fall back to a temp
        # location when the working tree is not writable (for example, during
        # constrained test runs or read-only deployments).
        try:
            os.makedirs(self.preview_dir, exist_ok=True)
        except OSError:
            self.preview_dir = os.path.join(tempfile.gettempdir(), "lightdiffusion-preview")
            os.makedirs(self.preview_dir, exist_ok=True)

        # Preview rendering/config options (tunable)
        self.preview_srgb = True  # apply sRGB curve to previews
        self.preview_format = "WEBP"  # 'WEBP' or 'JPEG' or 'PNG'
        self.preview_quality = 90  # quality for lossy formats (0-100)
        self.preview_resample = "LANCZOS"  # resampling preference name
        self.preview_apply_fast_autohdr = False  # lightweight autohdr for previews (disabled by default)

    def update_image(self, images: List[Any], step: int = 0, total_steps: int = 0):
        """Update the gallery with preview images in real-time.

        Images are stored unconverted (encoding happens lazily in
        ``get_latest_previews``) so the sampling loop pays no conversion cost.

        Args:
            images: List of PIL.Image or base64 data-URI strings. ``None`` is
                treated as "no previews".
            step: Current step
            total_steps: Total steps
        """
        # Copy into a fresh list so the caller mutating or reusing its list
        # afterwards cannot change what the UI is shown.
        snapshot = list(images) if images is not None else []
        with self.preview_lock:
            self.current_step = step
            self.total_steps = total_steps
            self.last_preview_time = int(time.time() * 1000)
            self.preview_images = snapshot
            # New images invalidate any previously encoded data URIs.
            self.preview_base64_cache = []

    def get_preview_metadata(self):
        """Lightweight check for preview updates (no image encoding)."""
        with self.preview_lock:
            return {
                "step": self.current_step,
                "total_steps": self.total_steps,
                "timestamp": self.last_preview_time,
                "has_images": len(self.preview_images) > 0
            }

    def _encode_preview(self, img):
        """Return *img* as a ``data:`` URI string, or None if it cannot be used.

        Strings already starting with ``data:image`` pass through unchanged;
        any other string is skipped. Objects with a ``save`` method (PIL images)
        are written in ``preview_format`` at ``preview_quality``, falling back
        to PNG if that save raises.
        """
        import base64
        import io

        if isinstance(img, str):
            return img if img.startswith("data:image") else None
        if not hasattr(img, "save"):
            return None

        fmt = getattr(self, "preview_format", "WEBP")
        quality = getattr(self, "preview_quality", 90)
        try:
            buffered = io.BytesIO()
            try:
                img.save(buffered, format=fmt, quality=quality)
                mime = f"image/{fmt.lower()}"
            except Exception:
                # Fallback: lossless PNG if the format is unsupported by this
                # build or incompatible with the image (e.g. RGBA as JPEG).
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                mime = "image/png"
        except Exception as e:
            print(f"Error encoding preview image: {e}")
            return None
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:{mime};base64,{img_str}"

    def get_latest_previews(self):
        """Get the latest preview images and metadata. Converts to base64 lazily.

        Returns:
            dict with ``paths`` (always empty; path-based previews are
            deprecated), ``base64`` (list of data URIs), ``step``,
            ``total_steps`` and ``timestamp``. On an unexpected error, a dict
            with empty lists and zeroed metadata is returned.
        """
        with self.preview_lock:
            try:
                if self.preview_images and not self.preview_base64_cache:
                    encoded = (self._encode_preview(img) for img in self.preview_images)
                    self.preview_base64_cache = [uri for uri in encoded if uri is not None]

                return {
                    "paths": [],  # Deprecated path-based previews
                    # Copy so callers cannot mutate the shared cache outside the lock.
                    "base64": list(self.preview_base64_cache),
                    "step": self.current_step,
                    "total_steps": self.total_steps,
                    "timestamp": self.last_preview_time
                }
            except Exception as e:
                print(f"Error loading preview images: {e}")
                return {"paths": [], "base64": [], "step": 0, "total_steps": 0, "timestamp": 0}

    def clear_preview_files(self):
        """Clear in-memory preview data: stored images, encoded cache and file list."""
        with self.preview_lock:
            # Clear the attributes the getters actually read; the previous code
            # assigned an unused ``preview_base64`` attribute, so stale previews
            # survived cleanup.
            self.preview_images = []
            self.preview_base64_cache = []
            self.preview_files = []

    def cleanup_all_previews(self):
        """Clear preview memory and delete ``preview_*.png``/``.webp`` files in preview_dir."""
        self.clear_preview_files()
        try:
            filenames = os.listdir(self.preview_dir)
        except FileNotFoundError:
            return  # nothing to clean
        except OSError as e:
            print(f"Error cleaning up preview directory: {e}")
            return
        for filename in filenames:
            if filename.startswith("preview_") and filename.endswith((".png", ".webp")):
                try:
                    os.remove(os.path.join(self.preview_dir, filename))
                except OSError:
                    # Best effort: the file may already be gone or be in use.
                    pass

    def cleanup(self):
        """Cleanup resources: drop preview data and reset the interrupt flag."""
        self.clear_preview_files()
        self.clear_interrupt()

    @property
    def interrupt_flag(self) -> bool:
        """Return True when an interrupt has been requested"""
        return self._interrupt_event.is_set()

    def request_interrupt(self):
        """Signal sampling loops to stop"""
        self._interrupt_event.set()

    def clear_interrupt(self):
        """Reset interrupt state after a run"""
        self._interrupt_event.clear()


class PreviewerVar:
    """On/off switch for live previews; starts out enabled."""

    def __init__(self):
        # Initialise through the setter so there is a single write path.
        self.set(True)

    def get(self) -> bool:
        """Return whether previews are currently enabled."""
        enabled = self._enabled
        return enabled

    def set(self, value: bool):
        """Store *value* as the preview-enabled state (stored as given)."""
        self._enabled = value


class ProgressTracker:
    """Holds one progress fraction for sampling, kept within [0.0, 1.0]."""

    def __init__(self):
        self._progress = 0.0

    def set(self, value: float):
        """Record *value* as the current progress, clamped into [0.0, 1.0]."""
        # Upper bound first, then lower bound (same order as min-then-max).
        capped = value if value < 1.0 else 1.0
        self._progress = capped if capped > 0.0 else 0.0

    def get(self) -> float:
        """Return the most recently recorded (clamped) progress value."""
        current = self._progress
        return current


# Global app instance: one shared AppInstance created when this module is
# imported. Note that construction creates the preview directory on disk
# (falling back to a temp dir). Presumably other modules import `app` to share
# preview/interrupt state — confirm against callers.
app = AppInstance()