| import argparse |
| import mmap |
| import os |
| import re |
| import subprocess |
| from torch.utils.cpp_extension import CUDA_HOME |
|
|
|
|
def run_cuobjdump(file_path):
    """Run ``cuobjdump -sass`` on *file_path* and return its stdout as text.

    The executable is looked up under ``{CUDA_HOME}/bin``. Both output
    streams are captured; only stdout is returned.

    Raises:
        AssertionError: if the tool exits with a non-zero return code.
    """
    executable = f'{CUDA_HOME}/bin/cuobjdump'
    completed = subprocess.run(
        [executable, '-sass', file_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    assert completed.returncode == 0
    return completed.stdout
|
|
|
|
def extract_ffma(sass):
    """Collect runs of consecutive FFMA instructions from a SASS text dump.

    The dump is scanned line by line. A line containing ``FFMA`` and the line
    immediately after it are stored together as one two-line entry. Any other
    line (apart from ``code for`` / ``Function :`` headers, which only update
    the current names) ends the current run. A run is kept only when it holds
    at least 16 lines, i.e. 8 two-line entries.

    Args:
        sass: Text to scan, e.g. the output of ``cuobjdump -sass``.

    Returns:
        A list of ``(name, lines)`` tuples, where ``name`` is
        ``'<arch>::<function>'`` (``'N/A'`` for a part not seen yet) and
        ``lines`` is the list of collected lines for that run.

    Raises:
        AssertionError: if a run that is terminated by another line has an
            odd number of lines.

    When the ``DG_PRINT_REG_REUSE`` environment variable is set to a
    non-empty value, the number of collected runs is printed.
    """
    collected = []
    current = []

    arch_name, func_name = 'N/A', 'N/A'
    skip_next_line = False
    for line in sass.splitlines():
        if 'code for' in line:
            # Take the text after the marker. `str.lstrip(chars)` strips a
            # *set of characters*, not a prefix, and would also eat leading
            # letters of the name itself (e.g. `compute_90` -> `mpute_90`).
            arch_name = line.partition('code for')[2].strip()
        elif 'Function :' in line:
            # Same reasoning: `init_kernel` must not become `_kernel`.
            func_name = line.partition('Function :')[2].strip()
        elif 'FFMA' in line:
            current.append(line)
            # The line right after an FFMA line belongs to the same entry.
            skip_next_line = True
        elif skip_next_line:
            current.append(line)
            skip_next_line = False
        else:
            if len(current) >= 16:
                assert len(current) % 2 == 0
                collected.append((f'{arch_name}::{func_name}', current))
            current = []

    # A run that reaches the end of the input has no terminating line; keep it
    # when it is complete (even line count) instead of silently dropping it.
    if len(current) >= 16 and len(current) % 2 == 0:
        collected.append((f'{arch_name}::{func_name}', current))

    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f'Found {len(collected)} FFMA segments')
    return collected
|
|
|
|
def extract_hex_from_line(line):
    """Return the integer value of the first ``/* 0x... */`` comment in *line*.

    Only comments whose content is a ``0x``-prefixed hexadecimal literal
    (optionally padded with whitespace) are considered.

    Raises:
        AssertionError: if *line* contains no such comment.
    """
    found = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
    assert found
    hex_literal = found.group(1)
    return int(hex_literal, 16)
|
|
|
|
def validate(m, offset, le_bytes, num_lines):
    """Check whether *m* holds every chunk of *le_bytes* back to back at *offset*.

    Args:
        m: Sliceable byte buffer to inspect.
        offset: Position in *m* where the first 16-byte chunk is expected.
        le_bytes: Expected chunks, one per pair of lines; each is compared
            against a 16-byte window of *m*.
        num_lines: Line count of the segment; ``num_lines // 2`` chunks are
            compared.

    Returns:
        ``True`` when all chunks after the first match, ``False`` otherwise.

    Raises:
        AssertionError: if ``len(le_bytes)`` is not ``num_lines // 2`` or the
            first chunk does not match at *offset* (the caller is expected to
            have located *offset* by searching for that chunk).
    """
    num_chunks = num_lines // 2
    assert len(le_bytes) == num_chunks
    assert m[offset:offset + 16] == le_bytes[0]
    return all(
        m[offset + 16 * k:offset + 16 * k + 16] == le_bytes[k]
        for k in range(1, num_chunks)
    )
|
|
|
|
def parse_registers(line):
    """Return the register names mentioned in an instruction line, in order.

    ``/* ... */`` comments and semicolons are removed first. The remaining
    text is split on commas and whitespace, and every word that starts with
    ``R`` is kept with anything from the first ``.`` onwards dropped
    (``R1.reuse`` -> ``R1``). Words with a leading modifier such as ``-R4``
    do not start with ``R`` and are therefore skipped.
    """
    without_comments = re.sub(r'/\*.*?\*/', '', line)
    cleaned = without_comments.replace(';', '')
    # Commas and whitespace are both separators, so normalise to whitespace.
    words = cleaned.replace(',', ' ').split()
    return [word.split('.')[0] for word in words if word.startswith('R')]
|
|
|
|
def modify_segment(m, name, ffma_lines):
    """Patch the encodings of one FFMA segment inside the byte buffer *m*.

    Args:
        m: Mutable byte buffer supporting ``find`` and slice assignment (the
            caller passes a writable ``mmap``).
        name: Label for the segment; only used in the diagnostic print.
        ffma_lines: Even-length list of lines. Line ``2 * i`` holds the
            instruction text and the low 64-bit encoding word, line
            ``2 * i + 1`` holds the high 64-bit encoding word.

    For every instruction whose high word has bit ``0x0800000000000000`` set,
    the bits in ``0x0800200000000000`` are flipped when either the tracked
    register has not been seen earlier in the segment, or the previous
    instruction kept that bit and tracked the same register. Every location
    in *m* where the complete original byte sequence occurs is then
    overwritten with the patched sequence.

    Raises:
        AssertionError: if ``ffma_lines`` has odd length, or an encoding
            comment is missing from a line.
        IndexError: if ``ffma_lines`` is empty or an instruction line names
            fewer than two registers.
    """
    num_lines = len(ffma_lines)
    assert num_lines % 2 == 0
    num_insts = num_lines // 2

    # NOTE(review): presumably the operand-reuse flag, and the reuse flag
    # combined with one further control bit - confirm against the SASS
    # encoding before relying on these names.
    reuse_mask = 0x0800000000000000
    flip_mask = 0x0800200000000000

    le_bytes, new_le_bytes = [], []
    reused_list = []
    dst_reg_set = set()
    last_reused, last_dst_reg = False, ''
    num_changed = 0
    for idx in range(num_insts):
        low_line = ffma_lines[2 * idx]
        high_line = ffma_lines[2 * idx + 1]
        # The second-to-last register on the instruction line is the one
        # tracked across the segment.
        dst_reg = parse_registers(low_line)[-2]
        low_hex = extract_hex_from_line(low_line)
        high_hex = extract_hex_from_line(high_line)
        low_bytes = low_hex.to_bytes(8, 'little')
        le_bytes.append(low_bytes + high_hex.to_bytes(8, 'little'))

        reused = bool(high_hex & reuse_mask)
        if reused:
            first_seen = dst_reg not in dst_reg_set
            repeats_previous = last_reused and dst_reg == last_dst_reg
            if first_seen or repeats_previous:
                assert high_hex & flip_mask, f'{hex(high_hex)}'
                high_hex ^= flip_mask
                reused = False
                num_changed += 1
            else:
                reused_list.append(idx)
        dst_reg_set.add(dst_reg)
        new_le_bytes.append(low_bytes + high_hex.to_bytes(8, 'little'))
        last_reused, last_dst_reg = reused, dst_reg

    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')

    # Locate every occurrence of the first instruction's bytes, then keep
    # only the positions where the whole original sequence follows.
    first_chunk = le_bytes[0]
    candidates = []
    position = m.find(first_chunk)
    while position >= 0:
        candidates.append(position)
        position = m.find(first_chunk, position + 1)
    offsets = [pos for pos in candidates if validate(m, pos, le_bytes, num_lines)]

    # Overwrite each confirmed location with the patched encodings.
    for base in offsets:
        for k, patched in enumerate(new_le_bytes):
            start = base + k * 16
            m[start:start + 16] = patched
|
|
|
|
def process(path):
    """Patch every collected FFMA segment of the binary at *path* in place.

    The file is disassembled with ``run_cuobjdump``, FFMA segments are
    gathered with ``extract_ffma``, and each segment is applied to a writable
    memory map of the file through ``modify_segment``.

    Args:
        path: Path of the file to disassemble and modify.

    When the ``DG_PRINT_REG_REUSE`` environment variable is set to a
    non-empty value, the path being processed is printed.
    """
    if os.getenv('DG_PRINT_REG_REUSE', None):
        print(f'Processing {path}')
    output = run_cuobjdump(path)
    segments = extract_ffma(output)
    with open(path, 'r+b') as f:
        # Use the mapping as a context manager so it is closed even when
        # patching a segment raises.
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) as mm:
            for segment in segments:
                modify_segment(mm, *segment)
|
|
|
|
# Command-line entry point: patch the file given with `--so`.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
    # The path is mandatory; without it `process` would receive `None` and
    # fail with an unrelated error instead of a usage message.
    parser.add_argument('--so', required=True, help='Path to the SO file')
    args = parser.parse_args()

    process(args.so)
|
|