# infinity1096
# initial commit
# c8b42eb
"""
Utility functions for UniCeption Encoders.
"""
import functools
import numpy as np
import torch
def _backward_from_output(output):
    """Back-propagate from the sum of an encoder output.

    A bare ``torch.Tensor`` is reduced directly; any other object must expose
    the tensor to reduce as its ``features`` attribute.
    """
    features = output if isinstance(output, torch.Tensor) else output.features
    features.sum().backward()


def _mean_and_std(samples):
    """Return ``(mean, std)`` of *samples*, or ``(nan, nan)`` when it is empty."""
    if not samples:
        # np.mean / np.std on an empty list also yield nan, but emit
        # RuntimeWarnings; short-circuit to keep the report quiet.
        return float("nan"), float("nan")
    return np.mean(samples), np.std(samples)


def profile_encoder(num_warmup=3, num_runs=20, autocast_precision="float16", use_compile=False, dynamic=True):
    """Build a decorator that profiles a method's forward and backward pass on CUDA.

    The decorated callable is invoked as ``func(self, *args, **kwargs)``. Each
    call of the wrapper runs the function ``num_warmup`` times untimed and then
    ``num_runs`` times timed, all under ``torch.autocast("cuda", ...)``. After
    every forward pass the output (or ``output.features`` when the output is
    not a tensor) is summed and back-propagated. Mean and standard deviation of
    the forward time, backward time and peak-memory increase are printed.

    Args:
        num_warmup: Number of untimed forward/backward iterations run first.
        num_runs: Number of timed forward/backward iterations.
        autocast_precision: Name of the ``torch`` dtype attribute used for
            autocast (looked up with ``getattr(torch, autocast_precision)``).
        use_compile: If true, the function is wrapped with
            ``torch.compile(..., mode="max-autotune")`` before profiling.
        dynamic: Forwarded to ``torch.compile`` as its ``dynamic`` argument.

    Returns:
        A decorator. The wrapper it produces returns the output of the last
        invocation of the profiled function.

    Note:
        The wrapper calls ``torch.cuda`` APIs unconditionally during profiling.
        Every ``backward()`` accumulates into the ``.grad`` of the leaf tensors
        involved, so gradients are left populated after profiling.
    """

    def decorator(func):
        # Built lazily on the first call and then reused, so the compiled
        # wrapper is not re-created on every invocation of the method.
        compiled_func = None

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            nonlocal compiled_func

            device = "cuda"
            autocast_dtype = getattr(torch, autocast_precision)

            if compiled_func is None:
                if use_compile:
                    compiled_func = torch.compile(func, dynamic=dynamic, mode="max-autotune")
                else:
                    compiled_func = func

            forward_times, backward_times, memory_usages = [], [], []

            with torch.autocast("cuda", dtype=autocast_dtype):
                if num_warmup <= 0 and num_runs <= 0:
                    # Nothing to profile: still honour the call so the wrapper
                    # has a result to return instead of an unbound `output`.
                    output = compiled_func(self, *args, **kwargs)

                # Warm-up runs (untimed).
                for _ in range(num_warmup):
                    output = compiled_func(self, *args, **kwargs)
                    _backward_from_output(output)
                    torch.cuda.synchronize()

                # Release cached blocks so warm-up allocations do not linger
                # in the allocator cache during the measured runs.
                torch.cuda.empty_cache()

                for _ in range(num_runs):
                    fwd_start = torch.cuda.Event(enable_timing=True)
                    fwd_end = torch.cuda.Event(enable_timing=True)
                    bwd_start = torch.cuda.Event(enable_timing=True)
                    bwd_end = torch.cuda.Event(enable_timing=True)

                    # Right after the reset, the peak equals the currently
                    # allocated amount, which serves as this run's baseline.
                    torch.cuda.reset_peak_memory_stats(device)
                    memory_before = torch.cuda.max_memory_allocated(device)

                    # Forward pass
                    fwd_start.record()
                    output = compiled_func(self, *args, **kwargs)
                    fwd_end.record()
                    # Events are recorded asynchronously; wait before reading.
                    torch.cuda.synchronize()
                    forward_times.append(fwd_start.elapsed_time(fwd_end))

                    # Backward pass
                    bwd_start.record()
                    _backward_from_output(output)
                    bwd_end.record()
                    torch.cuda.synchronize()
                    backward_times.append(bwd_start.elapsed_time(bwd_end))

                    memory_after = torch.cuda.max_memory_allocated(device)
                    memory_usages.append((memory_after - memory_before) / 1e6)  # Convert to MB

            # Compute mean and standard deviation
            fwd_mean, fwd_std = _mean_and_std(forward_times)
            bwd_mean, bwd_std = _mean_and_std(backward_times)
            mem_mean, mem_std = _mean_and_std(memory_usages)

            if not use_compile:
                compile_status = "without torch.compile"
            elif dynamic:
                compile_status = "with torch.compile (dynamic=True)"
            else:
                compile_status = "with torch.compile (dynamic=False)"

            print(f"Profiling results {compile_status}:")
            print(f"Forward Pass Time: {fwd_mean:.2f} ± {fwd_std:.2f} ms")
            print(f"Backward Pass Time: {bwd_mean:.2f} ± {bwd_std:.2f} ms")
            print(f"Peak GPU Memory Usage: {mem_mean:.2f} ± {mem_std:.2f} MB")

            return output

        return wrapper

    return decorator