File size: 2,827 Bytes
843a502
 
 
 
 
 
aa4aab0
 
 
 
 
843a502
aa4aab0
 
 
 
 
 
843a502
aa4aab0
 
843a502
aa4aab0
843a502
aa4aab0
843a502
 
aa4aab0
 
 
 
 
 
 
 
 
 
 
79484be
843a502
79484be
aa4aab0
79484be
aa4aab0
 
79484be
 
aa4aab0
79484be
aa4aab0
 
 
 
 
 
843a502
aa4aab0
 
 
 
 
 
1ffa54d
aa4aab0
 
 
 
843a502
aa4aab0
 
 
702ad4b
aa4aab0
843a502
aa4aab0
 
 
 
 
 
843a502
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

class plotting():

    def quick_plot(self, scatter=None, line=None, columns=["X", "Y"], output=None, hline=None):
        """
        Plot the data along with the best fit line and its associated confidence band using
        Matplotlib's object-oriented API to avoid race conditions.
    
        Parameters:
            scatter (list or array): Data to plot as scatter points
            line (list or array): Data to plot as lines
            columns (list): Axes labels (default is ["X","Y"])
            output (str): If provided, save figure to this path. If "marimo", return the figure object
            hline (float): If provided, add a horizontal line at this y-value
    
        Returns:
            None or Figure: Displays a matplotlib plot or returns the figure object if output="marimo"
    
        Example:
            >>> quick_plot(scatter=x_data, line=y_data)
        """
        # Input validation
        if scatter is None and line is None:
            raise ValueError("Either scatter or line should be provided")
    
        # Convert inputs to lists if they aren't already
        if not isinstance(scatter, list):
            scatter = [scatter] if scatter is not None else []
        if not isinstance(line, list):
            line = [line] if line is not None else []
    
        # Create figure and axes objects (OO approach)
        fig, ax = plt.subplots(figsize=(6, 6))
    
        # Plot scatter data
        idx = 0
        for ds in scatter:
            if ds is not None:
                ax.scatter(ds[:, 0], ds[:, 1], label=f"Data {idx}")
                idx += 1
    
        # Plot line data
        for ds in line:
            if ds is not None:
                ax.plot(ds[:, 0], ds[:, 1], color='red', label=f"Data {idx}")
                idx += 1
    
        # Set labels
        ax.set_xlabel(columns[0])
        ax.set_ylabel(columns[1])
    
        # Add horizontal line if specified
        if hline is not None:
            ax.axhline(hline, color='black', linestyle='--')
    
        # Add legend
        ax.legend()
    
        # Add watermark
        ax.text(0.5, 0.5, 'TEMPLATE', transform=ax.transAxes,
                fontsize=40, color='gray', alpha=0.5,
                ha='center', va='center', rotation=30)
    
        # Handle output
        if output is None:
            # Use this instead of plt.show() to avoid blocking behavior
            fig.canvas.draw_idle()
            plt.show(block=False)
        elif output == "marimo":
            return fig
        else:
            # Save the figure to the specified path
            fig.savefig(output)
    
        # Close the figure to free memory if not returning it
        if output != "marimo":
            plt.close(fig)