Commit
·
0296317
1
Parent(s):
a3b9c93
updated utils
Browse files
root_gnn_dgl/root_gnn_base/utils.py
CHANGED
|
@@ -8,10 +8,16 @@ import dgl
|
|
| 8 |
import signal
|
| 9 |
|
| 10 |
def buildFromConfig(conf, run_time_args = {}):
|
|
|
|
| 11 |
if 'module' in conf:
|
| 12 |
module = importlib.import_module(conf['module'])
|
| 13 |
cls = getattr(module, conf['class'])
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
else:
|
| 16 |
print('No module specified in config. Returning None.')
|
| 17 |
|
|
@@ -177,21 +183,101 @@ def get_specific_epoch(config, target_epoch, device = None, from_ryan = False):
|
|
| 177 |
checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
|
| 178 |
return last_epoch, checkpoint
|
| 179 |
|
| 180 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
def read_log(config):
|
| 182 |
lines = []
|
| 183 |
with open(config['Training_Directory'] + '/training.log', 'r') as f:
|
| 184 |
lines = f.readlines()
|
| 185 |
-
lines = [
|
| 186 |
-
|
| 187 |
labels = []
|
| 188 |
for field in lines[0].split('|'):
|
| 189 |
labels.append(field.split()[0])
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
for field in line.split('|'):
|
| 193 |
spl = field.split()
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
return log
|
| 196 |
|
| 197 |
#Plot training logs.
|
|
|
|
| 8 |
import signal
|
| 9 |
|
| 10 |
def buildFromConfig(conf, run_time_args=None):
    """Instantiate the class described by a config dictionary.

    Parameters
    ----------
    conf : dict
        Must contain 'module' (an importable module path) and 'class'
        (an attribute name inside that module). An optional 'args' dict
        is forwarded to the constructor as keyword arguments.
    run_time_args : dict, optional
        Extra keyword arguments merged into the constructor call. The
        special key 'device' is consumed here (used to place a 'weight'
        list on the right device) and is NOT forwarded to the class.

    Returns
    -------
    object or None
        The constructed instance, or None when 'module' is missing.
    """
    # Avoid the mutable-default-argument pitfall and never mutate the
    # caller's dict: work on a private copy.
    run_time_args = dict(run_time_args) if run_time_args else {}
    device = run_time_args.pop('device', 'cpu')

    # Guard clause: nothing to build without a module specification.
    if 'module' not in conf:
        print('No module specified in config. Returning None.')
        return None

    module = importlib.import_module(conf['module'])
    cls = getattr(module, conf['class'])

    # Copy so the caller's config is never mutated; tolerate a missing
    # 'args' section (treated as "no constructor arguments").
    args = dict(conf.get('args', {}))
    # Class weights arrive as plain lists from YAML/JSON configs;
    # convert them to a float tensor on the requested device.
    if 'weight' in args and isinstance(args['weight'], list):
        args['weight'] = torch.tensor(args['weight'], dtype=torch.float, device=device)

    return cls(**args, **run_time_args)
|
| 23 |
|
|
|
|
| 183 |
checkpoint = torch.load(os.path.join(config['Training_Directory'], f'model_epoch_{last_epoch}.pt'), map_location=device)
|
| 184 |
return last_epoch, checkpoint
|
| 185 |
|
| 186 |
+
#Return the index and checkpoint of the best epoch.
def _checkpoint_path(config, epoch, from_ryan=False):
    """Build the checkpoint file path for *epoch*.

    When from_ryan is True the training directory is resolved relative
    to the shared CFS area instead of the current working directory.
    """
    if from_ryan:
        return os.path.join(
            '/global/cfs/cdirs/atlas/berobert/root_gnn_dgl/',
            config['Training_Directory'],
            f'model_epoch_{epoch}.pt'
        )
    return os.path.join(config['Training_Directory'], f'model_epoch_{epoch}.pt')

def get_best_epoch(config, var='Test_AUC', mode='max', device=None, from_ryan=False):
    """Return (epoch_index, checkpoint) for the best epoch in the training log.

    Parameters
    ----------
    config : dict
        Training configuration; must contain 'Training_Directory'.
    var : str
        Column of the training log used to rank epochs (e.g. 'Test_AUC').
    mode : str
        'max' to pick the epoch maximising *var*, 'min' to minimise it.
    device : optional
        Passed to torch.load as map_location.
    from_ryan : bool
        Resolve checkpoints under the shared CFS area instead of the
        local training directory.

    Returns
    -------
    (int, dict or None)
        The last epoch <= best epoch whose checkpoint file exists (or -1
        if none exist) and the loaded checkpoint (or None).

    Raises
    ------
    ValueError
        If *var* is missing from the log or *mode* is not 'max'/'min'.
    """
    # Read the training log
    log = read_log(config)

    # Ensure the specified variable exists in the log
    if var not in log:
        raise ValueError(f"Variable '{var}' not found in the training log.")

    # Determine the target epoch based on the mode ('max' or 'min')
    if mode == 'max':
        target_epoch = int(np.argmax(log[var]))
    elif mode == 'min':
        target_epoch = int(np.argmin(log[var]))
    else:
        raise ValueError(f"Invalid mode '{mode}'. Expected 'max' or 'min'.")
    print(f"Best epoch based on '{var}' ({mode}): {target_epoch} with value: {log[var][target_epoch]}")

    # Walk forward from epoch 0 and remember the last epoch whose
    # checkpoint file exists; stop at the first gap (checkpoints are
    # written sequentially during training, so a gap means later files
    # are from a different run or missing).
    last_epoch = -1
    checkpoint = None
    for ep in range(target_epoch + 1):
        path = _checkpoint_path(config, ep, from_ryan)
        if not os.path.exists(path):
            print(f'Epoch {ep} not found. Stopping at epoch {last_epoch}')
            print('File not found: ', path)
            break
        last_epoch = ep

    # Load the checkpoint for the last valid epoch
    if last_epoch >= 0:
        checkpoint = torch.load(_checkpoint_path(config, last_epoch, from_ryan), map_location=device)

    return last_epoch, checkpoint
|
| 247 |
+
|
| 248 |
def read_log(config):
    """Parse the training log into a dict of numpy arrays.

    The log is read from <Training_Directory>/training.log and is
    expected to contain '|'-separated rows such as
    'Epoch 3 | Loss 0.12 | Test_AUC 0.9'. Lines that do not contain
    'Epoch' are ignored; rows with any non-numeric value are skipped
    entirely.

    Parameters
    ----------
    config : dict
        Training configuration; must contain 'Training_Directory'.

    Returns
    -------
    dict[str, numpy.ndarray]
        One array per column, keyed by the column label taken from the
        first valid row. Missing fields in a row are stored as NaN.

    Raises
    ------
    ValueError
        If the log contains no epoch rows at all.
    """
    log_path = os.path.join(config['Training_Directory'], 'training.log')
    with open(log_path, 'r') as f:
        lines = [l for l in f if 'Epoch' in l]

    # Fail with a clear message instead of an opaque IndexError below.
    if not lines:
        raise ValueError(f'No epoch rows found in {log_path}')

    # Column labels are the first token of each '|'-separated field of
    # the first row, e.g. 'Epoch', 'Loss', 'Test_AUC'.
    labels = [field.split()[0] for field in lines[0].split('|')]

    # Initialize log as a dictionary with empty lists
    log = {label: [] for label in labels}

    for line in lines:
        row = {}  # values for this row, keyed by label
        try:
            for field in line.split('|'):
                spl = field.split()
                row[spl[0]] = float(spl[1])
        except (ValueError, IndexError):
            # A field failed to parse: discard the whole row.
            continue
        for label in labels:
            log[label].append(row.get(label, np.nan))  # NaN for missing labels

    # Convert lists to numpy arrays for consistency
    for label in labels:
        log[label] = np.array(log[label])

    return log
|
| 282 |
|
| 283 |
#Plot training logs.
|