|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.ticker import MultipleLocator |
|
|
import os |
|
|
|
|
|
from CPR_Module.Common.logging_config import cpr_logger |
|
|
|
|
|
class GraphPlotter: |
|
|
"""Class to plot graphs for various metrics""" |
|
|
|
|
|
def __init__(self): |
|
|
self.chunks_y_preprocessed = [] |
|
|
self.chunks_peaks = [] |
|
|
self.chunks_depth = [] |
|
|
self.chunks_rate = [] |
|
|
self.chunks_start_and_end_indices = [] |
|
|
self.posture_warnings_regions = [] |
|
|
self.sampling_interval_in_frames = 0 |
|
|
self.fps = None |
|
|
|
|
|
self.error_symbols = { |
|
|
"Right arm bent!": ('o', '#A61D1D'), |
|
|
"Left arm bent!": ('s', '#A61D1D'), |
|
|
"Left hand not on chest!": ('P', '#A61D1D'), |
|
|
"Right hand not on chest!": ('*', '#A61D1D'), |
|
|
"Both hands not on chest!": ('D', '#A61D1D') |
|
|
} |
|
|
|
|
|
self.annotation_y_level = None |
|
|
|
|
|
def _assign_graph_data(self, chunks_y_preprocessed, chunks_peaks, chunks_depth, chunks_rate, chunks_start_and_end_indices, posture_warnings_regions, sampling_interval_in_frames, fps): |
|
|
"""Assign data members for the class""" |
|
|
self.chunks_y_preprocessed = chunks_y_preprocessed |
|
|
self.chunks_peaks = chunks_peaks |
|
|
self.chunks_depth = chunks_depth |
|
|
self.chunks_rate = chunks_rate |
|
|
self.chunks_start_and_end_indices = chunks_start_and_end_indices |
|
|
self.posture_warnings_regions = posture_warnings_regions |
|
|
self.sampling_interval_in_frames = sampling_interval_in_frames |
|
|
self.fps = fps |
|
|
|
|
|
cpr_logger.info(f"[Graph Plotter] Data members assigned with {len(self.chunks_start_and_end_indices)} chunks and {len(self.posture_warnings_regions)} error regions for a sampling interval of {self.sampling_interval_in_frames} frames and FPS {self.fps}") |
|
|
|
|
|
def _plot_single_chunk(self, ax, chunk, idx, prev_last_point, prev_chunk_end): |
|
|
(start_frame, end_frame), depth, rate = chunk |
|
|
|
|
|
chunk_frames = np.arange(start_frame, end_frame + 1, self.sampling_interval_in_frames) |
|
|
chunk_times = chunk_frames / self.fps |
|
|
y_preprocessed = self.chunks_y_preprocessed[idx] |
|
|
peaks = self.chunks_peaks[idx] |
|
|
|
|
|
|
|
|
if (prev_chunk_end is not None and |
|
|
start_frame == prev_chunk_end + self.sampling_interval_in_frames and |
|
|
prev_last_point is not None): |
|
|
|
|
|
|
|
|
connect_start = prev_chunk_end / self.fps |
|
|
connect_end = start_frame / self.fps |
|
|
connect_times = [connect_start, connect_end] |
|
|
|
|
|
cpr_logger.info(f"[Graph Plotter] Connecting chunk {idx+1} to previous chunk (time {connect_start:.2f}-{connect_end:.2f}s)") |
|
|
ax.plot(connect_times, [prev_last_point['y_preprocessed'], y_preprocessed[0]], |
|
|
color="#2F5597", linewidth=2.5) |
|
|
|
|
|
|
|
|
cpr_logger.info(f"[Graph Plotter] Plotting chunk {idx+1} (time {chunk_times[0]:.2f}-{chunk_times[-1]:.2f}s)") |
|
|
smooth_label = "Motion" if idx == 0 else "" |
|
|
peaks_label = "Peaks" if idx == 0 else "" |
|
|
|
|
|
|
|
|
ax.plot(chunk_times, y_preprocessed, |
|
|
color="#2F5597", linewidth=2.5, |
|
|
marker='o', markersize=4, |
|
|
markerfacecolor='#2F5597', markeredgecolor='#2F5597', |
|
|
label=smooth_label) |
|
|
|
|
|
|
|
|
if peaks.size > 0: |
|
|
ax.plot(chunk_times[peaks], y_preprocessed[peaks], |
|
|
"x", color="#ED7D31", markersize=8, |
|
|
label=peaks_label) |
|
|
|
|
|
|
|
|
if (depth is not None and rate is not None) and (depth > 0 and rate > 0): |
|
|
mid_time = (start_frame + end_frame) / (2 * self.fps) |
|
|
cpr_logger.info(f"[Graph Plotter] Chunk {idx+1} metrics: {depth:.1f}cm depth, {rate:.1f}cpm rate") |
|
|
|
|
|
|
|
|
if self.annotation_y_level is None: |
|
|
|
|
|
y_range = np.max(y_preprocessed) - np.min(y_preprocessed) |
|
|
self.annotation_y_level = np.min(y_preprocessed) + y_range * 0.5 |
|
|
cpr_logger.info(f"[Graph Plotter] Setting annotation y-level to {self.annotation_y_level:.2f}") |
|
|
|
|
|
|
|
|
ax.annotate(f"Depth: {depth:.1f}cm\nRate: {rate:.1f}cpm", |
|
|
xy=(mid_time, self.annotation_y_level), |
|
|
xytext=(0, 10), textcoords='offset points', |
|
|
ha='center', va='bottom', fontsize=9, |
|
|
bbox=dict(boxstyle='round,pad=0.5', |
|
|
fc='#F2F2F2', ec='#595959', alpha=0.8)) |
|
|
|
|
|
return {'y_preprocessed': y_preprocessed[-1]}, end_frame |
|
|
|
|
|
def _plot_error_regions(self, ax, computed_error_regions): |
|
|
"""Visualize error regions with adaptive symbol sizing""" |
|
|
cpr_logger.info("[Graph Plotter] Rendering error regions:") |
|
|
|
|
|
|
|
|
target_width_ratio = 0.7 |
|
|
legend_size = 80 |
|
|
|
|
|
legend_handles = [] |
|
|
y_mid = np.mean(ax.get_ylim()) |
|
|
|
|
|
|
|
|
fig = ax.figure |
|
|
fig_width_points = fig.get_figwidth() * fig.dpi |
|
|
x_min, x_max = ax.get_xlim() |
|
|
data_range = x_max - x_min |
|
|
points_per_second = fig_width_points / data_range |
|
|
|
|
|
for idx, (start_sec, end_sec) in enumerate(computed_error_regions): |
|
|
region_width = end_sec - start_sec |
|
|
region_data = self.posture_warnings_regions[idx] |
|
|
warnings = region_data.get('posture_warnings', []) |
|
|
|
|
|
|
|
|
max_data_width = region_width * target_width_ratio |
|
|
|
|
|
|
|
|
legend_data_width = (np.sqrt(legend_size) / points_per_second) |
|
|
|
|
|
|
|
|
symbol_data_width = min(legend_data_width, max_data_width) |
|
|
|
|
|
|
|
|
symbol_point_width = symbol_data_width * points_per_second |
|
|
symbol_size = symbol_point_width ** 2 |
|
|
|
|
|
for error in warnings: |
|
|
if error in self.error_symbols: |
|
|
marker, color = self.error_symbols[error] |
|
|
|
|
|
ax.scatter( |
|
|
x=(start_sec + end_sec)/2, |
|
|
y=y_mid, |
|
|
s=symbol_size, |
|
|
marker=marker, |
|
|
color=color, |
|
|
alpha=0.7, |
|
|
edgecolors='black', |
|
|
linewidths=0.5, |
|
|
zorder=5 |
|
|
) |
|
|
|
|
|
|
|
|
if not any(error == h.get_label() for h in legend_handles): |
|
|
legend_handles.append( |
|
|
ax.scatter([], [], |
|
|
s=legend_size, |
|
|
marker=marker, |
|
|
color=color, |
|
|
edgecolors='black', |
|
|
linewidths=0.5, |
|
|
alpha=0.7, |
|
|
label=error) |
|
|
) |
|
|
|
|
|
|
|
|
ax.axvspan(start_sec, end_sec, |
|
|
color='#FCE4D6', alpha=0.3, zorder=1) |
|
|
|
|
|
if not ax.get_xlabel(): |
|
|
ax.set_xlabel("Time (seconds)", fontsize=10) |
|
|
if not ax.get_ylabel(): |
|
|
ax.set_ylabel("Signal Value", fontsize=10) |
|
|
|
|
|
return legend_handles |
|
|
|
|
|
def plot_motion_curve_for_all_chunks(self, chunks_y_preprocessed, chunks_peaks, chunks_depth, chunks_rate, chunks_start_and_end_indices, posture_warnings_regions, sampling_interval_in_frames, fps, plot_output_path): |
|
|
"""Plot combined analysis with connected chunks and proper error regions""" |
|
|
|
|
|
self._assign_graph_data(chunks_y_preprocessed, chunks_peaks, chunks_depth, chunks_rate, chunks_start_and_end_indices, posture_warnings_regions, sampling_interval_in_frames, fps) |
|
|
cpr_logger.info("[Graph Plotter] Starting to plot motion curve for all chunks") |
|
|
|
|
|
|
|
|
plt.figure(figsize=(16, 8)) |
|
|
ax = plt.gca() |
|
|
ax.xaxis.set_major_locator(MultipleLocator(5)) |
|
|
|
|
|
|
|
|
if self.chunks_start_and_end_indices: |
|
|
sorted_chunks = sorted(zip(self.chunks_start_and_end_indices, |
|
|
self.chunks_depth, |
|
|
self.chunks_rate), |
|
|
key=lambda x: x[0][0]) |
|
|
cpr_logger.info(f"[Graph Plotter] Processing {len(sorted_chunks)} CPR chunks") |
|
|
|
|
|
prev_last_point = None |
|
|
prev_chunk_end = None |
|
|
|
|
|
for idx, chunk in enumerate(sorted_chunks): |
|
|
cpr_logger.info(f"[Graph Plotter] Rendering chunk {idx+1}/{len(sorted_chunks)}") |
|
|
prev_last_point, prev_chunk_end = self._plot_single_chunk(ax, chunk, idx, prev_last_point, prev_chunk_end) |
|
|
|
|
|
self._print_analysis_details(sorted_chunks) |
|
|
else: |
|
|
cpr_logger.info("[Graph Plotter] No chunk data available for plotting") |
|
|
|
|
|
ax.set_ylim(0, 100) |
|
|
|
|
|
|
|
|
computed_error_regions = [(er['start_frame']/self.fps, er['end_frame']/self.fps) |
|
|
for er in self.posture_warnings_regions] |
|
|
|
|
|
|
|
|
handles, labels = ax.get_legend_handles_labels() |
|
|
|
|
|
|
|
|
error_legend_handles = self._plot_error_regions(ax, computed_error_regions) |
|
|
|
|
|
|
|
|
if error_legend_handles: |
|
|
handles += error_legend_handles |
|
|
labels += [h.get_label() for h in error_legend_handles] |
|
|
|
|
|
|
|
|
unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] |
|
|
|
|
|
|
|
|
if unique: |
|
|
ax.legend( |
|
|
*zip(*unique), |
|
|
loc='upper center', |
|
|
bbox_to_anchor=(0.5, -0.08), |
|
|
ncol=len(unique), |
|
|
fontsize=8, |
|
|
handletextpad=0.3, |
|
|
columnspacing=1.5, |
|
|
framealpha=0.9, |
|
|
borderpad=0.7 |
|
|
) |
|
|
plt.tight_layout(rect=[0, 0.025, 1, 1]) |
|
|
|
|
|
plt.xlabel("Time (seconds)") |
|
|
plt.ylabel("Vertical Position (px)") |
|
|
plt.title("Complete CPR Analysis with Metrics", pad=20) |
|
|
|
|
|
plt.grid(True) |
|
|
cpr_logger.info(f"\n[Graph Plotter] Finalizing plot layout") |
|
|
|
|
|
|
|
|
plt.tight_layout(rect=[0, 0.025, 1, 0.95]) |
|
|
|
|
|
if plot_output_path: |
|
|
|
|
|
os.makedirs(os.path.dirname(plot_output_path), exist_ok=True) |
|
|
plt.savefig(plot_output_path, dpi=300, bbox_inches='tight') |
|
|
cpr_logger.info(f"[Graph Plotter] Plot saved to {plot_output_path}") |
|
|
|
|
|
plt.show() |
|
|
cpr_logger.info("[Graph Plotter] Plot display complete") |
|
|
|
|
|
def _print_analysis_details(self, sorted_chunks): |
|
|
"""Combined helper for printing chunks and error regions in seconds""" |
|
|
cpr_logger.info(f"\n\n=== CPR Chunk Analysis ===") |
|
|
display_idx = 0 |
|
|
|
|
|
|
|
|
fps = self.fps |
|
|
|
|
|
for ((start_frame, end_frame), depth, rate) in sorted_chunks: |
|
|
|
|
|
if depth == 0 and rate == 0: |
|
|
continue |
|
|
|
|
|
|
|
|
start_sec = start_frame / fps |
|
|
end_sec = end_frame / fps |
|
|
duration_sec = (end_frame - start_frame + 1) / fps |
|
|
|
|
|
cpr_logger.info(f"[Graph Plotter] Chunk {display_idx+1}: " |
|
|
f"Time {start_sec:.2f}s - {end_sec:.2f}s ({duration_sec:.2f}s), " |
|
|
f"Depth: {depth:.1f}cm, Rate: {rate:.1f}cpm") |
|
|
|
|
|
display_idx += 1 |
|
|
|
|
|
cpr_logger.info(f"\n\n=== Error Region Analysis ===") |
|
|
|
|
|
for i, region in enumerate(self.posture_warnings_regions): |
|
|
start_frame = region['start_frame'] |
|
|
end_frame = region['end_frame'] |
|
|
errors = region['posture_warnings'] |
|
|
|
|
|
|
|
|
start_sec = start_frame / fps |
|
|
end_sec = end_frame / fps |
|
|
error_str = ", ".join(errors) if errors else "No errors detected" |
|
|
|
|
|
cpr_logger.info(f"[Graph Plotter] Region {i+1}: " |
|
|
f"Time {start_sec:.2f}s - {end_sec:.2f}s - {error_str}") |
|
|
|
|
|
cpr_logger.info(f"\n\n") |
|
|
|