ho22joshua committed on
Commit
0296317
·
1 Parent(s): a3b9c93

updated utils

Browse files
Files changed (1) hide show
  1. root_gnn_dgl/root_gnn_base/utils.py +93 -7
root_gnn_dgl/root_gnn_base/utils.py CHANGED
@@ -8,10 +8,16 @@ import dgl
8
  import signal
9
 
10
  def buildFromConfig(conf, run_time_args = {}):
 
11
  if 'module' in conf:
12
  module = importlib.import_module(conf['module'])
13
  cls = getattr(module, conf['class'])
14
- return cls(**conf['args'], **run_time_args)
 
 
 
 
 
15
  else:
16
  print('No module specified in config. Returning None.')
17
 
@@ -177,21 +183,101 @@ def get_specific_epoch(config, target_epoch, device = None, from_ryan = False):
177
  checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
178
  return last_epoch, checkpoint
179
 
180
- #Convert training logs into dict for plotting.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  def read_log(config):
182
  lines = []
183
  with open(config['Training_Directory'] + '/training.log', 'r') as f:
184
  lines = f.readlines()
185
- lines = [ l for l in lines if 'Epoch' in l ]
186
- nlines = len(lines)
187
  labels = []
188
  for field in lines[0].split('|'):
189
  labels.append(field.split()[0])
190
- log = {label : np.zeros(nlines) for label in labels}
191
- for i, line in enumerate(lines):
 
 
 
 
 
 
192
  for field in line.split('|'):
193
  spl = field.split()
194
- log[spl[0]][i] = float(spl[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  return log
196
 
197
  #Plot training logs.
 
8
  import signal
9
 
10
def buildFromConfig(conf, run_time_args=None):
    """Instantiate the class named in *conf* with its configured arguments.

    Parameters
    ----------
    conf : dict
        Expected keys: 'module' (importable module path), 'class' (class
        name inside that module) and 'args' (kwargs for the constructor).
    run_time_args : dict, optional
        Extra kwargs merged into the constructor call.  A 'device' entry is
        consumed here (used to place a 'weight' tensor) and is NOT forwarded
        to the constructed class.

    Returns
    -------
    The constructed instance, or None when 'module' is missing from *conf*.
    """
    # Avoid the mutable-default-argument pitfall: bind a fresh dict per call.
    run_time_args = {} if run_time_args is None else run_time_args
    device = run_time_args.get('device', 'cpu')
    if 'module' in conf:
        module = importlib.import_module(conf['module'])
        cls = getattr(module, conf['class'])
        args = conf['args'].copy()
        # Loss-weight lists coming from a plain config file become tensors on
        # the requested device (e.g. for nn.CrossEntropyLoss).
        if 'weight' in args and isinstance(args['weight'], list):
            args['weight'] = torch.tensor(args['weight'], dtype=torch.float, device=device)
        # Remove device from run_time_args to not pass it to the class
        run_time_args = {k: v for k, v in run_time_args.items() if k != 'device'}
        return cls(**args, **run_time_args)
    else:
        print('No module specified in config. Returning None.')
23
 
 
183
  checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
184
  return last_epoch, checkpoint
185
 
186
#Return the index and checkpoint of the best epoch.
def get_best_epoch(config, var='Test_AUC', mode='max', device=None, from_ryan=False):
    """Select the epoch optimizing *var* in the training log and load it.

    Parameters
    ----------
    config : dict
        Needs 'Training_Directory' (used by read_log and checkpoint lookup).
    var : str
        Metric column of the training log to optimize.
    mode : str
        'max' or 'min' — whether larger or smaller values of *var* are better.
    device : optional
        Forwarded to torch.load as map_location.
    from_ryan : bool
        If True, resolve checkpoints under the shared CFS project directory.

    Returns
    -------
    (last_epoch, checkpoint)
        Index of the last checkpoint file that exists at or before the best
        epoch (-1 if none found) and the loaded checkpoint (None if none).

    Raises
    ------
    ValueError
        If *var* is missing from the log or *mode* is invalid.
    """
    def _ckpt_path(ep):
        # Single place to build a checkpoint path for either storage location.
        if from_ryan:
            return os.path.join(
                '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/',
                config['Training_Directory'],
                f'model_epoch_{ep}.pt'
            )
        return os.path.join(config['Training_Directory'], f'model_epoch_{ep}.pt')

    # Read the training log
    log = read_log(config)

    # Ensure the specified variable exists in the log
    if var not in log:
        raise ValueError(f"Variable '{var}' not found in the training log.")

    # Determine the target epoch based on the mode ('max' or 'min')
    if mode == 'max':
        target_epoch = int(np.argmax(log[var]))
        print(f"Best epoch based on '{var}' (max): {target_epoch} with value: {log[var][target_epoch]}")
    elif mode == 'min':
        target_epoch = int(np.argmin(log[var]))
        print(f"Best epoch based on '{var}' (min): {target_epoch} with value: {log[var][target_epoch]}")
    else:
        raise ValueError(f"Invalid mode '{mode}'. Expected 'max' or 'min'.")

    last_epoch = -1
    checkpoint = None

    # Walk forward so we land on the last checkpoint file that actually
    # exists at or before the target epoch (files may have been pruned).
    for ep in range(target_epoch + 1):
        checkpoint_path = _ckpt_path(ep)
        if os.path.exists(checkpoint_path):
            last_epoch = ep
        else:
            print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
            print('File not found: ', checkpoint_path)
            break

    # Load the checkpoint for the last valid epoch
    if last_epoch >= 0:
        checkpoint = torch.load(_ckpt_path(last_epoch), map_location=device)

    return last_epoch, checkpoint
247
+
248
def read_log(config):
    """Parse the training log into a dict of per-epoch metric arrays.

    The log is read from <Training_Directory>/training.log.  Only lines
    containing 'Epoch' are kept; each is a '|'-separated list of
    '<label> <value>' fields, with labels taken from the first kept line.
    Rows where any value fails float() conversion are skipped entirely;
    a label missing from an otherwise valid row yields np.nan.

    Parameters
    ----------
    config : dict
        Needs 'Training_Directory', the directory holding training.log.

    Returns
    -------
    dict mapping label -> np.ndarray of floats; empty dict when the log
    contains no 'Epoch' rows.
    """
    log_path = os.path.join(config['Training_Directory'], 'training.log')
    with open(log_path, 'r') as f:
        lines = f.readlines()
    lines = [l for l in lines if 'Epoch' in l]
    if not lines:
        # Nothing logged yet — avoid IndexError on lines[0] below.
        return {}

    # Column labels come from the first data row's field headers.
    labels = [field.split()[0] for field in lines[0].split('|')]

    # Initialize log as a dictionary with empty lists
    log = {label: [] for label in labels}

    for line in lines:
        valid_row = True   # Flag to check if the row is valid
        temp_row = {}      # Temporary row to store values before adding to log

        for field in line.split('|'):
            spl = field.split()
            try:
                temp_row[spl[0]] = float(spl[1])
            except (ValueError, IndexError):
                valid_row = False  # Mark row as invalid if conversion fails
                break

        if valid_row:  # Only add the row if all fields are valid
            for label in labels:
                log[label].append(temp_row.get(label, np.nan))  # Handle missing labels gracefully

    # Convert lists to numpy arrays for consistency
    return {label: np.array(vals) for label, vals in log.items()}
282
 
283
  #Plot training logs.