schrum2 commited on
Commit
9756b46
·
verified ·
1 Parent(s): 9ceb35a

don't need

Browse files
Files changed (1) hide show
  1. plotter.py +0 -173
plotter.py DELETED
@@ -1,173 +0,0 @@
1
- # Track changes in loss and learning rate during execution
2
- import argparse
3
- import matplotlib
4
- import matplotlib.pyplot as plt
5
- import os
6
- import time
7
- import json
8
- import tempfile
9
- import shutil
10
- from pathlib import Path
11
-
12
-
13
- def parse_args():
14
- parser = argparse.ArgumentParser(description="Train a text-conditional diffusion model for tile-based level generation")
15
-
16
- # Dataset args
17
- parser.add_argument("--log_file", type=str, default=None, help="The the filepath of the file to get the data from")
18
- parser.add_argument("--left_key", type=str, default=None, help="The key for the left y-axis")
19
- parser.add_argument("--right_key", type=str, default=None, help="The key for the right y-axis")
20
- parser.add_argument("--left_label", type=str, default=None, help="The label for the left y-axis")
21
- parser.add_argument("--right_label", type=str, default=None, help="The label for the right y-axis")
22
- parser.add_argument("--output_png", type=str, default="output.png", help="The output png file")
23
- parser.add_argument("--update_interval", type=int, default=1.0, help="The update inteval in epochs")
24
- parser.add_argument("--start_point", type=int, default=None, help="The start point for the plot")
25
-
26
- return parser.parse_args()
27
-
28
-
29
- def main():
30
- args = parse_args()
31
-
32
- log_file = args.log_file
33
- left_key = args.left_key
34
- right_key = args.right_key
35
- left_label = args.left_label
36
- right_label = args.right_label
37
- output_png = args.output_png
38
- update_interval = args.update_interval
39
- start_point = args.start_point
40
-
41
- general_update_plot(log_file, left_key, right_key, left_label, right_label, output_png, update_interval=update_interval, startPoint=start_point)
42
-
43
-
44
- def general_update_plot(log_file, left_key, right_key, left_label, right_label, output_png, update_interval=1.0, startPoint=None):
45
- log_dir = os.path.dirname(log_file)
46
-
47
- # Create figure here and ensure it's closed
48
- fig = plt.figure(figsize=(10, 6))
49
- ax = fig.add_subplot(111)
50
-
51
- try:
52
- if os.path.exists(log_file):
53
- with open(log_file, 'r') as f:
54
- data = [json.loads(line) for line in f if line.strip()]
55
-
56
- if not data:
57
- return
58
-
59
- if startPoint is not None:
60
- data = [entry for entry in data if entry.get('epoch', 0) >= startPoint]
61
-
62
- if not data:
63
- return
64
-
65
- epochs = [entry.get('epoch', 0) for entry in data]
66
- left = [entry.get(left_key, 0) for entry in data]
67
-
68
- # For right axis (e.g., lr), only include points where right_key exists
69
- right_points = [(entry.get('epoch', 0), entry.get(right_key))
70
- for entry in data if right_key in entry]
71
- if right_points:
72
- right_epochs, right_values = zip(*right_points)
73
- else:
74
- right_epochs, right_values = [], []
75
-
76
- # Clear axis
77
- ax.clear()
78
-
79
- # Plot both metrics on the same axis
80
- ax.plot(epochs, left, 'b-', label=left_label)
81
- if right_epochs:
82
- ax.plot(right_epochs, right_values, 'r-', label=right_label)
83
-
84
- ax.set_xlabel('Epoch')
85
- ax.set_ylabel(left_label) # "Loss" as y-axis label
86
- ax.set_title('Training Progress')
87
- ax.legend(loc='upper left')
88
- #Limit x-axis to startPoint if provided
89
- if startPoint is not None:
90
- ax.set_xlim(left=startPoint)
91
- fig.tight_layout()
92
-
93
- # Use the stored base directory instead of getting it from log_file
94
- if os.path.isabs(output_png) or os.path.dirname(output_png):
95
- output_path = output_png
96
- else:
97
- output_path = os.path.join(log_dir, output_png)
98
-
99
- save_figure_safely(fig, output_path)
100
- finally:
101
- plt.close(fig) # Ensure figure is closed even if an error occurs
102
-
103
- def save_figure_safely(fig, output_path):
104
- """Save figure to a temporary file first, then move it to the final location"""
105
- output_path = str(Path(output_path)) # Convert to string path
106
-
107
- # Create temporary file with .png extension
108
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
109
- tmp_path = tmp_file.name
110
-
111
- try:
112
- # Save to temporary file
113
- fig.savefig(tmp_path)
114
-
115
- # Create output directory if it doesn't exist
116
- os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
117
-
118
- # Try to move the file to final destination
119
- # If move fails, try to copy and then delete
120
- try:
121
- shutil.move(tmp_path, output_path)
122
- except OSError:
123
- shutil.copy2(tmp_path, output_path)
124
- os.unlink(tmp_path)
125
- except Exception as e:
126
- # Clean up temporary file if anything goes wrong
127
- if os.path.exists(tmp_path):
128
- os.unlink(tmp_path)
129
- raise e
130
-
131
- class Plotter:
132
- def __init__(self, log_file, update_interval=1.0, left_key='loss', right_key='lr',
133
- left_label='Loss', right_label='Learning Rate', output_png='training_progress.png'):
134
- self.log_dir = os.path.dirname(log_file)
135
- self.log_file = log_file
136
- self.update_interval = update_interval
137
- self.running = True
138
- self.output_png = output_png
139
- self.left_key = left_key
140
- self.right_key = right_key
141
- self.left_label = left_label
142
- self.right_label = right_label
143
-
144
- matplotlib.use('Agg')
145
-
146
- def __enter__(self):
147
- return self
148
-
149
- def __exit__(self, exc_type, exc_val, exc_tb):
150
- self.stop_plotting()
151
-
152
- def __del__(self):
153
- self.stop_plotting()
154
-
155
- def update_plot(self):
156
- general_update_plot(self.log_file, self.left_key, self.right_key,
157
- self.left_label, self.right_label, self.output_png,
158
- update_interval=self.update_interval)
159
-
160
- def start_plotting(self):
161
- print("Starting plotting in background")
162
- while self.running:
163
- self.update_plot()
164
- time.sleep(self.update_interval)
165
-
166
- def stop_plotting(self):
167
- if hasattr(self, 'running'): # Check if already stopped
168
- self.running = False
169
- self.update_plot()
170
- print("Plotting stopped")
171
-
172
- if __name__ == "__main__":
173
- main()