File size: 10,035 Bytes
3bb804c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""

PC Scanner Node - Automatically scans through all principal components

Creates a contact sheet showing what each PC controls

"""

import numpy as np
from PyQt6 import QtGui
import cv2

# BaseNode and QtGui are read from whatever module is running as __main__
# (presumably the host application that loads this node file - confirm).
import __main__
BaseNode = __main__.BaseNode
# NOTE: this rebinding replaces the QtGui name imported from PyQt6 above.
QtGui = __main__.QtGui

class PCScannerNode(BaseNode):
    """Scan every principal component of a latent vector and tile the results.

    While a scan is running the node offsets one entry of the stored base
    latent at a time, first by ``-amplitude`` and then by ``+amplitude``,
    exposes the modified vector on ``latent_out`` and captures the image that
    arrives on ``reconstructed_image``. The captures are laid out in a grid
    ordered ``[PC0-, PC0+, PC1-, PC1+, ...]``.
    """
    NODE_CATEGORY = "AI / Physics"
    NODE_COLOR = QtGui.QColor(220, 180, 100)

    # Frames to wait after switching PC/direction before the feedback image
    # is captured, so the reconstruction can catch up with the new latent.
    SETTLE_FRAMES = 5

    def __init__(self, scan_amplitude=2.0, grid_cols=4):
        """Create the node.

        Args:
            scan_amplitude: Offset applied to the scanned component when the
                ``amplitude`` input is not connected.
            grid_cols: Number of columns in the contact sheet.
        """
        super().__init__()
        self.node_title = "PC Scanner"

        self.inputs = {
            'latent_in': 'spectrum',
            'reconstructed_image': 'image',  # Feedback from iFFT
            'scan_speed': 'signal',  # How fast to scan
            'trigger': 'signal',  # Start scan
            'amplitude': 'signal'  # How much to modify each PC
        }
        self.outputs = {
            'latent_out': 'spectrum',  # Modified latent for current scan
            'contact_sheet': 'image',  # The full grid
            'current_pc': 'signal',  # Which PC we're scanning
            'progress': 'signal'  # 0-1 scan progress
        }

        self.scan_amplitude = float(scan_amplitude)
        self.grid_cols = int(grid_cols)

        # Scanning state
        self.is_scanning = False
        self.scan_index = 0  # Which PC we're currently scanning
        self.scan_direction = 1  # 1 for +, -1 for -
        self.frame_counter = 0
        self.frames_per_scan = 30  # How many frames to wait per PC

        # Storage
        self.base_latent = None
        self.current_latent = None
        self.captured_images = {}  # {(pc_idx, direction): image}
        self.contact_sheet = None

        # Dimensions
        self.cell_size = 64

    def step(self):
        """Advance the node by one frame.

        Reads the inputs, starts a scan when ``trigger`` exceeds 0.5, captures
        the feedback image for the current PC/direction once it has settled,
        moves on after ``frames_per_scan`` frames and rebuilds the contact
        sheet whenever at least one image has been captured.
        """
        latent_in = self.get_blended_input('latent_in', 'first')
        reconstructed = self.get_blended_input('reconstructed_image', 'mean')
        trigger = self.get_blended_input('trigger', 'sum') or 0.0
        scan_speed = self.get_blended_input('scan_speed', 'sum')
        amplitude_signal = self.get_blended_input('amplitude', 'sum')

        if amplitude_signal is not None:
            amplitude = amplitude_signal * 5.0
        else:
            amplitude = self.scan_amplitude

        if scan_speed is not None:
            # Negative speeds are treated as 0: this avoids a division by zero
            # at exactly -0.1 and stops negative values from producing the
            # *fastest* scan via the lower clamp.
            speed = max(scan_speed, 0.0)
            # The dwell time must exceed SETTLE_FRAMES, otherwise the capture
            # condition below is never met and the scan records nothing.
            self.frames_per_scan = max(self.SETTLE_FRAMES + 1,
                                       int(30 / (speed + 0.1)))

        # Store base latent.
        # NOTE(review): the base is taken from the first non-None latent_in
        # only and is never refreshed afterwards - confirm this is intended.
        if latent_in is not None and self.base_latent is None:
            self.base_latent = latent_in.copy()
            self.current_latent = latent_in.copy()

        # Trigger scan (level-sensitive: a held trigger restarts the scan as
        # soon as the previous one finishes).
        if trigger > 0.5 and not self.is_scanning:
            self.start_scan()

        # Scanning logic
        if self.is_scanning and self.base_latent is not None:
            self.frame_counter += 1

            # Capture the reconstructed image once per (PC, direction), after
            # the settle period.
            if reconstructed is not None and self.frame_counter > self.SETTLE_FRAMES:
                key = (self.scan_index, self.scan_direction)
                if key not in self.captured_images:
                    img_resized = cv2.resize(reconstructed, (self.cell_size, self.cell_size))
                    self.captured_images[key] = img_resized.copy()

            # Time to move to next scan?
            if self.frame_counter >= self.frames_per_scan:
                self.advance_scan()

            # Generate current modified latent. When the scan has just
            # finished scan_index is past the end and the base is emitted
            # unmodified.
            self.current_latent = self.base_latent.copy()
            if self.scan_index < len(self.base_latent):
                self.current_latent[self.scan_index] += amplitude * self.scan_direction

        # Build contact sheet
        if self.captured_images:
            self.build_contact_sheet()

    def start_scan(self):
        """Reset the scan state and begin at PC0, negative direction."""
        self.is_scanning = True
        self.scan_index = 0
        self.scan_direction = -1  # Start with negative
        self.frame_counter = 0
        self.captured_images = {}
        print("PC Scanner: Starting scan...")

    def advance_scan(self):
        """Move to the next direction or PC, ending the scan after the last PC.

        Expects ``self.base_latent`` to be set; ``step`` only calls this while
        a base latent is available.
        """
        self.frame_counter = 0

        if self.scan_direction == -1:
            # Switch to positive
            self.scan_direction = 1
        else:
            # Move to next PC
            self.scan_direction = -1
            self.scan_index += 1

            # Check if scan complete
            if self.scan_index >= len(self.base_latent):
                self.is_scanning = False
                self.current_latent = self.base_latent.copy()
                print(f"PC Scanner: Scan complete! Captured {len(self.captured_images)} images.")

    @staticmethod
    def _prepare_cell(img):
        """Return *img* as a 3-channel array suitable for the float canvas."""
        cell = np.asarray(img)
        if cell.dtype == np.uint8:
            # The canvas is float32 and is clipped to [0, 1] for display, so
            # 8-bit captures are rescaled instead of saturating to white.
            cell = cell.astype(np.float32) / 255.0
        if cell.ndim == 2:
            cell = np.stack([cell, cell, cell], axis=-1)
        elif cell.shape[2] == 1:
            cell = np.repeat(cell, 3, axis=2)
        elif cell.shape[2] > 3:
            # Drop extra channels (e.g. alpha) so the slice assignment fits.
            cell = cell[:, :, :3]
        return cell

    def build_contact_sheet(self):
        """Render all captured images into ``self.contact_sheet``.

        Every PC gets two cells (``-`` then ``+``). Cells without a capture
        get a grey outline, every cell is labelled, and the cell currently
        being scanned is outlined in green.
        """
        num_pcs = len(self.base_latent) if self.base_latent is not None else 8

        # grid_cols can be changed through the config options; guard against
        # zero/negative values that would otherwise divide by zero below.
        cols = max(1, int(self.grid_cols))
        size = self.cell_size

        # Each PC gets 2 cells (- and +)
        total_cells = num_pcs * 2
        rows = (total_cells + cols - 1) // cols

        canvas = np.zeros((rows * size, cols * size, 3), dtype=np.float32)

        cell_idx = 0
        for pc_idx in range(num_pcs):
            for direction in (-1, 1):
                row, col = divmod(cell_idx, cols)
                y_start = row * size
                x_start = col * size
                corner = (x_start + size - 1, y_start + size - 1)

                captured = self.captured_images.get((pc_idx, direction))
                if captured is not None:
                    canvas[y_start:y_start + size,
                           x_start:x_start + size] = self._prepare_cell(captured)
                else:
                    # Empty cell - draw placeholder
                    cv2.rectangle(canvas, (x_start, y_start), corner,
                                  (0.2, 0.2, 0.2), 1)

                # Label
                label = f"PC{pc_idx}" + ("-" if direction == -1 else "+")
                cv2.putText(canvas, label, (x_start + 2, y_start + 12),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.3, (1, 1, 1), 1)

                # Highlight current scan position
                if self.is_scanning and pc_idx == self.scan_index and direction == self.scan_direction:
                    cv2.rectangle(canvas, (x_start, y_start), corner,
                                  (0, 1, 0), 2)

                cell_idx += 1

        self.contact_sheet = canvas

    def get_output(self, port_name):
        """Return the current value for *port_name*, or ``None`` if unknown.

        ``current_pc`` is ``-1.0`` and ``progress`` is ``0.0`` whenever no
        scan is running.
        """
        if port_name == 'latent_out':
            return self.current_latent
        elif port_name == 'contact_sheet':
            return self.contact_sheet
        elif port_name == 'current_pc':
            return float(self.scan_index) if self.is_scanning else -1.0
        elif port_name == 'progress':
            if self.base_latent is not None and self.is_scanning:
                total = len(self.base_latent) * 2
                current = self.scan_index * 2 + (0 if self.scan_direction == -1 else 1)
                return current / total
            return 0.0
        return None

    def get_display_image(self):
        """Return a ``QImage`` preview: the contact sheet, or a status card.

        The returned image is a deep copy. ``QImage`` constructed from a raw
        buffer does not own the pixels, and the numpy arrays used here are
        locals that are released when this method returns.
        """
        if self.contact_sheet is not None:
            # Display the contact sheet
            img = np.ascontiguousarray(
                (np.clip(self.contact_sheet, 0, 1) * 255).astype(np.uint8))
            h, w = img.shape[:2]
            return QtGui.QImage(img.data, w, h, w*3,
                                QtGui.QImage.Format.Format_RGB888).copy()

        # Show status
        img = np.zeros((256, 256, 3), dtype=np.uint8)
        if self.is_scanning:
            fraction = self.get_output('progress')
            status = f"Scanning PC{self.scan_index}{'-' if self.scan_direction == -1 else '+'}"
            progress = int(fraction * 100)
            cv2.putText(img, status, (10, 128), cv2.FONT_HERSHEY_SIMPLEX,
                        0.6, (0,255,0), 2)
            cv2.putText(img, f"{progress}%", (10, 160), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255,255,255), 1)

            # Progress bar
            bar_width = int(256 * fraction)
            cv2.rectangle(img, (0, 240), (bar_width, 256), (0,255,0), -1)
        else:
            cv2.putText(img, "Ready to scan", (10, 128), cv2.FONT_HERSHEY_SIMPLEX,
                        0.6, (255,255,255), 1)
            cv2.putText(img, "Send trigger signal", (10, 160), cv2.FONT_HERSHEY_SIMPLEX,
                        0.4, (200,200,200), 1)

        return QtGui.QImage(img.data, 256, 256, 256*3,
                            QtGui.QImage.Format.Format_RGB888).copy()

    def get_config_options(self):
        """Return the editable settings as ``(label, attribute, value, extra)`` tuples."""
        return [
            ("Scan Amplitude", "scan_amplitude", self.scan_amplitude, None),
            ("Grid Columns", "grid_cols", self.grid_cols, None)
        ]