File size: 1,717 Bytes
cb099ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_sigma_sequence(sigs, stopping_index, log_filename, save_directory="modules/sd_simple_kes/image_generation_data", show_plot=False):
    """
    Plot the sigma sequence, mark the early stopping point, and save the plot as a PNG.

    The image is written to ``<save_directory>/<log basename>_sigma_plot.png``,
    where ``<log basename>`` is ``log_filename`` stripped of its directory and
    extension. ``save_directory`` is created if it does not exist yet.

    Parameters:
    - sigs: The sigma tensor or numpy array (can be truncated if stopping early).
      Objects exposing ``.cpu()`` are converted via ``.cpu().numpy()``;
      anything else is passed to ``np.array``.
    - stopping_index: The step index where early stopping was triggered.
      Pass ``None`` to omit the stopping-point marker.
    - log_filename: The filename of the generation log (used to match the graph name).
    - save_directory: The folder where the plot should be saved.
    - show_plot: Set to True to display the plot interactively.

    Returns:
    - The path of the saved plot image.
    """
    # Extract base name to match log filename
    base_filename = os.path.splitext(os.path.basename(log_filename))[0]
    graph_filename = f"{base_filename}_sigma_plot.png"
    graph_path = os.path.join(save_directory, graph_filename)

    # savefig does not create missing folders; an empty save_directory means
    # "current working directory", which os.makedirs would reject.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    # Prepare sigma sequence for plotting
    if hasattr(sigs, 'cpu'):
        # Tensors that track gradients cannot be converted with .numpy()
        # directly, so drop them from the autograd graph first when possible.
        tensor = sigs.detach() if hasattr(sigs, 'detach') else sigs
        sigs_np = tensor.cpu().numpy()
    else:
        sigs_np = np.array(sigs)
    x = np.arange(len(sigs_np))

    # Plotting
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(x, sigs_np, label='Sigma Sequence', marker='o')
        if stopping_index is not None:
            ax.axvline(x=stopping_index, color='red', linestyle='--', label=f'Stopping Point: {stopping_index}')
        ax.set_xlabel('Step Index')
        ax.set_ylabel('Sigma Value')
        ax.set_title('Sigma Sequence with Early Stopping Point')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(graph_path)

        if show_plot:
            plt.show()
    finally:
        # Always release this specific figure, even when plotting or saving
        # fails, so repeated calls do not accumulate open figures.
        plt.close(fig)

    return graph_path