r"""

This module exposes a TunableOp interface.



Some operations, such as GEMMs, could be implemented using more than one library

or more than one technique. For example, a GEMM could be implemented for CUDA or

ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and

hipblaslt libraries allow the user to query for all possible algorithms and then

choose one. How does one know which implementation is the fastest and should be

chosen? That's what TunableOp provides.



Enabling TunableOp and Tuning Separately
========================================

The TunableOp feature is enabled separately from enabling the tuning phase
itself. Enabling TunableOp means that PyTorch will replace any standard
operators with their Tunable implementations. Any call to a TunableOp first
checks whether it has already been tuned for the given operator inputs. If so,
it will immediately call the tuned operation; no further tuning will take place
even when the tuning setting is enabled. Otherwise, if no tuning result is found
and tuning is enabled, the TunableOp will benchmark every registered
implementation of that operator for the given set of inputs and select the
fastest.

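The lookup-then-tune flow described above can be illustrated with a small
pure-Python sketch. It is not PyTorch's implementation: ``TunableOpSketch``, its
``enabled`` and ``tuning`` flags, the fixed iteration count, and the use of
``time.perf_counter`` for timing are all assumptions made for illustration.
Candidate implementations are plain callables keyed by solution name, and the
chosen solution is cached per input signature.

.. code-block:: python

   import time
   from typing import Callable, Dict, Hashable


   class TunableOpSketch:
       # Hypothetical stand-in for a TunableOp; not a PyTorch API.
       def __init__(self, candidates: Dict[str, Callable], default: str = 'Default'):
           self.candidates = candidates  # solution name -> implementation
           self.default = default  # untuned fallback solution
           self.results: Dict[Hashable, str] = {}  # signature -> chosen solution
           self.enabled = True  # analogous to enable()
           self.tuning = True  # analogous to tuning_enable()

       def __call__(self, signature: Hashable, *args):
           if not self.enabled:
               return self.candidates[self.default](*args)
           name = self.results.get(signature)
           if name is None:
               if self.tuning:
                   # No result yet and tuning is on: benchmark every candidate.
                   name = self._tune(*args)
                   self.results[signature] = name
               else:
                   # No result and tuning is off: use the default solution.
                   name = self.default
           # A stored result is reused as-is; no re-tuning happens.
           return self.candidates[name](*args)

       def _tune(self, *args, iterations: int = 10) -> str:
           timings = {}
           for name, impl in self.candidates.items():
               start = time.perf_counter()
               for _ in range(iterations):
                   impl(*args)
               timings[name] = (time.perf_counter() - start) / iterations
           return min(timings, key=timings.get)


   def loop_sum(xs):
       total = 0
       for x in xs:
           total += x
       return total


   op = TunableOpSketch({'Default': loop_sum, 'Builtin': sum})
   data = list(range(1000))
   print(op(('sum', len(data)), data))  # 499500; this first call triggers tuning
   print(op.results)  # the chosen solution depends on the measured timings

Keeping the "is it tuned?" lookup ahead of the "may I tune?" check is what makes
a recorded result win even when tuning stays enabled.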

File Input and Output
=====================

The first time any TunableOp is invoked, the internal database of tuned
operations will be prepared by attempting to read the results from the given
file. The default filename is 'tunableop_results.csv'. To support tuning when
multiple GPUs are used across multiple processes, the GPU device ordinal is
automatically inserted into the filename to avoid multiple processes overwriting
the same file.

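The per-device naming can be sketched as a small helper. This is an
illustration under stated assumptions, not the C++ logic: the hypothetical
``with_device_ordinal`` substitutes the ordinal for a ``%d`` placeholder when one
is present (the placeholder that ``_gather_tunableop_results`` in this module
also looks for) and otherwise inserts it just before the file extension, which
matches default names such as ``tunableop_results0.csv``.

.. code-block:: python

   import os


   def with_device_ordinal(filename: str, ordinal: int) -> str:
       # Hypothetical helper: honor an explicit '%d' placeholder if present.
       if '%d' in filename:
           return filename.replace('%d', str(ordinal))
       # Otherwise insert the ordinal before the extension (or append it).
       root, ext = os.path.splitext(filename)
       return f'{root}{ordinal}{ext}'


   for device in range(2):
       print(with_device_ordinal('tunableop_results.csv', device))
   # tunableop_results0.csv
   # tunableop_results1.csv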

If tuning is enabled and new tunings are discovered during the course of your
workload, it will also write out to this same filename with all tunings, both
the ones it read in at startup as well as the new ones found at runtime. This
can be used, for example, to build up a tunings file across many workloads by
reusing the same file. The output file is automatically created when the
application terminates. This behavior can be controlled by the C++ and Python
APIs but not the environment variables.

Assuming you specified a filename, you'll end up with a CSV file with contents
like so::

  Validator,PT_VERSION,2.2.0
  Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
  Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
  Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
  GemmTunableOp_float_NT,nt_25088_4096_64,Gemm_Hipblaslt_1219,1.262
  GemmTunableOp_float_NT,nt_4096_4096_64,Gemm_Rocblas_1216,0.033

Note the "Validator" lines. If you change a library version, or ROCm version, or
PyTorch version, TunableOp will detect this and reject the tunings file because
the prior tunings are likely affected by other software changes.

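Reading such a file back only needs the standard :mod:`csv` module. The sketch
below is illustrative: ``load_results`` and the ``current`` dictionary of
validator values are hypothetical names, and the real validation happens inside
TunableOp itself. It separates ``Validator`` rows from result rows, treats the
fourth field as optional (the second sample row omits it to show that case),
and rejects the whole file if any recorded validator differs from the current
value.

.. code-block:: python

   import csv
   from typing import Dict, Iterable, List, Optional, Tuple

   Result = Tuple[str, str, str, Optional[float]]

   SAMPLE_LINES = [
       'Validator,PT_VERSION,2.2.0',
       'Validator,ROCM_VERSION,6.0.0.0-12969-1544e39',
       'GemmTunableOp_float_NT,nt_25088_4096_64,Gemm_Hipblaslt_1219,1.262',
       'GemmTunableOp_float_NT,nt_4096_4096_64,Gemm_Rocblas_1216',
   ]


   def load_results(lines: Iterable[str], current: Dict[str, str]) -> List[Result]:
       validators: Dict[str, str] = {}
       results: List[Result] = []
       for row in csv.reader(lines):
           if not row:
               continue  # skip blank lines
           if row[0] == 'Validator':
               validators[row[1]] = row[2]
               continue
           op_name, params, solution = row[:3]
           # Field 4, the average execution time, is optional.
           avg_time = float(row[3]) if len(row) > 3 else None
           results.append((op_name, params, solution, avg_time))
       # Reject the whole file if any recorded version differs from the current one.
       for key, value in validators.items():
           if current.get(key) != value:
               raise ValueError(f'validator {key} mismatch: {value!r} vs {current.get(key)!r}')
       return results


   current = {'PT_VERSION': '2.2.0', 'ROCM_VERSION': '6.0.0.0-12969-1544e39'}
   for entry in load_results(SAMPLE_LINES, current):
       print(entry)
   # ('GemmTunableOp_float_NT', 'nt_25088_4096_64', 'Gemm_Hipblaslt_1219', 1.262)
   # ('GemmTunableOp_float_NT', 'nt_4096_4096_64', 'Gemm_Rocblas_1216', None)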

The remaining lines are the tuned solutions for each TunableOp encountered
during your execution. Each line consists of 4 comma-separated fields: operator
name, operator parameters, solution name, and average execution time. The
execution time is an optional field. The CSV file can be edited, but with
caution. For example, the solution name (field 3) can be changed to "Default"
and it will fall back to the original PyTorch untuned implementation. Or, in the
case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
index you can override the solution that TunableOp selected by replacing the
value. The operator name and parameters (fields 1 and 2) are internally named
and should not be modified. In the case of GemmTunableOp, field 1 indicates the
datatype and whether the inputs are transposed (T) or not (N), and field 2
indicates the M, N, K input shapes.

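For plain ``GemmTunableOp`` rows, fields 1 and 2 can be decoded mechanically.
The sketch below follows the layout described above; ``decode_gemm_row`` and
``GemmEntry`` are hypothetical names, only the three-part
``GemmTunableOp_<dtype>_<layout>`` form is handled, and the transpose flags are
reported in letter order without assigning them to a particular operand. Any
trailing parts of field 2 (such as leading dimensions) are ignored.

.. code-block:: python

   from typing import NamedTuple, Tuple


   class GemmEntry(NamedTuple):
       op: str
       dtype: str
       transposed: Tuple[bool, bool]  # one flag per layout letter, in order
       sizes: Tuple[int, int, int]  # the three sizes listed in field 2


   def decode_gemm_row(field1: str, field2: str) -> GemmEntry:
       op, dtype, layout = field1.split('_')
       if op != 'GemmTunableOp':
           raise ValueError(f'unsupported operator: {op}')
       flags = (layout[0] == 'T', layout[1] == 'T')
       # Field 2 starts with the lower-case layout, followed by the sizes.
       parts = field2.split('_')
       m, n, k = (int(p) for p in parts[1:4])
       return GemmEntry(op, dtype, flags, (m, n, k))


   print(decode_gemm_row('GemmTunableOp_float_NT', 'nt_25088_4096_64'))
   # GemmEntry(op='GemmTunableOp', dtype='float', transposed=(False, True), sizes=(25088, 4096, 64))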

There is an option to enable verbose output, but it is only recommended for
debugging purposes. This will produce a lot of diagnostic messages but may be
useful to see if TunableOp is being used at all. Otherwise, TunableOp is
completely silent, besides file output, unless there is a warning or error
during its use. The verbose option is only available by setting the environment
variable PYTORCH_TUNABLEOP_VERBOSE=1.

A Note on Tuning Behavior, Warmup, and Cache Effects
====================================================

Tuning an operator consists of iterating through the list of registered
implementations and profiling each one. The profile is established by running a
single implementation in a loop multiple times and taking the average execution
time. There is also an optional warmup phase prior to tuning that can help the
hardware reach stable power states. During tuning of a workload, the various
hardware caches are more likely to produce hits than during normal execution.
There are options for flushing the instruction cache and rotating the input
tensors, which might help produce a more faithful profile of the tuned operator,
as if the operator were run within a larger workload instead of in a tight,
repetitive loop.

By default, each possible solution for a given operator will be run for either
100 iterations or as many iterations as can be run within 30 ms, whichever is
smaller, and its average execution time will be calculated. The fastest solution
among all that were successfully profiled will be chosen. A profile might fail
if the given solution doesn't achieve the same accuracy as the default
implementation or if the solution returns an error code.

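The default budget reduces to a small calculation: no more than the iteration
cap, no more iterations than fit in the time budget, and never fewer than one
(as ``set_max_tuning_duration`` notes). The sketch also picks the fastest
solution while skipping failed profiles. The function names, the use of a single
per-iteration time estimate, the representation of a failed profile as ``None``,
and the timings in the usage lines are assumptions for illustration.

.. code-block:: python

   import math
   from typing import Dict, Optional


   def tuning_iterations(per_iter_ms: float, max_duration_ms: float = 30.0,
                         max_iterations: int = 100) -> int:
       # Iterations that fit in the time budget, capped, but always at least one.
       fits = math.floor(max_duration_ms / per_iter_ms)
       return max(1, min(max_iterations, fits))


   def pick_fastest(avg_ms: Dict[str, Optional[float]]) -> Optional[str]:
       # None marks a solution whose profile failed (error or accuracy mismatch).
       ok = {name: t for name, t in avg_ms.items() if t is not None}
       return min(ok, key=ok.get) if ok else None


   print(tuning_iterations(0.5))   # 60: the 30 ms budget is the tighter limit
   print(tuning_iterations(0.1))   # 100: the iteration cap is the tighter limit
   print(tuning_iterations(50.0))  # 1: at least one iteration always runs
   print(pick_fastest({'Default': 0.041, 'Gemm_Rocblas_1216': 0.033,
                       'Gemm_Hipblaslt_1219': None}))  # Gemm_Rocblas_1216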

Current Tunable Operators
=========================

TunableGemm for ROCm
--------------------

Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
PyTorch will function correctly when using TunableOp but the only solution
available to CUDA builds is the 'Default' implementation, i.e. the original
cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
given set of input arguments (transa, transb, m, n, k) will attempt to use the
fastest available implementation across both rocblas and hipblaslt.

Offline Tuning
==============

Motivation
----------

There are several use cases for offline tuning.

One use case involves a workload with high memory utilization, where regular tuning might lead to running out of memory.

Another use case is for compute-intensive workloads. In such cases, it is more resource-efficient to collect
the GEMMs for the workload once and then tune repeatedly with different tuning parameters or libraries.

Workflow
--------

There are basically two steps:

1) Set the environment variables to collect the untuned GEMMs; this will generate ``tunableop_untuned0.csv``:

.. code-block:: bash

   export PYTORCH_TUNABLEOP_ENABLED=1
   export PYTORCH_TUNABLEOP_TUNING=0
   export PYTORCH_TUNABLEOP_RECORD_UNTUNED=1
   ...

2) Run a Python script that reads ``tunableop_untuned0.csv`` and generates ``tunableop_results0.csv``, like this:

.. code-block:: python

   import os

   # Set the variables before TunableOp first reads them (they are cached).
   os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
   os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"
   os.environ["PYTORCH_TUNABLEOP_RECORD_UNTUNED"] = "0"

   import torch.cuda.tunable as tunable

   tunable.tune_gemm_in_file("tunableop_untuned0.csv")

It is also possible to take multiple untuned files and distribute the GEMMs for tuning to multiple GPUs
within a single node. In the first step, the GEMMs are gathered and duplicate GEMMs are eliminated.
Next, the GEMMs are distributed to different GPUs for tuning. After all GEMMs are tuned, the results from
all the GPUs are gathered into a single file whose base filename has ``_full0`` appended to it
(for example ``tunableop_results_full0.csv``). Finally, this new file, containing the gathered results, is
duplicated N times, once for each GPU, as a convenience to the user who will run the workload with the tuned
configuration on N GPUs.

.. code-block:: python

   import torch.cuda.tunable as tunable

   if __name__ == "__main__":
       num_gpus = 8  # number of GPUs that will be used during the tuning process
       tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus)

Note that the usage of the ``mgpu_tune_gemm_in_file`` API is different from its single GPU counterpart
(``tune_gemm_in_file``). The body of the Python script that calls the API must be guarded by
``if __name__ == "__main__":`` as shown, due to the use of the ``concurrent.futures`` module. The argument to
``mgpu_tune_gemm_in_file`` must contain a wildcard expression (``?`` or ``*``) to generate the list of untuned
files containing the GEMMs to be processed. The ``num_gpus`` value must be between 1 and the total number of
GPUs available.

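The gather-and-distribute step can be sketched with the standard library alone.
Like ``_gather_unique_untuned_gemm_from_files`` in this module, the sketch keeps
only lines starting with ``Gemm`` or ``ScaledGemm`` and removes duplicates with a
set. The sorting, the round-robin assignment to GPUs, and the name
``split_untuned_gemms`` are assumptions, since the text above does not say how
the work is divided.

.. code-block:: python

   import glob
   import os
   import tempfile
   from typing import List


   def split_untuned_gemms(pattern: str, num_gpus: int) -> List[List[str]]:
       if num_gpus < 1:
           raise ValueError('num_gpus must be at least 1')
       unique = set()  # a set drops GEMMs repeated across files
       for path in glob.glob(pattern):
           with open(path) as f:
               for line in f:
                   if line.startswith(('Gemm', 'ScaledGemm')):
                       unique.add(line.strip())
       # Sort for a deterministic order, then deal entries out round-robin.
       ordered = sorted(unique)
       return [ordered[i::num_gpus] for i in range(num_gpus)]


   with tempfile.TemporaryDirectory() as tmp:
       for rank, rows in enumerate([['GemmTunableOp_float_NT,nt_8_8_8'],
                                    ['GemmTunableOp_float_NT,nt_8_8_8',
                                     'GemmTunableOp_Half_TN,tn_16_16_16']]):
           with open(os.path.join(tmp, f'tunableop_untuned{rank}.csv'), 'w') as f:
               f.write('\n'.join(rows) + '\n')
       shards = split_untuned_gemms(os.path.join(tmp, 'tunableop_untuned?.csv'), 2)
       print(shards)
   # [['GemmTunableOp_Half_TN,tn_16_16_16'], ['GemmTunableOp_float_NT,nt_8_8_8']]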

Tuning Context
==============

The behavior of TunableOp is currently manipulated through environment
variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
torch.cuda.tunable Python interfaces. The environment variables take precedence
over any setting you manipulate using the C++ or Python APIs.

Environment Variable Interface
------------------------------

Environment variables are cached the first time they are read. You cannot use the
environment variable interface programmatically since the settings become fixed.
Use the C++ or Python APIs instead.

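The two rules above, that environment variables override the APIs and are read
only once, can be modeled in a few lines. This is an analogy rather than
PyTorch's C++ code: ``read_env_once`` and ``effective_setting`` are hypothetical
helpers, and ``functools.lru_cache`` stands in for whatever caching the tuning
context performs.

.. code-block:: python

   import functools
   import os
   from typing import Optional


   @functools.lru_cache(maxsize=None)
   def read_env_once(name: str) -> Optional[str]:
       # The first lookup is cached; later changes to os.environ are not seen.
       return os.environ.get(name)


   def effective_setting(name: str, api_value: str) -> str:
       env_value = read_env_once(name)
       # A set environment variable takes precedence over the API value.
       return env_value if env_value is not None else api_value


   os.environ['PYTORCH_TUNABLEOP_TUNING'] = '0'
   print(effective_setting('PYTORCH_TUNABLEOP_TUNING', '1'))  # 0: env var wins
   os.environ['PYTORCH_TUNABLEOP_TUNING'] = '1'
   print(effective_setting('PYTORCH_TUNABLEOP_TUNING', '0'))  # still 0: cached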

"""
import concurrent.futures
import glob
import multiprocessing as mp
import os
import shutil
import warnings
from typing import Optional

import torch


__all__ = [
    "enable",
    "is_enabled",
    "tuning_enable",
    "tuning_is_enabled",
    "record_untuned_enable",
    "record_untuned_is_enabled",
    "set_max_tuning_duration",
    "get_max_tuning_duration",
    "set_max_tuning_iterations",
    "get_max_tuning_iterations",
    "set_filename",
    "get_filename",
    "get_results",
    "get_validators",
    "write_file_on_exit",
    "write_file",
    "read_file",
    "tune_gemm_in_file",
    "mgpu_tune_gemm_in_file",
    "set_rotating_buffer_size",
    "get_rotating_buffer_size",
]


def enable(val: bool = True) -> None:
    r"""This is the big on/off switch for all TunableOp implementations."""
    torch._C._cuda_tunableop_enable(val)  # type: ignore[attr-defined]


def is_enabled() -> bool:
    r"""Returns whether the TunableOp feature is enabled."""
    return torch._C._cuda_tunableop_is_enabled()  # type: ignore[attr-defined]


def tuning_enable(val: bool = True) -> None:
    r"""Enable tuning of TunableOp implementations.



    When enabled, if a tuned entry isn't found, run the tuning step and record

    the entry.

    """
    torch._C._cuda_tunableop_tuning_enable(val)  # type: ignore[attr-defined]


def tuning_is_enabled() -> bool:
    r"""Returns whether TunableOp implementations can be tuned."""
    return torch._C._cuda_tunableop_tuning_is_enabled()  # type: ignore[attr-defined]


def record_untuned_enable(val: bool = True) -> None:
    r"""Enable recording untuned of TunableOp perations for offline tuning.



    When enabled, if a tuned entry isn't found, write it to the untuned file.

    """
    torch._C._cuda_record_untuned_enable(val)  # type: ignore[attr-defined]


def record_untuned_is_enabled() -> bool:
    r"""Returns whether TunableOp operations are recorded for offline tuning."""
    return torch._C._cuda_record_untuned_is_enabled()  # type: ignore[attr-defined]


def set_max_tuning_duration(duration: int) -> None:
    r"""Set max time in milliseconds to spend tuning a given solution.



    If both max tuning duration and iterations are set, the smaller of the two

    will be honored. At minimum 1 tuning iteration will always be run.

    """
    torch._C._cuda_tunableop_set_max_tuning_duration(duration)  # type: ignore[attr-defined]


def get_max_tuning_duration() -> int:
    r"""Get max time to spend tuning a given solution."""
    return torch._C._cuda_tunableop_get_max_tuning_duration()  # type: ignore[attr-defined]


def set_max_tuning_iterations(iterations: int) -> None:
    r"""Set max number of iterations to spend tuning a given solution.



    If both max tuning duration and iterations are set, the smaller of the two

    will be honored. At minimum 1 tuning iteration will always be run.

    """
    torch._C._cuda_tunableop_set_max_tuning_iterations(iterations)  # type: ignore[attr-defined]


def get_max_tuning_iterations() -> int:
    r"""Get max iterations to spend tuning a given solution."""
    return torch._C._cuda_tunableop_get_max_tuning_iterations()  # type: ignore[attr-defined]


def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
    r"""Set the filename to use for input/output of tuning results.



    If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal

    will be added to the given filename automatically. This can be used in a

    1-process-per-gpu cenario to ensure all processes write to a separate file.

    """
    torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal)  # type: ignore[attr-defined]


def get_filename() -> str:
    r"""Get the results filename."""
    return torch._C._cuda_tunableop_get_filename()  # type: ignore[attr-defined]


def get_results() -> tuple[str, str, str, float]:
    r"""Return all TunableOp results."""
    return torch._C._cuda_tunableop_get_results()  # type: ignore[attr-defined]


def get_validators() -> tuple[str, str]:
    r"""Return the TunableOp validators."""
    return torch._C._cuda_tunableop_get_validators()  # type: ignore[attr-defined]


def write_file_on_exit(val: bool) -> None:
    r"""During Tuning Context destruction, write file to disk.



    This is useful as a final flush of your results to disk if your application

    terminates as result of normal operation or an error. Manual flushing of

    your results can be achieved by manually calling ``write_file()``."""
    torch._C._cuda_tunableop_write_file_on_exit(val)  # type: ignore[attr-defined]


def write_file(filename: Optional[str] = None) -> bool:
    r"""Write results to a CSV file.



    If :attr:`filename` is not given, ``get_filename()`` is called.

    """
    if filename is None:
        filename = get_filename()
    return torch._C._cuda_tunableop_write_file(filename)  # type: ignore[attr-defined]


def read_file(filename: Optional[str] = None) -> bool:
    r"""Read results from a TunableOp CSV file.



    If :attr:`filename` is not given, ``get_filename()`` is called.

    """
    if filename is None:
        filename = get_filename()
    return torch._C._cuda_tunableop_read_file(filename)  # type: ignore[attr-defined]


def set_rotating_buffer_size(buffer_size: int) -> None:
    r"""Set rotating buffer size to this value in MB, if the buffer size is greater than zero.



    If less than zero, query L2 cache size. If equal to zero, means deactivate rotating buffer.

    """
    return torch._C._cuda_tunableop_set_rotating_buffer_size(buffer_size)  # type: ignore[attr-defined]


def get_rotating_buffer_size() -> int:
    r"""Get the rotating buffer size in kilobytes."""
    return torch._C._cuda_tunableop_get_rotating_buffer_size()  # type: ignore[attr-defined]


def tune_gemm_in_file(filename: str) -> None:
    r"""tune GEMM in file."""

    assert is_enabled()
    assert tuning_is_enabled()

    deviceid = torch.cuda.current_device()

    with open(filename) as file:
        for line in file:
            if line.startswith(("Gemm", "ScaledGemm")):
                _process_single_offline_gemm(line, deviceid)


def _gather_unique_untuned_gemm_from_files(filename_pattern: str) -> set[str]:
    r"""Process multiple untuned results file and return a set with duplicates removed."""
    unique_gemm_entries = set()  # set will avoid duplicates

    for file_path in glob.glob(filename_pattern):
        with open(file_path) as file:
            for line in file:
                if line.startswith(("Gemm", "ScaledGemm")):
                    unique_gemm_entries.add(line)

    return unique_gemm_entries


def _gather_tunableop_results() -> None:
    r"""Gather results from multiple tunableop results file and create a single file."""
    gemm_lines = set()
    validator_lines = []

    # Need to allow for the possibility that results filename was
    # set with the Python API instead of with environment variable.
    # Also possible that results filename was not set at all.
    # There are several test cases to check, but ultimately we
    # need a glob-able expression
    results_filename = get_filename()  # Note empty string could be returned here

    if (
        results_filename is not None and results_filename != ""
    ):  # Case where the Python API was used to set the filename
        dot_pos = results_filename.find(".")
        if dot_pos != -1 and dot_pos > 0:
            # Replace the character just to the left of the dot
            filename_pattern = (
                results_filename[: dot_pos - 1] + "?" + results_filename[dot_pos:]
            )
        else:
            filename_pattern = ""  # Needed to make linter happy
    else:  # Case where the environment variable was used to set the filename.
        results_filename_env = os.getenv("PYTORCH_TUNABLEOP_FILENAME")
        if results_filename_env is None or results_filename_env == "":
            filename_pattern = "tunableop_results?.csv"
        elif "%d" in results_filename_env:
            filename_pattern = results_filename_env.replace("%d", "?")
        else:
            filename_pattern = results_filename_env.replace(".", "?.")

    assert "?" in filename_pattern

    FirstFile = False
    matching_files = glob.glob(filename_pattern)
    num_matching_files = len(matching_files)
    for file_path in matching_files:
        with open(file_path) as file:
            for line in file:
                if line.startswith("Validator"):
                    if not (FirstFile):
                        # Only read Validator from first file
                        validator_lines.append(line)
                else:
                    gemm_lines.add(line)

        FirstFile = True

    output_file = filename_pattern.replace("?", "_full0")

    with open(output_file, "w") as out_file:
        for line in validator_lines:
            out_file.write(line)
        for line in gemm_lines:
            out_file.write(line)

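    # Create copies so there is one gathered results file per matching input file.
    # Replace only the "_full0" suffix so other zeros in the path are left intact.
    for i in range(1, num_matching_files):
        duplicate_file = output_file.replace("_full0", f"_full{i}")
        shutil.copy(output_file, duplicate_file)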


def _create_matrices(
    m: int,
    n: int,
    k: int,
    lda: int,
    ldb: int,
    ldc: int,
    transA: bool,
    transB: bool,
    dtypeA: torch.dtype,
    deviceid: str,
    dtypeB: Optional[torch.dtype] = None,
    randn: bool = True,
    subMatrix: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Helper function for _process_single_offline_gemm.

    Creates matrices that are then consumed by one of the Torch GEMM APIs.

    """
    # Fill parameters set for use with ScaledGEMM
    fillA = 0.25
    fillB = 0.75

    if dtypeB is None:
        dtypeB = dtypeA

    if subMatrix:
        # Reference for understanding the leading dimension:
        # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f
        # TODO: According to lines 108 - 133, there is no lower bound on rowsA,
        # but there is a restriction on rowsB. Using this formula for now as it
        # seems to work for all UTs.
        rowsA = rowsB = max(ldc, k)

        # A uses dtypeA and B uses dtypeB, matching the non-subMatrix branch.
        if randn:
            matA = torch.randn(rowsA, lda, dtype=dtypeA, device=deviceid)
            matB = torch.randn(rowsB, ldb, dtype=dtypeB, device=deviceid)
        else:
            matA = torch.full((rowsA, lda), fillA, dtype=dtypeA, device=deviceid)
            matB = torch.full((rowsB, ldb), fillB, dtype=dtypeB, device=deviceid)

        subA = matA[:k, :m].t() if transA else matA[:m, :k]
        subB = matB[:n, :k].t() if transB else matB[:k, :n]
        return subA, subB
    else:
        if randn:
            matA = (
                torch.rand(k, m, dtype=dtypeA, device=deviceid).t()
                if transA
                else torch.rand(m, k, dtype=dtypeA, device=deviceid)
            )
            matB = (
                torch.rand(n, k, dtype=dtypeB, device=deviceid).t()
                if transB
                else torch.rand(k, n, dtype=dtypeB, device=deviceid)
            )
        else:
            matA = (
                torch.full((k, m), fillA, dtype=dtypeA, device=deviceid).t()
                if transA
                else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
            )
            matB = (
                torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
                if transB
                else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
            )
        return matA, matB


def _create_batch_matrices(
    m: int,
    n: int,
    k: int,
    b: int,
    lda: int,
    ldb: int,
    ldc: int,
    transA: bool,
    transB: bool,
    dtype: torch.dtype,
    deviceid: str,
    subMatrix: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Helper function for _process_single_offline_gemm.

    Creates batch matrices that are then consumed by one of the Torch GEMM APIs.

    Similar to _create_matrices but for 3D batch matrices.

    """
    if subMatrix:
        # Reference for understanding the leading dimension:
        # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f
        # TODO: According to lines 108 - 133, there is no lower bound on rowsA,
        # but there is a restriction on rowsB. Using this formula for now as it
        # seems to work for all UTs.
        rowsA = rowsB = max(ldc, k)

        matA = torch.randn(b, rowsA, lda, dtype=dtype, device=deviceid)
        matB = torch.randn(b, rowsB, ldb, dtype=dtype, device=deviceid)

        subA = matA[:b, :k, :m].transpose(1, 2) if transA else matA[:b, :m, :k]
        subB = matB[:b, :n, :k].transpose(1, 2) if transB else matB[:b, :k, :n]
        return subA, subB
    else:
        matA = (
            torch.rand(b, k, m, dtype=dtype, device=deviceid)
            if transA
            else torch.rand(b, m, k, dtype=dtype, device=deviceid)
        )
        matB = (
            torch.rand(b, n, k, dtype=dtype, device=deviceid)
            if transB
            else torch.rand(b, k, n, dtype=dtype, device=deviceid)
        )
        matA = matA.transpose(1, 2) if transA else matA
        matB = matB.transpose(1, 2) if transB else matB
        return matA, matB


def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
    r"""Process a single untuned GEMM."""

    deviceid = "cuda:" + str(gpu_id)

    dtype_dict = {
        "float": torch.float32,
        "tf32": torch.float32,
        "double": torch.float64,
        "BFloat16": torch.bfloat16,
        "Half": torch.half,
        "c10::complex<double>": torch.complex128,
        "c10::complex<float>": torch.complex64,
        "Float8_e4m3fn": torch.float8_e4m3fn,
        "Float8_e5m2": torch.float8_e5m2,
        "Float8_e4m3fnuz": torch.float8_e4m3fnuz,
        "Float8_e5m2fnuz": torch.float8_e5m2fnuz,
    }

    untuned_gemm = untuned_gemm_line.strip().split(",")[:]

    underscore_count = untuned_gemm[0].count("_")

    # Initialize dtype to make linter happy
    dtype = None
    dtypeA = None
    dtypeB = None
    dtypeC = None

    # Extract BLAS parameters
    if underscore_count == 2:
        [op_sig, data_type, layout] = untuned_gemm[0].split("_")
        transB = layout[0] == "T"
        transA = layout[1] == "T"
        dtype = dtype_dict.get(data_type)
        if data_type == "tf32":
            # User must still set HIPBLASLT_ALLOW_TF32=1
            torch.backends.cuda.matmul.allow_tf32 = True
        else:
            torch.backends.cuda.matmul.allow_tf32 = False

    else:  # ScaledGEMM
        count = untuned_gemm[0].count("_")
        assert count in [6, 7]
        untuned_gemm_temp = untuned_gemm[0].split("_")
        # dtypeC might not be an FP8 type, so keep track
        # of the number of underscores
        op_sig = untuned_gemm_temp[0]
        data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
        data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
        if count == 7:
            data_typeC = untuned_gemm_temp[5] + "_" + untuned_gemm_temp[6]
        else:
            data_typeC = untuned_gemm_temp[5]
        transB = untuned_gemm_temp[count][0] == "T"
        transA = untuned_gemm_temp[count][1] == "T"
        dtypeA = dtype_dict.get(data_typeA)
        dtypeB = dtype_dict.get(data_typeB)
        dtypeC = dtype_dict.get(data_typeC)

    untuned_gemm_temp = untuned_gemm[1].split("_")
    [n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]]
    if op_sig == "GemmStridedBatchedTunableOp":
        assert untuned_gemm_temp[6] == "ld"
        [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[7:10]]
    else:
        assert untuned_gemm_temp[4] == "ld"
        [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[5:8]]

    # Detect subMatrix case
    if all(item in [n, m, k] for item in [lda, ldb, ldc]):
        subMatrix = False
    else:
        subMatrix = True

    if op_sig == "GemmTunableOp":
        # Warnings for unsupported cases:
        if m == 1 or n == 1 or k == 1:
            if (not transA) and (not transB):
                pass  # case is supported
            elif transA and n == 1:
                pass  # case is supported
            else:
                warnings.warn(
                    "Offline tuning is not supported for this GEMM. Use online tuning instead. "
                    + f"Skipped tuning for: {untuned_gemm[1]}"
                )
                return

        # Resolve linter issue
        if dtype is None or not isinstance(dtype, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")

        matA, matB = _create_matrices(
            m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix
        )
        torch.mm(matA, matB)

    elif op_sig == "GemmStridedBatchedTunableOp":
        # Warnings for unsupported cases:
        if m == 1 or n == 1 or k == 1:
            warnings.warn(
                "Offline tuning is not support for this GEMM. Use online tuning instead. "
                + f"Skipped tuning for: {untuned_gemm[1]}"
            )
            return

        [b] = [int(g) for g in untuned_gemm_temp[5:6]]

        # Resolve linter issue
        if dtype is None or not isinstance(dtype, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")

        matA, matB = _create_batch_matrices(
            m,
            n,
            k,
            b,
            lda,
            ldb,
            ldc,
            transA,
            transB,
            dtype,
            deviceid,
            subMatrix=subMatrix,
        )
        torch.bmm(matA, matB)
    elif op_sig == "ScaledGemmTunableOp":
        # Only combination supported by PyTorch
        assert transB is True
        assert transA is False

        # Narrow dtypeA to torch.dtype for the type checker (and fail early).
        if dtypeA is None or not isinstance(dtypeA, torch.dtype):
            raise TypeError(f"dtypeA must be a torch.dtype, but got {dtypeA}")

        matA, matB = _create_matrices(
            m,
            n,
            k,
            lda,
            ldb,
            ldc,
            transA,
            transB,
            dtypeA,
            deviceid,
            dtypeB=dtypeB,
            randn=False,
            subMatrix=subMatrix,
        )

        assert untuned_gemm_temp[8] == "rw"
        rowwise = untuned_gemm_temp[9] == "1"
        if rowwise:
            scaleA = (
                torch.ones((1, m), device=deviceid)
                if transA
                else torch.ones((m, 1), device=deviceid)
            )
            scaleB = (
                torch.ones((1, n), device=deviceid)
                if transB
                else torch.ones((n, 1), device=deviceid)
            )
        else:
            scaleA = torch.tensor(0.8, device=deviceid)
            scaleB = torch.tensor(0.9, device=deviceid)

        assert untuned_gemm_temp[10] == "bias"
        if untuned_gemm_temp[11] == "None":  # no bias vector
            torch._scaled_mm(
                matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC
            )
        else:  # bias vector present
            fillbias = 0.10
            bias_dtype = dtype_dict.get(untuned_gemm_temp[11])
            bias = (
                torch.full((n,), fillbias, dtype=bias_dtype, device=deviceid)
                if transB
                else torch.full((m,), fillbias, dtype=bias_dtype, device=deviceid)
            )
            torch._scaled_mm(
                matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC, bias=bias
            )

    elif op_sig == "GemmAndBiasTunableOp":
        # y = x*A^T + b
        assert transA != transB

        # Narrow dtype to torch.dtype for the type checker (and fail early).
        if dtype is None or not isinstance(dtype, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")

        bias = torch.rand(n, dtype=dtype, device=deviceid)

        X, matA = _create_matrices(
            m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix
        )
        matA = matA.t()
        torch.nn.functional.linear(X, matA, bias)
    else:
        warnings.warn(f"Unknown TunableOp signature: {op_sig}. Skipped tuning.")


def _check_tuning_assertions() -> None:
    r"""Check the TunableOp state in a multi-GPU tuning worker process.

    Enables TunableOp (with a warning) if it is disabled, then asserts that
    tuning is enabled and that recording of untuned GEMMs is disabled.
    """

    if not is_enabled():
        warnings.warn("TunableOp was disabled. Trying to enable now.")
        enable(True)
    assert is_enabled()
    assert tuning_is_enabled()
    assert not record_untuned_is_enabled()


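def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
    r"""Tune the untuned GEMMs recorded in one or more files on one or more GPUs.

    The unique GEMMs gathered from the files matching ``filename_pattern`` are
    distributed round-robin over the first ``num_gpus`` devices.
    """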
    unique_gemm_entries = _gather_unique_untuned_gemm_from_files(filename_pattern)

    total_gpus = torch.cuda.device_count()

    assert 1 <= num_gpus <= total_gpus

    mp_context = mp.get_context("spawn")

    futures = []  # futures for the per-GEMM tuning jobs
    flush_results = []  # futures for the write_file() calls

    # GEMMs are assigned to GPUs round-robin: with num_gpus=3, entries
    # 0, 1, 2, 3 go to devices 0, 1, 2, 0.
    h = 0
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=num_gpus,
        mp_context=mp_context,
        initializer=_check_tuning_assertions,
    ) as executor:
        # Each worker runs in a separate (spawned) process. TunableOp is
        # enabled in the child processes if PYTORCH_TUNABLEOP_ENABLED=1 is set;
        # otherwise the initializer tries to enable it.

        for line in unique_gemm_entries:
            future = executor.submit(_process_single_offline_gemm, line, h)
            futures.append(future)
            h = (h + 1) % num_gpus

        for future in concurrent.futures.as_completed(futures):
            future.result()  # re-raises any exception raised in the worker

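        # Ask the workers to write their tuning results to file. One
        # write_file() call is submitted per worker, but the executor does not
        # guarantee that each call runs in a different worker process.
        for _ in range(num_gpus):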
            flush_result = executor.submit(write_file)
            flush_results.append(flush_result)

        for flush_result in concurrent.futures.as_completed(flush_results):
            flush_result.result()

    torch.cuda.synchronize()

    _gather_tunableop_results()