File size: 4,591 Bytes
67e9774 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import io
import json
import logging
import os
import tempfile
from typing import IO
import torch
from torch._inductor import config
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export.pt2_archive._package import (
AOTI_FILES,
AOTICompiledModel,
load_pt2,
package_pt2,
)
from torch.types import FileLike
log = logging.getLogger(__name__)
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
    """
    Builds a shared library from AOTInductor generated files.

    The first ``.cpp`` entry of ``aoti_files`` is compiled into an object file
    using the options stored next to it in ``<name>_compile_flags.json``. That
    object is then linked with the first ``.o`` entry of ``aoti_files`` using
    the options in ``<name>_linker_flags.json``. If a
    ``<name>_serialized_weights.bin`` file is listed in ``aoti_files``, its
    bytes are appended to the resulting library after padding.

    Args:
        aoti_dir: Directory that entries of ``aoti_files`` are joined onto.
        aoti_files: Names of the AOTInductor generated files.
        so_path: Its last path component is used as the name of the library
            and the whole value is passed to ``CppBuilder`` as ``output_dir``.

    Returns:
        The target file path reported by the ``CppBuilder`` that links the
        library.

    Raises:
        RuntimeError: If ``aoti_files`` has no entry ending in ``.cpp`` or no
            entry ending in ``.o``.
    """

    def get_aoti_file_with_suffix(suffix: str) -> str:
        # The first entry with a matching suffix wins.
        for file in aoti_files:
            if file.endswith(suffix):
                return file
        raise RuntimeError(f"Unable to find file with suffix {suffix}")

    # Compile all the files into a .so
    cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp"))
    consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o"))

    # Path of the .cpp file without its extension; the flag files and the
    # serialized weights file are looked up relative to this prefix.
    file_name = os.path.splitext(cpp_file)[0]

    # Parse compile flags and build the .o file
    with open(file_name + "_compile_flags.json") as f:
        compile_flags = json.load(f)

    compile_options = BuildOptionsBase(
        **compile_flags, use_relative_path=config.is_fbcode()
    )
    object_builder = CppBuilder(
        name=file_name,
        sources=cpp_file,
        BuildOption=compile_options,
    )
    output_o = object_builder.get_target_file_path()
    object_builder.build()

    # Parse linker flags and build the .so file
    with open(file_name + "_linker_flags.json") as f:
        linker_flags = json.load(f)

    linker_options = BuildOptionsBase(
        **linker_flags, use_relative_path=config.is_fbcode()
    )
    so_builder = CppBuilder(
        name=os.path.split(so_path)[-1],
        sources=[output_o, consts_o],
        BuildOption=linker_options,
        output_dir=so_path,
    )
    output_so = so_builder.get_target_file_path()
    so_builder.build()

    # mmapped weights
    serialized_weights_filename = file_name + "_serialized_weights.bin"
    # `serialized_weights_filename` already includes `aoti_dir`, so entries of
    # `aoti_files` are resolved against `aoti_dir` the same way the .cpp/.o
    # entries are above before comparing. The direct comparison is kept so
    # that anything that matched before still matches.
    has_serialized_weights = any(
        entry == serialized_weights_filename
        or os.path.join(aoti_dir, entry) == serialized_weights_filename
        for entry in aoti_files
    )
    if has_serialized_weights:
        with open(serialized_weights_filename, "rb") as f_weights:
            serialized_weights = f_weights.read()

        with open(output_so, "a+b") as f_so:
            so_size = f_so.tell()
            # Page align the weights: pad with spaces up to the next multiple
            # of 16384 bytes. When the size is already a multiple, a full
            # 16384 bytes of padding is written.
            f_so.write(b" " * (16384 - so_size % 16384))
            f_so.write(serialized_weights)

    return output_so
def package_aoti(
    archive_file: FileLike,
    aoti_files: AOTI_FILES,
) -> FileLike:
    """
    Packages AOTInductor generated files in the PT2Archive format.

    This is a thin wrapper that forwards both arguments to ``package_pt2``.

    Args:
        archive_file: The file name to save the package to.
        aoti_files: Either a single path to a directory containing the
            AOTInductor files, or a dictionary mapping each model name to the
            path of its AOTInductor generated files.

    Returns:
        Whatever ``package_pt2`` returns for the given arguments.
    """
    packaged = package_pt2(archive_file, aoti_files=aoti_files)
    return packaged
def load_package(
    path: FileLike,
    model_name: str = "model",
    run_single_threaded: bool = False,
    num_runners: int = 1,
    device_index: int = -1,
) -> AOTICompiledModel:  # type: ignore[type-arg]
    """
    Loads the AOTInductor model named ``model_name`` from a pt2 package.

    The package is first read with ``load_pt2``. If that raises a
    ``RuntimeError`` (including the one raised here when ``model_name`` is not
    among the loaded runners), a warning is logged and the package is loaded
    through ``torch._C._aoti.AOTIModelPackageLoader`` instead.

    Args:
        path: Path to the package, or a file-like object holding it.
        model_name: Name of the model to look up inside the package.
        run_single_threaded: Forwarded unchanged to the loaders.
        num_runners: Forwarded unchanged to the loaders.
        device_index: Forwarded unchanged to the loaders.

    Returns:
        The runner stored under ``model_name`` by ``load_pt2``, or an
        ``AOTICompiledModel`` wrapping the fallback loader.
    """
    try:
        pt2_contents = load_pt2(
            path,
            run_single_threaded=run_single_threaded,
            num_runners=num_runners,
            device_index=device_index,
        )
        if model_name not in pt2_contents.aoti_runners:
            raise RuntimeError(f"Model {model_name} not found in package")
        return pt2_contents.aoti_runners[model_name]
    except RuntimeError:
        log.warning("Loading outdated pt2 file. Please regenerate your package.")

    if isinstance(path, (io.IOBase, IO)):
        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            # TODO(angelayi): We shouldn't need to do this -- miniz should
            # handle reading the buffer. This is just a temporary workaround
            path.seek(0)
            f.write(path.read())
            # The loader below opens the file by name, so push everything out
            # of Python's write buffer first; otherwise it could read a
            # truncated or empty file.
            f.flush()
            log.debug("Writing buffer to tmp file located at %s.", f.name)
            # The loader is created while the temporary file is still open,
            # since the file is removed when the `with` block exits.
            loader = torch._C._aoti.AOTIModelPackageLoader(
                f.name, model_name, run_single_threaded, num_runners, device_index
            )
            return AOTICompiledModel(loader)

    path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
    loader = torch._C._aoti.AOTIModelPackageLoader(
        path, model_name, run_single_threaded, num_runners, device_index
    )
    return AOTICompiledModel(loader)
|