File size: 14,351 Bytes
89a012c
 
4cdc77a
89a012c
 
4cdc77a
cfbaa51
89a012c
4cdc77a
89a012c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cdc77a
89a012c
 
 
 
 
 
 
4cdc77a
89a012c
 
 
 
4cdc77a
89a012c
4cdc77a
89a012c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cdc77a
89a012c
 
 
4cdc77a
 
89a012c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cdc77a
89a012c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cdc77a
89a012c
 
 
 
 
4cdc77a
89a012c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import os

from CPR_Module.Common.logging_config import cpr_logger

class GraphPlotter:
    """Class to handle plotting of CPR motion curves and error regions with detailed metrics."""
   
    def __init__(self):
        self.chunks_y_preprocessed = []
        self.chunks_peaks = []
        self.chunks_depth = []
        self.chunks_rate = []
        self.chunks_start_and_end_indices = []
        self.posture_warnings_regions = []
        self.sampling_interval_in_frames = 0
        self.fps = None 
        self._chunks_json_data = []  # Store chunk data for JSON output 

        self.error_symbols = {
            "Right arm bent!": ('o', '#A61D1D'), # circle
            "Left arm bent!": ('s', '#A61D1D'), # square
            "Left hand not on chest!": ('P', '#A61D1D'), # plus
            "Right hand not on chest!": ('*', '#A61D1D'), # star
            "Both hands not on chest!": ('D', '#A61D1D') # diamond
        }

        self.annotation_y_level = None  # Will store our target y-position

    def _assign_graph_data(self, chunks_y_preprocessed, chunks_peaks, chunks_depth, chunks_rate, chunks_start_and_end_indices, posture_warnings_regions, sampling_interval_in_frames, fps):
        """Assign data members for the class"""

        self.chunks_y_preprocessed = chunks_y_preprocessed
        self.chunks_peaks = chunks_peaks
        self.chunks_depth = chunks_depth
        self.chunks_rate = chunks_rate
        self.chunks_start_and_end_indices = chunks_start_and_end_indices
        self.posture_warnings_regions = posture_warnings_regions
        self.sampling_interval_in_frames = sampling_interval_in_frames
        self.fps = fps 

        cpr_logger.info(f"[Graph Plotter] Data members assigned with {len(self.chunks_start_and_end_indices)} chunks and {len(self.posture_warnings_regions)} error regions for a sampling interval of {self.sampling_interval_in_frames} frames and FPS {self.fps}")

    def _plot_single_chunk(self, ax, chunk, idx, prev_last_point, prev_chunk_end):

        (start_frame, end_frame), depth, rate = chunk

        # Convert frames to time
        chunk_frames = np.arange(start_frame, end_frame + 1, self.sampling_interval_in_frames)
        chunk_times = chunk_frames / self.fps  # Convert to seconds
        y_preprocessed = self.chunks_y_preprocessed[idx]
        peaks = self.chunks_peaks[idx]
        
        # Check if chunks are contiguous and need connection (frame-based logic)
        if (prev_chunk_end is not None and 
            start_frame == prev_chunk_end + self.sampling_interval_in_frames and
            prev_last_point is not None):
            
            # Convert connection points to seconds
            connect_start = prev_chunk_end / self.fps
            connect_end = start_frame / self.fps
            connect_times = [connect_start, connect_end]
            
            cpr_logger.info(f"[Graph Plotter] Connecting chunk {idx+1} to previous chunk (time {connect_start:.2f}-{connect_end:.2f}s)")
            ax.plot(connect_times, [prev_last_point['y_preprocessed'], y_preprocessed[0]], 
                color="#2F5597", linewidth=2.5)
        
        # Plot current chunk data
        cpr_logger.info(f"[Graph Plotter] Plotting chunk {idx+1} (time {chunk_times[0]:.2f}-{chunk_times[-1]:.2f}s)")
        smooth_label = "Motion" if idx == 0 else ""
        peaks_label = "Peaks" if idx == 0 else ""

        # Updated motion plot
        ax.plot(chunk_times, y_preprocessed,
                color="#2F5597", linewidth=2.5,
                marker='o', markersize=4,
                markerfacecolor='#2F5597', markeredgecolor='#2F5597',
                label=smooth_label)
        
        # Updated peaks
        if peaks.size > 0:
            ax.plot(chunk_times[peaks], y_preprocessed[peaks],
                    "x", color="#ED7D31", markersize=8,
                    label=peaks_label)

       # Annotate chunk metrics (time-based)
        if (depth is not None and rate is not None) and (depth > 0 and rate > 0):
            mid_time = (start_frame + end_frame) / (2 * self.fps)
            cpr_logger.info(f"[Graph Plotter] Chunk {idx+1} metrics: {depth:.1f}cm depth, {rate:.1f}cpm rate")
            
            # Calculate or use stored annotation y-level
            if self.annotation_y_level is None:
                # For first chunk, calculate midpoint between min and max of y_preprocessed
                y_range = np.max(y_preprocessed) - np.min(y_preprocessed)
                self.annotation_y_level = np.min(y_preprocessed) + y_range * 0.5  # 70% up from bottom
                cpr_logger.info(f"[Graph Plotter] Setting annotation y-level to {self.annotation_y_level:.2f}")
            
            # Updated annotation box using consistent y-level
            ax.annotate(f"Depth: {depth:.1f}cm\nRate: {rate:.1f}cpm",
                    xy=(mid_time, self.annotation_y_level),
                    xytext=(0, 10), textcoords='offset points',
                    ha='center', va='bottom', fontsize=9,
                    bbox=dict(boxstyle='round,pad=0.5', 
                            fc='#F2F2F2', ec='#595959', alpha=0.8))
        
        return {'y_preprocessed': y_preprocessed[-1]}, end_frame

    def _plot_error_regions(self, ax, computed_error_regions):
        """Visualize error regions with adaptive symbol sizing"""

        cpr_logger.info("[Graph Plotter] Rendering error regions:")

        # Size parameters
        target_width_ratio = 0.7 
        legend_size = 80 
        
        legend_handles = []
        y_mid = np.mean(ax.get_ylim())
        
        # Get figure dimensions for size conversion
        fig = ax.figure
        fig_width_points = fig.get_figwidth() * fig.dpi
        x_min, x_max = ax.get_xlim()
        data_range = x_max - x_min
        points_per_second = fig_width_points / data_range

        for idx, (start_sec, end_sec) in enumerate(computed_error_regions):
            region_width = end_sec - start_sec
            region_data = self.posture_warnings_regions[idx]
            warnings = region_data.get('posture_warnings', [])

            # Calculate max allowed width in data units (seconds)
            max_data_width = region_width * target_width_ratio
            
            # Convert legend size to data units
            legend_data_width = (np.sqrt(legend_size) / points_per_second)
            
            # Determine final symbol width (data units)
            symbol_data_width = min(legend_data_width, max_data_width)
            
            # Convert back to points² for matplotlib
            symbol_point_width = symbol_data_width * points_per_second
            symbol_size = symbol_point_width ** 2

            for error in warnings:
                if error in self.error_symbols:
                    marker, color = self.error_symbols[error]
                    
                    ax.scatter(
                        x=(start_sec + end_sec)/2,
                        y=y_mid,
                        s=symbol_size,
                        marker=marker,
                        color=color,
                        alpha=0.7,
                        edgecolors='black',
                        linewidths=0.5,
                        zorder=5
                    )

                    # Create legend entry once
                    if not any(error == h.get_label() for h in legend_handles):
                        legend_handles.append(
                            ax.scatter([], [], 
                                    s=legend_size,
                                    marker=marker,
                                    color=color,
                                    edgecolors='black',
                                    linewidths=0.5,
                                    alpha=0.7,
                                    label=error)
                        )

            # Updated error region fill
            ax.axvspan(start_sec, end_sec, 
                    color='#FCE4D6', alpha=0.3, zorder=1)

        if not ax.get_xlabel():
            ax.set_xlabel("Time (seconds)", fontsize=10)
        if not ax.get_ylabel():
            ax.set_ylabel("Signal Value", fontsize=10)

        return legend_handles

    def plot_motion_curve_for_all_chunks(self, chunks_y_preprocessed, chunks_peaks, chunks_depth, chunks_rate, chunks_start_and_end_indices, posture_warnings_regions, sampling_interval_in_frames, fps, plot_output_path):
        """Plot combined analysis with connected chunks and proper error regions"""
        
        self._assign_graph_data(chunks_y_preprocessed, chunks_peaks, chunks_depth, chunks_rate, chunks_start_and_end_indices, posture_warnings_regions, sampling_interval_in_frames, fps)
        cpr_logger.info("[Graph Plotter] Starting to plot motion curve for all chunks")
        
        # Create figure even if there's only error regions to plot
        plt.figure(figsize=(16, 8))
        ax = plt.gca()
        ax.xaxis.set_major_locator(MultipleLocator(5))

        # Plot CPR chunks if they exist
        if self.chunks_start_and_end_indices:
            sorted_chunks = sorted(zip(self.chunks_start_and_end_indices, 
                                self.chunks_depth, 
                                self.chunks_rate),
                                key=lambda x: x[0][0])
            cpr_logger.info(f"[Graph Plotter] Processing {len(sorted_chunks)} CPR chunks")

            prev_last_point = None
            prev_chunk_end = None

            for idx, chunk in enumerate(sorted_chunks):
                cpr_logger.info(f"[Graph Plotter] Rendering chunk {idx+1}/{len(sorted_chunks)}")
                prev_last_point, prev_chunk_end = self._plot_single_chunk(ax, chunk, idx, prev_last_point, prev_chunk_end)
            
            self._print_analysis_details(sorted_chunks)
        else:
            cpr_logger.info("[Graph Plotter] No chunk data available for plotting")
            # Set reasonable default axis if only plotting errors
            ax.set_ylim(0, 100)  # Example default Y-axis range for position

        # Always plot error regions if they exist
        computed_error_regions = [(er['start_frame']/self.fps, er['end_frame']/self.fps) 
                                for er in self.posture_warnings_regions]
        
        # In the "Configure remaining elements" section (replace existing legend code):
        handles, labels = ax.get_legend_handles_labels()

        # Collect error handles from _plot_error_regions (modified to return them)
        error_legend_handles = self._plot_error_regions(ax, computed_error_regions)

        # Merge both sets of handles/labels
        if error_legend_handles:
            handles += error_legend_handles
            labels += [h.get_label() for h in error_legend_handles]

        # Remove duplicate labels
        unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]

        # Create single horizontal legend at bottom
        if unique:
           ax.legend(
            *zip(*unique),
            loc='upper center',
            bbox_to_anchor=(0.5, -0.08),
            ncol=len(unique),
            fontsize=8,
            handletextpad=0.3,
            columnspacing=1.5,
            framealpha=0.9,
            borderpad=0.7
        )
        plt.tight_layout(rect=[0, 0.025, 1, 1])

        plt.xlabel("Time (seconds)")
        plt.ylabel("Vertical Position (px)")
        plt.title("Complete CPR Analysis with Metrics", pad=20)  # Added pad parameter
        
        plt.grid(True)
        cpr_logger.info(f"\n[Graph Plotter] Finalizing plot layout")
        
        # Adjust tight_layout with additional padding
        plt.tight_layout(rect=[0, 0.025, 1, 0.95])  # Reduced top from 1 to 0.95 to make space

        if plot_output_path:
            # Ensure directory exists
            os.makedirs(os.path.dirname(plot_output_path), exist_ok=True)
            plt.savefig(plot_output_path, dpi=300, bbox_inches='tight')
            cpr_logger.info(f"[Graph Plotter] Plot saved to {plot_output_path}")
        
        #plt.show()
        cpr_logger.info("[Graph Plotter] Plot display complete")
       
    def _print_analysis_details(self, sorted_chunks):
        """Combined helper for printing chunks and error regions in seconds"""

        cpr_logger.info(f"\n\n=== CPR Chunk Analysis ===")
        display_idx = 0  # Separate counter for displayed indices
        
        # Convert frame numbers to seconds using video FPS
        fps = self.fps  # Get FPS from class instance
        
        for ((start_frame, end_frame), depth, rate) in sorted_chunks:
            # Skip chunks with both values at 0
            if depth == 0 and rate == 0:
                continue
                
            # Convert frames to seconds
            start_sec = start_frame / fps
            end_sec = end_frame / fps
            duration_sec = (end_frame - start_frame + 1) / fps  # +1 to include both endpoints
            
            cpr_logger.info(f"[Graph Plotter] Chunk {display_idx+1}: "
                f"Time {start_sec:.2f}s - {end_sec:.2f}s ({duration_sec:.2f}s), "
                f"Depth: {depth:.1f}cm, Rate: {rate:.1f}cpm")
            
            #& Formatted json to mobile
            chunk_data = {
            "start": round(start_sec, 2),
            "end": round(end_sec, 2),
            "depth": round(depth, 1),
            "rate": round(rate, 1)
            }
            self._chunks_json_data.append(chunk_data)
            
            display_idx += 1

        cpr_logger.info(f"\n\n=== Error Region Analysis ===")
        
        for i, region in enumerate(self.posture_warnings_regions):  # Updated to match actual attribute name
            start_frame = region['start_frame']
            end_frame = region['end_frame']
            errors = region['posture_warnings']
            
            # Convert to seconds
            start_sec = start_frame / fps
            end_sec = end_frame / fps
            error_str = ", ".join(errors) if errors else "No errors detected"
            
            cpr_logger.info(f"[Graph Plotter] Region {i+1}: "
                f"Time {start_sec:.2f}s - {end_sec:.2f}s - {error_str}")
        
        cpr_logger.info(f"\n\n")