File size: 3,376 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Utility functions for UniCeption Encoders.
"""

import functools

import numpy as np
import torch


def profile_encoder(num_warmup=3, num_runs=20, autocast_precision="float16", use_compile=False, dynamic=True):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            device = "cuda"
            autocast_dtype = getattr(torch, autocast_precision)

            # Compile the model if requested
            if use_compile:
                compiled_func = torch.compile(func, dynamic=dynamic, mode="max-autotune")
            else:
                compiled_func = func

            with torch.autocast("cuda", dtype=autocast_dtype):
                # Warm-up runs
                for _ in range(num_warmup):
                    output = compiled_func(self, *args, **kwargs)
                    if isinstance(output, torch.Tensor):
                        output.sum().backward()
                    else:
                        output.features.sum().backward()
                    torch.cuda.synchronize()

                # Clear memory cache
                torch.cuda.empty_cache()

                # Lists to store results
                forward_times, backward_times, memory_usages = [], [], []

                for _ in range(num_runs):
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)

                    torch.cuda.reset_peak_memory_stats()
                    memory_before = torch.cuda.max_memory_allocated(device)

                    # Forward pass
                    start_event.record()
                    output = compiled_func(self, *args, **kwargs)
                    end_event.record()
                    torch.cuda.synchronize()
                    forward_times.append(start_event.elapsed_time(end_event))

                    # Backward pass
                    start_event.record()
                    if isinstance(output, torch.Tensor):
                        output.sum().backward()
                    else:
                        output.features.sum().backward()
                    end_event.record()
                    torch.cuda.synchronize()
                    backward_times.append(start_event.elapsed_time(end_event))

                    memory_after = torch.cuda.max_memory_allocated(device)
                    memory_usages.append((memory_after - memory_before) / 1e6)  # Convert to MB

            # Compute mean and standard deviation
            fwd_mean, fwd_std = np.mean(forward_times), np.std(forward_times)
            bwd_mean, bwd_std = np.mean(backward_times), np.std(backward_times)
            mem_mean, mem_std = np.mean(memory_usages), np.std(memory_usages)

            compile_status = (
                "with torch.compile (dynamic=True)"
                if use_compile and dynamic
                else "with torch.compile (dynamic=False)" if use_compile else "without torch.compile"
            )
            print(f"Profiling results {compile_status}:")
            print(f"Forward Pass Time: {fwd_mean:.2f} ± {fwd_std:.2f} ms")
            print(f"Backward Pass Time: {bwd_mean:.2f} ± {bwd_std:.2f} ms")
            print(f"Peak GPU Memory Usage: {mem_mean:.2f} ± {mem_std:.2f} MB")

            return output

        return wrapper

    return decorator