diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_core_metadata.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_core_metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a30a30d0eba9d67b566763318ca16c24ca104a9e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_core_metadata.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_entry_points.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_entry_points.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ddbab587e245e1cdb4e8a84952906055e59a893 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_entry_points.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_importlib.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_importlib.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..585bd3781a24e414edf36969ef81dbd2feb83785 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_importlib.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_reqs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_reqs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db693db6e9d7b1649f3a102977ed60edd3dcdca Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/_reqs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/build_meta.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/build_meta.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d3938dfc5137d75902feb61c6d1774865ea0c8c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/build_meta.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/discovery.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/discovery.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7acec24c8b166c85cdd0ad4f02260c7f9824a7f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/discovery.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/installer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/installer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f4d6347378c1765c0eacb97703cb25c24e44aba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/installer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/logging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563c565cf7a983f434949655b4a1c59a31939a3c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/logging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/wheel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/wheel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1be8356aeb2a1bf39da7513c3c87a4e594f49e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/wheel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/windows_support.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/windows_support.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9720a70e4712ba0099677c532d9cad8492eafcf3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/__pycache__/windows_support.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/_requirestxt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/_requirestxt.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1967e7aa3501aef2880c24ea284bfbb1d3d291 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/_requirestxt.py @@ -0,0 +1,131 @@ +"""Helper code used to generate ``requires.txt`` files in the egg-info directory. + +The ``requires.txt`` file has an specific format: + - Environment markers need to be part of the section headers and + should not be part of the requirement spec itself. + +See https://setuptools.pypa.io/en/latest/deprecated/python_eggs.html#requires-txt +""" + +from __future__ import annotations + +import io +from collections import defaultdict +from itertools import filterfalse +from typing import Dict, Mapping, TypeVar + +from .. import _reqs +from ..extern.jaraco.text import yield_lines +from ..extern.packaging.requirements import Requirement + + +# dict can work as an ordered set +_T = TypeVar("_T") +_Ordered = Dict[_T, None] +_ordered = dict +_StrOrIter = _reqs._StrOrIter + + +def _prepare( + install_requires: _StrOrIter, extras_require: Mapping[str, _StrOrIter] +) -> tuple[list[str], dict[str, list[str]]]: + """Given values for ``install_requires`` and ``extras_require`` + create modified versions in a way that can be written in ``requires.txt`` + """ + extras = _convert_extras_requirements(extras_require) + return _move_install_requirements_markers(install_requires, extras) + + +def _convert_extras_requirements( + extras_require: Mapping[str, _StrOrIter], +) -> Mapping[str, _Ordered[Requirement]]: + """ + Convert requirements in `extras_require` of the form + `"extra": ["barbazquux; {marker}"]` to + `"extra:{marker}": ["barbazquux"]`. + """ + output: Mapping[str, _Ordered[Requirement]] = defaultdict(dict) + for section, v in extras_require.items(): + # Do not strip empty sections. + output[section] + for r in _reqs.parse(v): + output[section + _suffix_for(r)].setdefault(r) + + return output + + +def _move_install_requirements_markers( + install_requires: _StrOrIter, extras_require: Mapping[str, _Ordered[Requirement]] +) -> tuple[list[str], dict[str, list[str]]]: + """ + The ``requires.txt`` file has an specific format: + - Environment markers need to be part of the section headers and + should not be part of the requirement spec itself. + + Move requirements in ``install_requires`` that are using environment + markers ``extras_require``. + """ + + # divide the install_requires into two sets, simple ones still + # handled by install_requires and more complex ones handled by extras_require. + + inst_reqs = list(_reqs.parse(install_requires)) + simple_reqs = filter(_no_marker, inst_reqs) + complex_reqs = filterfalse(_no_marker, inst_reqs) + simple_install_requires = list(map(str, simple_reqs)) + + for r in complex_reqs: + extras_require[':' + str(r.marker)].setdefault(r) + + expanded_extras = dict( + # list(dict.fromkeys(...)) ensures a list of unique strings + (k, list(dict.fromkeys(str(r) for r in map(_clean_req, v)))) + for k, v in extras_require.items() + ) + + return simple_install_requires, expanded_extras + + +def _suffix_for(req): + """Return the 'extras_require' suffix for a given requirement.""" + return ':' + str(req.marker) if req.marker else '' + + +def _clean_req(req): + """Given a Requirement, remove environment markers and return it""" + r = Requirement(str(req)) # create a copy before modifying + r.marker = None + return r + + +def _no_marker(req): + return not req.marker + + +def _write_requirements(stream, reqs): + lines = yield_lines(reqs or ()) + + def append_cr(line): + return line + '\n' + + lines = map(append_cr, lines) + stream.writelines(lines) + + +def write_requirements(cmd, basename, filename): + dist = cmd.distribution + data = io.StringIO() + install_requires, extras_require = _prepare( + dist.install_requires or (), dist.extras_require or {} + ) + _write_requirements(data, install_requires) + for extra in sorted(extras_require): + data.write('\n[{extra}]\n'.format(**vars())) + _write_requirements(data, extras_require[extra]) + cmd.write_or_delete_file("requirements", filename, data.getvalue()) + + +def write_setup_requirements(cmd, basename, filename): + data = io.StringIO() + _write_requirements(data, cmd.distribution.setup_requires) + cmd.write_or_delete_file("setup-requirements", filename, data.getvalue()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_egg.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_egg.py new file mode 100644 index 0000000000000000000000000000000000000000..559f7d6032a2167645f2f9ec1d2204b7c3324888 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_egg.py @@ -0,0 +1,464 @@ +"""setuptools.command.bdist_egg + +Build .egg distributions""" + +from distutils.dir_util import remove_tree, mkpath +from distutils import log +from types import CodeType +import sys +import os +import re +import textwrap +import marshal + +from setuptools.extension import Library +from setuptools import Command +from .._path import ensure_directory + +from sysconfig import get_path, get_python_version + + +def _get_purelib(): + return get_path("purelib") + + +def strip_module(filename): + if '.' in filename: + filename = os.path.splitext(filename)[0] + if filename.endswith('module'): + filename = filename[:-6] + return filename + + +def sorted_walk(dir): + """Do os.walk in a reproducible way, + independent of indeterministic filesystem readdir order + """ + for base, dirs, files in os.walk(dir): + dirs.sort() + files.sort() + yield base, dirs, files + + +def write_stub(resource, pyfile): + _stub_template = textwrap.dedent( + """ + def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, importlib.util + __file__ = pkg_resources.resource_filename(__name__, %r) + __loader__ = None; del __bootstrap__, __loader__ + spec = importlib.util.spec_from_file_location(__name__,__file__) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + __bootstrap__() + """ + ).lstrip() + with open(pyfile, 'w', encoding="utf-8") as f: + f.write(_stub_template % resource) + + +class bdist_egg(Command): + description = "create an \"egg\" distribution" + + user_options = [ + ('bdist-dir=', 'b', "temporary directory for creating the distribution"), + ( + 'plat-name=', + 'p', + "platform name to embed in generated filenames " + "(by default uses `pkg_resources.get_build_platform()`)", + ), + ('exclude-source-files', None, "remove all .py files from the generated egg"), + ( + 'keep-temp', + 'k', + "keep the pseudo-installation tree around after " + "creating the distribution archive", + ), + ('dist-dir=', 'd', "directory to put final built distributions in"), + ('skip-build', None, "skip rebuilding everything (for testing/debugging)"), + ] + + boolean_options = ['keep-temp', 'skip-build', 'exclude-source-files'] + + def initialize_options(self): + self.bdist_dir = None + self.plat_name = None + self.keep_temp = False + self.dist_dir = None + self.skip_build = False + self.egg_output = None + self.exclude_source_files = None + + def finalize_options(self): + ei_cmd = self.ei_cmd = self.get_finalized_command("egg_info") + self.egg_info = ei_cmd.egg_info + + if self.bdist_dir is None: + bdist_base = self.get_finalized_command('bdist').bdist_base + self.bdist_dir = os.path.join(bdist_base, 'egg') + + if self.plat_name is None: + from pkg_resources import get_build_platform + + self.plat_name = get_build_platform() + + self.set_undefined_options('bdist', ('dist_dir', 'dist_dir')) + + if self.egg_output is None: + # Compute filename of the output egg + basename = ei_cmd._get_egg_basename( + py_version=get_python_version(), + platform=self.distribution.has_ext_modules() and self.plat_name, + ) + + self.egg_output = os.path.join(self.dist_dir, basename + '.egg') + + def do_install_data(self): + # Hack for packages that install data to install's --install-lib + self.get_finalized_command('install').install_lib = self.bdist_dir + + site_packages = os.path.normcase(os.path.realpath(_get_purelib())) + old, self.distribution.data_files = self.distribution.data_files, [] + + for item in old: + if isinstance(item, tuple) and len(item) == 2: + if os.path.isabs(item[0]): + realpath = os.path.realpath(item[0]) + normalized = os.path.normcase(realpath) + if normalized == site_packages or normalized.startswith( + site_packages + os.sep + ): + item = realpath[len(site_packages) + 1 :], item[1] + # XXX else: raise ??? + self.distribution.data_files.append(item) + + try: + log.info("installing package data to %s", self.bdist_dir) + self.call_command('install_data', force=False, root=None) + finally: + self.distribution.data_files = old + + def get_outputs(self): + return [self.egg_output] + + def call_command(self, cmdname, **kw): + """Invoke reinitialized command `cmdname` with keyword args""" + for dirname in INSTALL_DIRECTORY_ATTRS: + kw.setdefault(dirname, self.bdist_dir) + kw.setdefault('skip_build', self.skip_build) + kw.setdefault('dry_run', self.dry_run) + cmd = self.reinitialize_command(cmdname, **kw) + self.run_command(cmdname) + return cmd + + def run(self): # noqa: C901 # is too complex (14) # FIXME + # Generate metadata first + self.run_command("egg_info") + # We run install_lib before install_data, because some data hacks + # pull their data path from the install_lib command. + log.info("installing library code to %s", self.bdist_dir) + instcmd = self.get_finalized_command('install') + old_root = instcmd.root + instcmd.root = None + if self.distribution.has_c_libraries() and not self.skip_build: + self.run_command('build_clib') + cmd = self.call_command('install_lib', warn_dir=False) + instcmd.root = old_root + + all_outputs, ext_outputs = self.get_ext_outputs() + self.stubs = [] + to_compile = [] + for p, ext_name in enumerate(ext_outputs): + filename, ext = os.path.splitext(ext_name) + pyfile = os.path.join(self.bdist_dir, strip_module(filename) + '.py') + self.stubs.append(pyfile) + log.info("creating stub loader for %s", ext_name) + if not self.dry_run: + write_stub(os.path.basename(ext_name), pyfile) + to_compile.append(pyfile) + ext_outputs[p] = ext_name.replace(os.sep, '/') + + if to_compile: + cmd.byte_compile(to_compile) + if self.distribution.data_files: + self.do_install_data() + + # Make the EGG-INFO directory + archive_root = self.bdist_dir + egg_info = os.path.join(archive_root, 'EGG-INFO') + self.mkpath(egg_info) + if self.distribution.scripts: + script_dir = os.path.join(egg_info, 'scripts') + log.info("installing scripts to %s", script_dir) + self.call_command('install_scripts', install_dir=script_dir, no_ep=True) + + self.copy_metadata_to(egg_info) + native_libs = os.path.join(egg_info, "native_libs.txt") + if all_outputs: + log.info("writing %s", native_libs) + if not self.dry_run: + ensure_directory(native_libs) + with open(native_libs, 'wt', encoding="utf-8") as libs_file: + libs_file.write('\n'.join(all_outputs)) + libs_file.write('\n') + elif os.path.isfile(native_libs): + log.info("removing %s", native_libs) + if not self.dry_run: + os.unlink(native_libs) + + write_safety_flag(os.path.join(archive_root, 'EGG-INFO'), self.zip_safe()) + + if os.path.exists(os.path.join(self.egg_info, 'depends.txt')): + log.warn( + "WARNING: 'depends.txt' will not be used by setuptools 0.6!\n" + "Use the install_requires/extras_require setup() args instead." + ) + + if self.exclude_source_files: + self.zap_pyfiles() + + # Make the archive + make_zipfile( + self.egg_output, + archive_root, + verbose=self.verbose, + dry_run=self.dry_run, + mode=self.gen_header(), + ) + if not self.keep_temp: + remove_tree(self.bdist_dir, dry_run=self.dry_run) + + # Add to 'Distribution.dist_files' so that the "upload" command works + getattr(self.distribution, 'dist_files', []).append(( + 'bdist_egg', + get_python_version(), + self.egg_output, + )) + + def zap_pyfiles(self): + log.info("Removing .py files from temporary directory") + for base, dirs, files in walk_egg(self.bdist_dir): + for name in files: + path = os.path.join(base, name) + + if name.endswith('.py'): + log.debug("Deleting %s", path) + os.unlink(path) + + if base.endswith('__pycache__'): + path_old = path + + pattern = r'(?P.+)\.(?P[^.]+)\.pyc' + m = re.match(pattern, name) + path_new = os.path.join(base, os.pardir, m.group('name') + '.pyc') + log.info("Renaming file from [%s] to [%s]" % (path_old, path_new)) + try: + os.remove(path_new) + except OSError: + pass + os.rename(path_old, path_new) + + def zip_safe(self): + safe = getattr(self.distribution, 'zip_safe', None) + if safe is not None: + return safe + log.warn("zip_safe flag not set; analyzing archive contents...") + return analyze_egg(self.bdist_dir, self.stubs) + + def gen_header(self): + return 'w' + + def copy_metadata_to(self, target_dir): + "Copy metadata (egg info) to the target_dir" + # normalize the path (so that a forward-slash in egg_info will + # match using startswith below) + norm_egg_info = os.path.normpath(self.egg_info) + prefix = os.path.join(norm_egg_info, '') + for path in self.ei_cmd.filelist.files: + if path.startswith(prefix): + target = os.path.join(target_dir, path[len(prefix) :]) + ensure_directory(target) + self.copy_file(path, target) + + def get_ext_outputs(self): + """Get a list of relative paths to C extensions in the output distro""" + + all_outputs = [] + ext_outputs = [] + + paths = {self.bdist_dir: ''} + for base, dirs, files in sorted_walk(self.bdist_dir): + all_outputs.extend( + paths[base] + filename + for filename in files + if os.path.splitext(filename)[1].lower() in NATIVE_EXTENSIONS + ) + for filename in dirs: + paths[os.path.join(base, filename)] = paths[base] + filename + '/' + + if self.distribution.has_ext_modules(): + build_cmd = self.get_finalized_command('build_ext') + for ext in build_cmd.extensions: + if isinstance(ext, Library): + continue + fullname = build_cmd.get_ext_fullname(ext.name) + filename = build_cmd.get_ext_filename(fullname) + if not os.path.basename(filename).startswith('dl-'): + if os.path.exists(os.path.join(self.bdist_dir, filename)): + ext_outputs.append(filename) + + return all_outputs, ext_outputs + + +NATIVE_EXTENSIONS = dict.fromkeys('.dll .so .dylib .pyd'.split()) + + +def walk_egg(egg_dir): + """Walk an unpacked egg's contents, skipping the metadata directory""" + walker = sorted_walk(egg_dir) + base, dirs, files = next(walker) + if 'EGG-INFO' in dirs: + dirs.remove('EGG-INFO') + yield base, dirs, files + yield from walker + + +def analyze_egg(egg_dir, stubs): + # check for existing flag in EGG-INFO + for flag, fn in safety_flags.items(): + if os.path.exists(os.path.join(egg_dir, 'EGG-INFO', fn)): + return flag + if not can_scan(): + return False + safe = True + for base, dirs, files in walk_egg(egg_dir): + for name in files: + if name.endswith('.py') or name.endswith('.pyw'): + continue + elif name.endswith('.pyc') or name.endswith('.pyo'): + # always scan, even if we already know we're not safe + safe = scan_module(egg_dir, base, name, stubs) and safe + return safe + + +def write_safety_flag(egg_dir, safe): + # Write or remove zip safety flag file(s) + for flag, fn in safety_flags.items(): + fn = os.path.join(egg_dir, fn) + if os.path.exists(fn): + if safe is None or bool(safe) != flag: + os.unlink(fn) + elif safe is not None and bool(safe) == flag: + with open(fn, 'wt', encoding="utf-8") as f: + f.write('\n') + + +safety_flags = { + True: 'zip-safe', + False: 'not-zip-safe', +} + + +def scan_module(egg_dir, base, name, stubs): + """Check whether module possibly uses unsafe-for-zipfile stuff""" + + filename = os.path.join(base, name) + if filename[:-1] in stubs: + return True # Extension module + pkg = base[len(egg_dir) + 1 :].replace(os.sep, '.') + module = pkg + (pkg and '.' or '') + os.path.splitext(name)[0] + skip = 16 # skip magic & reserved? & date & file size + f = open(filename, 'rb') + f.read(skip) + code = marshal.load(f) + f.close() + safe = True + symbols = dict.fromkeys(iter_symbols(code)) + for bad in ['__file__', '__path__']: + if bad in symbols: + log.warn("%s: module references %s", module, bad) + safe = False + if 'inspect' in symbols: + for bad in [ + 'getsource', + 'getabsfile', + 'getfile', + 'getsourcefile', + 'getsourcelines', + 'findsource', + 'getcomments', + 'getframeinfo', + 'getinnerframes', + 'getouterframes', + 'stack', + 'trace', + ]: + if bad in symbols: + log.warn("%s: module MAY be using inspect.%s", module, bad) + safe = False + return safe + + +def iter_symbols(code): + """Yield names and strings used by `code` and its nested code objects""" + yield from code.co_names + for const in code.co_consts: + if isinstance(const, str): + yield const + elif isinstance(const, CodeType): + yield from iter_symbols(const) + + +def can_scan(): + if not sys.platform.startswith('java') and sys.platform != 'cli': + # CPython, PyPy, etc. + return True + log.warn("Unable to analyze compiled code on this platform.") + log.warn( + "Please ask the author to include a 'zip_safe'" + " setting (either True or False) in the package's setup.py" + ) + return False + + +# Attribute names of options for commands that might need to be convinced to +# install to the egg build directory + +INSTALL_DIRECTORY_ATTRS = ['install_lib', 'install_dir', 'install_data', 'install_base'] + + +def make_zipfile( + zip_filename, base_dir, verbose=False, dry_run=False, compress=True, mode='w' +): + """Create a zip file from all the files under 'base_dir'. The output + zip file will be named 'base_dir' + ".zip". Uses either the "zipfile" + Python module (if available) or the InfoZIP "zip" utility (if installed + and found on the default search path). If neither tool is available, + raises DistutilsExecError. Returns the name of the output zip file. + """ + import zipfile + + mkpath(os.path.dirname(zip_filename), dry_run=dry_run) + log.info("creating '%s' and adding '%s' to it", zip_filename, base_dir) + + def visit(z, dirname, names): + for name in names: + path = os.path.normpath(os.path.join(dirname, name)) + if os.path.isfile(path): + p = path[len(base_dir) + 1 :] + if not dry_run: + z.write(path, p) + log.debug("adding '%s'", p) + + compression = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED + if not dry_run: + z = zipfile.ZipFile(zip_filename, mode, compression=compression) + for dirname, dirs, files in sorted_walk(base_dir): + visit(z, dirname, files) + z.close() + else: + for dirname, dirs, files in sorted_walk(base_dir): + visit(None, dirname, files) + return zip_filename diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_rpm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_rpm.py new file mode 100644 index 0000000000000000000000000000000000000000..70ed6b6097fbe5de359539e9c5ba62801076e093 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_rpm.py @@ -0,0 +1,39 @@ +import distutils.command.bdist_rpm as orig + +from ..warnings import SetuptoolsDeprecationWarning + + +class bdist_rpm(orig.bdist_rpm): + """ + Override the default bdist_rpm behavior to do the following: + + 1. Run egg_info to ensure the name and version are properly calculated. + 2. Always run 'install' using --single-version-externally-managed to + disable eggs in RPM distributions. + """ + + def run(self): + SetuptoolsDeprecationWarning.emit( + "Deprecated command", + """ + bdist_rpm is deprecated and will be removed in a future version. + Use bdist_wheel (wheel packages) instead. + """, + see_url="https://github.com/pypa/setuptools/issues/1988", + due_date=(2023, 10, 30), # Deprecation introduced in 22 Oct 2021. + ) + + # ensure distro name is up-to-date + self.run_command('egg_info') + + orig.bdist_rpm.run(self) + + def _make_spec_file(self): + spec = orig.bdist_rpm._make_spec_file(self) + return [ + line.replace( + "setup.py install ", + "setup.py install --single-version-externally-managed ", + ).replace("%setup", "%setup -n %{name}-%{unmangled_version}") + for line in spec + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_wheel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_wheel.py new file mode 100644 index 0000000000000000000000000000000000000000..d8cdd4e4060f68027ca6d93721813bcce343c86e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/bdist_wheel.py @@ -0,0 +1,597 @@ +""" +Create a wheel (.whl) distribution. + +A wheel is a built archive format. +""" + +from __future__ import annotations + +import os +import re +import shutil +import stat +import struct +import sys +import sysconfig +import warnings +from email.generator import BytesGenerator, Generator +from email.policy import EmailPolicy +from distutils import log +from glob import iglob +from shutil import rmtree +from typing import TYPE_CHECKING, Callable, Iterable, Literal, Sequence, cast +from zipfile import ZIP_DEFLATED, ZIP_STORED + +from .. import Command, __version__ +from ..extern.wheel.metadata import pkginfo_to_metadata +from ..extern.packaging import tags +from ..extern.packaging import version as _packaging_version +from ..extern.wheel.wheelfile import WheelFile + +if TYPE_CHECKING: + import types + + +def safe_name(name: str) -> str: + """Convert an arbitrary string to a standard distribution name + Any runs of non-alphanumeric/. characters are replaced with a single '-'. + """ + return re.sub("[^A-Za-z0-9.]+", "-", name) + + +def safe_version(version: str) -> str: + """ + Convert an arbitrary string to a standard version string + """ + try: + # normalize the version + return str(_packaging_version.Version(version)) + except _packaging_version.InvalidVersion: + version = version.replace(" ", ".") + return re.sub("[^A-Za-z0-9.]+", "-", version) + + +setuptools_major_version = int(__version__.split(".")[0]) + +PY_LIMITED_API_PATTERN = r"cp3\d" + + +def _is_32bit_interpreter() -> bool: + return struct.calcsize("P") == 4 + + +def python_tag() -> str: + return f"py{sys.version_info[0]}" + + +def get_platform(archive_root: str | None) -> str: + """Return our platform name 'win32', 'linux_x86_64'""" + result = sysconfig.get_platform() + if result.startswith("macosx") and archive_root is not None: + from ..extern.wheel.macosx_libfile import calculate_macosx_platform_tag + + result = calculate_macosx_platform_tag(archive_root, result) + elif _is_32bit_interpreter(): + if result == "linux-x86_64": + # pip pull request #3497 + result = "linux-i686" + elif result == "linux-aarch64": + # packaging pull request #234 + # TODO armv8l, packaging pull request #690 => this did not land + # in pip/packaging yet + result = "linux-armv7l" + + return result.replace("-", "_") + + +def get_flag( + var: str, fallback: bool, expected: bool = True, warn: bool = True +) -> bool: + """Use a fallback value for determining SOABI flags if the needed config + var is unset or unavailable.""" + val = sysconfig.get_config_var(var) + if val is None: + if warn: + warnings.warn( + f"Config variable '{var}' is unset, Python ABI tag may be incorrect", + RuntimeWarning, + stacklevel=2, + ) + return fallback + return val == expected + + +def get_abi_tag() -> str | None: + """Return the ABI tag based on SOABI (if available) or emulate SOABI (PyPy2).""" + soabi: str = sysconfig.get_config_var("SOABI") + impl = tags.interpreter_name() + if not soabi and impl in ("cp", "pp") and hasattr(sys, "maxunicode"): + d = "" + m = "" + u = "" + if get_flag("Py_DEBUG", hasattr(sys, "gettotalrefcount"), warn=(impl == "cp")): + d = "d" + + if get_flag( + "WITH_PYMALLOC", + impl == "cp", + warn=(impl == "cp" and sys.version_info < (3, 8)), + ) and sys.version_info < (3, 8): + m = "m" + + abi = f"{impl}{tags.interpreter_version()}{d}{m}{u}" + elif soabi and impl == "cp" and soabi.startswith("cpython"): + # non-Windows + abi = "cp" + soabi.split("-")[1] + elif soabi and impl == "cp" and soabi.startswith("cp"): + # Windows + abi = soabi.split("-")[0] + elif soabi and impl == "pp": + # we want something like pypy36-pp73 + abi = "-".join(soabi.split("-")[:2]) + abi = abi.replace(".", "_").replace("-", "_") + elif soabi and impl == "graalpy": + abi = "-".join(soabi.split("-")[:3]) + abi = abi.replace(".", "_").replace("-", "_") + elif soabi: + abi = soabi.replace(".", "_").replace("-", "_") + else: + abi = None + + return abi + + +def safer_name(name: str) -> str: + return safe_name(name).replace("-", "_") + + +def safer_version(version: str) -> str: + return safe_version(version).replace("-", "_") + + +def remove_readonly( + func: Callable[..., object], + path: str, + excinfo: tuple[type[Exception], Exception, types.TracebackType], +) -> None: + remove_readonly_exc(func, path, excinfo[1]) + + +def remove_readonly_exc(func: Callable[..., object], path: str, exc: Exception) -> None: + os.chmod(path, stat.S_IWRITE) + func(path) + + +class bdist_wheel(Command): + description = "create a wheel distribution" + + supported_compressions = { + "stored": ZIP_STORED, + "deflated": ZIP_DEFLATED, + } + + user_options = [ + ("bdist-dir=", "b", "temporary directory for creating the distribution"), + ( + "plat-name=", + "p", + "platform name to embed in generated filenames " + f"[default: {get_platform(None)}]", + ), + ( + "keep-temp", + "k", + "keep the pseudo-installation tree around after " + "creating the distribution archive", + ), + ("dist-dir=", "d", "directory to put final built distributions in"), + ("skip-build", None, "skip rebuilding everything (for testing/debugging)"), + ( + "relative", + None, + "build the archive using relative paths [default: false]", + ), + ( + "owner=", + "u", + "Owner name used when creating a tar file [default: current user]", + ), + ( + "group=", + "g", + "Group name used when creating a tar file [default: current group]", + ), + ("universal", None, "make a universal wheel [default: false]"), + ( + "compression=", + None, + "zipfile compression (one of: {}) [default: 'deflated']".format( + ", ".join(supported_compressions) + ), + ), + ( + "python-tag=", + None, + f"Python implementation compatibility tag [default: '{python_tag()}']", + ), + ( + "build-number=", + None, + "Build number for this particular version. " + "As specified in PEP-0427, this must start with a digit. " + "[default: None]", + ), + ( + "py-limited-api=", + None, + "Python tag (cp32|cp33|cpNN) for abi3 wheel tag [default: false]", + ), + ] + + boolean_options = ["keep-temp", "skip-build", "relative", "universal"] + + def initialize_options(self) -> None: + self.bdist_dir: str | None = None + self.data_dir = None + self.plat_name: str | None = None + self.plat_tag = None + self.format = "zip" + self.keep_temp = False + self.dist_dir: str | None = None + self.egginfo_dir = None + self.root_is_pure: bool | None = None + self.skip_build = None + self.relative = False + self.owner = None + self.group = None + self.universal: bool = False + self.compression: str | int = "deflated" + self.python_tag: str = python_tag() + self.build_number: str | None = None + self.py_limited_api: str | Literal[False] = False + self.plat_name_supplied = False + + def finalize_options(self): + if self.bdist_dir is None: + bdist_base = self.get_finalized_command("bdist").bdist_base + self.bdist_dir = os.path.join(bdist_base, "wheel") + + egg_info = self.distribution.get_command_obj("egg_info") + egg_info.ensure_finalized() # needed for correct `wheel_dist_name` + + self.data_dir = self.wheel_dist_name + ".data" + self.plat_name_supplied = self.plat_name is not None + + try: + self.compression = self.supported_compressions[self.compression] + except KeyError: + raise ValueError(f"Unsupported compression: {self.compression}") from None + + need_options = ("dist_dir", "plat_name", "skip_build") + + self.set_undefined_options("bdist", *zip(need_options, need_options)) + + self.root_is_pure = not ( + self.distribution.has_ext_modules() or self.distribution.has_c_libraries() + ) + + if self.py_limited_api and not re.match( + PY_LIMITED_API_PATTERN, self.py_limited_api + ): + raise ValueError(f"py-limited-api must match '{PY_LIMITED_API_PATTERN}'") + + # Support legacy [wheel] section for setting universal + wheel = self.distribution.get_option_dict("wheel") + if "universal" in wheel: + # please don't define this in your global configs + log.warn("The [wheel] section is deprecated. Use [bdist_wheel] instead.") + val = wheel["universal"][1].strip() + if val.lower() in ("1", "true", "yes"): + self.universal = True + + if self.build_number is not None and not self.build_number[:1].isdigit(): + raise ValueError("Build tag (build-number) must start with a digit.") + + @property + def wheel_dist_name(self): + """Return distribution full name with - replaced with _""" + components = ( + safer_name(self.distribution.get_name()), + safer_version(self.distribution.get_version()), + ) + if self.build_number: + components += (self.build_number,) + return "-".join(components) + + def get_tag(self) -> tuple[str, str, str]: + # bdist sets self.plat_name if unset, we should only use it for purepy + # wheels if the user supplied it. + if self.plat_name_supplied: + plat_name = cast(str, self.plat_name) + elif self.root_is_pure: + plat_name = "any" + else: + # macosx contains system version in platform name so need special handle + if self.plat_name and not self.plat_name.startswith("macosx"): + plat_name = self.plat_name + else: + # on macosx always limit the platform name to comply with any + # c-extension modules in bdist_dir, since the user can specify + # a higher MACOSX_DEPLOYMENT_TARGET via tools like CMake + + # on other platforms, and on macosx if there are no c-extension + # modules, use the default platform name. + plat_name = get_platform(self.bdist_dir) + + if _is_32bit_interpreter(): + if plat_name in ("linux-x86_64", "linux_x86_64"): + plat_name = "linux_i686" + if plat_name in ("linux-aarch64", "linux_aarch64"): + # TODO armv8l, packaging pull request #690 => this did not land + # in pip/packaging yet + plat_name = "linux_armv7l" + + plat_name = ( + plat_name.lower().replace("-", "_").replace(".", "_").replace(" ", "_") + ) + + if self.root_is_pure: + if self.universal: + impl = "py2.py3" + else: + impl = self.python_tag + tag = (impl, "none", plat_name) + else: + impl_name = tags.interpreter_name() + impl_ver = tags.interpreter_version() + impl = impl_name + impl_ver + # We don't work on CPython 3.1, 3.0. + if self.py_limited_api and (impl_name + impl_ver).startswith("cp3"): + impl = self.py_limited_api + abi_tag = "abi3" + else: + abi_tag = str(get_abi_tag()).lower() + tag = (impl, abi_tag, plat_name) + # issue gh-374: allow overriding plat_name + supported_tags = [ + (t.interpreter, t.abi, plat_name) for t in tags.sys_tags() + ] + assert ( + tag in supported_tags + ), f"would build wheel with unsupported tag {tag}" + return tag + + def run(self): + build_scripts = self.reinitialize_command("build_scripts") + build_scripts.executable = "python" + build_scripts.force = True + + build_ext = self.reinitialize_command("build_ext") + build_ext.inplace = False + + if not self.skip_build: + self.run_command("build") + + install = self.reinitialize_command("install", reinit_subcommands=True) + install.root = self.bdist_dir + install.compile = False + install.skip_build = self.skip_build + install.warn_dir = False + + # A wheel without setuptools scripts is more cross-platform. + # Use the (undocumented) `no_ep` option to setuptools' + # install_scripts command to avoid creating entry point scripts. + install_scripts = self.reinitialize_command("install_scripts") + install_scripts.no_ep = True + + # Use a custom scheme for the archive, because we have to decide + # at installation time which scheme to use. + for key in ("headers", "scripts", "data", "purelib", "platlib"): + setattr(install, "install_" + key, os.path.join(self.data_dir, key)) + + basedir_observed = "" + + if os.name == "nt": + # win32 barfs if any of these are ''; could be '.'? + # (distutils.command.install:change_roots bug) + basedir_observed = os.path.normpath(os.path.join(self.data_dir, "..")) + self.install_libbase = self.install_lib = basedir_observed + + setattr( + install, + "install_purelib" if self.root_is_pure else "install_platlib", + basedir_observed, + ) + + log.info(f"installing to {self.bdist_dir}") + + self.run_command("install") + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + if not self.relative: + archive_root = self.bdist_dir + else: + archive_root = os.path.join( + self.bdist_dir, self._ensure_relative(install.install_base) + ) + + self.set_undefined_options("install_egg_info", ("target", "egginfo_dir")) + distinfo_dirname = ( + f"{safer_name(self.distribution.get_name())}-" + f"{safer_version(self.distribution.get_version())}.dist-info" + ) + distinfo_dir = os.path.join(self.bdist_dir, distinfo_dirname) + self.egg2dist(self.egginfo_dir, distinfo_dir) + + self.write_wheelfile(distinfo_dir) + + # Make the archive + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + with WheelFile(wheel_path, "w", self.compression) as wf: + wf.write_files(archive_root) + + # Add to 'Distribution.dist_files' so that the "upload" command works + getattr(self.distribution, "dist_files", []).append(( + "bdist_wheel", + "{}.{}".format(*sys.version_info[:2]), # like 3.7 + wheel_path, + )) + + if not self.keep_temp: + log.info(f"removing {self.bdist_dir}") + if not self.dry_run: + if sys.version_info < (3, 12): + rmtree(self.bdist_dir, onerror=remove_readonly) + else: + rmtree(self.bdist_dir, onexc=remove_readonly_exc) + + def write_wheelfile( + self, wheelfile_base: str, generator: str = f"setuptools ({__version__})" + ): + from email.message import Message + + msg = Message() + msg["Wheel-Version"] = "1.0" # of the spec + msg["Generator"] = generator + msg["Root-Is-Purelib"] = str(self.root_is_pure).lower() + if self.build_number is not None: + msg["Build"] = self.build_number + + # Doesn't work for bdist_wininst + impl_tag, abi_tag, plat_tag = self.get_tag() + for impl in impl_tag.split("."): + for abi in abi_tag.split("."): + for plat in plat_tag.split("."): + msg["Tag"] = "-".join((impl, abi, plat)) + + wheelfile_path = os.path.join(wheelfile_base, "WHEEL") + log.info(f"creating {wheelfile_path}") + with open(wheelfile_path, "wb") as f: + BytesGenerator(f, maxheaderlen=0).flatten(msg) + + def _ensure_relative(self, path: str) -> str: + # copied from dir_util, deleted + drive, path = os.path.splitdrive(path) + if path[0:1] == os.sep: + path = drive + path[1:] + return path + + @property + def license_paths(self) -> Iterable[str]: + if setuptools_major_version >= 57: + # Setuptools has resolved any patterns to actual file names + return self.distribution.metadata.license_files or () + + files: set[str] = set() + metadata = self.distribution.get_option_dict("metadata") + if setuptools_major_version >= 42: + # Setuptools recognizes the license_files option but does not do globbing + patterns = cast(Sequence[str], self.distribution.metadata.license_files) + else: + # Prior to those, wheel is entirely responsible for handling license files + if "license_files" in metadata: + patterns = metadata["license_files"][1].split() + else: + patterns = () + + if "license_file" in metadata: + warnings.warn( + 'The "license_file" option is deprecated. Use "license_files" instead.', + DeprecationWarning, + stacklevel=2, + ) + files.add(metadata["license_file"][1]) + + if not files and not patterns and not isinstance(patterns, list): + patterns = ("LICEN[CS]E*", "COPYING*", "NOTICE*", "AUTHORS*") + + for pattern in patterns: + for path in iglob(pattern): + if path.endswith("~"): + log.debug( + f'ignoring license file "{path}" as it looks like a backup' + ) + continue + + if path not in files and os.path.isfile(path): + log.info( + f'adding license file "{path}" (matched pattern "{pattern}")' + ) + files.add(path) + + return files + + def egg2dist(self, egginfo_path: str, distinfo_path: str): + """Convert an .egg-info directory into a .dist-info directory""" + + def adios(p: str) -> None: + """Appropriately delete directory, file or link.""" + if os.path.exists(p) and not os.path.islink(p) and os.path.isdir(p): + shutil.rmtree(p) + elif os.path.exists(p): + os.unlink(p) + + adios(distinfo_path) + + if not os.path.exists(egginfo_path): + # There is no egg-info. This is probably because the egg-info + # file/directory is not named matching the distribution name used + # to name the archive file. Check for this case and report + # accordingly. + import glob + + pat = os.path.join(os.path.dirname(egginfo_path), "*.egg-info") + possible = glob.glob(pat) + err = f"Egg metadata expected at {egginfo_path} but not found" + if possible: + alt = os.path.basename(possible[0]) + err += f" ({alt} found - possible misnamed archive file?)" + + raise ValueError(err) + + if os.path.isfile(egginfo_path): + # .egg-info is a single file + pkg_info = pkginfo_to_metadata(egginfo_path, egginfo_path) + os.mkdir(distinfo_path) + else: + # .egg-info is a directory + pkginfo_path = os.path.join(egginfo_path, "PKG-INFO") + pkg_info = pkginfo_to_metadata(egginfo_path, pkginfo_path) + + # ignore common egg metadata that is useless to wheel + shutil.copytree( + egginfo_path, + distinfo_path, + ignore=lambda x, y: { + "PKG-INFO", + "requires.txt", + "SOURCES.txt", + "not-zip-safe", + }, + ) + + # delete dependency_links if it is only whitespace + dependency_links_path = os.path.join(distinfo_path, "dependency_links.txt") + with open(dependency_links_path, encoding="utf-8") as dependency_links_file: + dependency_links = dependency_links_file.read().strip() + if not dependency_links: + adios(dependency_links_path) + + pkg_info_path = os.path.join(distinfo_path, "METADATA") + serialization_policy = EmailPolicy( + utf8=True, + mangle_from_=False, + max_line_length=0, + ) + with open(pkg_info_path, "w", encoding="utf-8") as out: + Generator(out, policy=serialization_policy).flatten(pkg_info) + + for license_path in self.license_paths: + filename = os.path.basename(license_path) + shutil.copy(license_path, os.path.join(distinfo_path, filename)) + + adios(egginfo_path) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/build_py.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/build_py.py new file mode 100644 index 0000000000000000000000000000000000000000..ab49874635fed0e07ff552fe4ccf8d0edb991223 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/build_py.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +from functools import partial +from glob import glob +from distutils.util import convert_path +import distutils.command.build_py as orig +import os +import fnmatch +import textwrap +import distutils.errors +import itertools +import stat +from pathlib import Path +from typing import Iterable, Iterator + +from ..extern.more_itertools import unique_everseen +from ..warnings import SetuptoolsDeprecationWarning + + +_IMPLICIT_DATA_FILES = ('*.pyi', 'py.typed') + + +def make_writable(target): + os.chmod(target, os.stat(target).st_mode | stat.S_IWRITE) + + +class build_py(orig.build_py): + """Enhanced 'build_py' command that includes data files with packages + + The data files are specified via a 'package_data' argument to 'setup()'. + See 'setuptools.dist.Distribution' for more details. + + Also, this version of the 'build_py' command allows you to specify both + 'py_modules' and 'packages' in the same setup operation. + """ + + editable_mode: bool = False + existing_egg_info_dir: str | None = None #: Private API, internal use only. + + def finalize_options(self): + orig.build_py.finalize_options(self) + self.package_data = self.distribution.package_data + self.exclude_package_data = self.distribution.exclude_package_data or {} + if 'data_files' in self.__dict__: + del self.__dict__['data_files'] + self.__updated_files = [] + + def copy_file( + self, + infile, + outfile, + preserve_mode=True, + preserve_times=True, + link=None, + level=1, + ): + # Overwrite base class to allow using links + if link: + infile = str(Path(infile).resolve()) + outfile = str(Path(outfile).resolve()) + return super().copy_file( + infile, outfile, preserve_mode, preserve_times, link, level + ) + + def run(self): + """Build modules, packages, and copy data files to build directory""" + if not (self.py_modules or self.packages) or self.editable_mode: + return + + if self.py_modules: + self.build_modules() + + if self.packages: + self.build_packages() + self.build_package_data() + + # Only compile actual .py files, using our base class' idea of what our + # output files are. + self.byte_compile(orig.build_py.get_outputs(self, include_bytecode=False)) + + def __getattr__(self, attr): + "lazily compute data files" + if attr == 'data_files': + self.data_files = self._get_data_files() + return self.data_files + return orig.build_py.__getattr__(self, attr) + + def build_module(self, module, module_file, package): + outfile, copied = orig.build_py.build_module(self, module, module_file, package) + if copied: + self.__updated_files.append(outfile) + return outfile, copied + + def _get_data_files(self): + """Generate list of '(package,src_dir,build_dir,filenames)' tuples""" + self.analyze_manifest() + return list(map(self._get_pkg_data_files, self.packages or ())) + + def get_data_files_without_manifest(self): + """ + Generate list of ``(package,src_dir,build_dir,filenames)`` tuples, + but without triggering any attempt to analyze or build the manifest. + """ + # Prevent eventual errors from unset `manifest_files` + # (that would otherwise be set by `analyze_manifest`) + self.__dict__.setdefault('manifest_files', {}) + return list(map(self._get_pkg_data_files, self.packages or ())) + + def _get_pkg_data_files(self, package): + # Locate package source directory + src_dir = self.get_package_dir(package) + + # Compute package build directory + build_dir = os.path.join(*([self.build_lib] + package.split('.'))) + + # Strip directory from globbed filenames + filenames = [ + os.path.relpath(file, src_dir) + for file in self.find_data_files(package, src_dir) + ] + return package, src_dir, build_dir, filenames + + def find_data_files(self, package, src_dir): + """Return filenames for package's data files in 'src_dir'""" + patterns = self._get_platform_patterns( + self.package_data, + package, + src_dir, + extra_patterns=_IMPLICIT_DATA_FILES, + ) + globs_expanded = map(partial(glob, recursive=True), patterns) + # flatten the expanded globs into an iterable of matches + globs_matches = itertools.chain.from_iterable(globs_expanded) + glob_files = filter(os.path.isfile, globs_matches) + files = itertools.chain( + self.manifest_files.get(package, []), + glob_files, + ) + return self.exclude_data_files(package, src_dir, files) + + def get_outputs(self, include_bytecode=True) -> list[str]: + """See :class:`setuptools.commands.build.SubCommand`""" + if self.editable_mode: + return list(self.get_output_mapping().keys()) + return super().get_outputs(include_bytecode) + + def get_output_mapping(self) -> dict[str, str]: + """See :class:`setuptools.commands.build.SubCommand`""" + mapping = itertools.chain( + self._get_package_data_output_mapping(), + self._get_module_mapping(), + ) + return dict(sorted(mapping, key=lambda x: x[0])) + + def _get_module_mapping(self) -> Iterator[tuple[str, str]]: + """Iterate over all modules producing (dest, src) pairs.""" + for package, module, module_file in self.find_all_modules(): + package = package.split('.') + filename = self.get_module_outfile(self.build_lib, package, module) + yield (filename, module_file) + + def _get_package_data_output_mapping(self) -> Iterator[tuple[str, str]]: + """Iterate over package data producing (dest, src) pairs.""" + for package, src_dir, build_dir, filenames in self.data_files: + for filename in filenames: + target = os.path.join(build_dir, filename) + srcfile = os.path.join(src_dir, filename) + yield (target, srcfile) + + def build_package_data(self): + """Copy data files into build directory""" + for target, srcfile in self._get_package_data_output_mapping(): + self.mkpath(os.path.dirname(target)) + _outf, _copied = self.copy_file(srcfile, target) + make_writable(target) + + def analyze_manifest(self): + self.manifest_files = mf = {} + if not self.distribution.include_package_data: + return + src_dirs = {} + for package in self.packages or (): + # Locate package source directory + src_dirs[assert_relative(self.get_package_dir(package))] = package + + if ( + getattr(self, 'existing_egg_info_dir', None) + and Path(self.existing_egg_info_dir, "SOURCES.txt").exists() + ): + egg_info_dir = self.existing_egg_info_dir + manifest = Path(egg_info_dir, "SOURCES.txt") + files = manifest.read_text(encoding="utf-8").splitlines() + else: + self.run_command('egg_info') + ei_cmd = self.get_finalized_command('egg_info') + egg_info_dir = ei_cmd.egg_info + files = ei_cmd.filelist.files + + check = _IncludePackageDataAbuse() + for path in self._filter_build_files(files, egg_info_dir): + d, f = os.path.split(assert_relative(path)) + prev = None + oldf = f + while d and d != prev and d not in src_dirs: + prev = d + d, df = os.path.split(d) + f = os.path.join(df, f) + if d in src_dirs: + if f == oldf: + if check.is_module(f): + continue # it's a module, not data + else: + importable = check.importable_subpackage(src_dirs[d], f) + if importable: + check.warn(importable) + mf.setdefault(src_dirs[d], []).append(path) + + def _filter_build_files(self, files: Iterable[str], egg_info: str) -> Iterator[str]: + """ + ``build_meta`` may try to create egg_info outside of the project directory, + and this can be problematic for certain plugins (reported in issue #3500). + + Extensions might also include between their sources files created on the + ``build_lib`` and ``build_temp`` directories. + + This function should filter this case of invalid files out. + """ + build = self.get_finalized_command("build") + build_dirs = (egg_info, self.build_lib, build.build_temp, build.build_base) + norm_dirs = [os.path.normpath(p) for p in build_dirs if p] + + for file in files: + norm_path = os.path.normpath(file) + if not os.path.isabs(file) or all(d not in norm_path for d in norm_dirs): + yield file + + def get_data_files(self): + pass # Lazily compute data files in _get_data_files() function. + + def check_package(self, package, package_dir): + """Check namespace packages' __init__ for declare_namespace""" + try: + return self.packages_checked[package] + except KeyError: + pass + + init_py = orig.build_py.check_package(self, package, package_dir) + self.packages_checked[package] = init_py + + if not init_py or not self.distribution.namespace_packages: + return init_py + + for pkg in self.distribution.namespace_packages: + if pkg == package or pkg.startswith(package + '.'): + break + else: + return init_py + + with open(init_py, 'rb') as f: + contents = f.read() + if b'declare_namespace' not in contents: + raise distutils.errors.DistutilsError( + "Namespace package problem: %s is a namespace package, but " + "its\n__init__.py does not call declare_namespace()! Please " + 'fix it.\n(See the setuptools manual under ' + '"Namespace Packages" for details.)\n"' % (package,) + ) + return init_py + + def initialize_options(self): + self.packages_checked = {} + orig.build_py.initialize_options(self) + self.editable_mode = False + self.existing_egg_info_dir = None + + def get_package_dir(self, package): + res = orig.build_py.get_package_dir(self, package) + if self.distribution.src_root is not None: + return os.path.join(self.distribution.src_root, res) + return res + + def exclude_data_files(self, package, src_dir, files): + """Filter filenames for package's data files in 'src_dir'""" + files = list(files) + patterns = self._get_platform_patterns( + self.exclude_package_data, + package, + src_dir, + ) + match_groups = (fnmatch.filter(files, pattern) for pattern in patterns) + # flatten the groups of matches into an iterable of matches + matches = itertools.chain.from_iterable(match_groups) + bad = set(matches) + keepers = (fn for fn in files if fn not in bad) + # ditch dupes + return list(unique_everseen(keepers)) + + @staticmethod + def _get_platform_patterns(spec, package, src_dir, extra_patterns=()): + """ + yield platform-specific path patterns (suitable for glob + or fn_match) from a glob-based spec (such as + self.package_data or self.exclude_package_data) + matching package in src_dir. + """ + raw_patterns = itertools.chain( + extra_patterns, + spec.get('', []), + spec.get(package, []), + ) + return ( + # Each pattern has to be converted to a platform-specific path + os.path.join(src_dir, convert_path(pattern)) + for pattern in raw_patterns + ) + + +def assert_relative(path): + if not os.path.isabs(path): + return path + from distutils.errors import DistutilsSetupError + + msg = ( + textwrap.dedent( + """ + Error: setup script specifies an absolute path: + + %s + + setup() arguments must *always* be /-separated paths relative to the + setup.py directory, *never* absolute paths. + """ + ).lstrip() + % path + ) + raise DistutilsSetupError(msg) + + +class _IncludePackageDataAbuse: + """Inform users that package or module is included as 'data file'""" + + class _Warning(SetuptoolsDeprecationWarning): + _SUMMARY = """ + Package {importable!r} is absent from the `packages` configuration. + """ + + _DETAILS = """ + ############################ + # Package would be ignored # + ############################ + Python recognizes {importable!r} as an importable package[^1], + but it is absent from setuptools' `packages` configuration. + + This leads to an ambiguous overall configuration. If you want to distribute this + package, please make sure that {importable!r} is explicitly added + to the `packages` configuration field. + + Alternatively, you can also rely on setuptools' discovery methods + (for example by using `find_namespace_packages(...)`/`find_namespace:` + instead of `find_packages(...)`/`find:`). + + You can read more about "package discovery" on setuptools documentation page: + + - https://setuptools.pypa.io/en/latest/userguide/package_discovery.html + + If you don't want {importable!r} to be distributed and are + already explicitly excluding {importable!r} via + `find_namespace_packages(...)/find_namespace` or `find_packages(...)/find`, + you can try to use `exclude_package_data`, or `include-package-data=False` in + combination with a more fine grained `package-data` configuration. + + You can read more about "package data files" on setuptools documentation page: + + - https://setuptools.pypa.io/en/latest/userguide/datafiles.html + + + [^1]: For Python, any directory (with suitable naming) can be imported, + even if it does not contain any `.py` files. + On the other hand, currently there is no concept of package data + directory, all directories are treated like packages. + """ + # _DUE_DATE: still not defined as this is particularly controversial. + # Warning initially introduced in May 2022. See issue #3340 for discussion. + + def __init__(self): + self._already_warned = set() + + def is_module(self, file): + return file.endswith(".py") and file[: -len(".py")].isidentifier() + + def importable_subpackage(self, parent, file): + pkg = Path(file).parent + parts = list(itertools.takewhile(str.isidentifier, pkg.parts)) + if parts: + return ".".join([parent, *parts]) + return None + + def warn(self, importable): + if importable not in self._already_warned: + self._Warning.emit(importable=importable) + self._already_warned.add(importable) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/develop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/develop.py new file mode 100644 index 0000000000000000000000000000000000000000..55f24f396cc6dd4822abd8cbf10ccc6b22ca618b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/develop.py @@ -0,0 +1,196 @@ +from distutils.util import convert_path +from distutils import log +from distutils.errors import DistutilsOptionError +import os +import glob + +from setuptools.command.easy_install import easy_install +from setuptools import _normalization +from setuptools import _path +from setuptools import namespaces +import setuptools + +from ..unicode_utils import _read_utf8_with_fallback + + +class develop(namespaces.DevelopInstaller, easy_install): + """Set up package for development""" + + description = "install package in 'development mode'" + + user_options = easy_install.user_options + [ + ("uninstall", "u", "Uninstall this source package"), + ("egg-path=", None, "Set the path to be used in the .egg-link file"), + ] + + boolean_options = easy_install.boolean_options + ['uninstall'] + + command_consumes_arguments = False # override base + + def run(self): + if self.uninstall: + self.multi_version = True + self.uninstall_link() + self.uninstall_namespaces() + else: + self.install_for_development() + self.warn_deprecated_options() + + def initialize_options(self): + self.uninstall = None + self.egg_path = None + easy_install.initialize_options(self) + self.setup_path = None + self.always_copy_from = '.' # always copy eggs installed in curdir + + def finalize_options(self): + import pkg_resources + + ei = self.get_finalized_command("egg_info") + self.args = [ei.egg_name] + + easy_install.finalize_options(self) + self.expand_basedirs() + self.expand_dirs() + # pick up setup-dir .egg files only: no .egg-info + self.package_index.scan(glob.glob('*.egg')) + + egg_link_fn = ( + _normalization.filename_component_broken(ei.egg_name) + '.egg-link' + ) + self.egg_link = os.path.join(self.install_dir, egg_link_fn) + self.egg_base = ei.egg_base + if self.egg_path is None: + self.egg_path = os.path.abspath(ei.egg_base) + + target = _path.normpath(self.egg_base) + egg_path = _path.normpath(os.path.join(self.install_dir, self.egg_path)) + if egg_path != target: + raise DistutilsOptionError( + "--egg-path must be a relative path from the install" + " directory to " + target + ) + + # Make a distribution for the package's source + self.dist = pkg_resources.Distribution( + target, + pkg_resources.PathMetadata(target, os.path.abspath(ei.egg_info)), + project_name=ei.egg_name, + ) + + self.setup_path = self._resolve_setup_path( + self.egg_base, + self.install_dir, + self.egg_path, + ) + + @staticmethod + def _resolve_setup_path(egg_base, install_dir, egg_path): + """ + Generate a path from egg_base back to '.' where the + setup script resides and ensure that path points to the + setup path from $install_dir/$egg_path. + """ + path_to_setup = egg_base.replace(os.sep, '/').rstrip('/') + if path_to_setup != os.curdir: + path_to_setup = '../' * (path_to_setup.count('/') + 1) + resolved = _path.normpath(os.path.join(install_dir, egg_path, path_to_setup)) + curdir = _path.normpath(os.curdir) + if resolved != curdir: + raise DistutilsOptionError( + "Can't get a consistent path to setup script from" + " installation directory", + resolved, + curdir, + ) + return path_to_setup + + def install_for_development(self): + self.run_command('egg_info') + + # Build extensions in-place + self.reinitialize_command('build_ext', inplace=True) + self.run_command('build_ext') + + if setuptools.bootstrap_install_from: + self.easy_install(setuptools.bootstrap_install_from) + setuptools.bootstrap_install_from = None + + self.install_namespaces() + + # create an .egg-link in the installation dir, pointing to our egg + log.info("Creating %s (link to %s)", self.egg_link, self.egg_base) + if not self.dry_run: + with open(self.egg_link, "w", encoding="utf-8") as f: + f.write(self.egg_path + "\n" + self.setup_path) + # postprocess the installed distro, fixing up .pth, installing scripts, + # and handling requirements + self.process_distribution(None, self.dist, not self.no_deps) + + def uninstall_link(self): + if os.path.exists(self.egg_link): + log.info("Removing %s (link to %s)", self.egg_link, self.egg_base) + + contents = [ + line.rstrip() + for line in _read_utf8_with_fallback(self.egg_link).splitlines() + ] + + if contents not in ([self.egg_path], [self.egg_path, self.setup_path]): + log.warn("Link points to %s: uninstall aborted", contents) + return + if not self.dry_run: + os.unlink(self.egg_link) + if not self.dry_run: + self.update_pth(self.dist) # remove any .pth link to us + if self.distribution.scripts: + # XXX should also check for entry point scripts! + log.warn("Note: you must uninstall or replace scripts manually!") + + def install_egg_scripts(self, dist): + if dist is not self.dist: + # Installing a dependency, so fall back to normal behavior + return easy_install.install_egg_scripts(self, dist) + + # create wrapper scripts in the script dir, pointing to dist.scripts + + # new-style... + self.install_wrapper_scripts(dist) + + # ...and old-style + for script_name in self.distribution.scripts or []: + script_path = os.path.abspath(convert_path(script_name)) + script_name = os.path.basename(script_path) + script_text = _read_utf8_with_fallback(script_path) + self.install_script(dist, script_name, script_text, script_path) + + return None + + def install_wrapper_scripts(self, dist): + dist = VersionlessRequirement(dist) + return easy_install.install_wrapper_scripts(self, dist) + + +class VersionlessRequirement: + """ + Adapt a pkg_resources.Distribution to simply return the project + name as the 'requirement' so that scripts will work across + multiple versions. + + >>> from pkg_resources import Distribution + >>> dist = Distribution(project_name='foo', version='1.0') + >>> str(dist.as_requirement()) + 'foo==1.0' + >>> adapted_dist = VersionlessRequirement(dist) + >>> str(adapted_dist.as_requirement()) + 'foo' + """ + + def __init__(self, dist): + self.__dist = dist + + def __getattr__(self, name): + return getattr(self.__dist, name) + + def as_requirement(self): + return self.project_name diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/dist_info.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/dist_info.py new file mode 100644 index 0000000000000000000000000000000000000000..2adc1c46f33a02f813d5dff648ae60bb30f1071f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/dist_info.py @@ -0,0 +1,106 @@ +""" +Create a dist_info directory +As defined in the wheel specification +""" + +import os +import shutil +from contextlib import contextmanager +from distutils import log +from distutils.core import Command +from pathlib import Path +from typing import cast + +from .. import _normalization +from .egg_info import egg_info as egg_info_cls + + +class dist_info(Command): + """ + This command is private and reserved for internal use of setuptools, + users should rely on ``setuptools.build_meta`` APIs. + """ + + description = "DO NOT CALL DIRECTLY, INTERNAL ONLY: create .dist-info directory" + + user_options = [ + ( + 'output-dir=', + 'o', + "directory inside of which the .dist-info will be" + "created [default: top of the source tree]", + ), + ('tag-date', 'd', "Add date stamp (e.g. 20050528) to version number"), + ('tag-build=', 'b', "Specify explicit tag to add to version number"), + ('no-date', 'D', "Don't include date stamp [default]"), + ('keep-egg-info', None, "*TRANSITIONAL* will be removed in the future"), + ] + + boolean_options = ['tag-date', 'keep-egg-info'] + negative_opt = {'no-date': 'tag-date'} + + def initialize_options(self): + self.output_dir = None + self.name = None + self.dist_info_dir = None + self.tag_date = None + self.tag_build = None + self.keep_egg_info = False + + def finalize_options(self): + dist = self.distribution + project_dir = dist.src_root or os.curdir + self.output_dir = Path(self.output_dir or project_dir) + + egg_info = cast(egg_info_cls, self.reinitialize_command("egg_info")) + egg_info.egg_base = str(self.output_dir) + + if self.tag_date: + egg_info.tag_date = self.tag_date + else: + self.tag_date = egg_info.tag_date + + if self.tag_build: + egg_info.tag_build = self.tag_build + else: + self.tag_build = egg_info.tag_build + + egg_info.finalize_options() + self.egg_info = egg_info + + name = _normalization.safer_name(dist.get_name()) + version = _normalization.safer_best_effort_version(dist.get_version()) + self.name = f"{name}-{version}" + self.dist_info_dir = os.path.join(self.output_dir, f"{self.name}.dist-info") + + @contextmanager + def _maybe_bkp_dir(self, dir_path: str, requires_bkp: bool): + if requires_bkp: + bkp_name = f"{dir_path}.__bkp__" + _rm(bkp_name, ignore_errors=True) + shutil.copytree(dir_path, bkp_name, dirs_exist_ok=True, symlinks=True) + try: + yield + finally: + _rm(dir_path, ignore_errors=True) + shutil.move(bkp_name, dir_path) + else: + yield + + def run(self): + self.output_dir.mkdir(parents=True, exist_ok=True) + self.egg_info.run() + egg_info_dir = self.egg_info.egg_info + assert os.path.isdir(egg_info_dir), ".egg-info dir should have been created" + + log.info("creating '{}'".format(os.path.abspath(self.dist_info_dir))) + bdist_wheel = self.get_finalized_command('bdist_wheel') + + # TODO: if bdist_wheel if merged into setuptools, just add "keep_egg_info" there + with self._maybe_bkp_dir(egg_info_dir, self.keep_egg_info): + bdist_wheel.egg2dist(egg_info_dir, self.dist_info_dir) + + +def _rm(dir_name, **opts): + if os.path.isdir(dir_name): + shutil.rmtree(dir_name, **opts) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/editable_wheel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/editable_wheel.py new file mode 100644 index 0000000000000000000000000000000000000000..ae31bb4c79be4eb60e4a618bb31207a9066185d4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/editable_wheel.py @@ -0,0 +1,918 @@ +""" +Create a wheel that, when installed, will make the source package 'editable' +(add it to the interpreter's path, including metadata) per PEP 660. Replaces +'setup.py develop'. + +.. note:: + One of the mechanisms briefly mentioned in PEP 660 to implement editable installs is + to create a separated directory inside ``build`` and use a .pth file to point to that + directory. In the context of this file such directory is referred as + *auxiliary build directory* or ``auxiliary_dir``. +""" + +from __future__ import annotations + +import logging +import io +import os +import shutil +import traceback +from contextlib import suppress +from enum import Enum +from inspect import cleandoc +from itertools import chain, starmap +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import ( + TYPE_CHECKING, + Iterable, + Iterator, + Mapping, + Protocol, + TypeVar, + cast, +) + +from .. import ( + Command, + _normalization, + _path, + errors, + namespaces, +) +from .._path import StrPath +from ..compat import py39 +from ..discovery import find_package_path +from ..dist import Distribution +from ..warnings import ( + InformationOnly, + SetuptoolsDeprecationWarning, + SetuptoolsWarning, +) +from .build import build as build_cls +from .build_py import build_py as build_py_cls +from .dist_info import dist_info as dist_info_cls +from .egg_info import egg_info as egg_info_cls +from .install import install as install_cls +from .install_scripts import install_scripts as install_scripts_cls + +if TYPE_CHECKING: + from .._vendor.wheel.wheelfile import WheelFile + +_P = TypeVar("_P", bound=StrPath) +_logger = logging.getLogger(__name__) + + +class _EditableMode(Enum): + """ + Possible editable installation modes: + `lenient` (new files automatically added to the package - DEFAULT); + `strict` (requires a new installation when files are added/removed); or + `compat` (attempts to emulate `python setup.py develop` - DEPRECATED). + """ + + STRICT = "strict" + LENIENT = "lenient" + COMPAT = "compat" # TODO: Remove `compat` after Dec/2022. + + @classmethod + def convert(cls, mode: str | None) -> _EditableMode: + if not mode: + return _EditableMode.LENIENT # default + + _mode = mode.upper() + if _mode not in _EditableMode.__members__: + raise errors.OptionError(f"Invalid editable mode: {mode!r}. Try: 'strict'.") + + if _mode == "COMPAT": + SetuptoolsDeprecationWarning.emit( + "Compat editable installs", + """ + The 'compat' editable mode is transitional and will be removed + in future versions of `setuptools`. + Please adapt your code accordingly to use either the 'strict' or the + 'lenient' modes. + """, + see_docs="userguide/development_mode.html", + # TODO: define due_date + # There is a series of shortcomings with the available editable install + # methods, and they are very controversial. This is something that still + # needs work. + # Moreover, `pip` is still hiding this warning, so users are not aware. + ) + + return _EditableMode[_mode] + + +_STRICT_WARNING = """ +New or renamed files may not be automatically picked up without a new installation. +""" + +_LENIENT_WARNING = """ +Options like `package-data`, `include/exclude-package-data` or +`packages.find.exclude/include` may have no effect. +""" + + +class editable_wheel(Command): + """Build 'editable' wheel for development. + This command is private and reserved for internal use of setuptools, + users should rely on ``setuptools.build_meta`` APIs. + """ + + description = "DO NOT CALL DIRECTLY, INTERNAL ONLY: create PEP 660 editable wheel" + + user_options = [ + ("dist-dir=", "d", "directory to put final built distributions in"), + ("dist-info-dir=", "I", "path to a pre-build .dist-info directory"), + ("mode=", None, cleandoc(_EditableMode.__doc__ or "")), + ] + + def initialize_options(self): + self.dist_dir = None + self.dist_info_dir = None + self.project_dir = None + self.mode = None + + def finalize_options(self): + dist = self.distribution + self.project_dir = dist.src_root or os.curdir + self.package_dir = dist.package_dir or {} + self.dist_dir = Path(self.dist_dir or os.path.join(self.project_dir, "dist")) + + def run(self): + try: + self.dist_dir.mkdir(exist_ok=True) + self._ensure_dist_info() + + # Add missing dist_info files + self.reinitialize_command("bdist_wheel") + bdist_wheel = self.get_finalized_command("bdist_wheel") + bdist_wheel.write_wheelfile(self.dist_info_dir) + + self._create_wheel_file(bdist_wheel) + except Exception: + traceback.print_exc() + project = self.distribution.name or self.distribution.get_name() + _DebuggingTips.emit(project=project) + raise + + def _ensure_dist_info(self): + if self.dist_info_dir is None: + dist_info = cast(dist_info_cls, self.reinitialize_command("dist_info")) + dist_info.output_dir = self.dist_dir + dist_info.ensure_finalized() + dist_info.run() + self.dist_info_dir = dist_info.dist_info_dir + else: + assert str(self.dist_info_dir).endswith(".dist-info") + assert Path(self.dist_info_dir, "METADATA").exists() + + def _install_namespaces(self, installation_dir, pth_prefix): + # XXX: Only required to support the deprecated namespace practice + dist = self.distribution + if not dist.namespace_packages: + return + + src_root = Path(self.project_dir, self.package_dir.get("", ".")).resolve() + installer = _NamespaceInstaller(dist, installation_dir, pth_prefix, src_root) + installer.install_namespaces() + + def _find_egg_info_dir(self) -> str | None: + parent_dir = Path(self.dist_info_dir).parent if self.dist_info_dir else Path() + candidates = map(str, parent_dir.glob("*.egg-info")) + return next(candidates, None) + + def _configure_build( + self, name: str, unpacked_wheel: StrPath, build_lib: StrPath, tmp_dir: StrPath + ): + """Configure commands to behave in the following ways: + + - Build commands can write to ``build_lib`` if they really want to... + (but this folder is expected to be ignored and modules are expected to live + in the project directory...) + - Binary extensions should be built in-place (editable_mode = True) + - Data/header/script files are not part of the "editable" specification + so they are written directly to the unpacked_wheel directory. + """ + # Non-editable files (data, headers, scripts) are written directly to the + # unpacked_wheel + + dist = self.distribution + wheel = str(unpacked_wheel) + build_lib = str(build_lib) + data = str(Path(unpacked_wheel, f"{name}.data", "data")) + headers = str(Path(unpacked_wheel, f"{name}.data", "headers")) + scripts = str(Path(unpacked_wheel, f"{name}.data", "scripts")) + + # egg-info may be generated again to create a manifest (used for package data) + egg_info = cast( + egg_info_cls, dist.reinitialize_command("egg_info", reinit_subcommands=True) + ) + egg_info.egg_base = str(tmp_dir) + egg_info.ignore_egg_info_in_manifest = True + + build = cast( + build_cls, dist.reinitialize_command("build", reinit_subcommands=True) + ) + install = cast( + install_cls, dist.reinitialize_command("install", reinit_subcommands=True) + ) + + build.build_platlib = build.build_purelib = build.build_lib = build_lib + install.install_purelib = install.install_platlib = install.install_lib = wheel + install.install_scripts = build.build_scripts = scripts + install.install_headers = headers + install.install_data = data + + install_scripts = cast( + install_scripts_cls, dist.get_command_obj("install_scripts") + ) + install_scripts.no_ep = True + + build.build_temp = str(tmp_dir) + + build_py = cast(build_py_cls, dist.get_command_obj("build_py")) + build_py.compile = False + build_py.existing_egg_info_dir = self._find_egg_info_dir() + + self._set_editable_mode() + + build.ensure_finalized() + install.ensure_finalized() + + def _set_editable_mode(self): + """Set the ``editable_mode`` flag in the build sub-commands""" + dist = self.distribution + build = dist.get_command_obj("build") + # TODO: Update typeshed distutils stubs to overload non-None return type by default + for cmd_name in build.get_sub_commands(): + cmd = dist.get_command_obj(cmd_name) + if hasattr(cmd, "editable_mode"): + cmd.editable_mode = True + elif hasattr(cmd, "inplace"): + cmd.inplace = True # backward compatibility with distutils + + def _collect_build_outputs(self) -> tuple[list[str], dict[str, str]]: + files: list[str] = [] + mapping: dict[str, str] = {} + build = self.get_finalized_command("build") + + for cmd_name in build.get_sub_commands(): + cmd = self.get_finalized_command(cmd_name) + if hasattr(cmd, "get_outputs"): + files.extend(cmd.get_outputs() or []) + if hasattr(cmd, "get_output_mapping"): + mapping.update(cmd.get_output_mapping() or {}) + + return files, mapping + + def _run_build_commands( + self, + dist_name: str, + unpacked_wheel: StrPath, + build_lib: StrPath, + tmp_dir: StrPath, + ) -> tuple[list[str], dict[str, str]]: + self._configure_build(dist_name, unpacked_wheel, build_lib, tmp_dir) + self._run_build_subcommands() + files, mapping = self._collect_build_outputs() + self._run_install("headers") + self._run_install("scripts") + self._run_install("data") + return files, mapping + + def _run_build_subcommands(self) -> None: + """ + Issue #3501 indicates that some plugins/customizations might rely on: + + 1. ``build_py`` not running + 2. ``build_py`` always copying files to ``build_lib`` + + However both these assumptions may be false in editable_wheel. + This method implements a temporary workaround to support the ecosystem + while the implementations catch up. + """ + # TODO: Once plugins/customisations had the chance to catch up, replace + # `self._run_build_subcommands()` with `self.run_command("build")`. + # Also remove _safely_run, TestCustomBuildPy. Suggested date: Aug/2023. + build = self.get_finalized_command("build") + for name in build.get_sub_commands(): + cmd = self.get_finalized_command(name) + if name == "build_py" and type(cmd) != build_py_cls: + self._safely_run(name) + else: + self.run_command(name) + + def _safely_run(self, cmd_name: str): + try: + return self.run_command(cmd_name) + except Exception: + SetuptoolsDeprecationWarning.emit( + "Customization incompatible with editable install", + f""" + {traceback.format_exc()} + + If you are seeing this warning it is very likely that a setuptools + plugin or customization overrides the `{cmd_name}` command, without + taking into consideration how editable installs run build steps + starting from setuptools v64.0.0. + + Plugin authors and developers relying on custom build steps are + encouraged to update their `{cmd_name}` implementation considering the + information about editable installs in + https://setuptools.pypa.io/en/latest/userguide/extension.html. + + For the time being `setuptools` will silence this error and ignore + the faulty command, but this behaviour will change in future versions. + """, + # TODO: define due_date + # There is a series of shortcomings with the available editable install + # methods, and they are very controversial. This is something that still + # needs work. + ) + + def _create_wheel_file(self, bdist_wheel): + from ..extern.wheel.wheelfile import WheelFile + + dist_info = self.get_finalized_command("dist_info") + dist_name = dist_info.name + tag = "-".join(bdist_wheel.get_tag()) + build_tag = "0.editable" # According to PEP 427 needs to start with digit + archive_name = f"{dist_name}-{build_tag}-{tag}.whl" + wheel_path = Path(self.dist_dir, archive_name) + if wheel_path.exists(): + wheel_path.unlink() + + unpacked_wheel = TemporaryDirectory(suffix=archive_name) + build_lib = TemporaryDirectory(suffix=".build-lib") + build_tmp = TemporaryDirectory(suffix=".build-temp") + + with unpacked_wheel as unpacked, build_lib as lib, build_tmp as tmp: + unpacked_dist_info = Path(unpacked, Path(self.dist_info_dir).name) + shutil.copytree(self.dist_info_dir, unpacked_dist_info) + self._install_namespaces(unpacked, dist_name) + files, mapping = self._run_build_commands(dist_name, unpacked, lib, tmp) + strategy = self._select_strategy(dist_name, tag, lib) + with strategy, WheelFile(wheel_path, "w") as wheel_obj: + strategy(wheel_obj, files, mapping) + wheel_obj.write_files(unpacked) + + return wheel_path + + def _run_install(self, category: str): + has_category = getattr(self.distribution, f"has_{category}", None) + if has_category and has_category(): + _logger.info(f"Installing {category} as non editable") + self.run_command(f"install_{category}") + + def _select_strategy( + self, + name: str, + tag: str, + build_lib: StrPath, + ) -> EditableStrategy: + """Decides which strategy to use to implement an editable installation.""" + build_name = f"__editable__.{name}-{tag}" + project_dir = Path(self.project_dir) + mode = _EditableMode.convert(self.mode) + + if mode is _EditableMode.STRICT: + auxiliary_dir = _empty_dir(Path(self.project_dir, "build", build_name)) + return _LinkTree(self.distribution, name, auxiliary_dir, build_lib) + + packages = _find_packages(self.distribution) + has_simple_layout = _simple_layout(packages, self.package_dir, project_dir) + is_compat_mode = mode is _EditableMode.COMPAT + if set(self.package_dir) == {""} and has_simple_layout or is_compat_mode: + # src-layout(ish) is relatively safe for a simple pth file + src_dir = self.package_dir.get("", ".") + return _StaticPth(self.distribution, name, [Path(project_dir, src_dir)]) + + # Use a MetaPathFinder to avoid adding accidental top-level packages/modules + return _TopLevelFinder(self.distribution, name) + + +class EditableStrategy(Protocol): + def __call__(self, wheel: WheelFile, files: list[str], mapping: dict[str, str]): ... + + def __enter__(self): ... + + def __exit__(self, _exc_type, _exc_value, _traceback): ... + + +class _StaticPth: + def __init__(self, dist: Distribution, name: str, path_entries: list[Path]): + self.dist = dist + self.name = name + self.path_entries = path_entries + + def __call__(self, wheel: WheelFile, files: list[str], mapping: dict[str, str]): + entries = "\n".join(str(p.resolve()) for p in self.path_entries) + contents = _encode_pth(f"{entries}\n") + wheel.writestr(f"__editable__.{self.name}.pth", contents) + + def __enter__(self): + msg = f""" + Editable install will be performed using .pth file to extend `sys.path` with: + {list(map(os.fspath, self.path_entries))!r} + """ + _logger.warning(msg + _LENIENT_WARNING) + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): ... + + +class _LinkTree(_StaticPth): + """ + Creates a ``.pth`` file that points to a link tree in the ``auxiliary_dir``. + + This strategy will only link files (not dirs), so it can be implemented in + any OS, even if that means using hardlinks instead of symlinks. + + By collocating ``auxiliary_dir`` and the original source code, limitations + with hardlinks should be avoided. + """ + + def __init__( + self, + dist: Distribution, + name: str, + auxiliary_dir: StrPath, + build_lib: StrPath, + ): + self.auxiliary_dir = Path(auxiliary_dir) + self.build_lib = Path(build_lib).resolve() + # TODO: Update typeshed distutils stubs to overload non-None return type by default + self._file = dist.get_command_obj("build_py").copy_file # type: ignore[union-attr] + super().__init__(dist, name, [self.auxiliary_dir]) + + def __call__(self, wheel: WheelFile, files: list[str], mapping: dict[str, str]): + self._create_links(files, mapping) + super().__call__(wheel, files, mapping) + + def _normalize_output(self, file: str) -> str | None: + # Files relative to build_lib will be normalized to None + with suppress(ValueError): + path = Path(file).resolve().relative_to(self.build_lib) + return str(path).replace(os.sep, '/') + return None + + def _create_file(self, relative_output: str, src_file: str, link=None): + dest = self.auxiliary_dir / relative_output + if not dest.parent.is_dir(): + dest.parent.mkdir(parents=True) + # TODO: Update typeshed distutils stubs so distutils.cmd.Command.copy_file, accepts PathLike + # same with methods used by copy_file + self._file(src_file, dest, link=link) # type: ignore[arg-type] + + def _create_links(self, outputs, output_mapping): + self.auxiliary_dir.mkdir(parents=True, exist_ok=True) + link_type = "sym" if _can_symlink_files(self.auxiliary_dir) else "hard" + mappings = {self._normalize_output(k): v for k, v in output_mapping.items()} + mappings.pop(None, None) # remove files that are not relative to build_lib + + for output in outputs: + relative = self._normalize_output(output) + if relative and relative not in mappings: + self._create_file(relative, output) + + for relative, src in mappings.items(): + self._create_file(relative, src, link=link_type) + + def __enter__(self): + msg = "Strict editable install will be performed using a link tree.\n" + _logger.warning(msg + _STRICT_WARNING) + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): + msg = f"""\n + Strict editable installation performed using the auxiliary directory: + {self.auxiliary_dir} + + Please be careful to not remove this directory, otherwise you might not be able + to import/use your package. + """ + InformationOnly.emit("Editable installation.", msg) + + +class _TopLevelFinder: + def __init__(self, dist: Distribution, name: str): + self.dist = dist + self.name = name + + def template_vars(self) -> tuple[str, str, dict[str, str], dict[str, list[str]]]: + src_root = self.dist.src_root or os.curdir + top_level = chain(_find_packages(self.dist), _find_top_level_modules(self.dist)) + package_dir = self.dist.package_dir or {} + roots = _find_package_roots(top_level, package_dir, src_root) + + namespaces_: dict[str, list[str]] = dict( + chain( + _find_namespaces(self.dist.packages or [], roots), + ((ns, []) for ns in _find_virtual_namespaces(roots)), + ) + ) + + legacy_namespaces = { + pkg: find_package_path(pkg, roots, self.dist.src_root or "") + for pkg in self.dist.namespace_packages or [] + } + + mapping = {**roots, **legacy_namespaces} + # ^-- We need to explicitly add the legacy_namespaces to the mapping to be + # able to import their modules even if another package sharing the same + # namespace is installed in a conventional (non-editable) way. + + name = f"__editable__.{self.name}.finder" + finder = _normalization.safe_identifier(name) + return finder, name, mapping, namespaces_ + + def get_implementation(self) -> Iterator[tuple[str, bytes]]: + finder, name, mapping, namespaces_ = self.template_vars() + + content = bytes(_finder_template(name, mapping, namespaces_), "utf-8") + yield (f"{finder}.py", content) + + content = _encode_pth(f"import {finder}; {finder}.install()") + yield (f"__editable__.{self.name}.pth", content) + + def __call__(self, wheel: WheelFile, files: list[str], mapping: dict[str, str]): + for file, content in self.get_implementation(): + wheel.writestr(file, content) + + def __enter__(self): + msg = "Editable install will be performed using a meta path finder.\n" + _logger.warning(msg + _LENIENT_WARNING) + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): + msg = """\n + Please be careful with folders in your working directory with the same + name as your package as they may take precedence during imports. + """ + InformationOnly.emit("Editable installation.", msg) + + +def _encode_pth(content: str) -> bytes: + """.pth files are always read with 'locale' encoding, the recommendation + from the cpython core developers is to write them as ``open(path, "w")`` + and ignore warnings (see python/cpython#77102, pypa/setuptools#3937). + This function tries to simulate this behaviour without having to create an + actual file, in a way that supports a range of active Python versions. + (There seems to be some variety in the way different version of Python handle + ``encoding=None``, not all of them use ``locale.getpreferredencoding(False)`` + or ``locale.getencoding()``). + """ + with io.BytesIO() as buffer: + wrapper = io.TextIOWrapper(buffer, encoding=py39.LOCALE_ENCODING) + wrapper.write(content) + wrapper.flush() + buffer.seek(0) + return buffer.read() + + +def _can_symlink_files(base_dir: Path) -> bool: + with TemporaryDirectory(dir=str(base_dir.resolve())) as tmp: + path1, path2 = Path(tmp, "file1.txt"), Path(tmp, "file2.txt") + path1.write_text("file1", encoding="utf-8") + with suppress(AttributeError, NotImplementedError, OSError): + os.symlink(path1, path2) + if path2.is_symlink() and path2.read_text(encoding="utf-8") == "file1": + return True + + try: + os.link(path1, path2) # Ensure hard links can be created + except Exception as ex: + msg = ( + "File system does not seem to support either symlinks or hard links. " + "Strict editable installs require one of them to be supported." + ) + raise LinksNotSupported(msg) from ex + return False + + +def _simple_layout( + packages: Iterable[str], package_dir: dict[str, str], project_dir: StrPath +) -> bool: + """Return ``True`` if: + - all packages are contained by the same parent directory, **and** + - all packages become importable if the parent directory is added to ``sys.path``. + + >>> _simple_layout(['a'], {"": "src"}, "/tmp/myproj") + True + >>> _simple_layout(['a', 'a.b'], {"": "src"}, "/tmp/myproj") + True + >>> _simple_layout(['a', 'a.b'], {}, "/tmp/myproj") + True + >>> _simple_layout(['a', 'a.a1', 'a.a1.a2', 'b'], {"": "src"}, "/tmp/myproj") + True + >>> _simple_layout(['a', 'a.a1', 'a.a1.a2', 'b'], {"a": "a", "b": "b"}, ".") + True + >>> _simple_layout(['a', 'a.a1', 'a.a1.a2', 'b'], {"a": "_a", "b": "_b"}, ".") + False + >>> _simple_layout(['a', 'a.a1', 'a.a1.a2', 'b'], {"a": "_a"}, "/tmp/myproj") + False + >>> _simple_layout(['a', 'a.a1', 'a.a1.a2', 'b'], {"a.a1.a2": "_a2"}, ".") + False + >>> _simple_layout(['a', 'a.b'], {"": "src", "a.b": "_ab"}, "/tmp/myproj") + False + >>> # Special cases, no packages yet: + >>> _simple_layout([], {"": "src"}, "/tmp/myproj") + True + >>> _simple_layout([], {"a": "_a", "": "src"}, "/tmp/myproj") + False + """ + layout = {pkg: find_package_path(pkg, package_dir, project_dir) for pkg in packages} + if not layout: + return set(package_dir) in ({}, {""}) + parent = os.path.commonpath(starmap(_parent_path, layout.items())) + return all( + _path.same_path(Path(parent, *key.split('.')), value) + for key, value in layout.items() + ) + + +def _parent_path(pkg, pkg_path): + """Infer the parent path containing a package, that if added to ``sys.path`` would + allow importing that package. + When ``pkg`` is directly mapped into a directory with a different name, return its + own path. + >>> _parent_path("a", "src/a") + 'src' + >>> _parent_path("b", "src/c") + 'src/c' + """ + parent = pkg_path[: -len(pkg)] if pkg_path.endswith(pkg) else pkg_path + return parent.rstrip("/" + os.sep) + + +def _find_packages(dist: Distribution) -> Iterator[str]: + yield from iter(dist.packages or []) + + py_modules = dist.py_modules or [] + nested_modules = [mod for mod in py_modules if "." in mod] + if dist.ext_package: + yield dist.ext_package + else: + ext_modules = dist.ext_modules or [] + nested_modules += [x.name for x in ext_modules if "." in x.name] + + for module in nested_modules: + package, _, _ = module.rpartition(".") + yield package + + +def _find_top_level_modules(dist: Distribution) -> Iterator[str]: + py_modules = dist.py_modules or [] + yield from (mod for mod in py_modules if "." not in mod) + + if not dist.ext_package: + ext_modules = dist.ext_modules or [] + yield from (x.name for x in ext_modules if "." not in x.name) + + +def _find_package_roots( + packages: Iterable[str], + package_dir: Mapping[str, str], + src_root: StrPath, +) -> dict[str, str]: + pkg_roots: dict[str, str] = { + pkg: _absolute_root(find_package_path(pkg, package_dir, src_root)) + for pkg in sorted(packages) + } + + return _remove_nested(pkg_roots) + + +def _absolute_root(path: StrPath) -> str: + """Works for packages and top-level modules""" + path_ = Path(path) + parent = path_.parent + + if path_.exists(): + return str(path_.resolve()) + else: + return str(parent.resolve() / path_.name) + + +def _find_virtual_namespaces(pkg_roots: dict[str, str]) -> Iterator[str]: + """By carefully designing ``package_dir``, it is possible to implement the logical + structure of PEP 420 in a package without the corresponding directories. + + Moreover a parent package can be purposefully/accidentally skipped in the discovery + phase (e.g. ``find_packages(include=["mypkg.*"])``, when ``mypkg.foo`` is included + by ``mypkg`` itself is not). + We consider this case to also be a virtual namespace (ignoring the original + directory) to emulate a non-editable installation. + + This function will try to find these kinds of namespaces. + """ + for pkg in pkg_roots: + if "." not in pkg: + continue + parts = pkg.split(".") + for i in range(len(parts) - 1, 0, -1): + partial_name = ".".join(parts[:i]) + path = Path(find_package_path(partial_name, pkg_roots, "")) + if not path.exists() or partial_name not in pkg_roots: + # partial_name not in pkg_roots ==> purposefully/accidentally skipped + yield partial_name + + +def _find_namespaces( + packages: list[str], pkg_roots: dict[str, str] +) -> Iterator[tuple[str, list[str]]]: + for pkg in packages: + path = find_package_path(pkg, pkg_roots, "") + if Path(path).exists() and not Path(path, "__init__.py").exists(): + yield (pkg, [path]) + + +def _remove_nested(pkg_roots: dict[str, str]) -> dict[str, str]: + output = dict(pkg_roots.copy()) + + for pkg, path in reversed(list(pkg_roots.items())): + if any( + pkg != other and _is_nested(pkg, path, other, other_path) + for other, other_path in pkg_roots.items() + ): + output.pop(pkg) + + return output + + +def _is_nested(pkg: str, pkg_path: str, parent: str, parent_path: str) -> bool: + """ + Return ``True`` if ``pkg`` is nested inside ``parent`` both logically and in the + file system. + >>> _is_nested("a.b", "path/a/b", "a", "path/a") + True + >>> _is_nested("a.b", "path/a/b", "a", "otherpath/a") + False + >>> _is_nested("a.b", "path/a/b", "c", "path/c") + False + >>> _is_nested("a.a", "path/a/a", "a", "path/a") + True + >>> _is_nested("b.a", "path/b/a", "a", "path/a") + False + """ + norm_pkg_path = _path.normpath(pkg_path) + rest = pkg.replace(parent, "", 1).strip(".").split(".") + return pkg.startswith(parent) and norm_pkg_path == _path.normpath( + Path(parent_path, *rest) + ) + + +def _empty_dir(dir_: _P) -> _P: + """Create a directory ensured to be empty. Existing files may be removed.""" + shutil.rmtree(dir_, ignore_errors=True) + os.makedirs(dir_) + return dir_ + + +class _NamespaceInstaller(namespaces.Installer): + def __init__(self, distribution, installation_dir, editable_name, src_root): + self.distribution = distribution + self.src_root = src_root + self.installation_dir = installation_dir + self.editable_name = editable_name + self.outputs = [] + self.dry_run = False + + def _get_nspkg_file(self): + """Installation target.""" + return os.path.join(self.installation_dir, self.editable_name + self.nspkg_ext) + + def _get_root(self): + """Where the modules/packages should be loaded from.""" + return repr(str(self.src_root)) + + +_FINDER_TEMPLATE = """\ +from __future__ import annotations +import sys +from importlib.machinery import ModuleSpec, PathFinder +from importlib.machinery import all_suffixes as module_suffixes +from importlib.util import spec_from_file_location +from itertools import chain +from pathlib import Path + +MAPPING: dict[str, str] = {mapping!r} +NAMESPACES: dict[str, list[str]] = {namespaces!r} +PATH_PLACEHOLDER = {name!r} + ".__path_hook__" + + +class _EditableFinder: # MetaPathFinder + @classmethod + def find_spec(cls, fullname: str, _path=None, _target=None) -> ModuleSpec | None: + # Top-level packages and modules (we know these exist in the FS) + if fullname in MAPPING: + pkg_path = MAPPING[fullname] + return cls._find_spec(fullname, Path(pkg_path)) + + # Handle immediate children modules (required for namespaces to work) + # To avoid problems with case sensitivity in the file system we delegate + # to the importlib.machinery implementation. + parent, _, child = fullname.rpartition(".") + if parent and parent in MAPPING: + return PathFinder.find_spec(fullname, path=[MAPPING[parent]]) + + # Other levels of nesting should be handled automatically by importlib + # using the parent path. + return None + + @classmethod + def _find_spec(cls, fullname: str, candidate_path: Path) -> ModuleSpec | None: + init = candidate_path / "__init__.py" + candidates = (candidate_path.with_suffix(x) for x in module_suffixes()) + for candidate in chain([init], candidates): + if candidate.exists(): + return spec_from_file_location(fullname, candidate) + return None + + +class _EditableNamespaceFinder: # PathEntryFinder + @classmethod + def _path_hook(cls, path) -> type[_EditableNamespaceFinder]: + if path == PATH_PLACEHOLDER: + return cls + raise ImportError + + @classmethod + def _paths(cls, fullname: str) -> list[str]: + paths = NAMESPACES[fullname] + if not paths and fullname in MAPPING: + paths = [MAPPING[fullname]] + # Always add placeholder, for 2 reasons: + # 1. __path__ cannot be empty for the spec to be considered namespace. + # 2. In the case of nested namespaces, we need to force + # import machinery to query _EditableNamespaceFinder again. + return [*paths, PATH_PLACEHOLDER] + + @classmethod + def find_spec(cls, fullname: str, _target=None) -> ModuleSpec | None: + if fullname in NAMESPACES: + spec = ModuleSpec(fullname, None, is_package=True) + spec.submodule_search_locations = cls._paths(fullname) + return spec + return None + + @classmethod + def find_module(cls, _fullname) -> None: + return None + + +def install(): + if not any(finder == _EditableFinder for finder in sys.meta_path): + sys.meta_path.append(_EditableFinder) + + if not NAMESPACES: + return + + if not any(hook == _EditableNamespaceFinder._path_hook for hook in sys.path_hooks): + # PathEntryFinder is needed to create NamespaceSpec without private APIS + sys.path_hooks.append(_EditableNamespaceFinder._path_hook) + if PATH_PLACEHOLDER not in sys.path: + sys.path.append(PATH_PLACEHOLDER) # Used just to trigger the path hook +""" + + +def _finder_template( + name: str, mapping: Mapping[str, str], namespaces: dict[str, list[str]] +) -> str: + """Create a string containing the code for the``MetaPathFinder`` and + ``PathEntryFinder``. + """ + mapping = dict(sorted(mapping.items(), key=lambda p: p[0])) + return _FINDER_TEMPLATE.format(name=name, mapping=mapping, namespaces=namespaces) + + +class LinksNotSupported(errors.FileError): + """File system does not seem to support either symlinks or hard links.""" + + +class _DebuggingTips(SetuptoolsWarning): + _SUMMARY = "Problem in editable installation." + _DETAILS = """ + An error happened while installing `{project}` in editable mode. + + The following steps are recommended to help debug this problem: + + - Try to install the project normally, without using the editable mode. + Does the error still persist? + (If it does, try fixing the problem before attempting the editable mode). + - If you are using binary extensions, make sure you have all OS-level + dependencies installed (e.g. compilers, toolchains, binary libraries, ...). + - Try the latest version of setuptools (maybe the error was already fixed). + - If you (or your project dependencies) are using any setuptools extension + or customization, make sure they support the editable mode. + + After following the steps above, if the problem still persists and + you think this is related to how setuptools handles editable installations, + please submit a reproducible example + (see https://stackoverflow.com/help/minimal-reproducible-example) to: + + https://github.com/pypa/setuptools/issues + """ + _SEE_DOCS = "userguide/development_mode.html" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/egg_info.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/egg_info.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2030334132ba31daf397fff0a9f9652a9f4bfb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/egg_info.py @@ -0,0 +1,737 @@ +"""setuptools.command.egg_info + +Create a distribution's .egg-info directory and contents""" + +from distutils.filelist import FileList as _FileList +from distutils.errors import DistutilsInternalError +from distutils.util import convert_path +from distutils import log +import distutils.errors +import distutils.filelist +import functools +import os +import re +import sys +import time +import collections + +from .._importlib import metadata +from .. import _entry_points, _normalization +from . import _requirestxt + +from setuptools import Command +from setuptools.command.sdist import sdist +from setuptools.command.sdist import walk_revctrl +from setuptools.command.setopt import edit_config +from setuptools.command import bdist_egg +import setuptools.unicode_utils as unicode_utils +from setuptools.glob import glob + +from setuptools.extern import packaging +from ..warnings import SetuptoolsDeprecationWarning + + +PY_MAJOR = '{}.{}'.format(*sys.version_info) + + +def translate_pattern(glob): # noqa: C901 # is too complex (14) # FIXME + """ + Translate a file path glob like '*.txt' in to a regular expression. + This differs from fnmatch.translate which allows wildcards to match + directory separators. It also knows about '**/' which matches any number of + directories. + """ + pat = '' + + # This will split on '/' within [character classes]. This is deliberate. + chunks = glob.split(os.path.sep) + + sep = re.escape(os.sep) + valid_char = '[^%s]' % (sep,) + + for c, chunk in enumerate(chunks): + last_chunk = c == len(chunks) - 1 + + # Chunks that are a literal ** are globstars. They match anything. + if chunk == '**': + if last_chunk: + # Match anything if this is the last component + pat += '.*' + else: + # Match '(name/)*' + pat += '(?:%s+%s)*' % (valid_char, sep) + continue # Break here as the whole path component has been handled + + # Find any special characters in the remainder + i = 0 + chunk_len = len(chunk) + while i < chunk_len: + char = chunk[i] + if char == '*': + # Match any number of name characters + pat += valid_char + '*' + elif char == '?': + # Match a name character + pat += valid_char + elif char == '[': + # Character class + inner_i = i + 1 + # Skip initial !/] chars + if inner_i < chunk_len and chunk[inner_i] == '!': + inner_i = inner_i + 1 + if inner_i < chunk_len and chunk[inner_i] == ']': + inner_i = inner_i + 1 + + # Loop till the closing ] is found + while inner_i < chunk_len and chunk[inner_i] != ']': + inner_i = inner_i + 1 + + if inner_i >= chunk_len: + # Got to the end of the string without finding a closing ] + # Do not treat this as a matching group, but as a literal [ + pat += re.escape(char) + else: + # Grab the insides of the [brackets] + inner = chunk[i + 1 : inner_i] + char_class = '' + + # Class negation + if inner[0] == '!': + char_class = '^' + inner = inner[1:] + + char_class += re.escape(inner) + pat += '[%s]' % (char_class,) + + # Skip to the end ] + i = inner_i + else: + pat += re.escape(char) + i += 1 + + # Join each chunk with the dir separator + if not last_chunk: + pat += sep + + pat += r'\Z' + return re.compile(pat, flags=re.MULTILINE | re.DOTALL) + + +class InfoCommon: + tag_build = None + tag_date = None + + @property + def name(self): + return _normalization.safe_name(self.distribution.get_name()) + + def tagged_version(self): + tagged = self._maybe_tag(self.distribution.get_version()) + return _normalization.safe_version(tagged) + + def _maybe_tag(self, version): + """ + egg_info may be called more than once for a distribution, + in which case the version string already contains all tags. + """ + return ( + version + if self.vtags and self._already_tagged(version) + else version + self.vtags + ) + + def _already_tagged(self, version: str) -> bool: + # Depending on their format, tags may change with version normalization. + # So in addition the regular tags, we have to search for the normalized ones. + return version.endswith(self.vtags) or version.endswith(self._safe_tags()) + + def _safe_tags(self) -> str: + # To implement this we can rely on `safe_version` pretending to be version 0 + # followed by tags. Then we simply discard the starting 0 (fake version number) + try: + return _normalization.safe_version(f"0{self.vtags}")[1:] + except packaging.version.InvalidVersion: + return _normalization.safe_name(self.vtags.replace(' ', '.')) + + def tags(self) -> str: + version = '' + if self.tag_build: + version += self.tag_build + if self.tag_date: + version += time.strftime("%Y%m%d") + return version + + vtags = property(tags) + + +class egg_info(InfoCommon, Command): + description = "create a distribution's .egg-info directory" + + user_options = [ + ( + 'egg-base=', + 'e', + "directory containing .egg-info directories" + " [default: top of the source tree]", + ), + ('tag-date', 'd', "Add date stamp (e.g. 20050528) to version number"), + ('tag-build=', 'b', "Specify explicit tag to add to version number"), + ('no-date', 'D', "Don't include date stamp [default]"), + ] + + boolean_options = ['tag-date'] + negative_opt = { + 'no-date': 'tag-date', + } + + def initialize_options(self): + self.egg_base = None + self.egg_name = None + self.egg_info = None + self.egg_version = None + self.ignore_egg_info_in_manifest = False + + #################################### + # allow the 'tag_svn_revision' to be detected and + # set, supporting sdists built on older Setuptools. + @property + def tag_svn_revision(self): + pass + + @tag_svn_revision.setter + def tag_svn_revision(self, value): + pass + + #################################### + + def save_version_info(self, filename): + """ + Materialize the value of date into the + build tag. Install build keys in a deterministic order + to avoid arbitrary reordering on subsequent builds. + """ + egg_info = collections.OrderedDict() + # follow the order these keys would have been added + # when PYTHONHASHSEED=0 + egg_info['tag_build'] = self.tags() + egg_info['tag_date'] = 0 + edit_config(filename, dict(egg_info=egg_info)) + + def finalize_options(self): + # Note: we need to capture the current value returned + # by `self.tagged_version()`, so we can later update + # `self.distribution.metadata.version` without + # repercussions. + self.egg_name = self.name + self.egg_version = self.tagged_version() + parsed_version = packaging.version.Version(self.egg_version) + + try: + is_version = isinstance(parsed_version, packaging.version.Version) + spec = "%s==%s" if is_version else "%s===%s" + packaging.requirements.Requirement(spec % (self.egg_name, self.egg_version)) + except ValueError as e: + raise distutils.errors.DistutilsOptionError( + "Invalid distribution name or version syntax: %s-%s" + % (self.egg_name, self.egg_version) + ) from e + + if self.egg_base is None: + dirs = self.distribution.package_dir + self.egg_base = (dirs or {}).get('', os.curdir) + + self.ensure_dirname('egg_base') + self.egg_info = _normalization.filename_component(self.egg_name) + '.egg-info' + if self.egg_base != os.curdir: + self.egg_info = os.path.join(self.egg_base, self.egg_info) + + # Set package version for the benefit of dumber commands + # (e.g. sdist, bdist_wininst, etc.) + # + self.distribution.metadata.version = self.egg_version + + # If we bootstrapped around the lack of a PKG-INFO, as might be the + # case in a fresh checkout, make sure that any special tags get added + # to the version info + # + pd = self.distribution._patched_dist + key = getattr(pd, "key", None) or getattr(pd, "name", None) + if pd is not None and key == self.egg_name.lower(): + pd._version = self.egg_version + pd._parsed_version = packaging.version.Version(self.egg_version) + self.distribution._patched_dist = None + + def _get_egg_basename(self, py_version=PY_MAJOR, platform=None): + """Compute filename of the output egg. Private API.""" + return _egg_basename(self.egg_name, self.egg_version, py_version, platform) + + def write_or_delete_file(self, what, filename, data, force=False): + """Write `data` to `filename` or delete if empty + + If `data` is non-empty, this routine is the same as ``write_file()``. + If `data` is empty but not ``None``, this is the same as calling + ``delete_file(filename)`. If `data` is ``None``, then this is a no-op + unless `filename` exists, in which case a warning is issued about the + orphaned file (if `force` is false), or deleted (if `force` is true). + """ + if data: + self.write_file(what, filename, data) + elif os.path.exists(filename): + if data is None and not force: + log.warn("%s not set in setup(), but %s exists", what, filename) + return + else: + self.delete_file(filename) + + def write_file(self, what, filename, data): + """Write `data` to `filename` (if not a dry run) after announcing it + + `what` is used in a log message to identify what is being written + to the file. + """ + log.info("writing %s to %s", what, filename) + data = data.encode("utf-8") + if not self.dry_run: + f = open(filename, 'wb') + f.write(data) + f.close() + + def delete_file(self, filename): + """Delete `filename` (if not a dry run) after announcing it""" + log.info("deleting %s", filename) + if not self.dry_run: + os.unlink(filename) + + def run(self): + self.mkpath(self.egg_info) + try: + os.utime(self.egg_info, None) + except OSError as e: + msg = f"Cannot update time stamp of directory '{self.egg_info}'" + raise distutils.errors.DistutilsFileError(msg) from e + for ep in metadata.entry_points(group='egg_info.writers'): + writer = ep.load() + writer(self, ep.name, os.path.join(self.egg_info, ep.name)) + + # Get rid of native_libs.txt if it was put there by older bdist_egg + nl = os.path.join(self.egg_info, "native_libs.txt") + if os.path.exists(nl): + self.delete_file(nl) + + self.find_sources() + + def find_sources(self): + """Generate SOURCES.txt manifest file""" + manifest_filename = os.path.join(self.egg_info, "SOURCES.txt") + mm = manifest_maker(self.distribution) + mm.ignore_egg_info_dir = self.ignore_egg_info_in_manifest + mm.manifest = manifest_filename + mm.run() + self.filelist = mm.filelist + + +class FileList(_FileList): + # Implementations of the various MANIFEST.in commands + + def __init__(self, warn=None, debug_print=None, ignore_egg_info_dir=False): + super().__init__(warn, debug_print) + self.ignore_egg_info_dir = ignore_egg_info_dir + + def process_template_line(self, line): + # Parse the line: split it up, make sure the right number of words + # is there, and return the relevant words. 'action' is always + # defined: it's the first word of the line. Which of the other + # three are defined depends on the action; it'll be either + # patterns, (dir and patterns), or (dir_pattern). + (action, patterns, dir, dir_pattern) = self._parse_template_line(line) + + action_map = { + 'include': self.include, + 'exclude': self.exclude, + 'global-include': self.global_include, + 'global-exclude': self.global_exclude, + 'recursive-include': functools.partial( + self.recursive_include, + dir, + ), + 'recursive-exclude': functools.partial( + self.recursive_exclude, + dir, + ), + 'graft': self.graft, + 'prune': self.prune, + } + log_map = { + 'include': "warning: no files found matching '%s'", + 'exclude': ("warning: no previously-included files found matching '%s'"), + 'global-include': ( + "warning: no files found matching '%s' anywhere in distribution" + ), + 'global-exclude': ( + "warning: no previously-included files matching " + "'%s' found anywhere in distribution" + ), + 'recursive-include': ( + "warning: no files found matching '%s' under directory '%s'" + ), + 'recursive-exclude': ( + "warning: no previously-included files matching " + "'%s' found under directory '%s'" + ), + 'graft': "warning: no directories found matching '%s'", + 'prune': "no previously-included directories found matching '%s'", + } + + try: + process_action = action_map[action] + except KeyError: + msg = f"Invalid MANIFEST.in: unknown action {action!r} in {line!r}" + raise DistutilsInternalError(msg) from None + + # OK, now we know that the action is valid and we have the + # right number of words on the line for that action -- so we + # can proceed with minimal error-checking. + + action_is_recursive = action.startswith('recursive-') + if action in {'graft', 'prune'}: + patterns = [dir_pattern] + extra_log_args = (dir,) if action_is_recursive else () + log_tmpl = log_map[action] + + self.debug_print( + ' '.join( + [action] + ([dir] if action_is_recursive else []) + patterns, + ) + ) + for pattern in patterns: + if not process_action(pattern): + log.warn(log_tmpl, pattern, *extra_log_args) + + def _remove_files(self, predicate): + """ + Remove all files from the file list that match the predicate. + Return True if any matching files were removed + """ + found = False + for i in range(len(self.files) - 1, -1, -1): + if predicate(self.files[i]): + self.debug_print(" removing " + self.files[i]) + del self.files[i] + found = True + return found + + def include(self, pattern): + """Include files that match 'pattern'.""" + found = [f for f in glob(pattern) if not os.path.isdir(f)] + self.extend(found) + return bool(found) + + def exclude(self, pattern): + """Exclude files that match 'pattern'.""" + match = translate_pattern(pattern) + return self._remove_files(match.match) + + def recursive_include(self, dir, pattern): + """ + Include all files anywhere in 'dir/' that match the pattern. + """ + full_pattern = os.path.join(dir, '**', pattern) + found = [f for f in glob(full_pattern, recursive=True) if not os.path.isdir(f)] + self.extend(found) + return bool(found) + + def recursive_exclude(self, dir, pattern): + """ + Exclude any file anywhere in 'dir/' that match the pattern. + """ + match = translate_pattern(os.path.join(dir, '**', pattern)) + return self._remove_files(match.match) + + def graft(self, dir): + """Include all files from 'dir/'.""" + found = [ + item + for match_dir in glob(dir) + for item in distutils.filelist.findall(match_dir) + ] + self.extend(found) + return bool(found) + + def prune(self, dir): + """Filter out files from 'dir/'.""" + match = translate_pattern(os.path.join(dir, '**')) + return self._remove_files(match.match) + + def global_include(self, pattern): + """ + Include all files anywhere in the current directory that match the + pattern. This is very inefficient on large file trees. + """ + if self.allfiles is None: + self.findall() + match = translate_pattern(os.path.join('**', pattern)) + found = [f for f in self.allfiles if match.match(f)] + self.extend(found) + return bool(found) + + def global_exclude(self, pattern): + """ + Exclude all files anywhere that match the pattern. + """ + match = translate_pattern(os.path.join('**', pattern)) + return self._remove_files(match.match) + + def append(self, item): + if item.endswith('\r'): # Fix older sdists built on Windows + item = item[:-1] + path = convert_path(item) + + if self._safe_path(path): + self.files.append(path) + + def extend(self, paths): + self.files.extend(filter(self._safe_path, paths)) + + def _repair(self): + """ + Replace self.files with only safe paths + + Because some owners of FileList manipulate the underlying + ``files`` attribute directly, this method must be called to + repair those paths. + """ + self.files = list(filter(self._safe_path, self.files)) + + def _safe_path(self, path): + enc_warn = "'%s' not %s encodable -- skipping" + + # To avoid accidental trans-codings errors, first to unicode + u_path = unicode_utils.filesys_decode(path) + if u_path is None: + log.warn("'%s' in unexpected encoding -- skipping" % path) + return False + + # Must ensure utf-8 encodability + utf8_path = unicode_utils.try_encode(u_path, "utf-8") + if utf8_path is None: + log.warn(enc_warn, path, 'utf-8') + return False + + try: + # ignore egg-info paths + is_egg_info = ".egg-info" in u_path or b".egg-info" in utf8_path + if self.ignore_egg_info_dir and is_egg_info: + return False + # accept is either way checks out + if os.path.exists(u_path) or os.path.exists(utf8_path): + return True + # this will catch any encode errors decoding u_path + except UnicodeEncodeError: + log.warn(enc_warn, path, sys.getfilesystemencoding()) + + +class manifest_maker(sdist): + template = "MANIFEST.in" + + def initialize_options(self): + self.use_defaults = True + self.prune = True + self.manifest_only = True + self.force_manifest = True + self.ignore_egg_info_dir = False + + def finalize_options(self): + pass + + def run(self): + self.filelist = FileList(ignore_egg_info_dir=self.ignore_egg_info_dir) + if not os.path.exists(self.manifest): + self.write_manifest() # it must exist so it'll get in the list + self.add_defaults() + if os.path.exists(self.template): + self.read_template() + self.add_license_files() + self._add_referenced_files() + self.prune_file_list() + self.filelist.sort() + self.filelist.remove_duplicates() + self.write_manifest() + + def _manifest_normalize(self, path): + path = unicode_utils.filesys_decode(path) + return path.replace(os.sep, '/') + + def write_manifest(self): + """ + Write the file list in 'self.filelist' to the manifest file + named by 'self.manifest'. + """ + self.filelist._repair() + + # Now _repairs should encodability, but not unicode + files = [self._manifest_normalize(f) for f in self.filelist.files] + msg = "writing manifest file '%s'" % self.manifest + self.execute(write_file, (self.manifest, files), msg) + + def warn(self, msg): + if not self._should_suppress_warning(msg): + sdist.warn(self, msg) + + @staticmethod + def _should_suppress_warning(msg): + """ + suppress missing-file warnings from sdist + """ + return re.match(r"standard file .*not found", msg) + + def add_defaults(self): + sdist.add_defaults(self) + self.filelist.append(self.template) + self.filelist.append(self.manifest) + rcfiles = list(walk_revctrl()) + if rcfiles: + self.filelist.extend(rcfiles) + elif os.path.exists(self.manifest): + self.read_manifest() + + if os.path.exists("setup.py"): + # setup.py should be included by default, even if it's not + # the script called to create the sdist + self.filelist.append("setup.py") + + ei_cmd = self.get_finalized_command('egg_info') + self.filelist.graft(ei_cmd.egg_info) + + def add_license_files(self): + license_files = self.distribution.metadata.license_files or [] + for lf in license_files: + log.info("adding license file '%s'", lf) + self.filelist.extend(license_files) + + def _add_referenced_files(self): + """Add files referenced by the config (e.g. `file:` directive) to filelist""" + referenced = getattr(self.distribution, '_referenced_files', []) + # ^-- fallback if dist comes from distutils or is a custom class + for rf in referenced: + log.debug("adding file referenced by config '%s'", rf) + self.filelist.extend(referenced) + + def prune_file_list(self): + build = self.get_finalized_command('build') + base_dir = self.distribution.get_fullname() + self.filelist.prune(build.build_base) + self.filelist.prune(base_dir) + sep = re.escape(os.sep) + self.filelist.exclude_pattern( + r'(^|' + sep + r')(RCS|CVS|\.svn)' + sep, is_regex=True + ) + + def _safe_data_files(self, build_py): + """ + The parent class implementation of this method + (``sdist``) will try to include data files, which + might cause recursion problems when + ``include_package_data=True``. + + Therefore, avoid triggering any attempt of + analyzing/building the manifest again. + """ + if hasattr(build_py, 'get_data_files_without_manifest'): + return build_py.get_data_files_without_manifest() + + SetuptoolsDeprecationWarning.emit( + "`build_py` command does not inherit from setuptools' `build_py`.", + """ + Custom 'build_py' does not implement 'get_data_files_without_manifest'. + Please extend command classes from setuptools instead of distutils. + """, + see_url="https://peps.python.org/pep-0632/", + # due_date not defined yet, old projects might still do it? + ) + return build_py.get_data_files() + + +def write_file(filename, contents): + """Create a file with the specified name and write 'contents' (a + sequence of strings without line terminators) to it. + """ + contents = "\n".join(contents) + + # assuming the contents has been vetted for utf-8 encoding + contents = contents.encode("utf-8") + + with open(filename, "wb") as f: # always write POSIX-style manifest + f.write(contents) + + +def write_pkg_info(cmd, basename, filename): + log.info("writing %s", filename) + if not cmd.dry_run: + metadata = cmd.distribution.metadata + metadata.version, oldver = cmd.egg_version, metadata.version + metadata.name, oldname = cmd.egg_name, metadata.name + + try: + # write unescaped data to PKG-INFO, so older pkg_resources + # can still parse it + metadata.write_pkg_info(cmd.egg_info) + finally: + metadata.name, metadata.version = oldname, oldver + + safe = getattr(cmd.distribution, 'zip_safe', None) + + bdist_egg.write_safety_flag(cmd.egg_info, safe) + + +def warn_depends_obsolete(cmd, basename, filename): + """ + Unused: left to avoid errors when updating (from source) from <= 67.8. + Old installations have a .dist-info directory with the entry-point + ``depends.txt = setuptools.command.egg_info:warn_depends_obsolete``. + This may trigger errors when running the first egg_info in build_meta. + TODO: Remove this function in a version sufficiently > 68. + """ + + +# Export API used in entry_points +write_requirements = _requirestxt.write_requirements +write_setup_requirements = _requirestxt.write_setup_requirements + + +def write_toplevel_names(cmd, basename, filename): + pkgs = dict.fromkeys([ + k.split('.', 1)[0] for k in cmd.distribution.iter_distribution_names() + ]) + cmd.write_file("top-level names", filename, '\n'.join(sorted(pkgs)) + '\n') + + +def overwrite_arg(cmd, basename, filename): + write_arg(cmd, basename, filename, True) + + +def write_arg(cmd, basename, filename, force=False): + argname = os.path.splitext(basename)[0] + value = getattr(cmd.distribution, argname, None) + if value is not None: + value = '\n'.join(value) + '\n' + cmd.write_or_delete_file(argname, filename, value, force) + + +def write_entries(cmd, basename, filename): + eps = _entry_points.load(cmd.distribution.entry_points) + defn = _entry_points.render(eps) + cmd.write_or_delete_file('entry points', filename, defn, True) + + +def _egg_basename(egg_name, egg_version, py_version=None, platform=None): + """Compute filename of the output egg. Private API.""" + name = _normalization.filename_component(egg_name) + version = _normalization.filename_component(egg_version) + egg = f"{name}-{version}-py{py_version or PY_MAJOR}" + if platform: + egg += f"-{platform}" + return egg + + +class EggInfoDeprecationWarning(SetuptoolsDeprecationWarning): + """Deprecated behavior warning for EggInfo, bypassing suppression.""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/install_lib.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/install_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..5e74be247e5ad51474131160517a865facd52f34 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/install_lib.py @@ -0,0 +1,126 @@ +import os +import sys +from itertools import product, starmap +import distutils.command.install_lib as orig +from .._path import StrPath + + +class install_lib(orig.install_lib): + """Don't add compiled flags to filenames of non-Python files""" + + def run(self): + self.build() + outfiles = self.install() + if outfiles is not None: + # always compile, in case we have any extension stubs to deal with + self.byte_compile(outfiles) + + def get_exclusions(self): + """ + Return a collections.Sized collections.Container of paths to be + excluded for single_version_externally_managed installations. + """ + all_packages = ( + pkg + for ns_pkg in self._get_SVEM_NSPs() + for pkg in self._all_packages(ns_pkg) + ) + + excl_specs = product(all_packages, self._gen_exclusion_paths()) + return set(starmap(self._exclude_pkg_path, excl_specs)) + + def _exclude_pkg_path(self, pkg, exclusion_path): + """ + Given a package name and exclusion path within that package, + compute the full exclusion path. + """ + parts = pkg.split('.') + [exclusion_path] + return os.path.join(self.install_dir, *parts) + + @staticmethod + def _all_packages(pkg_name): + """ + >>> list(install_lib._all_packages('foo.bar.baz')) + ['foo.bar.baz', 'foo.bar', 'foo'] + """ + while pkg_name: + yield pkg_name + pkg_name, sep, child = pkg_name.rpartition('.') + + def _get_SVEM_NSPs(self): + """ + Get namespace packages (list) but only for + single_version_externally_managed installations and empty otherwise. + """ + # TODO: is it necessary to short-circuit here? i.e. what's the cost + # if get_finalized_command is called even when namespace_packages is + # False? + if not self.distribution.namespace_packages: + return [] + + install_cmd = self.get_finalized_command('install') + svem = install_cmd.single_version_externally_managed + + return self.distribution.namespace_packages if svem else [] + + @staticmethod + def _gen_exclusion_paths(): + """ + Generate file paths to be excluded for namespace packages (bytecode + cache files). + """ + # always exclude the package module itself + yield '__init__.py' + + yield '__init__.pyc' + yield '__init__.pyo' + + if not hasattr(sys, 'implementation'): + return + + base = os.path.join('__pycache__', '__init__.' + sys.implementation.cache_tag) + yield base + '.pyc' + yield base + '.pyo' + yield base + '.opt-1.pyc' + yield base + '.opt-2.pyc' + + def copy_tree( + self, + infile: StrPath, + outfile: str, + preserve_mode=True, + preserve_times=True, + preserve_symlinks=False, + level=1, + ): + assert preserve_mode and preserve_times and not preserve_symlinks + exclude = self.get_exclusions() + + if not exclude: + return orig.install_lib.copy_tree(self, infile, outfile) # type: ignore[arg-type] # Fixed upstream + + # Exclude namespace package __init__.py* files from the output + + from setuptools.archive_util import unpack_directory + from distutils import log + + outfiles = [] + + def pf(src, dst): + if dst in exclude: + log.warn("Skipping installation of %s (namespace package)", dst) + return False + + log.info("copying %s -> %s", src, os.path.dirname(dst)) + outfiles.append(dst) + return dst + + unpack_directory(infile, outfile, pf) + return outfiles + + def get_outputs(self): + outputs = orig.install_lib.get_outputs(self) + exclude = self.get_exclusions() + if exclude: + return [f for f in outputs if f not in exclude] + return outputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/install_scripts.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/install_scripts.py new file mode 100644 index 0000000000000000000000000000000000000000..f44281b49b679ea5e88b716af348e1851fa4bc3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/install_scripts.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from distutils import log +import distutils.command.install_scripts as orig +import os +import sys + +from .._path import ensure_directory + + +class install_scripts(orig.install_scripts): + """Do normal script install, plus any egg_info wrapper scripts""" + + def initialize_options(self): + orig.install_scripts.initialize_options(self) + self.no_ep = False + + def run(self) -> None: + self.run_command("egg_info") + if self.distribution.scripts: + orig.install_scripts.run(self) # run first to set up self.outfiles + else: + self.outfiles: list[str] = [] + if self.no_ep: + # don't install entry point scripts into .egg file! + return + self._install_ep_scripts() + + def _install_ep_scripts(self): + # Delay import side-effects + from pkg_resources import Distribution, PathMetadata + from . import easy_install as ei + + ei_cmd = self.get_finalized_command("egg_info") + dist = Distribution( + ei_cmd.egg_base, + PathMetadata(ei_cmd.egg_base, ei_cmd.egg_info), + ei_cmd.egg_name, + ei_cmd.egg_version, + ) + bs_cmd = self.get_finalized_command('build_scripts') + exec_param = getattr(bs_cmd, 'executable', None) + writer = ei.ScriptWriter + if exec_param == sys.executable: + # In case the path to the Python executable contains a space, wrap + # it so it's not split up. + exec_param = [exec_param] + # resolve the writer to the environment + writer = writer.best() + cmd = writer.command_spec_class.best().from_param(exec_param) + for args in writer.get_args(dist, cmd.as_header()): + self.write_script(*args) + + def write_script(self, script_name, contents, mode="t", *ignored): + """Write an executable file to the scripts directory""" + from setuptools.command.easy_install import chmod, current_umask + + log.info("Installing %s script to %s", script_name, self.install_dir) + target = os.path.join(self.install_dir, script_name) + self.outfiles.append(target) + + encoding = None if "b" in mode else "utf-8" + mask = current_umask() + if not self.dry_run: + ensure_directory(target) + with open(target, "w" + mode, encoding=encoding) as f: + f.write(contents) + chmod(target, 0o777 - mask) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/launcher manifest.xml b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/launcher manifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..5972a96d8ded85cc14147ffc1400ec67c3b5a578 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/launcher manifest.xml @@ -0,0 +1,15 @@ + + + + + + + + + + + + diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/register.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/register.py new file mode 100644 index 0000000000000000000000000000000000000000..beee9782e7b9310cbe940286b2e1ffa41a184535 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/register.py @@ -0,0 +1,18 @@ +from distutils import log +import distutils.command.register as orig + +from setuptools.errors import RemovedCommandError + + +class register(orig.register): + """Formerly used to register packages on PyPI.""" + + def run(self): + msg = ( + "The register command has been removed, use twine to upload " + "instead (https://pypi.org/p/twine)" + ) + + self.announce("ERROR: " + msg, log.ERROR) + + raise RemovedCommandError(msg) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/rotate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/rotate.py new file mode 100644 index 0000000000000000000000000000000000000000..064d7959ff6364ea56c78466988b2654eef02a44 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/rotate.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from distutils.util import convert_path +from distutils import log +from distutils.errors import DistutilsOptionError +import os +import shutil + +from setuptools import Command + + +class rotate(Command): + """Delete older distributions""" + + description = "delete older distributions, keeping N newest files" + user_options = [ + ('match=', 'm', "patterns to match (required)"), + ('dist-dir=', 'd', "directory where the distributions are"), + ('keep=', 'k', "number of matching distributions to keep"), + ] + + boolean_options: list[str] = [] + + def initialize_options(self): + self.match = None + self.dist_dir = None + self.keep = None + + def finalize_options(self): + if self.match is None: + raise DistutilsOptionError( + "Must specify one or more (comma-separated) match patterns " + "(e.g. '.zip' or '.egg')" + ) + if self.keep is None: + raise DistutilsOptionError("Must specify number of files to keep") + try: + self.keep = int(self.keep) + except ValueError as e: + raise DistutilsOptionError("--keep must be an integer") from e + if isinstance(self.match, str): + self.match = [convert_path(p.strip()) for p in self.match.split(',')] + self.set_undefined_options('bdist', ('dist_dir', 'dist_dir')) + + def run(self): + self.run_command("egg_info") + from glob import glob + + for pattern in self.match: + pattern = self.distribution.get_name() + '*' + pattern + files = glob(os.path.join(self.dist_dir, pattern)) + files = [(os.path.getmtime(f), f) for f in files] + files.sort() + files.reverse() + + log.info("%d file(s) matching %s", len(files), pattern) + files = files[self.keep :] + for t, f in files: + log.info("Deleting %s", f) + if not self.dry_run: + if os.path.isdir(f): + shutil.rmtree(f) + else: + os.unlink(f) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/sdist.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/sdist.py new file mode 100644 index 0000000000000000000000000000000000000000..a834ba4a783987ae70c3a34a8fd1f11244833012 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/sdist.py @@ -0,0 +1,204 @@ +from distutils import log +import distutils.command.sdist as orig +import os +import contextlib +from itertools import chain + +from .._importlib import metadata +from .build import _ORIGINAL_SUBCOMMANDS + +_default_revctrl = list + + +def walk_revctrl(dirname=''): + """Find all files under revision control""" + for ep in metadata.entry_points(group='setuptools.file_finders'): + yield from ep.load()(dirname) + + +class sdist(orig.sdist): + """Smart sdist that finds anything supported by revision control""" + + user_options = [ + ('formats=', None, "formats for source distribution (comma-separated list)"), + ( + 'keep-temp', + 'k', + "keep the distribution tree around after creating " + "archive file(s)", + ), + ( + 'dist-dir=', + 'd', + "directory to put the source distribution archive(s) in [default: dist]", + ), + ( + 'owner=', + 'u', + "Owner name used when creating a tar file [default: current user]", + ), + ( + 'group=', + 'g', + "Group name used when creating a tar file [default: current group]", + ), + ] + + negative_opt = {} + + README_EXTENSIONS = ['', '.rst', '.txt', '.md'] + READMES = tuple('README{0}'.format(ext) for ext in README_EXTENSIONS) + + def run(self): + self.run_command('egg_info') + ei_cmd = self.get_finalized_command('egg_info') + self.filelist = ei_cmd.filelist + self.filelist.append(os.path.join(ei_cmd.egg_info, 'SOURCES.txt')) + self.check_readme() + + # Run sub commands + for cmd_name in self.get_sub_commands(): + self.run_command(cmd_name) + + self.make_distribution() + + dist_files = getattr(self.distribution, 'dist_files', []) + for file in self.archive_files: + data = ('sdist', '', file) + if data not in dist_files: + dist_files.append(data) + + def initialize_options(self): + orig.sdist.initialize_options(self) + + def make_distribution(self): + """ + Workaround for #516 + """ + with self._remove_os_link(): + orig.sdist.make_distribution(self) + + @staticmethod + @contextlib.contextmanager + def _remove_os_link(): + """ + In a context, remove and restore os.link if it exists + """ + + class NoValue: + pass + + orig_val = getattr(os, 'link', NoValue) + try: + del os.link + except Exception: + pass + try: + yield + finally: + if orig_val is not NoValue: + os.link = orig_val + + def add_defaults(self): + super().add_defaults() + self._add_defaults_build_sub_commands() + + def _add_defaults_optional(self): + super()._add_defaults_optional() + if os.path.isfile('pyproject.toml'): + self.filelist.append('pyproject.toml') + + def _add_defaults_python(self): + """getting python files""" + if self.distribution.has_pure_modules(): + build_py = self.get_finalized_command('build_py') + self.filelist.extend(build_py.get_source_files()) + self._add_data_files(self._safe_data_files(build_py)) + + def _add_defaults_build_sub_commands(self): + build = self.get_finalized_command("build") + missing_cmds = set(build.get_sub_commands()) - _ORIGINAL_SUBCOMMANDS + # ^-- the original built-in sub-commands are already handled by default. + cmds = (self.get_finalized_command(c) for c in missing_cmds) + files = (c.get_source_files() for c in cmds if hasattr(c, "get_source_files")) + self.filelist.extend(chain.from_iterable(files)) + + def _safe_data_files(self, build_py): + """ + Since the ``sdist`` class is also used to compute the MANIFEST + (via :obj:`setuptools.command.egg_info.manifest_maker`), + there might be recursion problems when trying to obtain the list of + data_files and ``include_package_data=True`` (which in turn depends on + the files included in the MANIFEST). + + To avoid that, ``manifest_maker`` should be able to overwrite this + method and avoid recursive attempts to build/analyze the MANIFEST. + """ + return build_py.data_files + + def _add_data_files(self, data_files): + """ + Add data files as found in build_py.data_files. + """ + self.filelist.extend( + os.path.join(src_dir, name) + for _, src_dir, _, filenames in data_files + for name in filenames + ) + + def _add_defaults_data_files(self): + try: + super()._add_defaults_data_files() + except TypeError: + log.warn("data_files contains unexpected objects") + + def check_readme(self): + for f in self.READMES: + if os.path.exists(f): + return + else: + self.warn( + "standard file not found: should have one of " + ', '.join(self.READMES) + ) + + def make_release_tree(self, base_dir, files): + orig.sdist.make_release_tree(self, base_dir, files) + + # Save any egg_info command line options used to create this sdist + dest = os.path.join(base_dir, 'setup.cfg') + if hasattr(os, 'link') and os.path.exists(dest): + # unlink and re-copy, since it might be hard-linked, and + # we don't want to change the source version + os.unlink(dest) + self.copy_file('setup.cfg', dest) + + self.get_finalized_command('egg_info').save_version_info(dest) + + def _manifest_is_not_generated(self): + # check for special comment used in 2.7.1 and higher + if not os.path.isfile(self.manifest): + return False + + with open(self.manifest, 'rb') as fp: + first_line = fp.readline() + return first_line != b'# file GENERATED by distutils, do NOT edit\n' + + def read_manifest(self): + """Read the manifest file (named by 'self.manifest') and use it to + fill in 'self.filelist', the list of files to include in the source + distribution. + """ + log.info("reading manifest file '%s'", self.manifest) + manifest = open(self.manifest, 'rb') + for line in manifest: + # The manifest must contain UTF-8. See #303. + try: + line = line.decode('UTF-8') + except UnicodeDecodeError: + log.warn("%r not UTF-8 decodable -- skipping" % line) + continue + # ignore comments and blank lines + line = line.strip() + if line.startswith('#') or not line: + continue + self.filelist.append(line) + manifest.close() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/setopt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/setopt.py new file mode 100644 index 0000000000000000000000000000000000000000..b78d845e60aafe24eb032d2df8fc5745dd6dd50b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/setopt.py @@ -0,0 +1,140 @@ +from distutils.util import convert_path +from distutils import log +from distutils.errors import DistutilsOptionError +import distutils +import os +import configparser + +from .. import Command +from ..unicode_utils import _cfg_read_utf8_with_fallback + +__all__ = ['config_file', 'edit_config', 'option_base', 'setopt'] + + +def config_file(kind="local"): + """Get the filename of the distutils, local, global, or per-user config + + `kind` must be one of "local", "global", or "user" + """ + if kind == 'local': + return 'setup.cfg' + if kind == 'global': + return os.path.join(os.path.dirname(distutils.__file__), 'distutils.cfg') + if kind == 'user': + dot = os.name == 'posix' and '.' or '' + return os.path.expanduser(convert_path("~/%spydistutils.cfg" % dot)) + raise ValueError("config_file() type must be 'local', 'global', or 'user'", kind) + + +def edit_config(filename, settings, dry_run=False): + """Edit a configuration file to include `settings` + + `settings` is a dictionary of dictionaries or ``None`` values, keyed by + command/section name. A ``None`` value means to delete the entire section, + while a dictionary lists settings to be changed or deleted in that section. + A setting of ``None`` means to delete that setting. + """ + log.debug("Reading configuration from %s", filename) + opts = configparser.RawConfigParser() + opts.optionxform = lambda x: x + _cfg_read_utf8_with_fallback(opts, filename) + + for section, options in settings.items(): + if options is None: + log.info("Deleting section [%s] from %s", section, filename) + opts.remove_section(section) + else: + if not opts.has_section(section): + log.debug("Adding new section [%s] to %s", section, filename) + opts.add_section(section) + for option, value in options.items(): + if value is None: + log.debug("Deleting %s.%s from %s", section, option, filename) + opts.remove_option(section, option) + if not opts.options(section): + log.info( + "Deleting empty [%s] section from %s", section, filename + ) + opts.remove_section(section) + else: + log.debug( + "Setting %s.%s to %r in %s", section, option, value, filename + ) + opts.set(section, option, value) + + log.info("Writing %s", filename) + if not dry_run: + with open(filename, 'w', encoding="utf-8") as f: + opts.write(f) + + +class option_base(Command): + """Abstract base class for commands that mess with config files""" + + user_options = [ + ('global-config', 'g', "save options to the site-wide distutils.cfg file"), + ('user-config', 'u', "save options to the current user's pydistutils.cfg file"), + ('filename=', 'f', "configuration file to use (default=setup.cfg)"), + ] + + boolean_options = [ + 'global-config', + 'user-config', + ] + + def initialize_options(self): + self.global_config = None + self.user_config = None + self.filename = None + + def finalize_options(self): + filenames = [] + if self.global_config: + filenames.append(config_file('global')) + if self.user_config: + filenames.append(config_file('user')) + if self.filename is not None: + filenames.append(self.filename) + if not filenames: + filenames.append(config_file('local')) + if len(filenames) > 1: + raise DistutilsOptionError( + "Must specify only one configuration file option", filenames + ) + (self.filename,) = filenames + + +class setopt(option_base): + """Save command-line options to a file""" + + description = "set an option in setup.cfg or another config file" + + user_options = [ + ('command=', 'c', 'command to set an option for'), + ('option=', 'o', 'option to set'), + ('set-value=', 's', 'value of the option'), + ('remove', 'r', 'remove (unset) the value'), + ] + option_base.user_options + + boolean_options = option_base.boolean_options + ['remove'] + + def initialize_options(self): + option_base.initialize_options(self) + self.command = None + self.option = None + self.set_value = None + self.remove = None + + def finalize_options(self): + option_base.finalize_options(self) + if self.command is None or self.option is None: + raise DistutilsOptionError("Must specify --command *and* --option") + if self.set_value is None and not self.remove: + raise DistutilsOptionError("Must specify --set-value or --remove") + + def run(self): + edit_config( + self.filename, + {self.command: {self.option.replace('-', '_'): self.set_value}}, + self.dry_run, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/test.py b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/test.py new file mode 100644 index 0000000000000000000000000000000000000000..af1349e1c6e3d943d965338c1dcaedc2d153f8b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/setuptools/command/test.py @@ -0,0 +1,250 @@ +import os +import operator +import sys +import contextlib +import itertools +import unittest +from distutils.errors import DistutilsError, DistutilsOptionError +from distutils import log +from unittest import TestLoader + +from pkg_resources import ( + resource_listdir, + resource_exists, + normalize_path, + working_set, + evaluate_marker, + add_activation_listener, + require, +) +from .._importlib import metadata +from setuptools import Command +from setuptools.extern.more_itertools import unique_everseen +from setuptools.extern.jaraco.functools import pass_none + + +class ScanningLoader(TestLoader): + def __init__(self): + TestLoader.__init__(self) + self._visited = set() + + def loadTestsFromModule(self, module, pattern=None): + """Return a suite of all tests cases contained in the given module + + If the module is a package, load tests from all the modules in it. + If the module has an ``additional_tests`` function, call it and add + the return value to the tests. + """ + if module in self._visited: + return None + self._visited.add(module) + + tests = [] + tests.append(TestLoader.loadTestsFromModule(self, module)) + + if hasattr(module, "additional_tests"): + tests.append(module.additional_tests()) + + if hasattr(module, '__path__'): + for file in resource_listdir(module.__name__, ''): + if file.endswith('.py') and file != '__init__.py': + submodule = module.__name__ + '.' + file[:-3] + else: + if resource_exists(module.__name__, file + '/__init__.py'): + submodule = module.__name__ + '.' + file + else: + continue + tests.append(self.loadTestsFromName(submodule)) + + if len(tests) != 1: + return self.suiteClass(tests) + else: + return tests[0] # don't create a nested suite for only one return + + +# adapted from jaraco.classes.properties:NonDataProperty +class NonDataProperty: + def __init__(self, fget): + self.fget = fget + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return self.fget(obj) + + +class test(Command): + """Command to run unit tests after in-place build""" + + description = "run unit tests after in-place build (deprecated)" + + user_options = [ + ('test-module=', 'm', "Run 'test_suite' in specified module"), + ( + 'test-suite=', + 's', + "Run single test, case or suite (e.g. 'module.test_suite')", + ), + ('test-runner=', 'r', "Test runner to use"), + ] + + def initialize_options(self): + self.test_suite = None + self.test_module = None + self.test_loader = None + self.test_runner = None + + def finalize_options(self): + if self.test_suite and self.test_module: + msg = "You may specify a module or a suite, but not both" + raise DistutilsOptionError(msg) + + if self.test_suite is None: + if self.test_module is None: + self.test_suite = self.distribution.test_suite + else: + self.test_suite = self.test_module + ".test_suite" + + if self.test_loader is None: + self.test_loader = getattr(self.distribution, 'test_loader', None) + if self.test_loader is None: + self.test_loader = "setuptools.command.test:ScanningLoader" + if self.test_runner is None: + self.test_runner = getattr(self.distribution, 'test_runner', None) + + @NonDataProperty + def test_args(self): + return list(self._test_args()) + + def _test_args(self): + if not self.test_suite: + yield 'discover' + if self.verbose: + yield '--verbose' + if self.test_suite: + yield self.test_suite + + def with_project_on_sys_path(self, func): + """ + Backward compatibility for project_on_sys_path context. + """ + with self.project_on_sys_path(): + func() + + @contextlib.contextmanager + def project_on_sys_path(self, include_dists=()): + self.run_command('egg_info') + + # Build extensions in-place + self.reinitialize_command('build_ext', inplace=True) + self.run_command('build_ext') + + ei_cmd = self.get_finalized_command("egg_info") + + old_path = sys.path[:] + old_modules = sys.modules.copy() + + try: + project_path = normalize_path(ei_cmd.egg_base) + sys.path.insert(0, project_path) + working_set.__init__() + add_activation_listener(lambda dist: dist.activate()) + require('%s==%s' % (ei_cmd.egg_name, ei_cmd.egg_version)) + with self.paths_on_pythonpath([project_path]): + yield + finally: + sys.path[:] = old_path + sys.modules.clear() + sys.modules.update(old_modules) + working_set.__init__() + + @staticmethod + @contextlib.contextmanager + def paths_on_pythonpath(paths): + """ + Add the indicated paths to the head of the PYTHONPATH environment + variable so that subprocesses will also see the packages at + these paths. + + Do this in a context that restores the value on exit. + """ + nothing = object() + orig_pythonpath = os.environ.get('PYTHONPATH', nothing) + current_pythonpath = os.environ.get('PYTHONPATH', '') + try: + prefix = os.pathsep.join(unique_everseen(paths)) + to_join = filter(None, [prefix, current_pythonpath]) + new_path = os.pathsep.join(to_join) + if new_path: + os.environ['PYTHONPATH'] = new_path + yield + finally: + if orig_pythonpath is nothing: + os.environ.pop('PYTHONPATH', None) + else: + os.environ['PYTHONPATH'] = orig_pythonpath + + @staticmethod + def install_dists(dist): + """ + Install the requirements indicated by self.distribution and + return an iterable of the dists that were built. + """ + ir_d = dist.fetch_build_eggs(dist.install_requires) + tr_d = dist.fetch_build_eggs(dist.tests_require or []) + er_d = dist.fetch_build_eggs( + v + for k, v in dist.extras_require.items() + if k.startswith(':') and evaluate_marker(k[1:]) + ) + return itertools.chain(ir_d, tr_d, er_d) + + def run(self): + self.announce( + "WARNING: Testing via this command is deprecated and will be " + "removed in a future version. Users looking for a generic test " + "entry point independent of test runner are encouraged to use " + "tox.", + log.WARN, + ) + + installed_dists = self.install_dists(self.distribution) + + cmd = ' '.join(self._argv) + if self.dry_run: + self.announce('skipping "%s" (dry run)' % cmd) + return + + self.announce('running "%s"' % cmd) + + paths = map(operator.attrgetter('location'), installed_dists) + with self.paths_on_pythonpath(paths): + with self.project_on_sys_path(): + self.run_tests() + + def run_tests(self): + test = unittest.main( + None, + None, + self._argv, + testLoader=self._resolve_as_ep(self.test_loader), + testRunner=self._resolve_as_ep(self.test_runner), + exit=False, + ) + if not test.result.wasSuccessful(): + msg = 'Test failed: %s' % test.result + self.announce(msg, log.ERROR) + raise DistutilsError(msg) + + @property + def _argv(self): + return ['unittest'] + self.test_args + + @staticmethod + @pass_none + def _resolve_as_ep(val): + """ + Load the indicated attribute value, called, as a as if it were + specified as an entry point. + """ + return metadata.EntryPoint(value=val, name=None, group=None).load()() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3d0e7ade9feed3d4daddf7cb885a68211416b93 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7be0dc94fd8d2c7722606b5af4e7d5cd77fbddb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..080f9242c2daba4489c3d231835e9d0c1a85ad7a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25656a37d92b5502ff754eb9b670fb374e83e0a2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/await_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/await_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bc1c9d37704f96c856c6912ea2ed21cb923bef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/await_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/bounds.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/bounds.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f91a163013fa679b8e05d0c53603b12fdb2a69b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/bounds.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d3c159c8cc67aa72fa2528cc167577bb4e7610c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12bc43848ff46e7f3179773be8aa72bad8d1ce3b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ecad8e8d222b4b650341276b061932729488e30 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19947fb2c9378275a58f44617c2b8a8ee2fe2b7a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms_debug.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms_debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5d4292c0068137142e5b1c1c929d91dcafbf08 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/comms_debug.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a808c26df89fcfb27c1a4fe567d4252226e1c5a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/config_comms.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/config_comms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30cd4baa36a84836fb9f8f44d280ef79360ecbbb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/config_comms.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cc44fae7d2c61743094d07290c4bb2b9ee84b49 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52660852a63b41e972cc67d6c4e629ec66d3ff0a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29569e659810ade523ec9b26ea5b2c6ca1c52db4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017846a0d64097721ed04d5ec3b5c031577cd72a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/debug.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..666f2d1419fdbcf26b9dc42b6ae5c4837837adf0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/debug.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/decomposition.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/decomposition.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5d1f63dde84baafd6030b00e36e2001a3ecf9da Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/decomposition.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/dependencies.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/dependencies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51a9d84a5459d3f90e75b66824186e9a828dd914 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/dependencies.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/distributed_autotune.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/distributed_autotune.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17baac66eb6579f625dee92c298e1ccb32d79484 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/distributed_autotune.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/exc.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/exc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d371888019d496685e17910ffa0217806697a0c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/exc.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbedd49b036f72a50ccbf7378e2391cda0a89997 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5220a471c5e7f2b953b0d337446be183ddec3e49 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/freezing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6574712f14ddccf28b398592136faa382969495 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35b07c1c477b47ca2feab727492b76bc0bda55ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/invert_expr_analysis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/invert_expr_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e4e10bb038c05784e770261832dbf8322c4066e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/invert_expr_analysis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..782f01bbddb9123065e1938ab38418e06069d7c1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_inputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8e09e2c998a2cd8f7c7f97d9aca775a7046eb70 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_inputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_template_choice.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_template_choice.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ab621d9132b450b937fdc699c3d05c06c5280c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/kernel_template_choice.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/loop_body.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/loop_body.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..870aebd48a469f6ef5bedfa53c79a5a53e90cfba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/loop_body.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/memory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2911afa64b8cbf34f16a31549244af1ec1b4c594 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/memory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/metrics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..892bd5d1dee57d5b1ef38dc054594472e8e4c7ea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/metrics.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2538cfc599f72d81f2dcafb6e9826c667cf4ae33 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c788f0cccc2c8e2e7e010f8d38a548a4d5186aa5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5c0d976efece54fe38055b3ad6f5bbb84f21a32 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bf3091c8b353ae563c0f9c0367f6809d3ace3b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1f3fabef56cc263fdfaa4db6af92a1cd7864119 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/output_code.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/output_code.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a28d935a6ee3700c4fd40387206fd27a1c090a18 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/output_code.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..743ff530f324e8bbd1e72a41587baaac6b3853e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/rocm_multiarch_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/rocm_multiarch_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709e2f41a66b2a874016a78c8efbeb11f144e0d4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/rocm_multiarch_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/shape_propagation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/shape_propagation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afbdf34cd93f1d029f6697521643c9cd4d845e7b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/shape_propagation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_case.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_case.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..377e35cf0ed6fbdaf366f1d40de495c60d7c6aeb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_case.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..370d869ca4ed86f7eda50a6d25690855f8347331 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfc0da0b53b1b959e020c720c8b4c21be86c1663 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/virtualized.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/virtualized.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78f2e1fee5abc2fa315c2106b6fe73422cdb3d2a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/virtualized.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef9b6785205316a792fe7018b5cb7d3f7d632e07 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aef1e4e2764aaab421ac60465da16456eef4bebc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/device_info.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/device_info.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88221f156f6fb953f94be70f04637dc53bdb7405 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/device_info.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/profile_analysis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/profile_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82d606eb98ee4e4870e15efa79291f3ecaf0935f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/analysis/__pycache__/profile_analysis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50739f405f81518e0d9a88b91783f964204a535d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9c1cea40aef76491f792f597f4392934bb8298f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20ddb45fa151b9df433df80a3bc30ebfb45dfdc1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6604e5d46c85357ab6b3dc1240938dd186b36ccd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08db6b3a231a923b60ffd6465fb6bf14a6622dbb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebf134c83d7c597dae05f572f6f6f7f702c9f6e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py @@ -0,0 +1,296 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 52.6245059967041: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 312.0: + return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)] + else: + if context.get_value('k') <= 40.0: + return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)] + else: + return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)] + else: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)] + else: + if context.get_value('k') <= 68.0: + return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)] + else: + return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)] + else: + if context.get_value('k') <= 35.0: + if context.get_value('k') <= 18.0: + if context.get_value('m*n') <= 19505152.0: + return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)] + else: + return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)] + else: + if context.get_value('n') <= 68.0: + return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)] + else: + return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)] + else: + if context.get_value('m*n') <= 309760.0: + return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)] + else: + if context.get_value('n') <= 72.0: + return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)] + else: + return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)] + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 815360.0: + if context.get_value('k') <= 1184.0: + return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)] + else: + return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)] + else: + if context.get_value('arith_intensity') <= 187.23922729492188: + if context.get_value('mat1_stride_0') <= 198.0: + return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)] + else: + return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)] + else: + return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)] + else: + return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py new file mode 100644 index 0000000000000000000000000000000000000000..6201acc4213aa153cc73971946d3a241d2063793 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py @@ -0,0 +1,321 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 29.89772129058838: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 432.0: + if context.get_value('arith_intensity') <= 7.8700292110443115: + return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)] + else: + return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)] + else: + if context.get_value('k') <= 40.0: + return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)] + else: + return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)] + else: + if context.get_value('mat1_stride_0') <= 40.0: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)] + else: + return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)] + else: + if context.get_value('mat1_stride_0') <= 68.0: + return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)] + else: + return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)] + else: + if context.get_value('k') <= 18.0: + if context.get_value('m*k') <= 528.0: + return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)] + else: + if context.get_value('n') <= 80.0: + return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)] + else: + return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)] + else: + if context.get_value('k') <= 36.0: + if context.get_value('n') <= 68.0: + return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)] + else: + return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)] + else: + if context.get_value('mat2_stride_0') <= 384.0: + return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)] + else: + return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)] + else: + if context.get_value('arith_intensity') <= 56.995582580566406: + if context.get_value('n') <= 68.0: + if context.get_value('k*n') <= 4448.0: + if context.get_value('m*n') <= 29626368.0: + return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)] + else: + return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)] + else: + if context.get_value('k') <= 348.0: + return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)] + else: + return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)] + else: + if context.get_value('m') <= 3264.0: + return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)] + else: + if context.get_value('k') <= 62.5: + return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)] + else: + return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)] + else: + if context.get_value('m*n') <= 1097728.0: + return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)] + else: + if context.get_value('m*n') <= 3244032.0: + return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)] + else: + if context.get_value('n') <= 136.0: + return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)] + else: + return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba7cbaf90275d1bb2cb50e8fd27fbd331173bbb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py @@ -0,0 +1,150 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if str(context.get_value('1LEQmLEQ16')) != 'True': + if context.get_value('m') <= 32.5: + if context.get_value('n') <= 6976.0: + if context.get_value('n') <= 3520.0: + if context.get_value('m*n') <= 37632.0: + return None + else: + return [(1.000, 13)] + else: + if context.get_value('m*k') <= 452352.0: + return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)] + else: + return [(0.778, 8), (0.222, 13)] + else: + if context.get_value('k*n') <= 102776832.0: + if context.get_value('n') <= 14656.0: + return [(1.000, 11)] + else: + return [(0.889, 11), (0.111, 13)] + else: + return [(1.000, 11)] + else: + if context.get_value('m*n') <= 446464.0: + if context.get_value('m*n') <= 223424.0: + if context.get_value('mat1_stride_0') <= 3968.0: + return None + else: + return None + else: + if context.get_value('m*n') <= 346112.0: + return [(0.960, 16), (0.040, 7)] + else: + return [(0.750, 16), (0.136, 14), (0.114, 7)] + else: + if str(context.get_value('33LEQmLEQ64')) != 'True': + if context.get_value('n') <= 6976.0: + return [(1.000, 14)] + else: + return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)] + else: + if context.get_value('n') <= 13888.0: + return [(0.710, 14), (0.275, 21), (0.014, 12)] + else: + return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)] + else: + if context.get_value('n') <= 3520.0: + if context.get_value('arith_intensity') <= 3.994754433631897: + if str(context.get_value('mat2_dtype')) != 'torch.uint8': + if context.get_value('m*k') <= 18944.0: + return [(0.577, 5), (0.423, 6)] + else: + return [(0.988, 5), (0.012, 6)] + else: + if context.get_value('arith_intensity') <= 2.9899919033050537: + return None + else: + return None + else: + if context.get_value('arith_intensity') <= 7.956453561782837: + if context.get_value('k*n') <= 9244032.0: + return [(0.822, 5), (0.178, 6)] + else: + return [(0.977, 5), (0.023, 0)] + else: + if context.get_value('m*k') <= 978944.0: + return [(1.000, 5)] + else: + return [(0.971, 5), (0.029, 0)] + else: + if context.get_value('n') <= 13632.0: + if context.get_value('n') <= 6976.0: + return [(1.000, 6)] + else: + if context.get_value('k') <= 3968.0: + return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)] + else: + return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)] + else: + if context.get_value('k*n') <= 39518208.0: + return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)] + else: + if context.get_value('n') <= 20800.0: + return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)] + else: + return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe46cf75d8c63fab36eef728edf34788d6e3b22 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py @@ -0,0 +1,149 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 15.988086223602295: + if context.get_value('n') <= 25280.0: + if context.get_value('n') <= 1344.0: + if context.get_value('mat1_stride_0') <= 7808.0: + return [(0.581, 7), (0.419, 6)] + else: + if context.get_value('m*n') <= 7680.0: + return [(0.875, 0), (0.125, 6)] + else: + return [(0.833, 0), (0.167, 7)] + else: + if context.get_value('n') <= 8512.0: + if str(context.get_value('mat2_dtype')) != 'torch.int8': + return [(0.763, 6), (0.237, 7)] + else: + return [(0.725, 7), (0.275, 6)] + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)] + else: + return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)] + else: + if context.get_value('n') <= 42254.0: + if context.get_value('n') <= 33856.0: + if context.get_value('k*n') <= 68157440.0: + return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)] + else: + return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)] + else: + return [(0.659, 5), (0.341, 6)] + else: + if context.get_value('k*n') <= 326052992.0: + if context.get_value('n') <= 55232.0: + return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)] + else: + return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)] + else: + if context.get_value('n') <= 57024.0: + return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)] + else: + return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)] + else: + if context.get_value('m*n') <= 543936.0: + if str(context.get_value('17LEQmLEQ32')) != 'True': + if context.get_value('m*n') <= 262272.0: + if context.get_value('n') <= 1592.5: + return [(0.860, 0), (0.140, 9)] + else: + return None + else: + if context.get_value('m*k') <= 1294336.0: + return [(0.833, 17), (0.150, 18), (0.017, 15)] + else: + return [(0.917, 17), (0.083, 8)] + else: + if context.get_value('n') <= 12416.0: + if context.get_value('m*n') <= 43008.0: + return None + else: + return [(0.853, 14), (0.147, 9)] + else: + return [(0.625, 12), (0.375, 14)] + else: + if context.get_value('m') <= 32.5: + if context.get_value('mat2_stride_1') <= 6656.0: + if context.get_value('n') <= 69184.0: + return [(0.611, 12), (0.361, 14), (0.028, 13)] + else: + return [(1.000, 12)] + else: + if context.get_value('mat2_stride_1') <= 20864.0: + return [(1.000, 12)] + else: + return [(0.958, 12), (0.042, 9)] + else: + if context.get_value('m*n') <= 1085440.0: + if context.get_value('n') <= 9152.0: + return [(1.000, 18)] + else: + return [(0.780, 18), (0.160, 16), (0.060, 20)] + else: + if context.get_value('m') <= 67.0: + return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)] + else: + return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..b61f8a9dd1e99056864a9dddc663b090f6971214 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py @@ -0,0 +1,109 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ +from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicRegression, +) + + +class PadMMA100(LearnedHeuristicRegression): + + def __init__(self) -> None: + pass + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + context.context_dict[CHOICE_COL] = choice + return self.predict(context) + + def get_confidence_threshold(self) -> float: + return 1.7025303314066 + + def get_name(self) -> str: + return 'pad_mm' + + def predict(self, context: AHContext) -> float: + if str(context.get_value('choice')) != 'pad': + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 4171264.0: + if context.get_value('m*k') <= 3999308.0: + return 1.8751469764071178 + else: + if str(context.get_value('n_multiple_32')) != 'True': + return 0.9117231355626345 + else: + return 1.1607689608873861 + else: + if str(context.get_value('n_multiple_2')) != 'True': + if str(context.get_value('using_tf32')) != 'True': + return 0.7430382200435992 + else: + return 0.8531269794448678 + else: + if str(context.get_value('k_multiple_2')) != 'True': + return 0.7577181972719917 + else: + return 0.8977349440424219 + else: + if context.get_value('m*n') <= 1299712.0: + return 1.1669723418995592 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + if context.get_value('m*n') <= 55884158.0: + return 1.0262769936909601 + else: + return 1.0022677428470845 + else: + if context.get_value('m') <= 18478.0: + return 1.1127066261894312 + else: + return 1.0337740659894263 + else: + if str(context.get_value('mat1_dtype')) != 'torch.float32': + if str(context.get_value('n_multiple_2')) != 'False': + if str(context.get_value('k_multiple_2')) != 'True': + if context.get_value('mat1_stride_0') <= 561.0: + return 1.2900382135142956 + else: + return 1.5761737616057887 + else: + if context.get_value('num_dims_needs_padding') <= 1.5: + return 1.0472263310239422 + else: + return 1.1727673465762514 + else: + if context.get_value('k') <= 28238.5: + if context.get_value('k/(m*n)') <= 0.00026227018679492176: + return 1.6770542505397175 + else: + return 1.3974785435105923 + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return 1.3952699800111992 + else: + return 1.5759286511628336 + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 14119424.0: + return 0.8875772670422478 + else: + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + return 1.1467728924377265 + else: + return 1.215842963532998 + else: + if context.get_value('arith_intensity') <= 396.8774871826172: + return 0.89940161869551 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + return 0.9964328169353532 + else: + return 0.9493479238294826 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1c67394a7a6fb7dfd32c0b6b2f29e3ce81e0840 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f63965b8d4ed324e4fbfb4d1d3009a1138afb5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6342cb029d2f49de51b62ba50802782fd940a13d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..222d87d49a64a63140732bdd68be6eacfa7a91d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..456ca0bc9bc8049d9f1320bbb19591d441a08f44 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5911e35399d93798ecc628bafe010cab76e604e3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..970d98b3a46c3681785d8c0f58172919671d4e47 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..757e53e0b41303e776d19acac4915599145da827 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78f0ba09054713f8ec7955fd6041c46dc3b02797 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7429ef51daa97596ac8abd6bac1ba9232471ca09 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31fade13992094bd2b0b854af0b85ccf6577b360 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/evt_extensions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/evt_extensions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3b65d895dd25666cae215493c8dda1bbc124a28 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/evt_extensions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48b04bb7de950c4ba113a40fd67cb133e51dbf0d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e1faf15989a62ff77e110ff17f459c16a9e2ce8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e12a86af8ab0ab8d7d7b2d8bf37ec6dec861e0ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py @@ -0,0 +1,6 @@ +import torch + + +__version__ = torch.version.cuda + +from .cuda import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14a47e994d546f2891187004d2fd2200eccd6179 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cuda.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cuda.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6057346b3c192c99153996c2b66c6842f73fb6dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cuda.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cudart.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cudart.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c73997715b0f785902f0473f1d6132953913f9f5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__pycache__/cudart.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..ad41f04fc897e33f4530eb42c76a104def58f413 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py @@ -0,0 +1,24 @@ +# mypy: disable-error-code="no-untyped-def" +# flake8: noqa +import torch + + +class CUdeviceptr: + pass + + +class CUstream: + def __init__(self, v): + pass + + +class CUresult: + CUDA_SUCCESS = True + + +class nvrtc: + pass + + +def cuDeviceGetCount(): + return (CUresult.CUDA_SUCCESS, torch.cuda.device_count()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2ee5f1f6163d7b20336d6102ce5d8f97880c87 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py @@ -0,0 +1,17 @@ +# mypy: disable-error-code="no-untyped-def" +import torch.cuda + + +class cudaError_t: + cudaSuccess = True + + +def cudaFree(n): + return (cudaError_t.cudaSuccess,) + + +def cudaGetDeviceProperties(d): + class DummyError: + value = False + + return (DummyError(), torch.cuda.get_device_properties(d)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8aefb6171b682f062cfe57a1876f51b280f120cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py @@ -0,0 +1,2 @@ +# mypy: disable-error-code="var-annotated" +Dot = None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d7c33463597bfea7ec4ba715ea5392acbe57e59 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0378d35a9c442559373f035e45de19b2be927cd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py @@ -0,0 +1,3 @@ +# typing: ignore +# flake8: noqa +from .special import * diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b8335ebbb32bb197743d2c1763f1c69202e51fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/special.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/special.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4df49c5923bd3c4984c9f2bb21dba334ead5a18 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__pycache__/special.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py new file mode 100644 index 0000000000000000000000000000000000000000..79af3029aa0b18d0ad55633f8cca8af8b76b520b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py @@ -0,0 +1,2 @@ +# mypy: disable-error-code="var-annotated" +erf = None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..469fa51630da6b8bfb188dddf3024e32baa25000 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/_cutedsl_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/_cutedsl_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92deb06d4baf31ba76ef4efcd52717e00c440503 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/_cutedsl_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_kernel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ee712c9d007d710c9ef1c510b2374f1749d784d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_kernel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_op_overrides.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5008c37266b28830e77ee8dc30abac230b6c6c55 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_op_overrides.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_scheduling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b48a1a409217dc2a4fd22c20dbf417fa5c794e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_scheduling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93278a7d0925e1b4f93a1892d2b867c4bb1949c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__pycache__/cutedsl_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..737f5e02475c989af03895e5d014ac8172238f16 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/device_op_overrides.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abf9a2e597bd7543e57c130077e53f6bf67a79dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__pycache__/device_op_overrides.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8469dfd1e934b3b4a1a530f32b04927638d94b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..974adf2c554d20fddac41d81976660954f53cbb4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fc554ff4d53328401ac33265a06db8f10923aba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fdaf82b62059de70ed3bcb6b590c234958dd7f6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18757aef640c8d83bba3b14788a4c07c37fc7c0c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38b20ccf69ce77bc6009d673d0d7ee65d2452c0d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92de3d37bd9041a666d33420799c5c7c8adc3a74 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e3083496a1038a7ba424399395e287e174b64ee Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3869a7d4f96375445bee0852d6a1c8448ea2b3b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8275842bc9b14367f8155e747559247b51a3ba9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93753932ee68ab93aa6643586d4b292e8561a26 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c079f334dfc8379f37b09577704b2046cb9c6080 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7083f6bc9d6bf2ef7c715bbc7bbea061e72a98bd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5d35b6054f54445b00202c5dda7dfc6ed2a84de Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2320e274af48f1742770529818a84f1effa4316e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca0f1e5a4fb2a6aeb1224285d76e78a05a0f499 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +import argparse +import base64 +import functools +import importlib +import logging +import os +import sys +from typing import TypeVar + +from torch._inductor.async_compile import pre_fork_setup +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.subproc_pool import ( + SubprocKind, + SubprocMain, + SubprocPickler, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path + + +_T = TypeVar("_T") + + +log = logging.getLogger(__name__) + +_set_triton_ptxas_path() + +try: + import triton + + assert triton is not None # preload in parent +except ImportError: + pass + + +def _lookup_and_create_type(base: type[_T], qname: str) -> _T: + """ + Given a base type and qualified name: import & lookup that name, check + that it's of the given type and then instantiate it. + """ + pkg, name = qname.rsplit(".", 1) + mod = importlib.import_module(pkg) + ty = getattr(mod, name) + if not issubclass(ty, base): + raise TypeError(f"Type {ty} is not a subtype of {base}") + return ty() + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument( + "--pickler", type=functools.partial(_lookup_and_create_type, SubprocPickler) + ) + parser.add_argument("--kind", type=SubprocKind) + parser.add_argument("--workers", type=int) + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + parser.add_argument("--torch-key", type=str) + args = parser.parse_args() + if os.getppid() != args.parent: + sys.exit(0) + read_fd = os.fdopen(args.read_fd, "rb") + write_fd = os.fdopen(args.write_fd, "wb") + + pre_fork_setup() + + torch_key.set(base64.b64decode(args.torch_key.encode("utf-8"))) # type: ignore[attr-defined] + + _async_compile_initializer(args.parent) + + SubprocMain(args.pickler, args.kind, args.workers, read_fd, write_fd).main() + except Exception: + log.exception("Uncaught exception in compile_worker subprocess") + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..addd158d43c46cbf9b33aad7cd33d26e4379d7b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfcb2a204e558f793f8173dae798d0ca42795e6f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54337354bb4f8fc227969843cba6d30aa2cf0bb5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/timer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/timer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19ae35f44ad9be4bba59db55938e722e1aa2e8fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/timer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..044a0d49d2b49f918a444bb752b108267bd89e0a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d05db750661326e301808aed07e16fdeffa7ff8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..07c59b8cbb860fd1ed0e1ff1ba6df34979abdf4f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py @@ -0,0 +1,496 @@ +import base64 +import functools +import itertools +import logging +import multiprocessing +import os +import pickle +import struct +import subprocess +import sys +import threading +import traceback +import typing +from collections.abc import Callable +from concurrent.futures import Future, ProcessPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from enum import Enum, IntEnum +from typing import Any, IO, Optional, TypeVar +from typing_extensions import Never, ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 +from torch._inductor import config +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.timer import Timer +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.utils import get_ld_library_path, python_subprocess_env +from torch._utils_internal import find_compile_subproc_binary +from torch.monitor import _WaitCounter, _WaitCounterTracker + + +log = logging.getLogger(__name__) + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class MsgHeader(IntEnum): + ERROR = 0 + SHUTDOWN = 1 + QUIESCE = 2 + WAKEUP = 3 + JOB = 4 + + +def _pack_msg(msg_header: MsgHeader, job_id: int, length: int) -> bytes: + return struct.pack("nnn", int(msg_header), job_id, length) + + +def _unpack_msg(data: bytes) -> tuple[MsgHeader, int, int]: + if not data: + return MsgHeader.ERROR, -1, -1 + msg_header, job_id, length = struct.unpack("nnn", data) + return MsgHeader(msg_header), job_id, length + + +msg_bytes = len(_pack_msg(MsgHeader.JOB, 0, 0)) + + +def _send_msg( + write_pipe: IO[bytes], msg_header: MsgHeader, job_id: int = -1, data: bytes = b"" +) -> None: + length = len(data) + write_pipe.write(_pack_msg(msg_header, job_id, length)) + if length > 0: + write_pipe.write(data) + write_pipe.flush() + + +def _recv_msg(read_pipe: IO[bytes]) -> tuple[MsgHeader, int, bytes]: + msg_header, job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) + data = read_pipe.read(length) if length > 0 else b"" + return msg_header, job_id, data + + +class _SubprocExceptionInfo: + """ + Carries exception info from subprocesses across the wire. traceback + objects are not pickleable, so we store the trace as a string and + use it for the message in the exception thrown in the main process. + """ + + def __init__(self, details: str) -> None: + self.details = details + + +class SubprocException(Exception): + """ + Thrown when a job in a subprocess raises an Exception. + """ + + def __init__(self, details: str, name: str = "") -> None: + self.details = details + super().__init__( + f"An exception occurred in a subprocess:\n\nName={name}\n{details}" + ) + + def with_name(self, name: str) -> "SubprocException": + return SubprocException(self.details, name) + + +class SubprocPickler: + """ + Allows a caller to provide a custom pickler for passing data with the + subprocess. + """ + + def dumps(self, obj: object) -> bytes: + return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) + + def loads(self, data: bytes) -> object: + return pickle.loads(data) + + +class SubprocKind(Enum): + FORK = "fork" + SPAWN = "spawn" + + +class SubprocPool: + """ + Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in + a subprocess.Popen() to try to avoid issues with forking/spawning + """ + + def __init__( + self, + nprocs: int, + pickler: Optional[SubprocPickler] = None, + kind: SubprocKind = SubprocKind.FORK, + quiesce: bool = False, + ) -> None: + entry = os.path.join(os.path.dirname(__file__), "__main__.py") + self.pickler = pickler or SubprocPickler() + self.kind = kind + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + torch_key_str = base64.b64encode(torch_key()).decode("utf-8") + + cmd = [ + sys.executable, + entry, + ] + if (binary := find_compile_subproc_binary()) is not None: + cmd = [binary] + + args = [ + f"--pickler={self.pickler.__class__.__module__}.{self.pickler.__class__.__name__}", + f"--kind={self.kind.value}", + f"--workers={nprocs}", + f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", + f"--torch-key={torch_key_str}", + ] + cmd.extend(args) + log_path = None + self.log_file = None + + if config.worker_suppress_logging: + log_path = os.devnull + log.info("Suppressing compile worker output due to config") + else: + log_path = config.torchinductor_worker_logpath + if not log_path: + log_path = config.get_worker_log_path() + + if log_path: + # pyrefly: ignore [bad-assignment] + self.log_file = open(log_path, "w") # noqa:SIM115 + + self.process = subprocess.Popen( + cmd, + env={ + **python_subprocess_env(), + # Safeguard against creating a SubprocPool in the subprocess. + "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": get_ld_library_path(), + }, + pass_fds=(subproc_read_fd, subproc_write_fd), + stdout=self.log_file, + stderr=self.log_file, + ) + self.write_lock = threading.Lock() + self.read_thread = threading.Thread( + target=self._read_thread, name="InductorSubproc", daemon=True + ) + + self.futures_lock = threading.Lock() + self.pending_futures: dict[int, Future[Any]] = {} + # The pending waitcounter, is used to indicate the time when we have any specific job running. + self.pending_waitcounters: dict[int, Any] = {} + self.job_id_count = itertools.count() + + # The running waitcounter indicates the time when the SubProcPool object exists. + self.running = True + self.running_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.running" + ).guard() + self.running_waitcounter.__enter__() + + # The quiesce waitcounter indicates when the job is in a quiesced state. + self.quiesce_waitcounter: Optional[_WaitCounterTracker] = None + + # Firstjob is used to capture the time from when the firstjob is queued, to when the first job is done. + self.firstjob = True + self.firstjob_id: Optional[int] = None + self.firstjob_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.first_job" + ).guard() + + if quiesce: + self.timer: Optional[Timer] = Timer( + config.quiesce_async_compile_time, self.quiesce + ) + else: + self.timer = None + + # Start thread last to ensure all member variables are initialized + # before any access. + self.read_thread.start() + + def submit( + self, job_fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_T]: + if args or kwargs: + # pyrefly: ignore [bad-assignment] + job_fn = functools.partial(job_fn, *args, **kwargs) + job_data = self.pickler.dumps(job_fn) + future: Future[_T] + with self.futures_lock: + job_id = next(self.job_id_count) + self.pending_futures[job_id] = future = Future() + self.pending_waitcounters[job_id] = _WaitCounter( + "pytorch.wait_counter.subproc_pool.job" + ).guard() + self.pending_waitcounters[job_id].__enter__() + if self.quiesce_waitcounter: + self.firstjob = True + self.quiesce_waitcounter.__exit__() + self.quiesce_waitcounter = None + # This can be entered from either quiesce wakeup, or from startup. + if self.firstjob: + self.firstjob_id = job_id + self.firstjob_waitcounter.__enter__() + self.firstjob = False + future.set_running_or_notify_cancel() + self._send(MsgHeader.JOB, job_id, job_data) + return future + + def _send(self, msg_header: MsgHeader, job_id: int = -1, data: bytes = b"") -> None: + with self.write_lock: + if not self.running: + raise RuntimeError("Attempting to use a closed pool") + _send_msg(self.write_pipe, msg_header, job_id, data) + + def _read_thread(self) -> None: + while True: + data = b"" + job_id = -1 + try: + msg_header, job_id, data = _recv_msg(self.read_pipe) + except Exception: + # Something went wrong during the read. There's no way we have a + # valid msg. + log.exception("failure in subproc_pool._recv_msg") + msg_header = MsgHeader.ERROR + + if msg_header != MsgHeader.JOB: + # read_pipe returned None or got exception + if self.running: + log.warning("SubprocPool unclean exit") + self.running = False + self.running_waitcounter.__exit__() + self.read_pipe.close() + # Cancel all the pending futures. + self.shutdown() + return + + try: + result = self.pickler.loads(data) + except Exception as e: + # Something went wrong unpickling. We have a job_id so just + # notify that particular future and continue on. + log.exception("unpickle failure in SubprocPool._read_thread") + result = e + + with self.futures_lock: + if not self.running: + return + if self.timer: + self.timer.record_call() + if isinstance(result, _SubprocExceptionInfo): + # An exception occurred in the submitted job + self.pending_futures[job_id].set_exception( + SubprocException(result.details) + ) + elif isinstance(result, Exception): + # An exception occurred in some of our subprocess machinery. + self.pending_futures[job_id].set_exception(result) + else: + self.pending_futures[job_id].set_result(result) + + self.pending_waitcounters[job_id].__exit__() + del self.pending_waitcounters[job_id] + if self.firstjob_id == job_id: + self.firstjob_waitcounter.__exit__() + + del self.pending_futures[job_id] + + def quiesce(self) -> None: + self._send(MsgHeader.QUIESCE) + if self.quiesce_waitcounter is None: + self.quiesce_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.quiesced" + ).guard() + self.quiesce_waitcounter.__enter__() + + def wakeup(self) -> None: + self._send(MsgHeader.WAKEUP) + + def shutdown(self) -> None: + try: + with self.write_lock: + if not self.running: + return + if self.timer: + self.timer.quit() + self.running = False + self.running_waitcounter.__exit__() + _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) + self.write_pipe.close() + self.process.wait(300) + if self.log_file: + self.log_file.close() + except OSError: + log.warning("Ignored OSError in pool shutdown", exc_info=True) + finally: + with self.futures_lock: + for future in self.pending_futures.values(): + if not future.cancel(): + future.set_exception(RuntimeError("SubprocPool closed")) + self.pending_futures.clear() + + +class SubprocMain: + """Communicates with a SubprocPool in the parent process, called by __main__.py""" + + def __init__( + self, + pickler: SubprocPickler, + kind: SubprocKind, + nprocs: int, + read_pipe: IO[bytes], + write_pipe: IO[bytes], + ) -> None: + self.pickler = pickler + self.kind = kind + self.read_pipe = read_pipe + self.write_pipe = write_pipe + self.write_lock = threading.Lock() + self.nprocs = nprocs + self.pool: Optional[ProcessPoolExecutor] = None + self.running = True + + def main(self) -> None: + while True: + msg_header, job_id, data = _recv_msg(self.read_pipe) + if msg_header == MsgHeader.JOB: + self.submit(job_id, data) + elif msg_header == MsgHeader.WAKEUP: + self._start_pool() + elif msg_header == MsgHeader.QUIESCE: + self._quiesce() + else: + return self._shutdown() + + def _quiesce(self) -> None: + if self.pool is not None: + self.pool.shutdown(wait=False) + self.pool = None + + def _shutdown(self) -> None: + with self.write_lock: + self.running = False + try: + _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) + self.write_pipe.close() + except BrokenPipeError: + pass # parent process already shutdown + self.read_pipe.close() + self._quiesce() + + def submit(self, job_id: int, data: bytes) -> None: + while self.running: + try: + self._submit_inner(job_id, data) + return + except BrokenProcessPool: + # If any subprocess in the pool crashes, we get a BrokenProcessPool + # exception and the whole pool becomes unusable. Handle crashes by + # recreating the pool and resubmitting. + self.pool = None + + def _submit_inner(self, job_id: int, data: bytes) -> None: + def callback(fut: Future[Any]) -> None: + if not self.running: + return + try: + result = fut.result() + except Exception as e: + log.exception("Error in subprocess") + result = self.pickler.dumps(e) + assert isinstance(result, bytes) + with self.write_lock: + if self.running: + _send_msg(self.write_pipe, MsgHeader.JOB, job_id, result) + return + + self._start_pool() + assert self.pool is not None + + future = self.pool.submit( + functools.partial(SubprocMain.do_job, self.pickler, data) + ) + future.add_done_callback(callback) + + def _start_pool(self) -> None: + if self.pool is not None: + return + + self.pool = TrackedProcessPoolExecutor( + self.nprocs, + mp_context=multiprocessing.get_context(self.kind.value), + initializer=functools.partial(_async_compile_initializer, os.getpid()), + ) + multiprocessing.util.Finalize( + None, self.pool.shutdown, exitpriority=sys.maxsize + ) + _warm_process_pool(self.pool, self.nprocs) + + @staticmethod + def do_job(pickler: SubprocPickler, data: bytes) -> bytes: + # do the pickle/unpickle in the sub-subproc + job = typing.cast(Callable[[], object], pickler.loads(data)) + + try: + result = job() + except Exception: + result = _SubprocExceptionInfo(traceback.format_exc()) + return pickler.dumps(result) + + +AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool] + + +def _warm_process_pool(pool: ProcessPoolExecutor, n: int) -> None: + # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the + # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread. + + # Examples: + # A simple x + x + x script: 10ms seconds in the middle of the program, 2ms at startup + # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program , 3ms at startup + + # So we want to start the workers early when it is still cheap, and also to allow the workers to get + # ready before we have work for them. + + # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle. + # But if we waited until then fork time will be long and we will be waiting for the processes to initialize. + + # We force them to start here with some YOLOing of the internal methods. + + if hasattr(pool, "_start_queue_management_thread"): + pool._start_queue_management_thread() + else: + for _ in range(n): + pool._adjust_process_count() + if hasattr(pool, "_start_executor_manager_thread"): + pool._start_executor_manager_thread() + + +class TestException(RuntimeError): + pass + + +def raise_testexc() -> Never: + raise TestException diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/timer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c495403b3a55ef8858bd6661607d7bcf25674e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/timer.py @@ -0,0 +1,55 @@ +from collections.abc import Callable +from threading import Lock, Thread +from time import monotonic, sleep +from typing import Optional, Union + + +class Timer: + """ + This measures how long we have gone since last receiving an event and if it is greater than a set interval, calls a function. + """ + + def __init__( + self, + duration: Union[int, float], # Duration in seconds + call: Callable[[], None], # Function to call when we expire + ) -> None: + # We don't start the background thread until we actually get an event. + self.background_thread: Optional[Thread] = None + self.last_called: Optional[float] = None + self.duration = duration + self.sleep_time = duration / 2 + self.call = call + self.exit = False + + self.lock = Lock() + + def record_call(self) -> None: + with self.lock: + if self.background_thread is None: + self.background_thread = Thread( + target=self.check, daemon=True, name="subproc_worker_timer" + ) + self.background_thread.start() + self.last_called = monotonic() + + def quit(self) -> None: + with self.lock: + self.exit = True + + def check(self) -> None: + while True: + # We have to be sensitive on checking here, to avoid too much impact on cpu + sleep(self.sleep_time) + with self.lock: + if self.exit: + return + assert self.last_called is not None + if self.last_called + self.duration >= monotonic(): + continue + self.last_called = None + self.background_thread = None + + # Releasing lock in case self.call() takes a very long time or is reentrant + self.call() + return diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..546a5cbc6395a104cede30dd94054cfb12193a1b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py @@ -0,0 +1,113 @@ +import atexit +import concurrent +import dataclasses +import logging +import threading +from collections.abc import Callable +from concurrent.futures import Future, ProcessPoolExecutor +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from time import time +from typing import Any, Optional, TypeVar +from typing_extensions import ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +log = logging.getLogger(__name__) + + +@dataclass +class _QueueStats: + # Mapping from id(future) -> start time + pending: dict[int, float] = dataclasses.field(default_factory=dict) + timing: list[float] = dataclasses.field(default_factory=list) + enqueue_count: int = 0 + dequeue_count: int = 0 + max_queue_depth: int = 0 + pool_count: int = 0 + + +# The queue statistics tracked by TrackedProcessPoolExecutor. Always grab +# _queue_stats_lock before touching. +_queue_stats = _QueueStats() +_queue_stats_lock = threading.Lock() + + +class TrackedProcessPoolExecutor(ProcessPoolExecutor): + def __init__( + self, + max_workers: Optional[int] = None, + mp_context: Optional[BaseContext] = None, + initializer: Optional[Callable[[], object]] = None, + ) -> None: + with _queue_stats_lock: + _queue_stats.pool_count += 1 + super().__init__(max_workers, mp_context, initializer) + + def _record_dequeue(self, f: Future[Any]) -> None: + now = time() + with _queue_stats_lock: + stats = _queue_stats + if (start_time := stats.pending.pop(id(f), None)) is None: + return + stats.dequeue_count += 1 + duration = now - start_time + stats.timing.append(duration) + + def _record_enqueue(self, f: Future[Any]) -> None: + # Monkeypatch the set_running_or_notify_cancel so we can track when the Future moves out of PENDING. + saved_running_or_notify_cancel = f.set_running_or_notify_cancel + + def set_running_or_notify_cancel() -> Any: + self._record_dequeue(f) + return saved_running_or_notify_cancel() + + now = time() + with _queue_stats_lock: + stats = _queue_stats + stats.pending[id(f)] = now + stats.enqueue_count += 1 + stats.max_queue_depth = max(stats.max_queue_depth, len(stats.pending)) + f.set_running_or_notify_cancel = set_running_or_notify_cancel # type: ignore[method-assign] + + if f._state != concurrent.futures._base.PENDING: + self._record_dequeue(f) + + def submit( + self, fn: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_R]: + # pyrefly: ignore [bad-argument-type] + f = super().submit(fn, *args, **kwargs) + self._record_enqueue(f) + return f + + +@atexit.register +def _queue_stats_report() -> None: + stats = _queue_stats + if stats.pool_count == 0: + return + + timing = stats.timing + timing.sort() + + log.info("AsyncCompile Metrics:") + log.info(" Pools %s", stats.pool_count) + log.info( + " Items %d enqueued / %d dequeued", stats.enqueue_count, stats.dequeue_count + ) + log.info(" Max Queue Depth: %d", stats.max_queue_depth) + n = len(timing) + if n > 0: + log.info(" Longest queue time: %0.2fs", timing[-1]) + log.info(" P50: %0.2fs", timing[n // 2]) + if n >= 20: + log.info(" P95: %0.2fs", timing[n * 95 // 100]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b5e21630c270ada0f45a1f3ff318620fa2deba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py @@ -0,0 +1,54 @@ +import os +import signal +from threading import Thread +from time import sleep +from typing import Optional + + +_IN_TOPLEVEL_PROCESS = True + + +def in_toplevel_process() -> bool: + global _IN_TOPLEVEL_PROCESS + return _IN_TOPLEVEL_PROCESS + + +# If this process dies abnormally (e.g. segfault) +# it will not shut down the workers. Instead, +# the workers will have their parent reassigned to the +# init process. This launches a separate thread to +# watch for the worker getting reassigned, +# and cleans it up in this case. +# +# This function cannot be an inner function since otherwise mp_context="spawn" would +# not work for ProcessPoolExecutor since inner functions cannot be pickled. +def _async_compile_initializer(orig_ppid: int) -> None: + import torch._C + + def run() -> None: + while True: + sleep(60) + if orig_ppid != os.getppid(): + os.kill(os.getpid(), signal.SIGKILL) + + global _watchdog_thread, _original_parent + _original_parent = orig_ppid + _watchdog_thread = Thread(target=run, daemon=True) + _watchdog_thread.start() + # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Install a crash handler to print out the stacktrace for SEGV + torch._C._initCrashHandler() + + # Set a bit to distinguish async_compile subprocesses from the toplevel process. + global _IN_TOPLEVEL_PROCESS + _IN_TOPLEVEL_PROCESS = False + + +_watchdog_thread: Optional[Thread] = None +_original_parent: Optional[int] = None + + +def has_parent_changed() -> bool: + return _original_parent != os.getppid() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e81660cea5aaf469d11ec98ac981007fd811762e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29a5124ac46dc0907f3047996d20491ae280ceaf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eaecec3d7d8930ae44b2c658d93b4e185eec319 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/bucketing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/bucketing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b799d830978d1e120a151f131a5c347f738e8899 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/bucketing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/control_dependencies.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/control_dependencies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5dac4b2b550e5ee7ac2965b0f28e7753574d5de Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/control_dependencies.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f31d350642a280ab897104826c6d5471336d472f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10f894ab85b82d05f42699a3033570117a5ee1a3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e64da8d17e05435c5ae6a015d71a9937b51cae3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae1a4a4d4a4124a72ae677669d978ec7b2679d1d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f560201b3641947ee2228936038b4eb6a0ebc2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fsdp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fsdp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a393117e1a09f701bc50777bddfbb9de6f8d4bea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fsdp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f32333b477fa4bc64ffad9d2e908d6834a0bf3a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/graph_view.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/graph_view.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa5a720822c43eabe70799d11dc7729cb61dc984 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/graph_view.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e87c12e5ff1e8cba67276fe12ac36df93927f6bd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42e42e78defab12b7ff0a8ea3e448c2bcde12e18 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/memory_estimator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/memory_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53b4588f71fbc02393cef58f96a95b6237249c73 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/memory_estimator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9602a4c13b1bd0a420311675159a8e5a80228e9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75154f2d5412416e8455e0fcab05151021d15e99 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..416b694ea5139eddbbfc694ca8b40cd94cf812ca Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/node_runtime_estimation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/node_runtime_estimation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8891c7a764d0552c9795a4864c379d2d78d9229 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/node_runtime_estimation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1b68a76d820a411ed6ffb2704c459b76183fa69 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_manual_scheduling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_manual_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98dd99d9ddc918cf79f009932823d653c100f1e8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_manual_scheduling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_preserving_bucketer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_preserving_bucketer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce2b9f839b7ae0c9cf3c62f49f69dfca742e7c4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_preserving_bucketer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_scheduling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99adb47c9e9f5359ef4af317e146d99e338fb106 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/overlap_scheduling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b8e770508e1ae8796af4d8fbe2bf7de5e852c7b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7620bc8194fa01cc5d76ded4d6c739f648ab0ace Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1265e31611b5c0c84e40f84f2208d842a2d286a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31ec89152d2a747c26ccd275a62d229fbdc52a9c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53f477c26ddf429acb400f780559f2b706bb37df Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71932f0bf4cbc454ec6748a23fe0785596d9d14a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df8b0899f7f121f84e4e231200c66e35d82ce06 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..424bd3c9c580770a6ca7f4ebd573b5aa4eb7ae62 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a31d7d07f1c278902e3c22e843939a79b9f779c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca979427c7a8e9a9fb0d6c37fc87a03630842051 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a629c779ef1cbc8cc413f8b1bc442f9d5c2efb4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14fb4eec8237acc0b9cf1bd17110c3dbed11b0c3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3a788919936c555e00a566fda69cc80786fa730 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fed6dc4ce20a79d44f2a6130ffe74013928b0a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c9199f0eb513d238e916066331527e8724856f6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a54e4ba5f79eb0c2dd726f813bb0017f2e2e15de Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6c4c827b40f3d9a12ca869a5f69b4f0639a863e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c429120c5f9f429aab2e517b554ea1273108d22 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8acf51befc8b2685ca66045c062769f6836c85c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecb994a6ef4109fab32b9b35265451884db228e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08cd9ad3539a74b92f2741db64f29fee658576c6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50e4bc22cc6cccac588bdf02dad8b19aaf6ea367 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_24.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_24.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..466ac2cc9d7c526431734fad46a28653260ff735 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_24.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca3a6dc3545bf1000456428c8f807f180c07d2fb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4179273ebf4c739f6fe53339ded21506dcc10213 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac343e954a08762845c41da610571e0d6c48eb03 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84d024fdf9891b7a60568544feac960590460d7e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a7e3c4cb2e57f22da09596ebf77e4e433e4196c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1395ca0e3d720342e8da169d91cc1048c007ed4f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3232af5ad92f5435ddedff907a46b44e75079687 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f67eb0392d29ff4a0f07d686500fe84c5ea84a2b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb3ba6df4a807862752c894b0437566fd6588928 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd02c157d9c57ac91f20ea34eddc4f7d15a2768a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8e6de3ff3cba5f0ebcec729c33061b04319d6a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py new file mode 100644 index 0000000000000000000000000000000000000000..567390838ede7dc4d4181f601f020e8066cb07b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa39474c67dd677008c8e7e9266cc875a153196 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -0,0 +1,204 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..87302d1bab3694a33eac14e263dca86c9f702c75 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..d465c1cb4e22b14bfbbdd35d5ed28a43af891523 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,130 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +neg_default = CallFunction(aten.neg.default, div_Tensor) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py new file mode 100644 index 0000000000000000000000000000000000000000..f102038e82c6d5858b8b334e956d87c7e86a9d22 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -0,0 +1,210 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cbb0df340bab14188ec9d5f04a29035dba8d84 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_8, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_8, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py new file mode 100644 index 0000000000000000000000000000000000000000..3a15abb9088ff5ffe8cd9af43df11ccc0d5bc143 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -0,0 +1,599 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py new file mode 100644 index 0000000000000000000000000000000000000000..812708907b3414e2c864ed36b98f2199e63ae5d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -0,0 +1,246 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py new file mode 100644 index 0000000000000000000000000000000000000000..567d898ed204257e23a2002479afc5d26cba623b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -0,0 +1,453 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6d316351b8595f75fdfd262a4cb2171a8a6b1e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -0,0 +1,209 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py new file mode 100644 index 0000000000000000000000000000000000000000..f28da434ef0c85ca3d80095e68c052e8dc19dd2d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py new file mode 100644 index 0000000000000000000000000000000000000000..9185aa3b1e3305cfa28f8080be04350beb17c065 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py @@ -0,0 +1,244 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py new file mode 100644 index 0000000000000000000000000000000000000000..4ebd4a4e14e48439eaa0a8b50e9fcf72145dc1a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py @@ -0,0 +1,391 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_bs1_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py new file mode 100644 index 0000000000000000000000000000000000000000..0971c09ad972f2bc07ac6ee9f548255a3760faa2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py @@ -0,0 +1,415 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py new file mode 100644 index 0000000000000000000000000000000000000000..2be036c2e8ae7922b51690da782c5565656d7998 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py @@ -0,0 +1,407 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_bs1_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_6 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_bs1_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py new file mode 100644 index 0000000000000000000000000000000000000000..72f23373c143e4f113f04d5228966e5e79c448a0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py @@ -0,0 +1,153 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=4) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +mul_Tensor = CallFunction(aten.mul.Tensor, bmm_default_2, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, view_default_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, div_Tensor, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +_sfdp_pattern_24_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_1, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, convert_element_type_default, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_half_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +_sfdp_pattern_24_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7f7519ad0570d2c2f700d4081c9b7253d16657 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9cfd506f950415f4f90b49edf83815432a641c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..f211e56b17a0a19c05bcb0efc681ed2623f4edf7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,178 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py new file mode 100644 index 0000000000000000000000000000000000000000..01304bf415163909c5ec5b03064ce064697e1de9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -0,0 +1,194 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py new file mode 100644 index 0000000000000000000000000000000000000000..b463c7e64a6130dd85063f5fb88c2317c392c8f2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -0,0 +1,221 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..3faff67089b17ad370d4d7642539c7ce3fd5d235 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf77120e836a5b577ea8a335f00bd63fd27163a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -0,0 +1,221 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..70d672442170905a411de63187a5b579b286bf73 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py @@ -0,0 +1,53 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) +mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta')) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True) +view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha')) +addmm_pattern_training = MultiOutputPattern([addmm_default, + view_default, + mul_Scalar_1, + mul_Scalar_2, + None, + None +]) + + +addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5ac59d6f06c97523e071e9b3ea78516ff09c0e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1')) +bmm_pattern_training = MultiOutputPattern([bmm_default, + bmm_default_1, + bmm_default_2 +]) + + +bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..058a2f881e3a52cb147cfd3fa0ef2bbd0a25945a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mm_pattern_training = MultiOutputPattern([mm_default, + mm_default_1, + mm_default_2 +]) + + +mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/custom_op.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/custom_op.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2160df127557aea94cf5902dc373d11d5ebc9002 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/custom_op.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..162686a99a54357c229cf2d09ba73aae3001cafd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffa02b38b67cfa475a208e5238dbe8e40888a498 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_grouped.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_grouped.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2a1920af3436e7e8d7f25ec7a3a3bfca3a2f7ad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_grouped.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e301867c702c12aba4b1deaea53d0a67a979596 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67a604adcb1e6057015f7fa1833d766b37d7c61b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__init__.py @@ -0,0 +1,3 @@ +# mypy: allow-untyped-defs +# Import so here and then reimport above so that register_lowering gets triggered +from . import flex_attention, flex_decoding diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b604514f30d1436de9db6433e00fea28a621e8fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/common.py @@ -0,0 +1,356 @@ +# mypy: allow-untyped-defs +"""Common utilities and functions for flex attention kernels""" + +import math +from collections.abc import Sequence +from functools import partial +from pathlib import Path +from typing import Any, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map, tree_map_only + + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda_combined_scheduling import _IntLike +else: + _IntLike = Union[int, sympy.Expr] + + +from ...ir import ( + ComputedBuffer, + ExternKernel, + FixedLayout, + FlexibleLayout, + get_fill_order, + InputBuffer, + IRNode, + MutationLayoutSHOULDREMOVE, + Scatter, + ShapeAsConstantBuffer, + StorageBox, + Subgraph, + TensorBox, +) +from ...lowering import ( + _full, + check_and_broadcast_indices, + expand, + index_output_size_and_inner_fn, + to_dtype, +) +from ...select_algorithm import realize_inputs +from ...utils import load_template + + +SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] + + +def zeros_and_scatter_lowering(shape: list[int], indices, values): + """To support backwards on captured buffers we register a specific lowering for our specific custom up""" + # Always accumulate into fp32 then cast + grad = _full(0, values.get_device(), torch.float32, shape) + assert isinstance(grad, TensorBox) + grad.realize() + x_size = grad.get_size() + values = to_dtype(values, grad.get_dtype()) + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device()) + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=True, + ) + + values = expand(values, expected_vals_size) + device = grad.get_device() + assert device is not None + scatter = Scatter( + device=device, + dtype=grad.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add", + ) + + buffer = ComputedBuffer( + name=grad.data.data.name, # type: ignore[attr-defined] + layout=MutationLayoutSHOULDREMOVE(grad), + data=scatter, + ) + return buffer + + +def get_fwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults +) -> list[Optional[ComputedBuffer]]: + subgraph_buffer = ( + # pyrefly: ignore [bad-assignment] + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + # pyrefly: ignore [bad-assignment] + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + # pyrefly: ignore [not-iterable] + return [*subgraph_buffer, *mask_graph_buffer] + + +def build_subgraph_module_buffer( + args: list[Union[TensorBox, ShapeAsConstantBuffer]], + graph_module: torch.fx.GraphModule, +) -> SubgraphResults: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph. Contains both fixed and lifted inputs. + subgraph: The Subgraph ir for which to produce the output node + """ + # This one we gotta keep lazy + from ...subgraph_lowering import PointwiseSubgraphLowering + + pw_subgraph = PointwiseSubgraphLowering( + graph_module, + root_graph_lowering=V.graph, + allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]), + additional_lowerings={ + torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering + }, + ) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*args) + + def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: + if output_buffer is None: + return None + if isinstance(output_buffer, ComputedBuffer): + # These nodes are coming from the output of zeros_and_scatter + return output_buffer + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + device = output_buffer.data.get_device() + assert device is not None + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=device, + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) + + +def build_subgraph_buffer( + args: list[Union[TensorBox, ShapeAsConstantBuffer]], subgraph: Subgraph +) -> SubgraphResults: + return build_subgraph_module_buffer(args, subgraph.graph_module) + + +def maybe_realize(args: list[Optional[IRNode]]): + """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" + return tree_map( + lambda x: ( + realize_inputs(x) + if x is not None and not isinstance(x, sympy.Symbol) + else x + ), + args, + ) + + +def freeze_irnodes(tree: Any) -> Any: + """Freeze layouts for every IRNode contained in a pytree.""" + + if tree is None: + return None + + def _freeze(node: IRNode) -> IRNode: + try: + node.freeze_layout() + except NotImplementedError: + pass + return node + + return tree_map_only(IRNode, _freeze, tree) + + +def create_placeholder( + name: str, + dtype: torch.dtype, + device: torch.device, + size: Optional[list[int]] = None, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer( + name=name, + layout=FixedLayout( + device, + dtype, + size if size else [], + FlexibleLayout.contiguous_strides(size) if size else [], + ), + ) + return TensorBox.create(input_buffer) + + +def construct_strides( + sizes: Sequence[_IntLike], + fill_order: Sequence[int], +) -> Sequence[_IntLike]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len(fill_order), ( + "Length of sizes must match the length of the fill order" + ) + strides: list[_IntLike] = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride: _IntLike = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def infer_dense_strides( + size: Sequence[_IntLike], + orig_strides: Sequence[_IntLike], +): + """This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp + + Args: + size: The size of the output tensor + orig_strides: The strides of the input tensor + Returns: + List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation. + The returned strides follow the same stride propagation rules as TensorIterator. This matches + The behavior of empty_like() + """ + fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env) + return construct_strides(size, fill_order) + + +def create_indices_fake(x) -> torch.Tensor: + """Create a fake indices that is used for autotuning.""" + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device()) + indices = indices.expand(size).contiguous() + return indices + + +def create_num_blocks_fake_generator(sparse_indices): + """Create a fake num_blocks that is used for autotuning. + + The idea here is that we need to create a real tensor with real data + that's representative for benchmarking. + For example, returning all zeros for the `kv_num_blocks` input would mean + that we are computing 0 blocks for each row, which would provide bogus + autotuning results. + + In this case, we choose to use min(16, max_block) blocks, because I + (Horace) think it'll probably result in pretty representative performance. + If it's too short then prefetching won't help. If it's too long then + autotuning will take longer for no good reason. + """ + + def create_num_blocks_fake(x) -> torch.Tensor: + num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1]) + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + return torch.full( + size, + num_blocks_for_autotuning, + dtype=x.get_dtype(), + device=x.get_device(), + ) + + return create_num_blocks_fake + + +def contiguous_last_dim(x): + """Ensure that realized IR node has a contiguous stride in the last dimension.""" + strides = x.maybe_get_stride() + if strides and strides[-1] != 1: + contiguous_stride_order = list(reversed(range(len(x.get_size())))) + return ExternKernel.require_stride_order(x, contiguous_stride_order) + return x + + +def set_head_dim_values( + kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars +): + """ + Mutates kernel options, adding head dimension calculations. + + Args: + kernel_options: Dictionary to populate with options + qk_head_dim: Query/Key head dimension + v_head_dim: Value head dimension + graph_sizevars: Graph size variables object with guard_int method + + """ + # QK dimensions + qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim) + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static) + kernel_options.setdefault( + "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static) + ) + + # V dimensions + v_head_dim_static = graph_sizevars.guard_int(v_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static) + kernel_options.setdefault( + "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static) + ) + + # Safety flag + kernel_options.setdefault( + "SAFE_HEAD_DIM", + is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static), + ) + + +def is_power_of_2(n): + return n != 0 and ((n & (n - 1)) == 0) + + +def next_power_of_two(n): + if n <= 0: + return 1 + return 2 ** math.ceil(math.log2(n)) + + +_FLEX_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_flex_template = partial(load_template, template_dir=_FLEX_TEMPLATE_DIR) + + +# Template strings have been moved to templates/common.py.jinja diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_attention.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d36b8d56cc711504dad6f9071453e887e23e1a83 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_attention.py @@ -0,0 +1,977 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel""" + +from __future__ import annotations + +import logging +import math +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, cast, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.nn.attention.flex_attention import _Backend + +from ...ir import ComputedBuffer, ExternKernel, FixedLayout, TensorBox +from ...lowering import empty, empty_strided, lowerings, register_lowering +from ...select_algorithm import ( + autotune_select_algorithm, + SymbolicGridFn, + TritonTemplate, +) +from .common import ( + build_subgraph_buffer, + create_indices_fake, + create_num_blocks_fake_generator, + create_placeholder, + freeze_irnodes, + get_fwd_subgraph_outputs, + infer_dense_strides, + load_flex_template, + maybe_realize, + set_head_dim_values, + SubgraphResults, +) +from .flex_cpu import lower_cpu +from .flex_decoding import _use_flex_decoding, create_flex_decoding_kernel +from .flex_flash_attention import ( + _use_flex_flash_attention, + _use_flex_flash_attention_backward, + create_flex_flash_attention_backward_kernel, + create_flex_flash_attention_kernel, +) + + +if TYPE_CHECKING: + from ...template_heuristics.triton import FlexBwDConfig, FlexConfig + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +Expr = sympy.Expr + + +def _sanitize_kernel_options_for_triton( + kernel_options: dict[str, Any], +) -> tuple[dict[str, Any], _Backend]: + """We always strip quotes around str values, we only need this in lowering, so we pop it here + to avoid passing to triton constexpr dict + """ + sanitized = dict(kernel_options) + backend = cast(_Backend, sanitized.pop("BACKEND", "AUTO")) + return sanitized, backend + + +@SymbolicGridFn +def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv): + """How is this kernel parallelized? + We create a grid of (ceil_div(n_queries, query_block_size), batch_size, num_heads) + Each block is responsible for iterating over blocks of keys and values calculating + the final attention output. + """ + return (cdiv(num_queries, meta["BLOCK_M"]), batch_size, q_heads) + + +def get_float32_precision(): + if ( + ( + torch.backends.cuda.matmul.fp32_precision == "ieee" + if torch.backends.cuda.matmul.fp32_precision != "none" + else torch.get_float32_matmul_precision() == "highest" + ) + or torch.version.hip + or torch.mtia.is_available() + ): + return "'ieee'" + else: + return "'tf32'" + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, + source=load_flex_template("flex_attention") + + load_flex_template("utilities") + + load_flex_template("common"), +) + + +@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) +def flex_attention( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options: dict[str, Any], + score_mod_other_buffers, + mask_mod_other_buffers, +): + """The main lowering for the flex_attention hop + This can currently lower to one of 3 templates: + 1. Base Triton Template + 2. Flex Decode Triton Template + 3. Cpu specific CPP template + """ + if query.get_device().type == "cpu": + return lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + # below is cuda path if device is not cpu + # tl.dot does not support embedding size less than 16 + small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16)) + small_dv = V.graph.sizevars.evaluate_expr(sympy.Lt(value.get_size()[-1], 16)) + if small_dqk or small_dv: + raise NotImplementedError( + f"NYI: embedding dimension of the query, key, and value must be " + f"at least 16 but got E={query.get_size()[-1]} and Ev={value.get_size()[-1]}" + ) + + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + freeze_irnodes(subgraph_buffer) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + freeze_irnodes(mask_graph_buffer) + + kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + enable_gqa = V.graph.sizevars.evaluate_expr( + sympy.Ne(query.get_size()[1], key.get_size()[1]), + ) + + can_use_decode = _use_flex_decoding( + query, kv_indices, value, kernel_options, enable_gqa + ) + use_decode = (backend == "TRITON_DECODE") or (backend == "AUTO" and can_use_decode) + + if backend == "TRITON_DECODE" and not can_use_decode: + raise RuntimeError( + "BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used for this input. " + "flex_decoding is only available for short sequence lengths with specific configurations." + ) + + if use_decode: + return create_flex_decoding_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + if _use_flex_flash_attention( + subgraph, + mask_graph, + kernel_options, + num_score_mod_placeholders=len(placeholder_inps), + backend=backend, + ): + return create_flex_flash_attention_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + mask_graph=mask_graph, + subgraph=subgraph, + ) + + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + + freeze_irnodes(score_mod_other_buffers) + freeze_irnodes(mask_mod_other_buffers) + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), ( + "Query length must be greater than 0" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), ( + "Key length must be greater than 0" + ) + + B = Bq + + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # NB it is okay that the v_head_dim is different + # We are using these to match fill order of the output. + q_strides = query.get_stride() + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, q_strides) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = [B, Hq, seq_len_q] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + max_scores = empty_strided( + logsumexp_shape, # Same shape as logsumexp + None, + dtype=torch.float32, # The max scores are always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA broadcast factor. + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is None if partial blocks are not computed + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.guard_int(query.get_size()[-1]) + configs: list[FlexConfig] = V.choices.get_flex_attention_fwd_configs( + head_dim, dtype, query.get_device().type + ) + + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE) + + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + original_kernel_options = kernel_options.copy() + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + cur_kernel_options = original_kernel_options.copy() + # Performance tuning + # Triton parameters + # Remove prefix for forward kernels options and delete backward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # USE TMA = false by default + cur_kernel_options.setdefault("USE_TMA", False) + + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + if ( + cur_kernel_options["SPARSE_KV_BLOCK_SIZE"] % cur_kernel_options["BLOCK_N"] + != 0 + or cur_kernel_options["SPARSE_Q_BLOCK_SIZE"] % cur_kernel_options["BLOCK_M"] + != 0 + ): + if len(configs) == 1: + raise ValueError( + f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We " + f"got Q_BLOCK_SIZE={cur_kernel_options['SPARSE_Q_BLOCK_SIZE']} and " + f"KV_BLOCK_SIZE={cur_kernel_options['SPARSE_KV_BLOCK_SIZE']}." + ) + continue + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + error = flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + max_scores, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout, + subgraphs=[ + subgraph_buffer, + mask_graph_buffer, + ], + mutated_inputs=[ + logsumexp, + max_scores, + ], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + if error is not None and len(configs) == 1: + raise error + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + max_scores, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + input_gen_fns = { + 5: create_num_blocks_fake_generator(kv_indices), + 6: create_indices_fake, + 7: create_num_blocks_fake_generator(full_kv_indices), + 8: create_indices_fake, + } + + out = autotune_select_algorithm( + "flex_attention", + choices, + # Need to filter out symbols since there is an invariant + # that all input_nodes are of type IRNode + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + out.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + out.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (out, logsumexp, max_scores) + + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +@SymbolicGridFn +def flex_attention_backward_grid( + batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta, *, cdiv +): + """How is this kernel parallelized? + We create a grid of (ceil_div(n_queries, query_block_size) * heads_ratio + ceil_div(n_kv, kv_block_size), batch_size, kv_heads) + Currently this is only parallelizing over batch* kv_heads, but we can, and want to + parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size). + To do this will either require atomic updates to some grad values or to have a two pass kernel design. + """ + return ( + cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads) + + cdiv(num_key_value, meta["BLOCK_N1"]), + batch_size, + kv_heads, + ) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=load_flex_template("flex_backwards") + load_flex_template("utilities"), +) + + +def validate_joint_graph(joint_graph: torch.fx.Graph): + """We do some pre lowering graph checks in order to raise nicer error messages""" + for node in joint_graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.flex_lib.zeros_and_scatter.default + ): + for user in node.users: + if user.op != "output": + raise NotImplementedError( + "Using multiple indexing operations on the same tensor that requires gradients " + "in a score_mod function is not currently supported. " + "This typically happens when indexing the same tensor multiple times, like:\n\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias[kv_idx] # bias used twice!\n\n" + "A valid workaround is to clone() the tensors that will be indexed multiple times. For example:\n\n" + " bias1 = bias.clone()\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias1[kv_idx]\n\n" + "Note that this solution will use additional memory." + ) + return + + +@dataclass(frozen=True) +class JointOutputResult: + """Results from processing joint outputs.""" + + grad_input: ComputedBuffer + captured_grads_compute: list[ComputedBuffer] + captured_grads: list[Optional[TensorBox]] + mutated_grads: list[TensorBox] + + +def process_joint_outputs( + all_joint_outputs: SubgraphResults, num_placeholders: int +) -> JointOutputResult: + """Process joint outputs and extract various buffers needed for lowering + + Args: + all_joint_outputs: List of all the outputs from build_subgraphs + num_placeholders: The number of placeholder inputs, used to skip over unused backward compute buffers + + Returns: + JointOutputResult containing processed buffers and gradients + """ + assert isinstance(all_joint_outputs, list) + assert all_joint_outputs[0] is not None, ( + "joint_subgraph_buffer is None - this is a bug!" + ) + + joint_buffer = all_joint_outputs[0] + other_grads = all_joint_outputs[num_placeholders - 1 :] + + # outer_grads has the structure: Len(other_buffer_grads) if buffer doesn't require grad than it will be None + # We only grab the buffers that require grad for inlining into kernel + grads_compute = [buf for buf in other_grads if buf is not None] + + def get_out(buf): + if buf is None: + return None + assert isinstance(buf, ComputedBuffer) + assert buf.name is not None + return TensorBox.create(V.graph.get_buffer(buf.name)) + + grads_out = [get_out(x) for x in other_grads] + mutated_grads = [buf for buf in grads_out if buf is not None] + + return JointOutputResult( + grad_input=joint_buffer, + captured_grads_compute=grads_compute, + captured_grads=grads_out, + mutated_grads=mutated_grads, + ) + + +# TODO: We probably also need a layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + """Lowering for the flex_attention_backward op in triton""" + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + ( + query, + key, + value, + logsumexp, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + logsumexp, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + device = query.get_device() + dtype = query.get_dtype() + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + + kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0) + seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0) + if seq_q_divisible and seq_kv_divisible: + kernel_options.setdefault("IS_DIVISIBLE", True) + else: + kernel_options.setdefault("IS_DIVISIBLE", False) + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) + for name, dtype in [ + ("score", dtype), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + fw_subgraph_buffer = build_subgraph_buffer( + fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph + ) + freeze_irnodes(fw_subgraph_buffer) + + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("grad_score_mod", dtype, device) + ] + # Sometimes we have weird unused nodes here + joint_graph.graph_module.graph.eliminate_dead_code() + + # It is hard to raise nice errors for some joint graphs during subgraph lowering + # This lets us do some checks before attempting to lower + validate_joint_graph(joint_graph.graph_module.graph) + + all_joint_outputs = build_subgraph_buffer( + joint_placeholder_inps + list(score_mod_other_buffers), + joint_graph, + ) + freeze_irnodes(all_joint_outputs) + + joint_outputs = process_joint_outputs( + all_joint_outputs, len(joint_placeholder_inps) + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + freeze_irnodes(mask_graph_buffer) + + if _use_flex_flash_attention_backward( + fw_graph, + mask_graph, + backend=backend, + ): + return create_flex_flash_attention_backward_kernel( + query, key, value, out, logsumexp, grad_out, scale, kernel_options + ) + + # Construct layout with stride order matching K + key_size = [Bq, Hkv, seq_len_kv, qk_head_dim] + key_strides = infer_dense_strides(key_size, key.get_stride()) + + layout_broadcasted_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key_size, + stride=[sympy.sympify(s) for s in key_strides], + ) + + # Create delta which will is needed for the bwd's kernel + grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2)) + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + delta = lowerings[aten.sub](delta, grad_lse_exp2) + delta = ExternKernel.require_contiguous(delta) + + grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta]) + + # # see NOTE:[TritonTemplates with multiple outputs] + query_size = [Bq, Hq, seq_len_q, qk_head_dim] + grad_query_strides = infer_dense_strides(query_size, query.get_stride()) + grad_query = empty_strided( + query_size, + stride=[sympy.sympify(s) for s in grad_query_strides], + dtype=query.get_dtype(), + device=query.get_device(), + ) + + # Construct output layout with stride order matching value + value_size = [Bq, Hkv, seq_len_kv, v_head_dim] + value_strides = infer_dense_strides(value_size, value.get_stride()) + + broadcasted_grad_value = empty_strided( + value_size, + stride=[sympy.sympify(s) for s in value_strides], + dtype=value.get_dtype(), + device=value.get_device(), + ) + + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA factor + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed. + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = ( + empty(0, device=query.get_device()) for _ in range(4) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE) + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.guard_int(query.get_size()[-1]) + configs: list[FlexBwDConfig] = V.choices.get_flex_attention_bwd_configs( + head_dim, dtype, query.get_device().type + ) + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + original_kernel_options = kernel_options.copy() + + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_n1 != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m1 != 0 + or SPARSE_KV_BLOCK_SIZE % conf.block_n2 != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m2 != 0 + ): + continue + + # Performance tuning + # Triton heuristics + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for backward kernels options and delete forward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("bwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("fwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m1) + cur_kernel_options.setdefault("BLOCK_N1", conf.block_n1) + cur_kernel_options.setdefault("BLOCK_M2", conf.block_m2) + cur_kernel_options.setdefault("BLOCK_N2", conf.block_n2) + + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_broadcasted_k, # We use store_output only for grad_key + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], + call_sizes=query.get_size() + key.get_size()[1:3], + **cur_kernel_options, + ) + inputs_for_autotuning = ( + # pyrefly: ignore [unsupported-operation] + [ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + + joint_outputs.mutated_grads + ) + input_gen_fns = { + 8: create_num_blocks_fake_generator(kv_indices), # kv_num_blocks + 9: create_indices_fake, + 10: create_num_blocks_fake_generator(q_indices), # q_num_blocks + 11: create_indices_fake, + 12: create_num_blocks_fake_generator(full_kv_indices), # full_kv_num_blocks + 13: create_indices_fake, + 14: create_num_blocks_fake_generator(full_q_indices), # full_q_num_blocks + 15: create_indices_fake, + } + + broadcasted_grad_key = autotune_select_algorithm( + "flex_attention_backward", + choices, + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout_broadcasted_k, + input_gen_fns=input_gen_fns, + ) # [Bq, Hkv, seq_len_kv, k_head_dim] + + # need subgraph inputs and outputs to analyze all symints used in flex attention + broadcasted_grad_key.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + broadcasted_grad_key.data.data.subgraph_outs = get_bwd_subgraph_outputs( + fw_subgraph_buffer, mask_graph_buffer, joint_outputs + ) + + if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)): + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value + else: + assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. " + f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} " + f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" + ) + grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) + grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) + + return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads)) + + +def get_bwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, + mask_graph_buffer: SubgraphResults, + joint_outputs: JointOutputResult, +) -> list[Optional[Union[ComputedBuffer, TensorBox]]]: + subgraph_buffer = ( + # pyrefly: ignore [bad-assignment] + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + # pyrefly: ignore [bad-assignment] + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + joint_output_buffers = [ + joint_outputs.grad_input, + *joint_outputs.captured_grads_compute, + *joint_outputs.captured_grads, + *joint_outputs.mutated_grads, + ] + + # pyrefly: ignore [not-iterable] + return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_cpu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..6987e64546fe3503b6a7b7e9bb1a44e72fbb2660 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_cpu.py @@ -0,0 +1,339 @@ +# mypy: allow-untyped-defs +"""CPU-specific implementations for flex attention""" + +import copy +import os +import sys +from typing import Any + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import ValueRanges + +from ...codegen.cpp_flex_attention_template import CppFlexAttentionTemplate +from ...ir import Buffer, FixedLayout, TensorBox +from ...select_algorithm import autotune_select_algorithm +from .common import ( + build_subgraph_buffer, + build_subgraph_module_buffer, + contiguous_last_dim, + create_placeholder, + get_fwd_subgraph_outputs, + infer_dense_strides, + maybe_realize, +) + + +def check_cpu_supported(): + requires_avx2_on_cpu = ( + torch.cpu._is_avx2_supported() and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ) + supported = ( + requires_avx2_on_cpu + and not torch.xpu.is_available() + and sys.platform != "darwin" + ) + return supported + + +def lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + """CPP based template for flex attention for x86 CPUs""" + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + if kernel_options["OUTPUT_LOGSUMEXP"]: + raise NotImplementedError( + "torch.compile on CPU only supports inference and `return_lse` is not supported yet." + ) + if not check_cpu_supported(): + raise NotImplementedError( + "torch.compile on current platform is not supported for CPU." + ) + + fake_buffers: list[Buffer] = [] # noqa: F821 + + # [Note] Handle the case where the split sizes are not statically known. + # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime. + # We use symbols to represent them during the compilation here. + # They'll be replaced by the string "cur_qSplitSize" and "cur_kvSplitSize" in + # the modification function of the CppFlexAttentionTemplate class. + cur_qSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + shape_env = V.graph.sizevars.shape_env + + # We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation. + # Mark symbols > 1 to ensure broadcasting is always applied. + # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`. + shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo) + shape_env.var_to_range[cur_kvSplitSize] = ValueRanges(2, int_oo) + + score_dtype = torch.float + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + if subgraph_buffer is not None: + if isinstance(subgraph_buffer, list): + for _buf in subgraph_buffer: + if _buf is not None: + _buf.freeze_layout() + else: + subgraph_buffer.freeze_layout() + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + + # The original mask_graph works on a scalar and only includes + # the logic of calculating the mask value. + # We need to add the logic of applying the mark to the qk_data tensor + # into the graph for the later codegen of this part. + # Example: + # mask_graph: + # def mask_fn(b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # return mask + # The converted_mask_graph should be: + # def converted_mask_fn(qk_data, b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf"))) + # return qk_data + def convert_mask_graph_module(mask_graph): + gm = copy.deepcopy(mask_graph.graph_module) + graph = gm.graph + # Add qk_data as the first input + with graph.inserting_before(next(iter(graph.nodes))): + qk_data_node = graph.placeholder("qk_data") + + # Find the node that returns the mask + output_node = None + for node in graph.nodes: + if node.op == "output": + output_node = node + break + + # Get the mask node + assert output_node is not None + mask_node = output_node.args[0] + + size_node = [cur_qSplitSize, cur_kvSplitSize] + # Create a new node for torch.full + with graph.inserting_after(mask_node): + full_node = graph.call_function( + torch.full, + args=(size_node, -float("inf")), + kwargs={"dtype": score_dtype}, + ) + + # Create a new node for torch.where + with graph.inserting_after(full_node): + where_node = graph.call_function( + torch.ops.aten.where, args=(mask_node, qk_data_node, full_node) + ) + + # Update the output node to return the result of torch.where + output_node.args = (where_node,) + + graph.lint() + converted = torch.fx.GraphModule(gm, graph) + return converted + + converted_mask_graph_module = convert_mask_graph_module(mask_graph) + + mask_graph_buffer = build_subgraph_module_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), + converted_mask_graph_module, + ) + + # Clear the pending fresh unbacked symbols that are created for cur_qSplitSize and cur_kvSplitSize in the current kernel. + pending = V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols + V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols = [ + x for x in pending if x not in (cur_qSplitSize, cur_kvSplitSize) + ] + + buffer_list = ( + placeholder_inps + + list(score_mod_other_buffers) + + mask_graph_placeholder_inps + + list(mask_mod_other_buffers) + ) + for item in buffer_list: + if isinstance(item, TensorBox): + fake_buffers.append(item.data.data) # type: ignore[attr-defined] + + # CPU kernel requires last dim to be contiguous + query, key, value = map(contiguous_last_dim, [query, key, value]) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + if len(OrderedSet([query.get_name(), key.get_name(), value.get_name()])) != 3: + raise NotImplementedError( + "Unsupported for now if query, key, value are the same buffer." + ) + if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]: + raise NotImplementedError( + "`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. " + f"Found input tensors are `{query.get_dtype()}`." + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + B = Bq + + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, query.get_stride()) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + _choices: list[Any] = [] + input_nodes = [query, key, value, kv_num_blocks, kv_indices] + if not full_kv_num_blocks: + no_full_kv_block = True + else: + no_full_kv_block = False + input_nodes += [full_kv_num_blocks] + input_nodes += [full_kv_indices] + has_other_buffer = False + kernel_input_name_to_buffer = {} + if score_mod_other_buffers or mask_mod_other_buffers: + has_other_buffer = True + + for prefix, buffers in [ + ("score_others", score_mod_other_buffers), + ("mask_others", mask_mod_other_buffers), + ]: + kernel_input_name_to_buffer.update( + {f"{prefix}_{i}": buf for i, buf in enumerate(buffers)} + ) + input_nodes += [ + value + for value in kernel_input_name_to_buffer.values() + if not isinstance(value, sympy.Symbol) + ] + + skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False) + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) + ), ( + "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + ) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) + ), ( + "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." + ) + CppFlexAttentionTemplate.add_choices( + choices=_choices, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=None if skip_mask_score else subgraph_buffer, + mask_mod=None if skip_mask_score else mask_graph_buffer, + kv_block_size=SPARSE_KV_BLOCK_SIZE, + q_block_size=SPARSE_Q_BLOCK_SIZE, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len(score_mod_other_buffers), + len_mask_other=len(mask_mod_other_buffers), + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=(cur_qSplitSize, cur_kvSplitSize), + ) + inputs_for_autotuning = [ + query, + key, + value, + ] + res = autotune_select_algorithm( + "flex_attention", + _choices, + inputs_for_autotuning, + layout, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + res.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + res.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (res,) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_decoding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..37113a1d82a8455eca455b6d5e077fa06b952f5a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_decoding.py @@ -0,0 +1,436 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" + +from typing import Any + +import sympy + +import torch +from torch._inductor.virtualized import V + +from ... import ir +from ...ir import FixedLayout, FlexibleLayout +from ...lowering import empty, empty_strided, lowerings +from ...runtime.runtime_utils import is_power_of_2, next_power_of_2 +from ...select_algorithm import ( + autotune_select_algorithm, + SymbolicGridFn, + TritonTemplate, +) +from .common import ( + create_indices_fake, + create_num_blocks_fake_generator, + freeze_irnodes, + get_fwd_subgraph_outputs, + load_flex_template, + maybe_realize, + set_head_dim_values, +) + + +aten = torch.ops.aten +prims = torch.ops.prims + + +def _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa) -> bool: + """Decide which kernel to use, return true if use flex decoding kernel. + Note: + Since the number of splits is calculated based of the number of batch and head dims + we need to ensure that the batch and head dims are statically known. Otherwise we just + use the main flex_attention kernel. + """ + force_flex = kernel_options.get("FORCE_USE_FLEX_ATTENTION", False) + short_query_length = V.graph.sizevars.evaluate_expr( + sympy.Lt(query.get_size()[-2], 128) + ) + non_zero_length = V.graph.sizevars.evaluate_expr(sympy.Gt(query.get_size()[-2], 0)) + static_batch = isinstance(query.get_size()[0], (int, sympy.Integer)) + static_num_heads = isinstance(query.get_size()[1], (int, sympy.Integer)) + if enable_gqa: + # in the current flex decoding triton kernel, grouped query heads for the + # same kv head are handled by the same block. So it's hard to support different + # kv num blocks for grouped query heads. We just fall back to main flex_attention + # kernel where each query head is handled by a separate block. + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Eq(kv_indices.get_size()[1], 1) + ) + else: + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Or( + sympy.Eq(kv_indices.get_size()[1], 1), + sympy.Eq(kv_indices.get_size()[1], query.get_size()[1]), + ) + ) + + Hq = query.get_size()[1] + Hkv = value.get_size()[1] + ratio = Hq // Hkv + + pw_of_two = V.graph.sizevars.guard_or_false( + sympy.And(sympy.Gt(ratio, 0), sympy.Eq(ratio & (ratio - 1), 0)) + ) + + return ( + not force_flex + and not kernel_options.get("OUTPUT_MAX", False) + and short_query_length + and static_batch + and static_num_heads + and non_zero_length + and valid_block_mask_num_heads + and pw_of_two + ) + + +@SymbolicGridFn +def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta): + """How is this kernel parallelized? + We create a grid of (batch_size * kv_heads, SPLIT_KV, 1) + Each block is responsible for iterating over blocks of keys and values calculating + the local output for their tile of keys and values over all full length of query. + groups of SPLIT_KV blocks then combine their output to produce the final result. + """ + + return (batch_size * kv_heads, meta["SPLIT_KV"], 1) + + +flex_decoding_template = TritonTemplate( + name="flex_decoding", + grid=flex_decoding_grid, + source=load_flex_template("flex_decode") + + load_flex_template("utilities") + + load_flex_template("common"), +) + + +def get_split_k(B: int, H: int, Mk: int) -> int: + if torch.xpu.is_available(): + num_SM = torch.xpu.get_device_properties("xpu").gpu_subslice_count + else: + num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers" + split_k = num_SM // bh * 2 # Each SM should at least get one block. + # TODO: workload evening at runtime for splits fully masked out. + # Before we have runtime workload evening, assign 2 splits per SM. + split_k = max(split_k, 1) + + return split_k + + +def create_flex_decoding_kernel(*args, **kwargs): + """Flex decode lowering that is optimized for small Q_LEN and GQA packing""" + ( + query, + key, + value, + block_mask, + scale, + kernel_options, + score_mod_subgraph, + mask_mod_subgraph, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, # full_kv_num_blocks, + full_kv_indices, # full_kv_indices, + _, # q_num_blocks + _, # q_indices + _, # full_q_num_blocks, + _, # full_q_indices, + _, # SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + _, + ) = block_mask + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + + B = Bq + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v + for k, v in kernel_options.items() + } + + seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0) + seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0) + if seq_q_divisible and seq_kv_divisible: + kernel_options.setdefault("IS_DIVISIBLE", True) + else: + kernel_options.setdefault("IS_DIVISIBLE", False) + + # Calculate GQA head sharing + gqa_shared_heads = Hq // Hkv + if not is_power_of_2(gqa_shared_heads): + raise ValueError( + "Number of shared query heads sharing the same KV head must be power of 2. " + ) + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + # Create a plackeholder full block list in case it is empty + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + + freeze_irnodes(score_mod_other_buffers) + freeze_irnodes(mask_mod_other_buffers) + + choices: list[Any] = [] + dtype = key.get_dtype() + head_dim = V.graph.sizevars.guard_int(key.get_size()[-1]) + configs = V.choices.get_flex_decode_configs( + head_dim, dtype, query.get_device().type + ) + + # TODO: fix autotuning. + + kernel_options.setdefault("SM_SCALE", scale) + kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv)) + MAX_SPLIT_KV = kernel_options["SPLIT_KV"] + + # create config dependent intermediate buffers + buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim] + buf_ML_shape = buf_ACC_shape[:-1] + buf_M = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The rowmax is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + buf_L = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The intermediate sumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + + layout_acc = FixedLayout( + query.get_device(), + torch.float32, + buf_ACC_shape, + FlexibleLayout.contiguous_strides(buf_ACC_shape), + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + kernel_options.setdefault( + "BLOCK_M", + ( + # m + # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0)) + # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin + max( + next_power_of_2( + V.graph.sizevars.size_hint( + seq_len_q, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + * gqa_shared_heads + ), + 1 if torch.xpu.is_available() else 16, + ) + ), + ) + + query = ir.ExternKernel.realize_input(query) + stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride() + + # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D] + gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim) + gqa_query_stride = ( + stride_b, + stride_hq * gqa_shared_heads, + stride_hq, + stride_seq_len_q, + stride_qk_head_dim, + ) + query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride) + + V.graph.sizevars.check_leq( + seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"]) + ) + + kernel_options.setdefault( + "SAFE_M_BOUNDARY", + ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0, + ) + # TODO: This feels sketchy + kernel_options.setdefault("SAFE_N_BOUNDARY", True) + # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) + + original_kernel_options = kernel_options.copy() + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if SPARSE_KV_BLOCK_SIZE % conf.block_n != 0: + continue + + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for forward kernels options and delete backward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + # Performance tuning + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Set default to False + cur_kernel_options.setdefault("USE_TMA", False) + + # Add ROCm-specific parameters if they exist in the config + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_decoding_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout_acc, + subgraphs=[ + score_mod_subgraph, + mask_mod_subgraph, + ], + mutated_inputs=[buf_M, buf_L], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + + filtered_score_mod_buffers = [ + buf for buf in score_mod_other_buffers if not isinstance(buf, sympy.Symbol) + ] + filtered_mask_mod_buffers = [ + buf for buf in mask_mod_other_buffers if not isinstance(buf, sympy.Symbol) + ] + + inputs_for_flex_decoding = ( + # pyrefly: ignore [unsupported-operation] + [ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + filtered_score_mod_buffers + + filtered_mask_mod_buffers + ) + + input_gen_fns = { + 5: create_num_blocks_fake_generator(kv_indices), + 6: create_indices_fake, + 7: create_num_blocks_fake_generator(full_kv_indices), + 8: create_indices_fake, + } + + buf_ACC = autotune_select_algorithm( + "flex_decoding", + choices, + inputs_for_flex_decoding, + layout_acc, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + buf_ACC.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + buf_ACC.data.data.subgraph_outs = get_fwd_subgraph_outputs( + score_mod_subgraph, mask_mod_subgraph + ) + + # Reduction + + g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0] + # See [Note] Handle fully masked out rows: + # g_M Is the global max among split kv blocks. + masked_rows = lowerings[aten.eq](g_M, -float("inf")) + adj_M = lowerings[aten.sub](buf_M, g_M) + adj_M = lowerings[aten.where](masked_rows, 0, adj_M) + alpha = lowerings[aten.exp2](adj_M) + + buf_L = lowerings[aten.mul](buf_L, alpha) + g_L = lowerings[aten.sum](buf_L, axis=1) + masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1) + g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L) + logsumexp = lowerings[aten.log2](g_L) + logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1)) + + alpha_unseq = lowerings[aten.unsqueeze](alpha, 4) + buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq) + output = lowerings[aten.sum](buf_ACC, axis=1) + L_unseq = lowerings[aten.unsqueeze](g_L, 3) + output = lowerings[aten.div](output, L_unseq) + output = lowerings[prims.convert_element_type](output, query.get_dtype()) + + return ( + output, + logsumexp, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_flash_attention.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..05d1290f0ab49f55dbe2b4ed331f3f408c772831 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -0,0 +1,491 @@ +# mypy: allow-untyped-defs +"""Call into flash-attention 4 for flexattention""" + +import functools +import importlib +from collections.abc import Callable, Sequence +from contextlib import contextmanager +from typing import Any, Literal, Optional + +import sympy +from sympy import Expr, Integer + +import torch +from torch.fx import GraphModule +from torch.utils._sympy.functions import Identity + +from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox +from ...lowering import empty_strided +from .common import infer_dense_strides, load_flex_template, SubgraphResults + + +aten = torch.ops.aten +prims = torch.ops.prims + + +@functools.lru_cache(maxsize=1) +def ensure_flash_available() -> bool: + """Check if flash-attn is importable; cache the result for reuse. + + Call ensure_flash_available.cache_clear() after installing flash-attn + in the same interpreter to retry the import. + """ + try: + return importlib.util.find_spec("flash_attn.cute") is not None # type: ignore[attr-defined] + except ImportError: + return False + + +from ...codegen.cutedsl.cutedsl_template import CuteDSLTemplate + + +flash_attention_cutedsl_template = CuteDSLTemplate( + name="flash_attention_cutedsl", source=load_flex_template("flash_attention") +) +flash_attention_backward_cutedsl_template = CuteDSLTemplate( + name="flash_attention_backward_cutedsl", + source=load_flex_template("flash_attention_backward"), +) + + +def _fixed_indexer_cute( + size: Sequence[int], + stride: Optional[Sequence[int]] = None, + offset: Expr = Integer(0), +) -> Callable[[Sequence[Expr]], Expr]: + """ + Colexicographic indexer for CuteDSL - matches CuTe's coordinate interpretation. + + CuTe interprets linear indices in colexicographic (column-major) order, + whereas Inductor's default _fixed_indexer uses lexicographic (row-major) order. + + For size=[4, 128] with index=[b, q_idx]: + - Lexicographic: b*128 + q_idx*1 + - Colexicographic: b*1 + q_idx*2 + + CuTe then applies the tensor's actual memory strides to get the correct offset. + """ + + def indexer(index: Sequence[Expr]) -> Expr: + assert offset == Integer(0), "Offset not supported for colexicographic indexing" + if not index: + return Integer(0) + + result = index[0] + runner = size[0] + + for idx, sz in zip(index[1:], size[1:], strict=True): + result = result + runner * Identity(idx) + runner = runner * sz + + return result + + return indexer + + +@contextmanager +def patch_fixed_layout_indexer_for_cutedsl(): + """ + Temporarily swap FixedLayout.make_indexer so CuteDSL sees colexicographic indexing. + + Note [CuteDSL indexer patch]: + Flex flash attention only supports a limited set of IR ops (pointwise, reads, no stores), + so temporarily changing the indexing order is safe for the kernels we emit today. + TODO(dynamic shapes): Reconfirm once flex flash attention supports dynamic shapes. + """ + original_make_indexer = FixedLayout.make_indexer + + def cutedsl_make_indexer(self): + return _fixed_indexer_cute(self.size, self.stride, self.offset) + + FixedLayout.make_indexer = cutedsl_make_indexer # type: ignore[assignment] + try: + yield + finally: + FixedLayout.make_indexer = original_make_indexer # type: ignore[assignment] + + +def wrap_choice_render_with_cutedsl_indexer(choice: Any) -> None: + """ + Wrap a template choice's kernel render to apply CuteDSL indexer patching. + + See Note [CuteDSL indexer patch]: + This wrapper allows the template to construct its closures normally, then + scopes the indexer patch to the actual render call that emits the kernel. + This ensures CuteDSL templates see colexicographic indexing while preserving + the template's setup logic. + """ + original_make_kernel_render = choice.make_kernel_render + + def make_kernel_render_with_patch(*args, **kwargs): + render_kernel, render = original_make_kernel_render(*args, **kwargs) + # Let the template construct its closures, then scope the indexer patch + # to the actual render call that emits the kernel + render_with_patch = patch_fixed_layout_indexer_for_cutedsl()(render) + return render_kernel, render_with_patch + + choice.make_kernel_render = make_kernel_render_with_patch + + +def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int): + """Check if any of the input buffers (beyond the score mod placeholders) require gradients.""" + inputs = [] + for node in graph_module.graph.nodes: + if node.op == "placeholder": + inputs.append(node) + if len(inputs) <= num_score_mod_placeholders: + return False + + def requires_grad(n): + tensor_meta = n.meta.get("tensor_meta") + return tensor_meta.requires_grad if tensor_meta is not None else False + + return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:]) + + +def is_trivial_score_graph(graph_module: GraphModule) -> bool: + """Backwards currently doesn't support score_mods, match against identity""" + graph = graph_module.graph + nodes = list(graph.nodes) + placeholders = [n for n in nodes if n.op == "placeholder"] + output = [n for n in nodes if n.op == "output"] + assert len(output) == 1, "Got graph w/ multiple outputs" + output_val = output[0].args[0] + # The identity graph just sends the score straight through + return output_val == placeholders[0] + + +def is_trivial_mask_graph(graph_module: GraphModule) -> bool: + """Mask graph is trivial when it only gates via the default full op.""" + graph = graph_module.graph + nodes = list(graph.nodes) + placeholders = [n for n in nodes if n.op == "placeholder"] + output = [n for n in nodes if n.op == "output"] + assert len(output) == 1, "Got graph w/ multiple outputs" + output_val = output[0].args[0] + + # mask mod graph is empty if we have 4 inputs and full_default output + return len(placeholders) == 4 and output_val.target is torch.ops.aten.full.default + + +@functools.lru_cache(maxsize=1) +def _supports_nontrivial_mask_graphs() -> bool: + """Currently only supported on Hopper (SM90) GPUs.""" + return torch.cuda.get_device_capability()[0] in [9, 10] + + +def _can_use_flex_flash_attention( + subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int +) -> tuple[bool, str]: + """Check if flex flash attention can be used for the given inputs. + + Returns: + tuple: (can_use, reason) where reason explains why it can't be used if can_use is False + """ + if not ensure_flash_available(): + return False, "CUTE flash attention library is not available" + + if input_buffers_require_grads(subgraph.graph_module, num_score_mod_placeholders): + return ( + False, + "Input buffers require gradients (not supported by flash attention)", + ) + mask_trivial = is_trivial_mask_graph(mask_graph.graph_module) + + if mask_trivial: + return True, "" + + if not _supports_nontrivial_mask_graphs(): + return ( + False, + "NYI: Non-trivial mask graphs only supported on Hopper (SM90) for flash attention", + ) + + return True, "" + + +def _use_flex_flash_attention( + subgraph: Subgraph, + mask_graph: Subgraph, + kernel_options: dict[str, Any], + num_score_mod_placeholders: int, + backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"], +) -> bool: + """Determine if we should use flex flash attention for the given inputs. + + Args: + subgraph: The score modification subgraph + mask_graph: The mask modification subgraph + kernel_options: Kernel configuration options + num_score_mod_placeholders: Number of placeholders in score_mod + backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE) + + Returns: + True if flash attention should be used, False otherwise + """ + # Flash is experimental and must be explicitly requested + if backend != "FLASH": + return False + + can_use, reason = _can_use_flex_flash_attention( + subgraph, mask_graph, num_score_mod_placeholders + ) + + if not can_use: + raise RuntimeError( + f"BACKEND='FLASH' but flash attention cannot be used: {reason}" + ) + + return True + + +def create_flex_flash_attention_kernel( + query: TensorBox, + key: TensorBox, + value: TensorBox, + block_mask: tuple[Any, ...], + scale: float, + kernel_options: dict[str, Any], + subgraph_buffer: SubgraphResults, + mask_graph_buffer: SubgraphResults, + score_mod_other_buffers: list[TensorBox], + mask_mod_other_buffers: list[TensorBox], + kv_num_blocks: TensorBox | None, + kv_indices: TensorBox | None, + full_kv_num_blocks: TensorBox | None, + full_kv_indices: TensorBox | None, + mask_graph: Subgraph, + subgraph: Subgraph | None = None, +) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]: + """Create a flex flash attention kernel using CuteDSL template.""" + if not ensure_flash_available(): + raise RuntimeError("CUTE flash attention not available") + + # Get dimensions + batch_size, num_heads, seq_len_q, head_dim = query.get_size() + v_head_dim = value.get_size()[-1] + device = query.get_device() + dtype = query.get_dtype() + assert device is not None, "Device must be specified" + + # Match stride pattern from query tensor + q_strides = query.get_stride() + out_size = [batch_size, num_heads, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, q_strides) + + output = empty_strided( + size=out_size, + stride=out_strides, + dtype=dtype, + device=device, + ) + + lse = empty_strided( + size=[batch_size, num_heads, seq_len_q], + stride=None, # LSE can be contiguous + dtype=torch.float32, # LSE is always fp32 + device=device, + ) + + # Create layout for primary output + output_layout = FixedLayout( + device=device, + dtype=dtype, + size=[batch_size, num_heads, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in output.get_stride()], + ) + + # Used to check if we can skip block sparse impl + mask_graph_is_trivial = is_trivial_mask_graph(mask_graph.graph_module) + + needs_block_mask = not mask_graph_is_trivial + has_full_blocks = full_kv_num_blocks is not None + + choices: list[Any] = [] + assert flash_attention_cutedsl_template is not None + + input_nodes = [query, key, value, lse] + if has_full_blocks: + input_nodes.extend( + [kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices] + ) + + if needs_block_mask and not has_full_blocks: + raise NotImplementedError( + "Flash attention with block mask but without full blocks is not supported yet" + ) + + error = flash_attention_cutedsl_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=output_layout, + mutated_inputs=[lse], + subgraphs=[subgraph_buffer, mask_graph_buffer], + SM_SCALE=scale, + NEEDS_BLOCK_MASK=needs_block_mask, + ) + + for choice in choices: + wrap_choice_render_with_cutedsl_indexer(choice) + + if error or not choices: + # Fallback to original implementation + raise RuntimeError(f"CuteDSL template failed: {error}") + + # No autotune for now + template_output = choices[0].output_node() + + return (template_output, lse) + + +def _can_use_flex_flash_attention_backward( + fw_subgraph: Subgraph, + mask_graph: Subgraph, +) -> tuple[bool, str]: + if not ensure_flash_available(): + return False, "CUTE flash attention is not available" + + if not is_trivial_score_graph(fw_subgraph.graph_module): + return ( + False, + "NYI: Flex Flash Attention doesn't support score_mods in bwds yet.", + ) + + if not is_trivial_mask_graph(mask_graph.graph_module): + return False, "NYI: Flex Flash Attention doesn't support block_sparsity yet." + + return True, "" + + +def _use_flex_flash_attention_backward( + fw_subgraph: Subgraph, + mask_graph: Subgraph, + backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"], +) -> bool: + """Determine if we should use flex flash attention for the given inputs. + + Args: + subgraph: The score modification subgraph + mask_graph: The mask modification subgraph + kernel_options: Kernel configuration options + num_score_mod_placeholders: Number of placeholders in score_mod + backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE) + + Returns: + True if flash attention should be used, False otherwise + """ + # Flash is experimental and must be explicitly requested + if backend != "FLASH": + return False + + can_use, reason = _can_use_flex_flash_attention_backward( + fw_subgraph, + mask_graph, + ) + + if not can_use: + raise RuntimeError( + f"BACKEND='FLASH' but flash attention cannot be used: {reason}" + ) + + return True + + +def create_flex_flash_attention_backward_kernel( + query: TensorBox, + key: TensorBox, + value: TensorBox, + out: TensorBox, + logsumexp: TensorBox, + grad_out: TensorBox, + scale: float, + kernel_options: dict[str, Any], + # TODO: will be needed + # grad_logsumexp, + # fw_graph: SubgraphResults, + # joint_graph: SubgraphResults, + # mask_graph: SubgraphResults, + # score_mod_other_buffers: list[TensorBox], + # mask_mod_other_buffers: list[TensorBox], + # kv_num_blocks: TensorBox | None, + # kv_indices: TensorBox | None, + # full_kv_num_blocks: TensorBox | None, + # full_kv_indices: TensorBox | None, +) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox, TensorBox, tuple]: + """Create a CuteDSL flash attention backward kernel for the default mod path.""" + if not ensure_flash_available(): + raise RuntimeError("CUTE flash attention not available") + + batch_size, num_heads, seq_len_q, head_dim = query.get_size() + v_head_dim = value.get_size()[-1] + device = query.get_device() + dtype = query.get_dtype() + assert device is not None + + grad_query_strides = infer_dense_strides( + [batch_size, num_heads, seq_len_q, head_dim], query.get_stride() + ) + grad_query = empty_strided( + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=grad_query_strides, + dtype=dtype, + device=device, + ) + + grad_key_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], head_dim], key.get_stride() + ) + grad_key = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], head_dim], + stride=grad_key_strides, + dtype=dtype, + device=device, + ) + + grad_value_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], v_head_dim], value.get_stride() + ) + grad_value = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], v_head_dim], + stride=grad_value_strides, + dtype=dtype, + device=device, + ) + + output_layout = FixedLayout( + device=device, + dtype=dtype, + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=[sympy.sympify(s) for s in grad_query.get_stride()], + ) + + choices: list[Any] = [] + + input_nodes = [ + query, + key, + value, + out, + grad_out, + logsumexp, + grad_key, + grad_value, + ] + + error = flash_attention_backward_cutedsl_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=output_layout, + mutated_inputs=[grad_key, grad_value], + SM_SCALE=scale, + ) + + for choice in choices: + wrap_choice_render_with_cutedsl_indexer(choice) + + if error or not choices: + raise RuntimeError(f"CuteDSL template failed: {error}") + + template_output = choices[0].output_node() + + return (template_output, grad_key, grad_value, tuple()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..989f297c5f80f4053cbc54f6299181d4722efdb2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja @@ -0,0 +1,333 @@ +import functools +from torch._inductor.runtime.runtime_utils import ceildiv +from cutlass.utils import TensorMapUpdateMode +{{gen_defines()}} +# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- +from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( + GroupedGemmKernel, +) + + +# Note about caching: +# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor +# maintains its own local caching system. At this stage, all compile-time +# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel +# name itself ({{kernel_name}}) are permanently baked into the file, so they +# do not need to be included in any cache key. +# +# The caching mechanism is split into two levels: +# +# 1. prep_cache +# Caches the compiled executor for build_group_ptrs_from_bases(). This +# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, +# and can therefore be safely reused across runs with different group +# partitioning (`offs`). +# +# 2. gemm_cache +# Caches the compiled Grouped GEMM executor. Its key extends the prep +# cache key with hardware- and grid-specific parameters: +# (prep_cache_key, max_active_clusters, total_num_clusters). +# This is necessary because different `offs` tensors can change the +# per-group problem sizes and thus alter `total_num_clusters`, which in +# turn changes the grid shape and persistent scheduler configuration. +# Kernels compiled for one grid cannot be safely reused for another. +# +# +# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, +# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, +# despite depending only on the GPU type. We cache this function to mitigate +# redundant recompiles even when shape/stride/dtype cache misses force kernel +# regeneration. A follow-up study will investigate the root cause. + +prep_cache = {} +gemm_cache = {} + + +@functools.lru_cache +def get_hardware_info(): + hw = cutlass.utils.HardwareInfo() + sm_count = hw.get_max_active_clusters(1) + max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) + + return (sm_count, max_active_clusters) + + +def get_prep_cache_key(input_a, input_b, output): + """ + Returns a tuple key for caching the preprocessing kernel executor based on kernel name, + shapes, strides, and dtypes of input/output tensors. + """ + return ( + tuple(input_a.shape), + tuple(input_a.stride()), + input_a.dtype, + tuple(input_b.shape), + tuple(input_b.stride()), + input_b.dtype, + tuple(output.shape), + tuple(output.stride()), + output.dtype, + ) + + +def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): + """ + Returns a tuple key for caching the gemm kernel executor by extending the + prep cache key with hardware- and grid-specific parameters. + """ + return ( + prep_cache_key, + max_active_clusters, + total_num_clusters, + ) + + +@cute.kernel +def build_group_ptrs_from_bases_kernel( + base_A_u64: cutlass.Int64, # device addr of input_a (bytes) + base_B_u64: cutlass.Int64, # device addr of input_b (bytes) + base_C_u64: cutlass.Int64, # device addr of Output (bytes) + offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Int32, # bytes + # -------- STRIDES (in ELEMENTS) -------- + stride_A_m_elems: cutlass.Constexpr, # A.stride(0) + stride_A_k_elems: cutlass.Constexpr, # A.stride(1) + stride_B0_elems: cutlass.Constexpr, # B.stride(0) + stride_Bk_elems: cutlass.Constexpr, # B.stride(1) + stride_Bn_elems: cutlass.Constexpr, # B.stride(2) + stride_C_m_elems: cutlass.Constexpr, # C.stride(0) + stride_C_n_elems: cutlass.Constexpr, # C.stride(1) + # -------- OUTPUTS -------- + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) + out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) + out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] +): + tidx, _, _ = cute.arch.thread_idx() + g = tidx + + m_beg_i32 = 0 + if g > 0: + m_beg_i32 = offs[g - 1] + m_end_i32 = offs[g] + m_g_i32 = m_end_i32 - m_beg_i32 + + a_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) + ) + c_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) + ) + b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) + + # ---- pointers ---- + out_ptrs[g, 0] = base_A_u64 + a_byte_off + out_ptrs[g, 1] = base_B_u64 + b_byte_off + out_ptrs[g, 2] = base_C_u64 + c_byte_off + + # ---- (m, n, k, 1) ---- + out_problem[g, 0] = m_g_i32 + out_problem[g, 1] = N + out_problem[g, 2] = K + out_problem[g, 3] = cutlass.Int32(1) + + # ---- strides ---- + out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) + out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) + out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) + out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) + out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) + out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) + + +@cute.jit +def launch_build_group_ptrs_from_bases( + base_A_u64: cutlass.Int64, + base_B_u64: cutlass.Int64, + base_C_u64: cutlass.Int64, + offs: cute.Tensor, + G: cutlass.Constexpr, + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Constexpr, + stride_A_m_elems: cutlass.Constexpr, + stride_A_k_elems: cutlass.Constexpr, + stride_B0_elems: cutlass.Constexpr, + stride_Bk_elems: cutlass.Constexpr, + stride_Bn_elems: cutlass.Constexpr, + stride_C_m_elems: cutlass.Constexpr, + stride_C_n_elems: cutlass.Constexpr, + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 + out_problem: cute.Tensor, # [G,4] cutlass.Int32 + out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 + stream: cuda.CUstream, +): + build_group_ptrs_from_bases_kernel( + base_A_u64, + base_B_u64, + base_C_u64, + offs, + K, + N, + sizeof_element, + stride_A_m_elems, + stride_A_k_elems, + stride_B0_elems, + stride_Bk_elems, + stride_Bn_elems, + stride_C_m_elems, + stride_C_n_elems, + out_ptrs, + out_problem, + out_strides_abc, + ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) + + +{{def_kernel("input_a", "input_b", "input_a_offs")}} + stream = cuda.CUstream(stream) + + input_b = input_b.transpose(1, 2) + + sumM, K = input_a.shape + G, N, Kb = input_b.shape + + dev = input_a.device + + base_A_u64 = int(input_a.data_ptr()) + base_B_u64 = int(input_b.data_ptr()) + base_C_u64 = int({{get_output()}}.data_ptr()) + + ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) + probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) + strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) + ptrs = from_dlpack(ptrs_t) + probs = from_dlpack(probs_t) + strides = from_dlpack(strides_t) + + prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) + prep_executor = prep_cache.get(prep_cache_key) + + if prep_executor is None: + sizeof_element = int(input_a.element_size()) + sA_m, sA_k = map(int, input_a.stride()) + sB_0, sB_n, sB_k = map(int, input_b.stride()) + sC_m, sC_n = map(int, {{get_output()}}.stride()) + + prep_executor = cute.compile( + launch_build_group_ptrs_from_bases, + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + G=int(G), + K=int(K), + N=int(N), + sizeof_element=sizeof_element, + stride_A_m_elems=sA_m, + stride_A_k_elems=sA_k, + stride_B0_elems=sB_0, + stride_Bk_elems=sB_k, + stride_Bn_elems=sB_n, + stride_C_m_elems=sC_m, + stride_C_n_elems=sC_n, + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + prep_cache[prep_cache_key] = prep_executor + + prep_executor( + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + # --- Tensormap workspace per SM --- + num_tensormap_buffers, max_active_clusters = get_hardware_info() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) + tensormap_workspace = from_dlpack(tensormap_workspace_t) + + # --- Total clusters --- + def compute_total_num_clusters( + problem_sizes_mnkl, + cluster_tile_shape_mn, + ): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + ): + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) + ) + + total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) + + gemm_cache_key = get_gemm_cache_key( + prep_cache_key, max_active_clusters, total_num_clusters + ) + gemm_executor = gemm_cache.get(gemm_cache_key) + + if gemm_executor is None: + grouped_gemm = GroupedGemmKernel( + acc_dtype=ACC_DTYPE, + use_2cta_instrs=USE_2_CTA, + mma_tiler_mn=(TILE_M, TILE_N), + cluster_shape_mn=(CLUSTER_M, CLUSTER_N), + tensormap_update_mode=TENSORMAP_UPDATE_MODE, + ) + + gemm_executor = cute.compile( + grouped_gemm, + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + G, + probs, + strides, + ptrs, + total_num_clusters, + tensormap_workspace, + max_active_clusters, + stream, + ) + + gemm_cache[gemm_cache_key] = gemm_executor + + gemm_executor( + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + probs, + strides, + ptrs, + tensormap_workspace, + stream, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..34ff2d69793c004b050cfbbd939218a7ed6a255f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja @@ -0,0 +1,107 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + start_pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + + # Note: We require TMA_EXPERIMENTAL_API == False, which + # we will check before invoking this template. + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + + # tile_id_c is used in the epilogue to break the dependency between + # the prologue and the epilogue + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_M * grid_n + + for tile_id in tl.range( + start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE + ): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_K + a = tl.load_tensor_descriptor( + a_desc, + [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am], + ) + b = tl.load_tensor_descriptor( + b_desc, + [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k], + ) + accumulator += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid( + tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_cm = pid_m * BLOCK_M + offs_cn = pid_n * BLOCK_N + {%- if EPILOGUE_SUBTILE %} + tl.static_assert(BLOCK_N % 2 == 0) + acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + {{store_output( + ("offs_cm", "offs_cn"), + "acc0", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} + offs_cn2 = offs_cn + BLOCK_N // 2 + {{store_output( + ("offs_cm", "offs_cn2"), + "acc1", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} + {%- else %} + {{store_output( + ("offs_cm", "offs_cn"), + "accumulator", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True + )}} + {%- endif %} + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + GROUP_M = min(grid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % GROUP_M) + pid_n = (tile_id % num_pid_in_group) // GROUP_M + return pid_m, pid_n diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..56ef18b7a91e3cea8fb49da3465082cc47162a09 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja @@ -0,0 +1,194 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALE_RECIPE_A == 1: # ScalingType.RowWise + stride_a_scale_m = 1 + else: + stride_a_scale_m = 0 + + if SCALE_RECIPE_B == 1: # ScalingType.RowWise + stride_b_scale_n = 1 + else: + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + {%- if TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + {%- endif %} + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + {%- else %} + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + {%- endif %} + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A, + SCALE_RECIPE_B, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + {%- endif %} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A: tl.constexpr, + SCALE_RECIPE_B: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1: # (ScalingType.RowWise, ScalingType.RowWise) + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: # (ScalingType.TensorWise, ScalingType.TensorWise) + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..171340a2c92333c3e514f560183ac746c458b9ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja @@ -0,0 +1,212 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + + am_blocks = tl.cdiv(M, TILE_SIZE_A) + ak_blocks = tl.cdiv(K, TILE_SIZE_A) + bn_blocks = tl.cdiv(N, TILE_SIZE_B) + bk_blocks = tl.cdiv(K, TILE_SIZE_B) + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_a_block = blockwise128x128_scaling( + pid_m, + a_scale, + ki, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AM, + MIN_BLOCK_TILE_AK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_a_block = blockwise1xTILESIZE_scaling( + pid_m, + a_scale, + ki, + M, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AK, + TILE_SIZE_A, + ) + {%- endif %} + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_b_block = blockwise128x128_scaling( + pid_n, + b_scale, + ki, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BN, + MIN_BLOCK_TILE_BK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_b_block = blockwise1xTILESIZE_scaling( + pid_n, + b_scale, + ki, + N, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BK, + TILE_SIZE_B, + ) + {%- endif %} + + a_scaled = a * scale_a_block + b_scaled = b * scale_b_block + accumulator = tl.dot(a_scaled, b_scaled.T, accumulator) + + if ki == k_tiles - 1: + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + + # inductor generates a suffix + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def blockwise1xTILESIZE_scaling( + pid, + scale, + ki, + lhs_size, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, + TILE_SIZE: tl.constexpr, +): + row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs) + col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + (BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + BLOCK_lhs, + ((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K + ) + + return scale_expanded + + +@triton.jit +def blockwise128x128_scaling( + pid, + scale, + ki, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_lhs: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, +): + row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128) + col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_blocks) & (col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + ((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + ((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs, + ((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K + ) + + return scale_expanded diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..2da348f3e767cfbb91350ccb3831c9bf07b07528 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm.py.jinja @@ -0,0 +1,72 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..42b99c70d5cbd5394c00662793b212661c48e48b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja @@ -0,0 +1,71 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..38fe092c257803f4676092af83e40e3eeb55f8c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja @@ -0,0 +1,129 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + start_pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + width = GROUP_M * grid_n + rk_for_mask = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + {%- if TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + global_size=[M, K] if A_ROW_MAJOR else [K, M], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + global_size=[K, N] if B_ROW_MAJOR else [N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + {%- endif %} + + pid_m = 0 + pid_n = 0 + rm = 0 + rn = 0 + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + + rm = pid_m * BLOCK_M + rn = pid_n * BLOCK_N + + rk = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + A.dtype.element_ty, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + B.dtype.element_ty, + ) + {%- else %} + a = tl.load_tensor_descriptor( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + ) + b = tl.load_tensor_descriptor( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + ) + {%- endif %} + acc += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + if ki == k_tiles - 1: + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output(("rm", "rn"), "acc", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"), block_indexing=True)}} + {%- endif %} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..becac750003df0240b2708840bbc9fa19599ff2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py @@ -0,0 +1,2372 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import functools +from typing import List, Type, Union +from inspect import isclass + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.torch as cutlass_torch + +""" +A grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of grouped GEMM using a TMA plus Blackwell SM100 TensorCore +warp-specialized persistent kernel. +The grouped GEMM workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices +in global memory are passed to the kernel in an array (also held in global memory). Similarly, problem shapes and +strides are also stored in arrays in GMEM. + +This differs from "Batched Array" GEMM since the size of each GEMM problem in the grouped GEMM concept may be distinct. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM + +The above example command makes 4 groups of different m, n, k sizes. The Blackwell tcgen05 MMA tile shape +is specified as (128, 64) and the cluster shape is (1,1). The input, mma accumulator and output data type +are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + +There are some constrains for this example. Besides the constrains from the Balckwell dense GEMM persistent example, +there are also the following constrains: +* Only fp16 and bf16 data types are supported as inputs. +* Output data types could be fp16, bf16 or fp32. +* The contiguous dimension of each tensor must be at least 16 bytes aligned. +* The l mode(aka, batch size) for each group must be 1. +* The majorness for A, B and C must be the same across all groups. +""" + + +class GroupedGemmKernel: + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + tensormap_update_mode: utils.TensorMapUpdateMode = utils.TensorMapUpdateMode.SMEM, + ): + """Initializes the configuration for a Blackwell grouped GEMM kernel. + + Besides configurations for dense persistent GEMM, there is an extra config specific to grouped GEMM: + + Tensormap Update Mode: + - tensormap_update_mode: Specifies whether the tensormap is + updated in global memory(GMEM) or shared memory(SMEM). + The 2 modes are functionally equivalent and the difference are: + - We buffer 3 tensormaps in SMEM for A, B, and C tensors (each TMA descriptor takes 128B) when TMA updates performed on SMEM. + - Performance varies between modes depending on problem size; optimal choice differs across workloads. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param mma_tiler_mn: tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: tuple[int, int] + :param cluster_shape_mn: tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: tuple[int, int] + :param tensormap_update_mode: Mode for updating the tensormap (GMEM or SMEM), defaults to SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode, optional + """ + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.tensormap_update_mode = tensormap_update_mode + # Delegate tensormap ab initialization to MMA warp when SMEM mode is used for better latency hiding + self.delegate_tensormap_ab_init = ( + tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ) + + self.num_mcast_ctas_a = 1 + self.num_mcast_ctas_b = 1 + self.is_a_mcast = False + self.is_b_mcast = False + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier for epilog sync, tmem ptr sync and tensormap update sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + # Barrier used by MMA/TMA warps to signal A/B tensormap initialization completion + self.tensormap_ab_init_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * (len(self.epilog_warp_id) + 1), + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.num_tma_load_bytes = 0 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + Most of the implementation follows standard dense GEMM patterns, + with the key difference being additional consideration for SMEM + buffer needed for tensormap updates. + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cluster_tile_shape_mnk = tuple( + x * y for x, y in zip(self.cta_tile_shape_mnk, (*self.cluster_shape_mn, 1)) + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_epi_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + ) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_epi_stage, + ) + + mbar_smem_bytes = self._get_mbar_smem_bytes( + num_acc_stage=self.num_acc_stage, + num_ab_stage=self.num_ab_stage, + num_epi_stage=self.num_epi_stage, + ) + tensormap_smem_bytes = self._get_tensormap_smem_bytes( + self.tensormap_update_mode + ) + if ( + mbar_smem_bytes + + tensormap_smem_bytes + + GroupedGemmKernel.tensor_memory_management_bytes + > self.reserved_smem_bytes + ): + raise ValueError( + f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the " + f"reserved smem bytes {self.reserved_smem_bytes}" + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + initial_a: cute.Tensor, + initial_b: cute.Tensor, + initial_c: cute.Tensor, + group_count: cutlass.Constexpr[int], + problem_shape_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + total_num_clusters: cutlass.Constexpr[int], + tensormap_cute_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr[int], + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided + by different tensors in global memory. The "initial" tensors only carry data type and + majorness information. + + :param initial_a: Initial tensor A, used for data type and majorness information. + :type initial_a: cute.Tensor + :param initial_b: Initial tensor B, used for data type and majorness information. + :type initial_b: cute.Tensor + :param initial_c: Initial tensor C, used for data type and majorness information. + :type initial_c: cute.Tensor + :param group_count: The number of GEMM groups. + :type group_count: cutlass.Constexpr[int] + :param problem_shape_mnkl: Tensor containing the (M, N, K, L) shape for each group. + :type problem_shape_mnkl: cute.Tensor + :param strides_abc: Tensor containing the strides for A, B, and C for each group. + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing the base addresses for A, B, and C for each group. + :type tensor_address_abc: cute.Tensor + :param total_num_clusters: Total number of clusters needed for all groups. + :type total_num_clusters: cutlass.Constexpr[int] + :param tensormap_cute_tensor: Tensor for storing tensormaps. + :type tensormap_cute_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + :param stream: CUDA stream for asynchronous execution. + :type stream: cuda.CUstream + :raises TypeError: If A and B data types do not match. + """ + self.a_dtype = initial_a.element_type + self.b_dtype = initial_b.element_type + self.c_dtype = initial_c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(initial_c) + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + initial_a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + initial_b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + epi_smem_layout = cute.slice_(self.epi_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + initial_c, + epi_smem_layout, + self.epi_tile, + ) + + self.tile_sched_params, grid = self._compute_grid( + total_num_clusters, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + self.size_tensormap_in_i64 = ( + 0 + if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM + else GroupedGemmKernel.num_tensormaps + * GroupedGemmKernel.bytes_per_tensormap + // 8 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, self.size_tensormap_in_i64 + ] + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.epi_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + group_count, + problem_shape_mnkl, + strides_abc, + tensor_address_abc, + tensormap_cute_tensor, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + group_count: cutlass.Constexpr[int], + problem_sizes_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + ptrs_abc: cute.Tensor, + tensormaps: cute.Tensor, + ): + """ + GPU device kernel performing the grouped GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coord inside cluster + bid = cute.arch.block_idx() + mma_tile_coord_v = bid[0] % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: tensormap buffer, a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tensormap_a_smem_ptr = None + tensormap_b_smem_ptr = None + tensormap_c_smem_ptr = None + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_c_smem_ptr = ( + tensormap_b_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + ab_full_mbar_ptr = storage.ab_full_mbar_ptr.data_ptr() + ab_empty_mbar_ptr = storage.ab_empty_mbar_ptr.data_ptr() + acc_full_mbar_ptr = storage.acc_full_mbar_ptr.data_ptr() + acc_empty_mbar_ptr = storage.acc_empty_mbar_ptr.data_ptr() + + # init barrier for loading A, B with TMA + if warp_idx == self.epilog_warp_id[0]: + for k_stage in range(self.num_ab_stage): + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + with cute.arch.elect_one(): + cute.arch.mbarrier_init(ab_full_mbar_ptr + k_stage, 1) + cute.arch.mbarrier_init( + ab_empty_mbar_ptr + k_stage, num_tma_producer + ) + # Accumulator barrier init + if warp_idx == self.mma_warp_id: + for acc_stage in range(self.num_acc_stage): + with cute.arch.elect_one(): + cute.arch.mbarrier_init(acc_full_mbar_ptr + acc_stage, 1) + cute.arch.mbarrier_init( + acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4 + ) + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full and empty + # + a_full_mcast_mask = None + b_full_mcast_mask = None + ab_empty_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + ab_empty_mcast_mask = a_full_mcast_mask | b_full_mcast_mask + acc_full_mcast_mask = None + if cutlass.const_expr(use_2cta_instrs): + acc_full_mcast_mask = cute.make_layout_image_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mode=0 + ) + block_in_cluster_coord_vmnk_peer = ( + block_in_cluster_coord_vmnk[0] ^ 1, + *block_in_cluster_coord_vmnk[1:], + ) + a_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 + ) + b_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 + ) + ab_empty_mcast_mask = ( + a_full_mcast_mask_peer + | b_full_mcast_mask_peer + | cutlass.Int16( + 0 if ab_empty_mcast_mask is None else ab_empty_mcast_mask + ) + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for load A, B with TMA + # + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Get tensormap buffer address + # + grid_dim = cute.arch.grid_dim() + tensormap_workspace_idx = ( + bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0] + ) + + tensormap_manager = utils.TensorMapManager( + self.tensormap_update_mode, GroupedGemmKernel.bytes_per_tensormap + ) + tensormap_a_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 0, None)].iterator + ) + tensormap_b_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 1, None)].iterator + ) + tensormap_c_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 2, None)].iterator + ) + # Setup tensormap initialization pointer based on the mode + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + tensormap_a_init_ptr = tensormap_a_smem_ptr + tensormap_b_init_ptr = tensormap_b_smem_ptr + tensormap_c_init_ptr = tensormap_c_smem_ptr + else: + tensormap_a_init_ptr = tensormap_a_ptr + tensormap_b_init_ptr = tensormap_b_ptr + tensormap_c_init_ptr = tensormap_c_ptr + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # Initialize tensormaps for A, B + if cutlass.const_expr(self.delegate_tensormap_ab_init == False): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.tma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.tma_warp_id + ) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + tensormap_init_done = cutlass.Boolean(False) + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + # skip tensormap update if we're working on the same group + if is_group_changed: + real_tensor_a = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.a_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 0, # 0 for tensor A + ) + real_tensor_b = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.b_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 1, # 1 for tensor B + ) + # wait tensormap initialization complete before update + if tensormap_init_done == False: + if cutlass.const_expr(self.delegate_tensormap_ab_init): + self.tensormap_ab_init_barrier.arrive_and_wait() + tensormap_manager.fence_tensormap_initialization() + tensormap_init_done = True + + tensormap_manager.update_tensormap( + (real_tensor_a, real_tensor_b), + (tma_atom_a, tma_atom_b), + (tensormap_a_ptr, tensormap_b_ptr), + self.tma_warp_id, + (tensormap_a_smem_ptr, tensormap_b_smem_ptr), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + tma_wr_k_tile = cutlass.Int32(0) + smem_wr_buffer = (num_prev_k_blk + tma_wr_k_tile) % self.num_ab_stage + tma_wr_ab_empty_phase = ( + num_prev_k_blk + tma_wr_k_tile + ) // self.num_ab_stage % 2 ^ 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile < cur_k_tile_cnt, + ab_empty_mbar_ptr + smem_wr_buffer, + tma_wr_ab_empty_phase, + ) + # ensure the update to tensormap has completed before using it + if is_group_changed: + tensormap_manager.fence_tensormap_update(tensormap_a_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_ptr) + # + # Tma load loop + # + for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): + tma_wr_k_tile_next = tma_wr_k_tile + 1 + smem_wr_buffer_next = ( + num_prev_k_blk + tma_wr_k_tile_next + ) % self.num_ab_stage + tma_wr_ab_empty_phase_next = ( + tma_wr_ab_empty_phase ^ 1 + if smem_wr_buffer_next == 0 + else tma_wr_ab_empty_phase + ) + + smem_full_mbar_ptr = ab_full_mbar_ptr + smem_wr_buffer + + # Wait for AB buffer empty + if peek_ab_empty_status == 0: + cute.arch.mbarrier_wait( + ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase + ) + + # Arrive AB buffer and expect full transaction bytes + if is_leader_cta: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + smem_full_mbar_ptr, self.num_tma_load_bytes + ) + + # Load A/B with TMA + cute.copy( + tma_atom_a, + tAgA_slice[(None, tma_wr_k_tile)], + tAsA[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=a_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, tma_wr_k_tile)], + tBsB[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=b_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_ptr, + cute.AddressSpace.generic, + ), + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile_next < cur_k_tile_cnt, + ab_empty_mbar_ptr + smem_wr_buffer_next, + tma_wr_ab_empty_phase_next, + ) + + tma_wr_k_tile = tma_wr_k_tile_next + smem_wr_buffer = smem_wr_buffer_next + tma_wr_ab_empty_phase = tma_wr_ab_empty_phase_next + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # Bar sync for retrieve tmem ptr from shared mem + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + # MMA warp is only interested in number of tiles along K dimension + ( + cur_k_tile_cnt, + cur_group_idx, + ) = group_gemm_ts_helper.search_cluster_tile_count_k( + cur_tile_coord, + problem_sizes_mnkl, + ) + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_buf_idx)] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer full for k_tile = 0 + mma_rd_k_tile = cutlass.Int32(0) + smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage + if is_leader_cta: + need_check_rd_buffer_full = ( + mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta + ) + mma_rd_ab_full_phase = ( + (num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2 + ) + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer, + mma_rd_ab_full_phase, + ) + + # + # Wait for accumulator buffer empty + # + acc_empty_phase = ( + tile_sched.num_tiles_executed // self.num_acc_stage % 2 ^ 1 + ) + cute.arch.mbarrier_wait( + acc_empty_mbar_ptr + acc_buf_idx, acc_empty_phase + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(cur_k_tile_cnt): + mma_rd_k_tile_next = cutlass.Int32(k_tile + 1) + smem_rd_buffer_next = ( + num_prev_k_blk + mma_rd_k_tile_next + ) % self.num_ab_stage + mma_rd_ab_full_phase_next = ( + mma_rd_ab_full_phase ^ 1 + if smem_rd_buffer_next == 0 + else mma_rd_ab_full_phase + ) + # Wait for AB buffer full + if peek_ab_full_status == 0: + cute.arch.mbarrier_wait( + ab_full_mbar_ptr + smem_rd_buffer, mma_rd_ab_full_phase + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, smem_rd_buffer) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + with cute.arch.elect_one(): + tcgen05.commit( + ab_empty_mbar_ptr + smem_rd_buffer, + ab_empty_mcast_mask, + self.cta_group, + ) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + need_check_rd_buffer_full = ( + mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta + ) + + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer_next, + mma_rd_ab_full_phase_next, + ) + + mma_rd_k_tile = mma_rd_k_tile_next + smem_rd_buffer = smem_rd_buffer_next + mma_rd_ab_full_phase = mma_rd_ab_full_phase_next + + # + # Async arrive accumulator buffer full + # + with cute.arch.elect_one(): + tcgen05.commit( + acc_full_mbar_ptr + acc_buf_idx, + acc_full_mcast_mask, + self.cta_group, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # initialize tensormap A, B for TMA warp + if cutlass.const_expr(self.delegate_tensormap_ab_init): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.epilog_warp_id[0] + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.epilog_warp_id[0] + ) + # signal tensormap initialization has finished + self.tensormap_ab_init_barrier.arrive_and_wait() + # initialize tensorap for C + tensormap_manager.init_tensormap_from_atom( + tma_atom_c, + tensormap_c_init_ptr, + self.epilog_warp_id[0], + ) + # Alloc tensor memory buffer + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + epi_tidx = tidx + # + # Partition for epilogue + # + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # wait tensormap initialization complete before update + tensormap_manager.fence_tensormap_initialization() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + if is_group_changed: + # construct tensor C based on real address, shape and stride information + real_tensor_c = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.c_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 2, # 2 for tensor C + ) + tensormap_manager.update_tensormap( + ((real_tensor_c),), + ((tma_atom_c),), + ((tensormap_c_ptr),), + self.epilog_warp_id[0], + (tensormap_c_smem_ptr,), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + total_k_tile_cnt += cur_k_tile_cnt + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_buf_idx)] + + # + # Wait for accumulator buffer full + # + acc_full_phase = tile_sched.num_tiles_executed // self.num_acc_stage % 2 + cute.arch.mbarrier_wait(acc_full_mbar_ptr + acc_buf_idx, acc_full_phase) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + # ensure the update to tensormap has completed before using it + if is_group_changed: + if warp_idx == self.epilog_warp_id[0]: + tensormap_manager.fence_tensormap_update(tensormap_c_ptr) + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to output type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + tRS_rC.store(acc_vec.to(self.c_dtype)) + # + # Store C to shared memory + # + epi_buffer = (num_prev_subtiles + subtile_idx) % self.num_epi_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, epi_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + # + # store C to global memory with TMA + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, epi_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_c_ptr, + cute.AddressSpace.generic, + ), + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group( + self.num_epi_stage - 1, read=True + ) + self.epilog_sync_barrier.arrive_and_wait() + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive( + acc_empty_mbar_ptr + acc_buf_idx, + cta_rank_in_cluster // 2 * 2 if use_2cta_instrs else None, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + + # + # Wait a/b buffer empty + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.mbarrier_wait( + (ab_empty_mbar_ptr + ((total_k_tile_cnt - 1) % self.num_ab_stage)), + (((total_k_tile_cnt - 1) // self.num_ab_stage) % 2), + ) + + @cute.jit + def make_tensor_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_index: int, + ): + """Extract stride and tensor address for a given group and construct a global tensor. + + This function is used within the kernel to dynamically create a CUTE tensor + representing A, B, or C for the current group being processed, using the + group-specific address, shape, and stride information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param strides_abc: Tensor containing strides for A, B, C for all groups. Layout: (group_count, 3, 2). + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing global memory addresses for A, B, C for all groups. Layout: (group_count, 3). + :type tensor_address_abc: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for A, 1 for B, 2 for C. + :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (A, B, or C) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. + """ + ptr_i64 = tensor_address_abc[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] + strides_tensor_reg = cute.make_rmem_tensor( + cute.make_layout(2), + strides_abc.element_type, + ) + cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg) + stride_mn = strides_tensor_reg[0] + stride_k = strides_tensor_reg[1] + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(tensor_index == 0): # tensor A + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)), + ) + elif cutlass.const_expr(tensor_index == 1): # tensor B + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)), + ) + else: # tensor C + m = problem_shape_mnk[0] + n = problem_shape_mnk[1] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)), + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load(t2r) + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tma_atom_c: cute.CopyAtom, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to partition + shared memory (source) and global memory (destination) for TMA store version. + + :param tma_atom_c: The TMA copy atom configured for storing tensor C. + :type tma_atom_c: cute.CopyAtom + :param gC_mnl: The global memory tensor C. + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler defining the granularity of the operation. + :type epi_tile: cute.Tile + :param sC: The shared memory epilogue buffer tensor. + :type sC: cute.Tensor + :return: A tuple containing: + - tma_atom_c: The input TMA copy atom (passed through). + - bSG_sC: The source shared memory tensor partitioned for the TMA operation. + - tCgC: The destination global memory tensor partitioned for the TMA operation. + :rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int, int]: + """Computes the number of stages for accumulator, A/B operands, and epilogue based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (accumulator stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default accumulator and epilogue stages + num_acc_stage = 2 + num_epi_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and Epilogue + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # stage=1 + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # stage=1 + ) + epi_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, # stage=1 + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + epi_bytes_per_stage = cute.size_in_bytes(c_dtype, epi_smem_layout_staged_one) + epi_bytes = epi_bytes_per_stage * num_epi_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial epilogue bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity // occupancy + - GroupedGemmKernel.reserved_smem_bytes + - epi_bytes + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + remaining_smem = ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (GroupedGemmKernel.reserved_smem_bytes + epi_bytes) + ) + num_epi_stage += remaining_smem // (occupancy * epi_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_epi_stage + + @staticmethod + def _compute_grid( + total_num_clusters: int, + cluster_shape_mn: tuple[int, int], + max_active_clusters: cutlass.Constexpr[int], + ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: + """Compute tile scheduler parameters and grid shape for grouped GEMM operations. + + :param total_num_clusters: Total number of clusters to process across all groups. + :type total_num_clusters: int + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, ...]] + """ + # Create problem shape with M, N dimensions from cluster shape + # and L dimension representing the total number of clusters. + problem_shape_ntile_mnl = ( + cluster_shape_mn[0], + cluster_shape_mn[1], + cutlass.Int32(total_num_clusters), + ) + + tile_sched_params = utils.PersistentTileSchedulerParams( + problem_shape_ntile_mnl, (*cluster_shape_mn, 1) + ) + + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_mbar_smem_bytes(**kwargs_stages: int) -> int: + """Calculate shared memory consumption for memory barriers based on provided stages. + + Each stage requires 2 barriers, and each barrier consumes 8 bytes of shared memory. + The total consumption is the sum across all provided stages. This function calculates the total + shared memory needed for these barriers. + + :param kwargs_stages: Variable keyword arguments where each key is a stage name + (e.g., num_acc_stage, num_ab_stage) and each value is the + number of stages of that type. + :type kwargs_stages: int + :return: Total shared memory bytes required for all memory barriers. + :rtype: int + """ + num_barriers_per_stage = 2 + num_bytes_per_barrier = 8 + mbar_smem_consumption = sum( + [ + num_barriers_per_stage * num_bytes_per_barrier * stage + for stage in kwargs_stages.values() + ] + ) + return mbar_smem_consumption + + @staticmethod + def _get_tensormap_smem_bytes( + tensormap_update_mode: utils.TensorMapUpdateMode, + ) -> int: + """Get the SMEM consumption for the tensormap buffer based on the update mode. + + :param tensormap_update_mode: Specifies whether tensormaps are updated in GMEM or SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode + :return: The shared memory bytes required for the tensormap buffer. Returns 0 if mode is GMEM. + :rtype: int + :raises ValueError: If an invalid tensormap update mode is provided. + """ + if tensormap_update_mode == utils.TensorMapUpdateMode.GMEM: + return 0 + elif tensormap_update_mode == utils.TensorMapUpdateMode.SMEM: + return ( + GroupedGemmKernel.bytes_per_tensormap * GroupedGemmKernel.num_tensormaps + ) + else: + raise ValueError(f"Invalid tensormap update mode: {tensormap_update_mode}") + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + :param acc_stage: The stage of the accumulator tensor. + :type acc_stage: int + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + return num_tmem_alloc_cols + + # Size of smem we reserved for mbarrier, tensor memory management and tensormap update + reserved_smem_bytes = 1024 + bytes_per_tensormap = 128 + num_tensormaps = 3 + # size of smem used for tensor memory management + tensor_memory_management_bytes = 12 + + +# Create tensor and return the pointer, tensor, and stride +def create_tensor_and_stride( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: type[cutlass.Numeric], + is_dynamic_layout: bool = True, + torch_tensor_cpu: torch.Tensor = None, +) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: + """Create GPU tensor from either a new or existing CPU tensor. + + :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. + :type torch_tensor_cpu: torch.Tensor, optional + """ + if torch_tensor_cpu is None: + # Create new CPU tensor + torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype) + + # Create GPU tensor from CPU tensor (new or existing) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 + ) + return ( + torch_tensor.data_ptr(), + torch_tensor, + cute_tensor, + torch_tensor_cpu, + torch_tensor.stride()[:-1], + ) + + +def create_tensors_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + torch_fp32_tensors_abc: List[List[torch.Tensor]] = None, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[tuple]], + List[List[torch.Tensor]], +]: + if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len( + problem_sizes_mnkl + ): + raise ValueError("torch_fp32_tensors_abc must have one entry per group") + + # Initialize lists to store tensors for all groups + new_torch_fp32_tensors_abc = ( + [] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc + ) + torch_tensors_abc = [] + cute_tensors_abc = [] + strides_abc = [] + ptrs_abc = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + # Get existing CPU tensors if available, otherwise None + existing_cpu_a = ( + torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None + ) + existing_cpu_b = ( + torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None + ) + existing_cpu_c = ( + torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None + ) + + # Create tensors (reusing CPU tensors if provided) + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + tensor_fp32_a, + stride_mk_a, + ) = create_tensor_and_stride( + l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a + ) + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + tensor_fp32_b, + stride_nk_b, + ) = create_tensor_and_stride( + l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b + ) + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + tensor_fp32_c, + stride_mn_c, + ) = create_tensor_and_stride( + l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c + ) + + # Only append to new_torch_fp32_tensors_abc if we created new CPU tensors + if torch_fp32_tensors_abc is None: + new_torch_fp32_tensors_abc.append( + [tensor_fp32_a, tensor_fp32_b, tensor_fp32_c] + ) + + ptrs_abc.append([ptr_a, ptr_b, ptr_c]) + torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + cute_tensors_abc.append( + ( + cute_tensor_a, + cute_tensor_b, + cute_tensor_c, + ) + ) + + return ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + new_torch_fp32_tensors_abc, + ) + + +def run( + num_groups: int, + problem_sizes_mnkl: tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + tensormap_update_mode: utils.TensorMapUpdateMode, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """Run grouped GEMM example with specified configurations. + + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print("Running Blackwell Grouped GEMM test with:") + print(f"{num_groups} groups") + for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): + print(f"Group {i}: {m}x{n}x{k}x{l}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Tensor map update mode: {tensormap_update_mode}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Skip unsupported types + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + }: + raise ValueError(f"Skip unsupported ab_dtype {ab_dtype}") + if c_dtype not in {cutlass.Float16, cutlass.BFloat16, cutlass.Float32}: + raise ValueError(f"Skip unsupported c_dtype {c_dtype}") + # Skip unsupported acc dtype + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: + raise ValueError(f"Skip unsupported acc_dtype {acc_dtype}") + # Skip invalid ab_dtype and acc_dtype combination + if ab_dtype == cutlass.BFloat16 and acc_dtype == cutlass.Float16: + raise ValueError("Skip invalid ab_dtype and acc_dtype combination") + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + raise ValueError(f"Skip invalid mma tiler M {mma_tiler_mn[0]}") + if mma_tiler_mn[1] not in range(32, 257, 32): + raise ValueError(f"Skip invalid mma tiler N {mma_tiler_mn[1]}") + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape_m need align with use_2cta_instrs config {cluster_shape_mn}" + ) + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + raise ValueError(f"Skip invalid cluster shape {cluster_shape_mn}") + + # Skip illegal problem shape for load/store alignment + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + raise ValueError("Skip invalid problem alignment") + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Create tensors for all groups using the new function + ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + torch_fp32_tensors_abc, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + + # Choose A, B, C with the smallest size to create initial tensormaps + key_size_a = lambda item: item[1][0] * item[1][2] + key_size_b = lambda item: item[1][1] * item[1][2] + key_size_c = lambda item: item[1][0] * item[1][1] + # Find the indices of the groups with the smallest tensor sizes + min_a_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_a) + min_b_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_b) + min_c_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_c) + initial_cute_tensors_abc = [ + cute_tensors_abc[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc[min_c_idx][2], # C with smallest (m, n) + ] + + hardware_info = utils.HardwareInfo() + sm_count = hardware_info.get_max_active_clusters(1) + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + # Prepare tensormap buffer for each SM + num_tensormap_buffers = sm_count + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + grouped_gemm = GroupedGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + tensormap_update_mode, + ) + + # layout (num_groups, 4):(4, 1) + ( + tensor_of_dim_size_mnkl, + tensor_of_dim_size_mnkl_torch, + ) = cutlass_torch.cute_tensor_like( + torch.tensor(problem_sizes_mnkl, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups, 3, 2):(6, 2, 1) + tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,3):(3, 1) + tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + def compute_total_num_clusters( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + cluster_tile_shape_mn: tuple[int, int], + ) -> int: + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + ) -> tuple[int, int]: + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + mma_tiler_mn, cluster_shape_mn, use_2cta_instrs + ) + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cluster_tile_shape_mn + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile grouped GEMM kernel + compiled_grouped_gemm = cute.compile( + grouped_gemm, + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + num_groups, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + total_num_clusters, + tensor_of_tensormap, + max_active_clusters, + current_stream, + ) + + if not skip_ref_check: + compiled_grouped_gemm( + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_tensormap, + current_stream, + ) + + # Compute reference result + for i, (a, b, c) in enumerate(torch_tensors_abc): + ref = torch.einsum( + "mkl,nkl->mnl", + a.cpu().to(dtype=torch.float32), + b.cpu().to(dtype=torch.float32), + ) + print(f"checking group {i}") + torch.testing.assert_close( + c.cpu(), + ref.to(cutlass_torch.dtype(c_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + # Reuse existing CPU tensors and create new GPU tensors from them + ( + ptrs_abc_workspace, + torch_tensors_abc_workspace, + cute_tensors_abc_workspace, + strides_abc_workspace, + _, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + torch_fp32_tensors_abc, + ) + + initial_cute_tensors_abc_workspace = [ + cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n) + ] + + # Create new tensors for this workspace + tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc_workspace, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensormap_workspace, _ = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + return testing.JitArguments( + initial_cute_tensors_abc_workspace[0], + initial_cute_tensors_abc_workspace[1], + initial_cute_tensors_abc_workspace[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc_workspace, + tensor_of_ptrs_abc_workspace, + tensormap_workspace, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + sum( + [ + sum( + [ + torch_tensor.numel() * torch_tensor.element_size() + for torch_tensor in group_tensors + ] + ) + for group_tensors in torch_tensors_abc + ] + ) + + + # Add size of strides tensor + tensor_of_strides_abc_torch.numel() + * tensor_of_strides_abc_torch.element_size() + + + # Add size of ptrs tensor + tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + + + # Add size of tensormap tensor + tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_grouped_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_comma_separated_tuples(s: str) -> List[tuple[int, ...]]: + if s.strip().startswith("("): + # Split on ),( to separate tuples + tuples = s.strip("()").split("),(") + result = [] + tuple_len = None + + for t in tuples: + # Parse individual tuple + nums = [int(x.strip()) for x in t.split(",")] + + # Validate tuple length consistency + if tuple_len is None: + tuple_len = len(nums) + elif len(nums) != tuple_len: + raise argparse.ArgumentTypeError( + "All tuples must have the same length" + ) + + result.append(tuple(nums)) + return result + + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers or list of tuples" + ) + + parser = argparse.ArgumentParser( + description="Example of Grouped GEMM on Blackwell." + ) + parser.add_argument( + "--num_groups", + type=int, + default=2, + help="Number of groups", + ) + parser.add_argument( + "--problem_sizes_mnkl", + type=parse_comma_separated_tuples, + default=((128, 128, 128, 1), (128, 128, 128, 1)), + help="a tuple of problem sizes for each group (comma-separated tuples)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--tensormap_update_mode", + type=str, + default="SMEM", + help="Tensor map update mode", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if ( + len(args.problem_sizes_mnkl) != 0 + and len(args.problem_sizes_mnkl) != args.num_groups + ): + parser.error("--problem_sizes_mnkl must contain exactly num_groups tuples") + + # l mode must be 1 for all groups + for _, _, _, l in args.problem_sizes_mnkl: + if l != 1: + parser.error("l must be 1 for all groups") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + if args.tensormap_update_mode not in ["GMEM", "SMEM"]: + parser.error("--tensormap_update_mode must be GMEM or SMEM") + + if args.tensormap_update_mode == "GMEM": + tensormap_update_mode = utils.TensorMapUpdateMode.GMEM + else: + tensormap_update_mode = utils.TensorMapUpdateMode.SMEM + + torch.manual_seed(2025) + + run( + args.num_groups, + args.problem_sizes_mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + tensormap_update_mode, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d6156527bb1331fdd004f506fae7cc69b80e97e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/choices.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/choices.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ac0e01ff71c8fe2c0a1140b686e7834fb0c7ce2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/lookup_table/__pycache__/choices.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e8a8a7da836bec77df223fab776311758e7c852 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa3ae3d360489fa88752f2718f09143e6927948b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/package.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..317ac927b4a5d25b426beb83423fff607596861c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/package/__pycache__/package.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..390e7e9fa1dd5e9621545e6155b572bf39d59284 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d72d6a02b0d20ed60c30a9ed5887a26028725a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23cb56259d34a3a1b756ffd6a4b1804b58a27755 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c22f80358a4ea07662795544dbf6a4faa4cc077 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9aee84397ebafebb579fa910075a90f9dac5121d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f805bd5710e31a2ef0e478f4658c843506e5152 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/debug_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a57da77de09175f9a5bfe0ffb4213952a9602f9f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/debug_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..697bf8e2957711fd5f935536e049f1bc0caa3a21 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d0efb369ce58334f5b3d817746eaed8472b4468 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ed9f5db0741dc811f846ee6df5d52039a9fa0a7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1b323ebd895ad124585f782addf7436038bd4f8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a3eab5b1ad18b8d71e3aec070d1603bb8759595 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b220290ce88cd70c7987a7bbcaa5aba86db9e48f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1d364eaf51e009b557e422fe0b5093fe9cfb17 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__init__.py @@ -0,0 +1,68 @@ +from threading import Lock + +from . import config, interfaces as intfs, locks +from .context import IsolationSchema, SelectedCompileContext, SelectedRuntimeContext +from .exceptions import ( + CacheError, + CustomParamsEncoderRequiredError, + CustomResultDecoderRequiredError, + CustomResultEncoderRequiredError, + DeterministicCachingDisabledError, + DeterministicCachingIMCDumpConflictError, + DeterministicCachingInvalidConfigurationError, + DeterministicCachingRequiresStrongConsistencyError, + FileLockTimeoutError, + KeyEncodingError, + KeyPicklingError, + LockTimeoutError, + StrictDeterministicCachingKeyNotFoundError, + SystemError, + UserError, + ValueDecodingError, + ValueEncodingError, + ValuePicklingError, + ValueUnPicklingError, +) + + +# fast cache; does not bother supporting deterministic caching, and is essentially +# a memoized on-disk cache. use when deterministic caching is not required +fcache: intfs._CacheIntf = intfs._FastCacheIntf() +# deterministic cache; slower than fcache but provides deterministic guarantees. +# use when deterministic caching is absolutely required, as this will raise +# an exception if use is attempted when deterministic caching is disabled +dcache: intfs._CacheIntf = intfs._DeterministicCacheIntf() +# inductor cache; defaults to the deterministic cache if deterministic caching +# is enabled, otherwise uses the fast cache. use when you would like deterministic +# caching but are okay with non-deterministic caching if deterministic caching is disabled +icache: intfs._CacheIntf = ( + dcache if config.IS_DETERMINISTIC_CACHING_ENABLED() else fcache +) + +__all__ = [ + "SelectedCompileContext", + "SelectedRuntimeContext", + "IsolationSchema", + "CacheError", + "SystemError", + "UserError", + "LockTimeoutError", + "FileLockTimeoutError", + "KeyEncodingError", + "KeyPicklingError", + "ValueEncodingError", + "ValuePicklingError", + "ValueDecodingError", + "ValueUnPicklingError", + "CustomParamsEncoderRequiredError", + "CustomResultEncoderRequiredError", + "CustomResultDecoderRequiredError", + "DeterministicCachingDisabledError", + "DeterministicCachingRequiresStrongConsistencyError", + "StrictDeterministicCachingKeyNotFoundError", + "DeterministicCachingInvalidConfigurationError", + "DeterministicCachingIMCDumpConflictError", + "fcache", + "dcache", + "icache", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59211917addc662ad020f4a67d608bc12f405efb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36bf87ee2df996ff09b831f6fa17d79534ba27bd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/context.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..994673d96811025d111c35835b4da2f109eef9d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/context.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/exceptions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/exceptions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..021b907054c72d9e1ee819c9e7056e8eedac2514 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/exceptions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/implementations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/implementations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64438d19cab89fe688fa8ff19989857ebdd93e03 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/implementations.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/interfaces.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/interfaces.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28795fb01cc79177a856e1bb95f017a0a1d67ee0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/interfaces.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/locks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/locks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5be0e47e56b7173889775e59bb29315b82af45e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/locks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af14b69fcf497acb4edd036a4b51b544bb305132 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/config.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/config.py new file mode 100644 index 0000000000000000000000000000000000000000..14e13f937dbb75ad0b8ca0c197df3e8c2559c098 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/config.py @@ -0,0 +1,127 @@ +import os +from collections.abc import Callable +from functools import cache, partial + +import torch +from torch._environment import is_fbcode + + +@cache +def _env_var_config(env_var: str, default: bool) -> bool: + if (env_val := os.environ.get(env_var)) is not None: + return env_val == "1" + return default + + +@cache +def _versioned_config( + jk_name: str, + this_version: int, + oss_default: bool, + env_var_override: str | None = None, +) -> bool: + """ + A versioned configuration utility that determines boolean settings based on: + 1. Environment variable override (highest priority) + 2. JustKnobs version comparison in fbcode environments + 3. OSS default fallback + + This function enables gradual rollouts of features in fbcode by comparing + a local version against a JustKnobs-controlled remote version, while + allowing environment variable overrides for testing and OSS defaults + for non-fbcode environments. + + Args: + jk_name: JustKnobs key name (e.g., "pytorch/inductor:feature_version") + this_version: Local version number to compare against JustKnobs version + oss_default: Default value to use in non-fbcode environments + env_var_override: Optional environment variable name that, when set, + overrides all other logic + + Returns: + bool: Configuration value determined by the priority order above + """ + if ( + env_var_override + and (env_var_value := os.environ.get(env_var_override)) is not None + ): + return env_var_value == "1" + elif is_fbcode(): + # this method returns 0 on failure, which we should check for specifically. + # in the case of JK failure, the safe bet is to simply disable the config + jk_version: int = torch._utils_internal.justknobs_getval_int(jk_name) + return (this_version >= jk_version) and (jk_version != 0) + return oss_default + + +# toggles the entire caching module, but only when calling through the +# public facing interfaces. get/insert operations become no-ops in the sense +# that get will always miss and insert will never insert; record becomes a +# no-op in the sense that the function will always be called and the cache +# will never be accessed +_CACHING_MODULE_VERSION: int = 0 +_CACHING_MODULE_VERSION_JK: str = "pytorch/inductor:caching_module_version" +_CACHING_MODULE_OSS_DEFAULT: bool = False +_CACHING_MODULE_ENV_VAR_OVERRIDE: str = "TORCHINDUCTOR_ENABLE_CACHING_MODULE" +IS_CACHING_MODULE_ENABLED: Callable[[], bool] = partial( + _versioned_config, + _CACHING_MODULE_VERSION_JK, + _CACHING_MODULE_VERSION, + _CACHING_MODULE_OSS_DEFAULT, + _CACHING_MODULE_ENV_VAR_OVERRIDE, +) + + +# toggles the deterministic caching interface. silently disabling deterministic +# caching (i.e. by mimicking the functionality of IS_CACHING_MODULE_ENABLED) can +# be problematic if the user is directly calling the deterministic caching interface +# (for example, if they were to interface with dcache instead of icache). instead, if +# the user tries to use the deterministic caching interface while it is disabled we +# will simply throw DeterministicCachingDisabledError +_DETERMINISTIC_CACHING_VERSION: int = 0 +_DETERMINISTIC_CACHING_VERSION_JK: str = ( + "pytorch/inductor:deterministic_caching_version" +) +_DETERMINISTIC_CACHING_OSS_DEFAULT: bool = False +_DETERMINISTIC_CACHING_ENV_VAR_OVERRIDE: str = ( + "TORCHINDUCTOR_ENABLE_DETERMINISTIC_CACHING" +) +IS_DETERMINISTIC_CACHING_ENABLED: Callable[[], bool] = partial( + _versioned_config, + _DETERMINISTIC_CACHING_VERSION_JK, + _DETERMINISTIC_CACHING_VERSION, + _DETERMINISTIC_CACHING_OSS_DEFAULT, + _DETERMINISTIC_CACHING_ENV_VAR_OVERRIDE, +) + +# enabling strictly pre-populated determinism forces the deterministic caching +# interface to pull from and only from a pre-populated in-memory cache. this +# in-memory cache gets pre-populated from a file path that is specified by +# environment variable "TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE". +# coincidentally, the deterministic caching interface will dump its in-memory +# cache to disk on program exit (check the logs for the exact file path) which +# can be used as a drop-in solution for pre-population on subsequent runs. if +# strictly pre-populated determinism is enabled and a key is encountered which +# is not covered by the pre-populated in-memory cache an exception, +# StrictDeterministicCachingKeyNotFoundError, will be raised +STRICTLY_PRE_POPULATED_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_STRICTLY_PRE_POPULATED_DETERMINISM", + default=False, +) +# similar to strictly pre-populated determinism, except that any key can either +# be in the pre-populated in-memory cache or the on-disk/remote cache (depending +# on whether or not local/global determinism is enabled). +STRICTLY_CACHED_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_STRICTLY_CACHED_DETERMINISM", + default=False, +) +# local determinism ensures that caching is deterministic on a single machine, +# hence an on-disk cache is used for synchronization of results +LOCAL_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_LOCAL_DETERMINISM", default=(not is_fbcode()) +) +# global determinism ensures that caching is deterministic across any/all machines, +# hence a remote cache (with strong consistency!) is used for synchronization of results +GLOBAL_DETERMINISM: bool = _env_var_config( + "TORCHINDUCTOR_GLOBAL_DETERMINISM", default=is_fbcode() +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/context.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/context.py new file mode 100644 index 0000000000000000000000000000000000000000..7f52a70ff6d70982a5626a1ff48d7078b6b4ccf8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/context.py @@ -0,0 +1,292 @@ +"""Context management for PyTorch Inductor runtime caching. + +This module provides context classes for collecting configuration and environment +information used in caching decisions for PyTorch's Inductor runtime. +""" + +import json +from abc import ABC, abstractmethod +from base64 import b64encode +from collections.abc import Sequence +from functools import cache +from hashlib import sha256 +from typing import Any +from typing_extensions import override, TypedDict + +import torch + + +class _Context(ABC): + """Abstract base class for context providers. + + Context providers collect specific configuration and environment information + that affects compilation and runtime behavior. + """ + + @staticmethod + @abstractmethod + def forms_of_context() -> Sequence[str]: + """Return a sequence of context form names provided by this context class. + + Returns: + A sequence of strings representing the available context forms. + """ + + +class _RuntimeContext(_Context): + """Context provider for runtime configuration and environment settings. + + Collects configuration settings that affect runtime behavior but not + compilation, such as Inductor configs, determinism settings, and CUDA + matmul precision configurations. + """ + + @override + @staticmethod + def forms_of_context() -> Sequence[str]: + """Return the runtime context forms provided by this class. + + Returns: + A sequence containing the available runtime context forms: + - "inductor_configs": PyTorch Inductor configuration settings + - "torch_determinism_configs": Deterministic algorithm settings + - "cuda_matmul_precision_configs": CUDA matrix multiplication precision settings + """ + return ( + "inductor_configs", + "torch_determinism_configs", + "cuda_matmul_precision_configs", + ) + + @staticmethod + def inductor_configs() -> dict[str, Any]: + """Get portable Inductor configuration settings. + + Returns: + A dictionary containing Inductor configuration settings, + including private configs. + """ + from torch._inductor import config + + return config.save_config_portable(ignore_private_configs=False) + + @staticmethod + def torch_determinism_configs() -> dict[str, Any]: + """Get PyTorch deterministic algorithm configuration settings. + + Returns: + A dictionary containing deterministic algorithm settings: + - Whether deterministic algorithms are enabled + - Whether deterministic algorithm warnings are enabled + - Fill uninitialized memory setting + """ + return { + "torch.are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(), + "torch.is_deterministic_algorithms_warn_only_enabled": ( + torch.is_deterministic_algorithms_warn_only_enabled() + ), + "torch.utils.deterministic.fill_uninitialized_memory": ( + torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] + ), + } + + @staticmethod + def cuda_matmul_precision_configs() -> dict[str, Any]: + """Get CUDA matrix multiplication precision configuration settings. + + Returns: + A dictionary containing CUDA matmul precision settings: + - FP32 precision setting + - FP16 reduced precision reduction allowance + - BF16 reduced precision reduction allowance + """ + return { + "torch.backends.cuda.matmul.fp32_precision": torch.backends.cuda.matmul.fp32_precision, + "torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction": ( + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + ), + "torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction": ( + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + ), + } + + +class _CompileContext(_Context): + """Context provider for compilation-related configuration and environment settings. + + Collects information that affects compilation behavior, such as PyTorch and Triton + versions, runtime environment, and accelerator properties. + """ + + @override + @staticmethod + def forms_of_context() -> Sequence[str]: + """Return the compile context forms provided by this class. + + Returns: + A sequence containing the available compile context forms: + - "torch_version_hash": PyTorch version hash + - "triton_version_hash": Triton version hash (if available) + - "runtime": Runtime type (CUDA/HIP/None) + - "runtime_version": Runtime version string + - "accelerator_properties": GPU/accelerator properties + """ + return ( + "torch_version_hash", + "triton_version_hash", + "runtime", + "runtime_version", + "accelerator_properties", + ) + + @cache + @staticmethod + def torch_version_hash() -> str: + """Get base64-encoded PyTorch version hash. + + Returns: + A base64-encoded string representing the PyTorch version hash. + """ + from torch._inductor.codecache import torch_key + + return b64encode(torch_key()).decode() + + @cache + @staticmethod + def triton_version_hash() -> str | None: + """Get Triton version key if Triton is available. + + Returns: + Triton version key if Triton is available, None otherwise. + """ + from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key + + return triton_key() if HAS_TRITON else None + + @cache + @staticmethod + def runtime() -> str | None: + """Determine the runtime type based on available backends. + + Returns: + "CUDA" if CUDA is available, "HIP" if HIP is available, None otherwise. + """ + return "CUDA" if torch.version.cuda else "HIP" if torch.version.hip else None + + @cache + @staticmethod + def runtime_version() -> str | None: + """Get the version string for the detected runtime. + + Returns: + Version string for the current runtime (CUDA or HIP), or None if + no supported runtime is detected. + """ + return { + "CUDA": torch.version.cuda, + "HIP": torch.version.hip, + }.get(_CompileContext.runtime()) # type: ignore[arg-type] + + @cache + @staticmethod + def accelerator_properties() -> str | None: + """Get string representation of CUDA device properties. + + Returns: + String representation of CUDA device properties if a runtime is + available, None otherwise. + """ + return ( + repr(torch.cuda.get_device_properties()) + if _CompileContext.runtime() and torch.cuda.is_available() + else None + ) + + +class SelectedRuntimeContext(TypedDict): + inductor_configs: bool + torch_determinism_configs: bool + cuda_matmul_precision_configs: bool + + +class SelectedCompileContext(TypedDict): + torch_version_hash: bool + triton_version_hash: bool + runtime: bool + runtime_version: bool + accelerator_properties: bool + + +class IsolationSchema(TypedDict): + """Schema for specifying which context forms to include in cache isolation. + + Attributes: + runtime_context: Either True (include all runtime context), False (exclude all), + or a SelectedRuntimeContext dict specifying which forms to include. + compile_context: Either True (include all compile context), False (exclude all), + or a SelectedCompileContext dict specifying which forms to include. + """ + + runtime_context: SelectedRuntimeContext | bool + compile_context: SelectedCompileContext | bool + + +_DEFAULT_ISOLATION_SCHEMA: IsolationSchema = IsolationSchema( + runtime_context=True, compile_context=True +) + + +def _isolation_context( + ischema: IsolationSchema = _DEFAULT_ISOLATION_SCHEMA, +) -> dict[str, Any]: + """Generate context data based on the isolation schema. + + Args: + ischema: Schema specifying which context forms to include. + Defaults to including all runtime and compile context. + + Returns: + A dictionary containing the selected context data with keys + "runtime_context" and "compile_context", where each value is + either None (if excluded) or a dict of context form data. + """ + isolation_context: dict[str, Any] = {} + for context_name, context_cls in ( + ("runtime_context", _RuntimeContext), + ("compile_context", _CompileContext), + ): + selected_context: dict[str, Any] | None = None + if ischema[context_name] is True: # type: ignore[literal-required] + selected_context = { + form_of_context: getattr(context_cls, form_of_context)() + for form_of_context in context_cls.forms_of_context() + } + elif ischema[context_name] is False: # type: ignore[literal-required] + selected_context = None + else: + selected_context = {} + for form_of_context in ischema[context_name]: # type: ignore[literal-required] + selected = ischema[context_name][form_of_context] # type: ignore[literal-required] + if selected: + selected_context[form_of_context] = getattr( + context_cls, form_of_context + )() + selected_context = selected_context or None + isolation_context[context_name] = selected_context + return isolation_context + + +def _isolation_key(ischema: IsolationSchema = _DEFAULT_ISOLATION_SCHEMA) -> str: + """Generate a unique key for the given isolation schema. + + Args: + ischema: Schema specifying which context forms to include. + Defaults to including all runtime and compile context. + + Returns: + A 32-character hexadecimal string that uniquely identifies + the context specified by the isolation schema. + """ + return sha256( + json.dumps(_isolation_context(ischema), sort_keys=True).encode() + ).hexdigest()[:32] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/exceptions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..02e47fa1e6127a44b45c61966d3aa6e3d9fb65da --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/exceptions.py @@ -0,0 +1,189 @@ +# pyre-strict + +"""Exception classes for PyTorch Inductor runtime caching. + +This module defines a hierarchy of exceptions used throughout the caching system. +All custom exceptions inherit from CacheError, with UserError serving as a base +for user-facing errors that also inherit from TypeError for compatibility. +""" + +from threading import Lock +from typing import Any + +from filelock import FileLock + + +class CacheError(Exception): + """Base class for all caching-related errors. + + This is the root exception class for all custom exceptions raised by the caching + module, providing a common interface for error handling and logging. + """ + + +class SystemError(CacheError, RuntimeError): + """Base class for system-level caching errors. + + This class represents errors that occur during cache operations, such as + storage or retrieval failures. It inherits from RuntimeError to indicate + that the error is not caused by user input. + """ + + +class LockTimeoutError(SystemError): + """Error raised when a lock operation times out. + + This exception is raised when a lock operation exceeds the specified timeout + limit, indicating that the lock could not be acquired within the allotted time. + """ + + def __init__(self, lock: Lock, timeout: float) -> None: + """Initialize the lock timeout error with detailed lock information. + + Args: + lock: The lock object that timed out. + timeout: The timeout limit that was exceeded. + """ + super().__init__(f"Failed to acquire lock {lock} within {timeout} seconds.") + + +class FileLockTimeoutError(SystemError): + """Error raised when a file lock operation times out. + + This exception is raised when a file lock operation exceeds the specified timeout + limit, indicating that the lock could not be acquired within the allotted time. + """ + + def __init__(self, flock: FileLock, timeout: float) -> None: + """Initialize the file lock timeout error with detailed lock information. + + Args: + flock: The file lock object that timed out. + timeout: The timeout limit that was exceeded. + """ + super().__init__( + f"Failed to acquire file lock {flock} within {timeout} seconds." + ) + + +class UserError(CacheError, TypeError): + """Base class for user-facing cache errors that also inherit from TypeError. + + This class combines CacheError with TypeError to provide compatibility + with existing exception handling patterns while maintaining the cache + error hierarchy. All user-facing cache errors should inherit from this class. + """ + + +class KeyEncodingError(UserError): + """Base class for errors that occur during cache key encoding operations. + + Raised when cache keys cannot be properly encoded for storage or transmission. + This includes serialization, hashing, or other encoding-related failures. + """ + + +class KeyPicklingError(KeyEncodingError): + """Error raised when a cache key cannot be pickled for serialization. + + This typically occurs when trying to cache objects with keys that contain + non-serializable components, lambda functions, or other unpickleable types. + """ + + def __init__(self, key: Any) -> None: + """Initialize the key pickling error with detailed key information. + + Args: + key: The cache key that failed to be pickled. + """ + super().__init__( + f"Failed to pickle cache key with type {type(key)} and value {key!r}." + ) + + +class ValueEncodingError(UserError): + """Base class for errors that occur during cache value encoding operations. + + Raised when cache values cannot be properly encoded for storage or transmission. + This includes serialization, compression, or other encoding-related failures. + """ + + +class ValuePicklingError(ValueEncodingError): + """Error raised when a cache value cannot be pickled for serialization. + + This occurs when trying to cache objects that contain non-serializable + components, file handles, network connections, or other unpickleable types. + """ + + def __init__(self, value: Any) -> None: + """Initialize the value pickling error with detailed value information. + + Args: + value: The cache value that failed to be pickled. + """ + super().__init__( + f"Failed to pickle cache value with type {type(value)} and value {value!r}." + ) + + +class ValueDecodingError(UserError): + """Base class for errors that occur during cache value decoding operations. + + Raised when cached values cannot be properly decoded during retrieval. + This includes deserialization, decompression, or other decoding-related failures. + """ + + +class ValueUnPicklingError(ValueDecodingError): + """Error raised when cached value data cannot be unpickled during retrieval. + + This typically indicates corruption, version incompatibility, or missing + dependencies required to reconstruct the cached object. + """ + + def __init__(self, pickled_value: bytes) -> None: + """Initialize the value unpickling error with the problematic data. + + Args: + pickled_value: The bytes that failed to be unpickled. + """ + super().__init__( + f"Failed to unpickle cache value from pickled value {pickled_value!r}." + ) + + +class CustomParamsEncoderRequiredError(UserError): + pass + + +class CustomResultEncoderRequiredError(UserError): + pass + + +class CustomResultDecoderRequiredError(UserError): + pass + + +class DeterministicCachingDisabledError(UserError): + pass + + +class DeterministicCachingRequiresStrongConsistencyError(UserError): + pass + + +class StrictDeterministicCachingKeyNotFoundError(UserError): + pass + + +class DeterministicCachingInvalidConfigurationError(UserError): + pass + + +class StrictDeterministicCachingInsertionError(UserError): + pass + + +class DeterministicCachingIMCDumpConflictError(SystemError): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/implementations.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/implementations.py new file mode 100644 index 0000000000000000000000000000000000000000..ed83e490fd316059e7d877b63adb2eeaec69ed70 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/implementations.py @@ -0,0 +1,415 @@ +"""Cache implementation classes for PyTorch Inductor runtime caching. + +This module provides concrete implementations of caching backends including +in-memory, on-disk, and remote caching strategies. Each implementation follows +the abstract _CacheImpl interface and provides thread-safe operations with +appropriate locking mechanisms. +""" + +from abc import ABC, abstractmethod +from collections.abc import Generator +from contextlib import contextmanager +from dataclasses import dataclass +from hashlib import sha256 +from io import BufferedReader, BufferedWriter +from os import PathLike +from pathlib import Path +from threading import Lock +from typing import Any +from typing_extensions import override + +from filelock import FileLock + +from . import locks, utils + + +@dataclass +class Hit: + """Result wrapper for hits on cache get operations. + + Allows distinguishing between a cache miss and a cached None value. + + Attributes: + value: The cached value. + """ + + value: Any + + +class Miss: + """Sentinel class representing a cache miss. + + Used to distinguish between a cached None value and a cache miss + when None is a valid cached value. + """ + + +# Singleton instance for cache miss sentinel +miss = Miss() + + +class _CacheImpl(ABC): + """Abstract base class for cache implementations. + + This class defines the interface that all cache implementations must follow. + It provides thread-safe operations through a locking mechanism and supports + both get and insert operations. + + Note: We don't use generics here as doing so would require that the interfaces + know which k/v types the implementation can work with. Instead, we leave that + determination up to the implementation itself and require that the interfaces + handle any potential errors from invalid k/v types being passed to the + implementation. + """ + + def __init__(self) -> None: + """Initialize the cache implementation with a threading lock.""" + self._lock: Lock = Lock() + + @property + def lock(self) -> locks._LockProtocol: + """Get a context manager for acquiring the cache lock. + + Locking of the cache is not done by the implementation itself, but by the + interface that uses it. The interface may want to hold the lock for longer + than a single cache operation, for example when dealing with multiple + cache implementations at once, so we leave that decision up to the interface. + + Args: + timeout: Optional timeout in seconds (float) for acquiring the lock. + + Returns: + A callable that returns a context manager for the lock. + """ + + def _lock_with_timeout( + timeout: float | None = None, + ) -> locks._LockContextManager: + return locks._acquire_lock_with_timeout(self._lock, timeout) + + return _lock_with_timeout + + @abstractmethod + def get(self, key: Any) -> Hit | None: + """Retrieve a value from the cache. + + Args: + key: The key to look up in the cache. + + Returns: + A Hit object on cache hit where Hit.value is the cached value, + or None on cache miss. + """ + + @abstractmethod + def insert(self, key: Any, value: Any) -> bool: + """Insert a key-value pair into the cache. + + Args: + key: The key to insert. + value: The value to associate with the key. + + Returns: + True if the insertion was successful, False if not inserted. + """ + + +class _InMemoryCacheImpl(_CacheImpl): + """In-memory cache implementation using a dictionary. + + This implementation stores key-value pairs in a Python dictionary, + with keys being pickled for consistent hashing. It provides fast + access but is limited by available memory and process lifetime. + """ + + def __init__(self) -> None: + """Initialize the in-memory cache with an empty dictionary.""" + super().__init__() + self._memory: dict[bytes, Any] = {} + + @override + def get(self, key: Any) -> Hit | None: + """Retrieve a value from the in-memory cache. + + Args: + key: The key to look up. Will be pickled for storage. + + Returns: + A Hit object on cache hit where Hit.value is the cached value, + or None on cache miss. + """ + pickled_key: bytes = utils._try_pickle_key(key) + if (value := self._memory.get(pickled_key, miss)) is not miss: + return Hit(value=value) + return None + + @override + def insert(self, key: Any, value: Any) -> bool: + """Insert a key-value pair into the in-memory cache. + + Args: + key: The key to insert. Will be pickled for storage. + value: The value to associate with the key. + + Returns: + True if the insertion was successful (key was new), + False if not inserted (key already existed). + """ + pickled_key: bytes = utils._try_pickle_key(key) + if pickled_key not in self._memory: + self._memory[pickled_key] = value + return True + return False + + +class _OnDiskCacheImpl(_CacheImpl): + """On-disk cache implementation using file system storage. + + This implementation stores cached data as files on disk, with version + headers to handle cache invalidation. It uses file locking to ensure + thread safety across processes and provides persistent storage that + survives process restarts. + + Attributes: + _version: Version number for cache format compatibility. + _version_header_length: Length of the version header in bytes. + """ + + _version: int = 0 + _version_header_length: int = 4 + + def __init__(self, sub_dir: PathLike[str] | None = None) -> None: + """Initialize the on-disk cache with a specified subdirectory. + + Args: + sub_dir: Subdirectory name within the cache directory. + Defaults to empty string if not specified. + """ + self._cache_dir: Path = self._base_dir / (sub_dir or "") + # pyrefly: ignore [bad-assignment] + self._flock: FileLock = FileLock(str(self._cache_dir / "dir.lock")) + + @property + def _base_dir(self) -> Path: + """Get the base directory for cache storage. + + Returns: + Path to the cache directory based on the default cache dir + and the specified subdirectory. + """ + from torch._inductor.runtime.runtime_utils import default_cache_dir + + return Path(default_cache_dir(), "cache") + + def _fpath_from_key(self, key: Any) -> Path: + """Generate a file path from a cache key. + + Args: + key: The cache key to convert to a file path. + + Returns: + A Path object representing the file location for this key. + """ + pickled_key: bytes = utils._try_pickle_key(key) + return self._cache_dir / sha256(pickled_key).hexdigest()[:32] + + @classmethod + def _version_header(cls) -> bytes: + """Generate the version header bytes. + + Returns: + A byte string representing the current cache version header. + """ + return sha256(str(cls._version).encode()).digest()[: cls._version_header_length] + + def _version_header_matches(self, fp: BufferedReader) -> bool: + """Check if the file's version header matches the current version. + + Args: + fp: File pointer positioned at the start of the file. + + Returns: + True if the version header matches, False otherwise. + """ + return fp.read(self._version_header_length) == self._version_header() + + def _write_version_header(self, fp: BufferedWriter) -> None: + """Write the version header to a file. + + Args: + fp: File pointer where the version header should be written. + """ + fp.write(self._version_header()) + + @override + @property + def lock(self) -> locks._LockProtocol: + """Get a context manager for acquiring the file lock. + + Uses file locking to ensure thread safety across processes. + + Args: + timeout: Optional timeout in seconds (float) for acquiring the file lock. + + Returns: + A callable that returns a context manager for the file lock. + """ + + def _lock_with_timeout( + timeout: float | None = None, + ) -> locks._LockContextManager: + return locks._acquire_flock_with_timeout(self._flock, timeout) + + return _lock_with_timeout + + @override + def get(self, key: Any) -> Hit | None: + """Retrieve a value from the on-disk cache. + + Args: + key: The key to look up in the cache. + + Returns: + A Hit object on cache hit where Hit.value is the cached value, + or None on cache miss or version mismatch. + """ + fpath: Path = self._fpath_from_key(key) + + if not fpath.is_file(): + return None + + pickled_value: bytes | None = None + with open(fpath, "rb") as fp: + if self._version_header_matches(fp): + pickled_value = fp.read() + + if not pickled_value: + # if pickled_value is still None, even though the file exists, then + # we know that the version header did not match. in this case implementation + # is up to preference, we choose to remove entries that do not match + # the version header so that the key can be re-cached later with the correct + # version header + fpath.unlink() + return None + + return Hit(value=utils._try_unpickle_value(pickled_value)) + + @override + def insert(self, key: Any, value: Any) -> bool: + """Insert a key-value pair into the on-disk cache. + + Args: + key: The key to insert. + value: The value to associate with the key. + + Returns: + True if successfully inserted, False if the key already exists + with a valid version. + """ + fpath: Path = self._fpath_from_key(key) + fpath.parent.mkdir(parents=True, exist_ok=True) + + r_fp, w_fp, inserted = None, None, False + try: + w_fp = open(fpath, "xb") # noqa: SIM115 + except FileExistsError: + is_stale: bool = False + with open(fpath, "rb") as r_fp: + is_stale = not self._version_header_matches(r_fp) + + if is_stale: + # same story as above, in this case the version header doesn't + # match so we choose to remove the old entry so that the new + # k/v pair can be cached + fpath.unlink() + w_fp = open(fpath, "xb") # noqa: SIM115 + else: + w_fp = None + finally: + if w_fp: + try: + pickled_value: bytes = utils._try_pickle_value(value) + self._write_version_header(w_fp) + w_fp.write(pickled_value) + inserted = True + finally: + w_fp.close() + + return inserted + + +try: + from .fb.implementations import _RemoteCacheImpl +except ModuleNotFoundError: + + class _RemoteCacheImpl(_CacheImpl): # type: ignore[no-redef] + """Fallback remote cache implementation for non-Facebook environments. + + This is a no-op implementation that always raises NotImplementedError. + The actual remote cache implementation is provided in the `.fb` module + for Facebook-specific environments. + + Attributes: + _version: Version number for cache format compatibility. + has_strong_consistency: Whether the remote cache provides strong + consistency guarantees. + """ + + _version: int = 0 + has_strong_consistency: bool = False + + def __init__(self) -> None: + """Initialize the fallback remote cache implementation. + + Note: We don't need to initialize any form of lock since this + implementation provides a pseudo-lock context manager. + """ + + @override + @property + def lock(self) -> locks._LockProtocol: + """Get a pseudo lock that does nothing. + + Most remote cache implementations don't have an ability to implement + any form of locking, so we provide a no-op pseudo-lock for consistency + with the interface. + + Args: + timeout: Optional timeout in seconds (float). Ignored in this + + Returns: + A callable that returns a no-op context manager. + """ + + @contextmanager + def pseudo_lock( + timeout: float | None = None, + ) -> Generator[None, None, None]: + yield + + return pseudo_lock + + @override + def get(self, key: Any) -> Hit | None: + """Raise NotImplementedError for remote cache get operations. + + Args: + key: The key to look up (ignored). + + Raises: + NotImplementedError: Always raised as this is a fallback implementation. + """ + raise NotImplementedError + + @override + def insert(self, key: Any, value: Any) -> bool: + """Raise NotImplementedError for remote cache insert operations. + + Args: + key: The key to insert (ignored). + value: The value to insert (ignored). + + Raises: + NotImplementedError: Always raised as this is a fallback implementation. + """ + raise NotImplementedError diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/interfaces.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4b8251bc3997c6e03e742af55ad879266eaa73 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/interfaces.py @@ -0,0 +1,818 @@ +from __future__ import annotations + +import atexit +import json +import os +from abc import ABC, abstractmethod +from ast import literal_eval +from enum import Enum +from functools import partial, wraps +from logging import DEBUG, getLogger, INFO, Logger +from os import PathLike +from pathlib import Path +from threading import Lock +from time import time +from typing import Any, TYPE_CHECKING, TypeAlias +from typing_extensions import override + +from . import config, context, exceptions, implementations as impls, locks + + +if TYPE_CHECKING: + from collections.abc import Callable + + from .utils import P, R + + +# ideally we could annotate this as tuple[P.args, P.kwargs] but +# functionally that doesn't work as P is defined in a specific +# scope and P.args/P.kwargs are only valid in that scope +Params: TypeAlias = tuple[Any, Any] + +logger: Logger = getLogger(__name__) + + +class _IntfCallbackOrigin(Enum): + RECORD = "record" + GET = "get" + INSERT = "insert" + + +class _IntfCallbackAction(Enum): + REPLAY = "replay" + RECORD_INSERTED = "record_inserted" + RECORD_NOT_INSERTED = "record_not_inserted" + RECORD_NOT_INSERTED_REPLAY = "record_not_inserted_replay" + HIT = "hit" + MISS = "miss" + INSERTED = "inserted" + NOT_INSERTED = "not_inserted" + + +def _intf_callback( + origin: _IntfCallbackOrigin, + action: _IntfCallbackAction, + dur: float, + fn: Callable[P, R], + params: Params, + *args: Any, +) -> None: + if origin == _IntfCallbackOrigin.RECORD: + result: R = args[0] + if action == _IntfCallbackAction.REPLAY: + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r cached, " + "returned result %r in %f seconds.", + fn.__name__, + params, + result, + dur, + ) + elif action == _IntfCallbackAction.RECORD_INSERTED: + fn_dur: float = args[1] + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r not cached, " + "calculated and cached result %r in %f seconds " + "of which %f seconds was spent on the function call.", + fn.__name__, + params, + result, + dur, + fn_dur, + ) + elif action == _IntfCallbackAction.RECORD_NOT_INSERTED: + fn_dur = args[1] + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r not cached, " + "calculated result %r but was not able to " + "insert it into the cache as a matching " + "entry already exists; returned calculated result in %f seconds " + "of which %f seconds was spent on the function call.", + fn.__name__, + params, + result, + dur, + fn_dur, + ) + elif action == _IntfCallbackAction.RECORD_NOT_INSERTED_REPLAY: + fn_dur = args[1] + cached_result: R = args[2] + logger.log( + DEBUG, + "[RECORD] for fn %s with params %r not cached, " + "calculated result %r but was not able to " + "insert it into the synchronization cache as a matching " + "entry already exists; returned cached result %r in %f seconds " + "of which %f seconds was spent on the function call.", + fn.__name__, + params, + result, + cached_result, + dur, + fn_dur, + ) + else: + raise NotImplementedError + elif origin == _IntfCallbackOrigin.GET: + if action == _IntfCallbackAction.HIT: + result = args[0] + logger.log( + DEBUG, + "[GET] for fn %s with params %r cached, " + "returned result %r in %f seconds.", + fn.__name__, + params, + result, + dur, + ) + elif action == _IntfCallbackAction.MISS: + logger.log( + DEBUG, + "[GET] for fn %s with params %r not cached, " + "returned nothing in %f seconds.", + fn.__name__, + params, + dur, + ) + else: + raise NotImplementedError + elif origin == _IntfCallbackOrigin.INSERT: + result = args[0] + if action == _IntfCallbackAction.INSERTED: + logger.log( + DEBUG, + "[INSERT] for fn %s with params %r and " + "result %r inserted in %f seconds.", + fn.__name__, + params, + result, + dur, + ) + elif action == _IntfCallbackAction.NOT_INSERTED: + logger.log( + DEBUG, + "[INSERT] for fn %s with params %r and " + "result %r not inserted in %f seconds as there is " + "already has a matching entry.", + fn.__name__, + params, + result, + dur, + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + +class _CacheIntf(ABC): + def __init__(self) -> None: + self._lock: Lock = Lock() + + def _make_key( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + ) -> Any: + callee: str = fn.__name__ + fkey: Any = ( + (callee, params) + if not custom_params_encoder + # pyrefly: ignore [invalid-param-spec] + else (callee, custom_params_encoder(*params[0], **params[1])) + ) + ikey: Any = context._isolation_key( + ischema if ischema is not None else context._DEFAULT_ISOLATION_SCHEMA + ) + return (fkey, ikey) + + def _make_dummy_record_wrapper(self, fn: Callable[P, R]) -> Callable[P, R]: + @wraps(fn) + def dummy_wrapper(*args: Any, **kwargs: Any) -> R: + # pyrefly: ignore [invalid-param-spec] + return fn(*args, **kwargs) + + # pyrefly: ignore [bad-return] + return dummy_wrapper + + @abstractmethod + def _make_record_wrapper( + self, + fn: Callable[P, R], + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> Callable[P, R]: + pass + + @abstractmethod + def _get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + pass + + @abstractmethod + def _insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + pass + + @property + def lock(self) -> locks._LockProtocol: + """Get a context manager for acquiring the file lock. + + Uses file locking to ensure thread safety across processes. + + Args: + timeout: Optional timeout in seconds (float) for acquiring the file lock. + + Returns: + A callable that returns a context manager for the file lock. + """ + + def _lock_with_timeout( + timeout: float | None = None, + ) -> locks._LockContextManager: + return locks._acquire_lock_with_timeout(self._lock, timeout) + + return _lock_with_timeout + + def get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + if not config.IS_CACHING_MODULE_ENABLED(): + return None + + start_t: float = time() + with self.lock(): # type: ignore[call-arg] + result: impls.Hit | None = self._get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + dur: float = time() - start_t + + _intf_callback( + _IntfCallbackOrigin.GET, + _IntfCallbackAction.HIT if result else _IntfCallbackAction.MISS, + dur, + fn, + params, + *((result.value,) if result else ()), + ) + + return result + + def insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + if not config.IS_CACHING_MODULE_ENABLED(): + return False + + start_t: float = time() + with self.lock(): # type: ignore[call-arg] + inserted: bool = self._insert( + fn, + params, + result, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + ) + dur: float = time() - start_t + + _intf_callback( + _IntfCallbackOrigin.INSERT, + _IntfCallbackAction.INSERTED + if inserted + else _IntfCallbackAction.NOT_INSERTED, + dur, + fn, + params, + result, + ) + + return inserted + + def record( + self, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[..., Any] | None = None, + custom_result_encoder: Callable[..., Any] | None = None, + custom_result_decoder: Callable[..., ...] | None = None, + ) -> Callable[[Callable[..., ...]], Callable[..., ...]]: + if custom_result_encoder and not custom_result_decoder: + raise exceptions.CustomResultDecoderRequiredError( + "Custom result encoder provided without custom result decoder." + ) + elif not custom_result_encoder and custom_result_decoder: + raise exceptions.CustomResultEncoderRequiredError( + "Custom result decoder provided without custom result encoder." + ) + elif not config.IS_CACHING_MODULE_ENABLED(): + return self._make_dummy_record_wrapper + else: + return partial( + self._make_record_wrapper, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + custom_result_decoder=custom_result_decoder, + ) + + +class _FastCacheIntf(_CacheIntf): + def __init__(self) -> None: + super().__init__() + self._imc: impls._InMemoryCacheImpl = impls._InMemoryCacheImpl() + self._callee_to_odc: dict[str, impls._OnDiskCacheImpl] = {} + + def _get_odc_from_callee(self, callee: str) -> impls._OnDiskCacheImpl: + if not (odc := self._callee_to_odc.get(callee)): + callee_sub_dir: PathLike[str] = Path(callee) + odc = impls._OnDiskCacheImpl(sub_dir=callee_sub_dir) + self._callee_to_odc[callee] = odc + # pyrefly: ignore [unbound-name] + return odc + + @override + def _make_record_wrapper( + self, + fn: Callable[P, R], + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> Callable[P, R]: + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + start_t: float = time() + params = ( + args, + kwargs, + ) + with self.lock(): + get: impls.Hit | None = self._get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + + if get: + dur: float = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.REPLAY, + dur, + fn, + params, + get.value, + ) + return get.value + else: + fn_start_t: float = time() + result: R = fn(*args, **kwargs) + fn_dur: float = time() - fn_start_t + inserted: bool = self._insert( + fn, + params, + result, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + ) + dur = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.RECORD_INSERTED + if inserted + else _IntfCallbackAction.RECORD_NOT_INSERTED, + dur, + fn, + params, + result, + fn_dur, + ) + return result + + return wrapper + + @override + def _get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + key: Any = self._make_key( + fn, params, ischema=ischema, custom_params_encoder=custom_params_encoder + ) + odc: impls._OnDiskCacheImpl = self._get_odc_from_callee(fn.__name__) + with locks._acquire_many_impl_locks_with_timeout(self._imc, odc): + try: + # we'll check the memoization first, since that is much faster + # than checking the on-disk cache (and the two should be consistent + # regardless) + imc_get: impls.Hit | None = self._imc.get(key) + if imc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(imc_get.value)) + else: + return imc_get + else: + odc_get: impls.Hit | None = odc.get(key) + if odc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(odc_get.value)) + return odc_get + return None + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + + @override + def _insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + key: Any = self._make_key( + fn, params, ischema=ischema, custom_params_encoder=custom_params_encoder + ) + odc: impls._OnDiskCacheImpl = self._get_odc_from_callee(fn.__name__) + with locks._acquire_many_impl_locks_with_timeout(self._imc, odc): + try: + encoded_result: Any = ( + result + if not custom_result_encoder + else custom_result_encoder(result) + ) + # reverse order of get, as we don't want to memoize values + # if we haven't actually inserted them into the on-disk cache + # so that the memoization and the on-disk cache remain consistent + if odc.insert(key, encoded_result): + assert self._imc.insert(key, encoded_result) + return True + return False + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + except exceptions.ValueEncodingError as err: + raise exceptions.CustomResultEncoderRequiredError( + f"Custom result encoder required for function {fn} with parameters {params} and result {result}." + ) from err + + +class _DeterministicCacheIntf(_CacheIntf): + def __init__(self) -> None: + super().__init__() + self._imc: impls._InMemoryCacheImpl = impls._InMemoryCacheImpl() + + if fpath_str := os.environ.get( + "TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE" + ): + fpath: Path = Path(fpath_str) + fpath_parent: PathLike[str] = fpath.parent + if fpath.is_file(): + odc: impls._OnDiskCacheImpl = impls._OnDiskCacheImpl( + sub_dir=fpath_parent + ) + with odc.lock(): + with open(fpath) as fp: + dump_for_pre_population: dict[str, str] = json.load(fp) + for key_r, value_r in dump_for_pre_population.items(): + key: bytes = literal_eval(key_r) + value: bytes = literal_eval(value_r) + self._imc._memory[key] = value + + if config.STRICTLY_PRE_POPULATED_DETERMINISM: + # we'll never need a synchronization cache if we're in strictly pre-populated mode, + # as we'll only ever be checking the memoized pre-population + self._get_sc_from_callee: Callable[ + [str], None | impls._OnDiskCacheImpl | impls._RemoteCacheImpl + ] = lambda callee: None + elif config.GLOBAL_DETERMINISM: + # if we want global determinism we need to use a remote cache with strong + # consistency as the synchronization cache + self._rc: impls._RemoteCacheImpl = impls._RemoteCacheImpl() + if not self._rc.has_strong_consistency: + raise exceptions.DeterministicCachingRequiresStrongConsistencyError + self._get_sc_from_callee = lambda callee: self._rc + elif config.LOCAL_DETERMINISM: + # local determinism can use the on-disk cache as the synchronization cache, + # for cleanliness of the on-disk cache we subdir based on the callee + self._callee_to_odc: dict[str, impls._OnDiskCacheImpl] = {} + self._get_sc_from_callee = self._get_odc_from_callee + else: + raise exceptions.DeterministicCachingInvalidConfigurationError( + "Deterministic caching must specify at least one of STRICTLY_PRE_POPULATED_DETERMINISM, " + "GLOBAL_DETERMINISM, or LOCAL_DETERMINISM." + ) + + atexit.register(self._dump_imc_to_disk) + + def __del__(self) -> None: + atexit.unregister(self._dump_imc_to_disk) + del self + + def _get_odc_from_callee(self, callee: str) -> impls._OnDiskCacheImpl: + if not (odc := self._callee_to_odc.get(callee)): + callee_sub_dir: PathLike[str] = Path(callee) + odc = impls._OnDiskCacheImpl(sub_dir=callee_sub_dir) + self._callee_to_odc[callee] = odc + # pyrefly: ignore [unbound-name] + return odc + + def _dump_imc_to_disk(self) -> Path | None: + with self.lock(): # type: ignore[call-arg] + to_dump: dict[str, str] = { + repr(key): repr(value) for key, value in self._imc._memory.items() + } + if not to_dump: + return None + + odc: impls._OnDiskCacheImpl = impls._OnDiskCacheImpl( + sub_dir=Path("dcache_dump") + ) + fpath: Path = odc._cache_dir / "imc.save" + with odc.lock(): + w_fp = None + try: + w_fp = open(fpath, "x") # noqa:SIM115 + except FileExistsError: + with open(fpath) as r_fp: + existing_dump = json.load(r_fp) + + for key, value in existing_dump.items(): + if key not in to_dump: + to_dump[key] = value + elif to_dump[key] != value: + raise exceptions.DeterministicCachingIMCDumpConflictError from None + + w_fp = open(fpath, "w") # noqa:SIM115 + finally: + assert w_fp is not None + try: + json.dump(to_dump, w_fp, indent=4) + logger.log( + INFO, "Dumped deterministic cache memoization to %s", fpath + ) + finally: + w_fp.close() + + return fpath + + @override + def _make_record_wrapper( + self, + fn: Callable[P, R], + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> Callable[P, R]: + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if not config.IS_DETERMINISTIC_CACHING_ENABLED(): + raise exceptions.DeterministicCachingDisabledError + start_t: float = time() + params = ( + args, + kwargs, + ) + with self.lock(): + get: impls.Hit | None = self._get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + + if get: + dur: float = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.REPLAY, + dur, + fn, + params, + get.value, + ) + return get.value + else: + fn_start_t: float = time() + result: R = fn(*args, **kwargs) + fn_dur: float = time() - fn_start_t + if not self._insert( + fn, + params, + result, + ischema, + custom_params_encoder, + custom_result_encoder, + ): + # if we couldn't insert that means that some other callee has populated + # the key entry in the remote cache within the time between our first get + # and the insert attempt; in that case, to be deterministic, we should + # call get again and return that value as the assumption is that other + # compile workers will also use that value + get = self._get( + fn, + params, + ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + assert get is not None, ( + "remote cache should get(key) if insert(key, _) failed" + ) + dur = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.RECORD_NOT_INSERTED_REPLAY, + dur, + fn, + params, + fn_dur, + get.value, + ) + return get.value + dur = time() - start_t + _intf_callback( + _IntfCallbackOrigin.RECORD, + _IntfCallbackAction.RECORD_INSERTED, + dur, + fn, + params, + result, + fn_dur, + ) + return result + + return wrapper + + @override + def _get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + key: Any = self._make_key( + fn, params, ischema=ischema, custom_params_encoder=custom_params_encoder + ) + sc: impls._OnDiskCacheImpl | impls._RemoteCacheImpl | None = ( + self._get_sc_from_callee(fn.__name__) + ) + with locks._acquire_many_impl_locks_with_timeout( + *([self._imc, sc] if sc else [self._imc]) + ): + try: + # we'll check the memoization first, since that is much faster + # than checking the remote cache and the two should be consistent + imc_get: impls.Hit | None = self._imc.get(key) + if imc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(imc_get.value)) + else: + return imc_get + elif not sc: + raise exceptions.StrictDeterministicCachingKeyNotFoundError + else: + sc_get: impls.Hit | None = sc.get(key) + if sc_get: + if custom_result_decoder: + return impls.Hit(value=custom_result_decoder(sc_get.value)) + return sc_get + elif config.STRICTLY_CACHED_DETERMINISM: + raise exceptions.StrictDeterministicCachingKeyNotFoundError + return None + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + + @override + def _insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + if ( + config.STRICTLY_PRE_POPULATED_DETERMINISM + or config.STRICTLY_CACHED_DETERMINISM + ): + raise exceptions.StrictDeterministicCachingInsertionError + + key: Any = self._make_key( + fn, params, ischema=ischema, custom_params_encoder=custom_params_encoder + ) + sc: impls._OnDiskCacheImpl | impls._RemoteCacheImpl | None = ( + self._get_sc_from_callee(fn.__name__) + ) + assert sc, ( + "sc should be either an on-disk cache or a remote cache if we're inserting" + ) + with locks._acquire_many_impl_locks_with_timeout(self._imc, sc): + try: + encoded_result: Any = ( + result + if not custom_result_encoder + else custom_result_encoder(result) + ) + # reverse order of get, as we don't want to memoize values + # if we haven't actually inserted them into the remote cache + # so that the memoization and the remote cache remain consistent + if sc.insert(key, encoded_result): + if not self._imc.insert(key, encoded_result): + # imc might have the mapping already, if pre-populated + assert self._imc.get(key) == encoded_result + return True + return False + except exceptions.KeyEncodingError as err: + raise exceptions.CustomParamsEncoderRequiredError(fn, params) from err + except exceptions.ValueEncodingError as err: + raise exceptions.CustomResultEncoderRequiredError( + f"Custom result encoder required for function {fn} with parameters {params} and result {result}." + ) from err + + @override + def get( + self, + fn: Callable[P, R], + params: Params, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_decoder: Callable[[Any], R] | None = None, + ) -> impls.Hit | None: + if not config.IS_DETERMINISTIC_CACHING_ENABLED(): + raise exceptions.DeterministicCachingDisabledError + return super().get( + fn, + params, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_decoder=custom_result_decoder, + ) + + @override + def insert( + self, + fn: Callable[P, R], + params: Params, + result: R, + ischema: context.IsolationSchema | None = None, + custom_params_encoder: Callable[P, Any] | None = None, + custom_result_encoder: Callable[[R], Any] | None = None, + ) -> bool: + if not config.IS_DETERMINISTIC_CACHING_ENABLED(): + raise exceptions.DeterministicCachingDisabledError + return super().insert( + fn, + params, + result, + ischema=ischema, + custom_params_encoder=custom_params_encoder, + custom_result_encoder=custom_result_encoder, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/locks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/locks.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8cd011e2d443a814b01842db5677cab6e70132 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/locks.py @@ -0,0 +1,202 @@ +"""Lock acquisition utilities for caching system with timeout support. + +This module provides safe and unsafe lock acquisition functions for both threading.Lock +and FileLock objects, with configurable timeout behaviors. It supports three timeout modes: +blocking (infinite wait), non-blocking (immediate), and blocking with timeout (finite wait). + +The module offers both context manager and manual acquisition patterns: +- Safe acquisition: Uses context managers that automatically handle lock release +- Unsafe acquisition: Manual acquisition that requires explicit release by the caller +""" + +from __future__ import annotations + +from contextlib import _GeneratorContextManager, contextmanager, ExitStack +from typing import TYPE_CHECKING, TypeAlias +from typing_extensions import Protocol + +from filelock import FileLock, Timeout + +from . import exceptions, implementations as impls + + +if TYPE_CHECKING: + from collections.abc import Generator + from threading import Lock + + +_LockContextManager: TypeAlias = _GeneratorContextManager[None, None, None] + + +class _LockProtocol(Protocol): # noqa: PYI046 + def __call__(self, timeout: float | None = None) -> _LockContextManager: ... + + +# Infinite timeout - blocks indefinitely until lock is acquired. +_BLOCKING: float = -1 +# No timeout - returns immediately if lock cannot be acquired. +_NON_BLOCKING: float = 0 +# Finite timeout - blocks for a specified duration before raising a timeout error. +_BLOCKING_WITH_TIMEOUT: float = 60.0 +# Default timeout for lock acquisition. +_DEFAULT_TIMEOUT: float = _BLOCKING_WITH_TIMEOUT + + +@contextmanager +def _acquire_lock_with_timeout( + lock: Lock, + timeout: float | None = None, +) -> Generator[None, None, None]: + """Context manager that safely acquires a threading.Lock with timeout and automatically releases it. + + This function provides a safe way to acquire a lock with timeout support, ensuring + the lock is always released even if an exception occurs during execution. + + Args: + lock: The threading.Lock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. + - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Yields: + None: Yields control to the caller while holding the lock + + Raises: + LockTimeoutError: If the lock cannot be acquired within the timeout period + + Example: + with _acquire_lock_with_timeout(my_lock, timeout=30.0): + # Critical section - lock is held + perform_critical_operation() + # Lock is automatically released here + """ + _unsafe_acquire_lock_with_timeout(lock, timeout=timeout) + + try: + yield + finally: + lock.release() + + +def _unsafe_acquire_lock_with_timeout(lock: Lock, timeout: float | None = None) -> None: + """Acquire a threading.Lock with timeout without automatic release (unsafe). + + This function acquires a lock with timeout support but does NOT automatically + release it. The caller is responsible for releasing the lock explicitly. + Use this only when you need manual control over lock lifetime. + + Args: + lock: The threading.Lock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. + - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Raises: + LockTimeoutError: If the lock cannot be acquired within the timeout period + + Warning: + This is an "unsafe" function because it does not automatically release + the lock. Always call lock.release() when done, preferably in a try/finally + block or use the safe _acquire_lock_with_timeout context manager instead. + + Example: + lock = Lock() + try: + _unsafe_acquire_lock_with_timeout(lock, timeout=30.0) + # Critical section - lock is held + perform_critical_operation() + finally: + lock.release() # Must manually release! + """ + _timeout: float = timeout if timeout is not None else _DEFAULT_TIMEOUT + if not lock.acquire(timeout=_timeout): + raise exceptions.LockTimeoutError(lock, _timeout) + + +@contextmanager +def _acquire_flock_with_timeout( + flock: FileLock, + timeout: float | None = None, +) -> Generator[None, None, None]: + """Context manager that safely acquires a FileLock with timeout and automatically releases it. + + This function provides a safe way to acquire a file lock with timeout support, ensuring + the lock is always released even if an exception occurs during execution. + + Args: + flock: The FileLock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. + - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Yields: + None: Yields control to the caller while holding the file lock + + Raises: + FileLockTimeoutError: If the file lock cannot be acquired within the timeout period + + Example: + flock = FileLock("/tmp/my_process.lock") + with _acquire_flock_with_timeout(flock, timeout=30.0): + # Critical section - file lock is held + perform_exclusive_file_operation() + # File lock is automatically released here + """ + _unsafe_acquire_flock_with_timeout(flock, timeout=timeout) + + try: + yield + finally: + flock.release() + + +def _unsafe_acquire_flock_with_timeout(flock: FileLock, timeout: float | None) -> None: + """Acquire a FileLock with timeout without automatic release (unsafe). + + This function acquires a file lock with timeout support but does NOT automatically + release it. The caller is responsible for releasing the lock explicitly. + Use this only when you need manual control over lock lifetime. + + Args: + flock: The FileLock object to acquire + timeout: Timeout in seconds. If None, uses _DEFAULT_TIMEOUT. + - Use _BLOCKING (-1.0) for infinite wait + - Use _NON_BLOCKING (0.0) for immediate return + - Use positive value for finite timeout + + Raises: + FileLockTimeoutError: If the file lock cannot be acquired within the timeout period + + Warning: + This is an "unsafe" function because it does not automatically release + the lock. Always call flock.release() when done, preferably in a try/finally + block or use the safe _acquire_flock_with_timeout context manager instead. + + Example: + flock = FileLock("/tmp/my_process.lock") + try: + _unsafe_acquire_flock_with_timeout(flock, timeout=30.0) + # Critical section - file lock is held + perform_exclusive_file_operation() + finally: + flock.release() # Must manually release! + """ + _timeout: float = timeout if timeout is not None else _DEFAULT_TIMEOUT + try: + _ = flock.acquire(timeout=_timeout) + except Timeout as err: + raise exceptions.FileLockTimeoutError(flock, _timeout) from err + + +@contextmanager +def _acquire_many_impl_locks_with_timeout( + *impls: impls._CacheImpl, + timeout: float | None = None, +) -> Generator[None, None, None]: + with ExitStack() as stack: + for impl in impls: + stack.enter_context(impl.lock(timeout)) + yield diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb25573f2e37346d2f16501f4fb6ff731353cef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/runtime/caching/utils.py @@ -0,0 +1,109 @@ +"""Utility functions for caching operations in PyTorch Inductor runtime. + +This module provides helper functions for pickling/unpickling operations +with error handling, LRU caching decorators, and type-safe serialization +utilities used throughout the caching system. +""" + +import pickle +from collections.abc import Callable +from functools import lru_cache, partial, wraps +from typing import Any +from typing_extensions import ParamSpec, TypeVar + +from . import exceptions + + +# Type specification for function parameters +P = ParamSpec("P") +# Type variable for function return values +R = TypeVar("R") + + +def _lru_cache(fn: Callable[P, R]) -> Callable[P, R]: + """LRU cache decorator with TypeError fallback. + + Provides LRU caching with a fallback mechanism that calls the original + function if caching fails due to unhashable arguments. Uses a cache + size of 64 with typed comparison. + + Args: + fn: The function to be cached. + + Returns: + A wrapper function that attempts caching with fallback to original function. + """ + cached_fn = lru_cache(maxsize=64, typed=True)(fn) + + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[type-var] + try: + return cached_fn(*args, **kwargs) # type: ignore[arg-type] + except TypeError: + return fn(*args, **kwargs) + + return wrapper + + +@_lru_cache +def _try_pickle(to_pickle: Any, raise_if_failed: type = exceptions.CacheError) -> bytes: + """Attempt to pickle an object with error handling. + + Tries to serialize an object using pickle.dumps with appropriate error + handling and custom exception raising. + + Args: + to_pickle: The object to be pickled. + raise_if_failed: Exception class to raise if pickling fails. + + Returns: + The pickled bytes representation of the object. + + Raises: + The exception class specified in raise_if_failed if pickling fails. + """ + try: + pickled: bytes = pickle.dumps(to_pickle) + except (pickle.PicklingError, AttributeError) as err: + raise raise_if_failed(to_pickle) from err + return pickled + + +# Specialized pickle function for cache keys with KeyPicklingError handling. +_try_pickle_key: Callable[[Any], bytes] = partial( + _try_pickle, raise_if_failed=exceptions.KeyPicklingError +) +# Specialized pickle function for cache values with ValuePicklingError handling. +_try_pickle_value: Callable[[Any], bytes] = partial( + _try_pickle, raise_if_failed=exceptions.ValuePicklingError +) + + +@_lru_cache +def _try_unpickle(pickled: bytes, raise_if_failed: type = exceptions.CacheError) -> Any: + """Attempt to unpickle bytes with error handling. + + Tries to deserialize bytes using pickle.loads with appropriate error + handling and custom exception raising. + + Args: + pickled: The bytes to be unpickled. + raise_if_failed: Exception class to raise if unpickling fails. + + Returns: + The unpickled object. + + Raises: + The exception class specified in raise_if_failed if unpickling fails. + """ + try: + unpickled: Any = pickle.loads(pickled) + except pickle.UnpicklingError as err: + raise raise_if_failed(pickled) from err + return unpickled + + +# Specialized unpickle function for cache keys with KeyUnPicklingError handling. +_try_unpickle_value: Callable[[Any], bytes] = partial( + _try_unpickle, raise_if_failed=exceptions.ValueUnPicklingError +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3d731525ea8d1bebac20a4b2e9ac732469cdd4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__init__.py @@ -0,0 +1,6 @@ +# NOTE: add new template heuristics here, so they get imported and registered +# TODO: write a simple glob if there are many heuristics to auto import them in the right order +from . import aten, base, contiguous_mm, decompose_k, registry, triton + +# expose the entry function +from .registry import get_template_heuristic diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..202fa8431883856f7f017dfca6c564c4c54be210 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/aten.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/aten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaf87bd3b05ad09f95577c944c3e6fa5084d0fc0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/aten.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/base.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb37fa86f80788dd8535c0e32852fd1d2a7021c8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/base.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/contiguous_mm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/contiguous_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca1945be0c4fcfafa5e099013da099254d103a0f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/contiguous_mm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/cutedsl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/cutedsl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f6cc117ada526789cc14591a62348238faf1b17 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/cutedsl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/decompose_k.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/decompose_k.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87f011cb0e6b34abf8fb3831c7cdcc9b52b4cb29 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/decompose_k.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/gemm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/gemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b4da57687f6dd10d175b7eb6ab9f679014c2748 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/gemm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/params.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/params.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25a3223e3a9242c93bb9390c3240619895a0eb5d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/params.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5a5aa25b36707d7a7f1659b40923d1f788db932 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/triton_addmm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/triton_addmm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46072247f1c5e0b960894df84fc05b89a373e354 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/triton_addmm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/aten.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/aten.py new file mode 100644 index 0000000000000000000000000000000000000000..103668aa056faae96c6e65ef9a8d912ef6543c6e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/aten.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from torch._inductor import config as inductor_config + +from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype +from ..kernel.mm import ( + aten__fp8_mm, + aten__int_mm, + aten_addmm, + aten_bias_addmm, + aten_mm, + aten_mm_dtype, +) +from ..kernel.mm_plus_mm import aten_mm_plus_mm +from .base import TemplateConfigHeuristics +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Generator + + from ..kernel_inputs import KernelInputs + + +# These are all labeled as device type None to indicate that they +# are valid for all device types +@register_template_heuristic(aten_mm.uid, None) +@register_template_heuristic(aten_mm_dtype.uid, "cuda") +@register_template_heuristic(aten__fp8_mm.uid, None) +@register_template_heuristic(aten__int_mm.uid, None) +@register_template_heuristic(aten_bmm.uid, None) +@register_template_heuristic(aten_mm_plus_mm.uid, None) +# bmm dtype is only valid on cuda +@register_template_heuristic(aten_bmm_dtype.uid, "cuda") +class ATenConfigHeuristics(TemplateConfigHeuristics): + """ + Pseudo heuristic to make ATen choices go through the same flow as other templates + + This is a single choice without kwargs + + If you want to use this with an ATen choice that has kwargs, just subclass + """ + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + yield dict() + + +# None here indicates that this is valid for all device types on that op +# Note (None, op) takes precedence over (device_type, None) +@register_template_heuristic(aten_addmm.uid, None, op_name="addmm") +@register_template_heuristic(aten_baddbmm.uid, None, op_name="baddbmm") +class ATenAddMMConfigHeuristics(ATenConfigHeuristics): + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + alpha = kernel_inputs.get_scalar("alpha") + beta = kernel_inputs.get_scalar("beta") + return { + **kwargs, + "alpha": alpha, + "beta": beta, + } + + +@register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm") +class ATenBiasAddMMConfigHeuristics( + ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics +): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + nodes = kernel_inputs.nodes() + # for addmm, bias is the first input + bias = nodes[0] + if bias.get_stride()[0] == 0 and inductor_config.triton.autotune_cublasLt: + yield dict() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0343270f3a1111de9963f2dfb4781b7aabd1d855 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/base.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from .params import DictKernelTemplateParams, KernelTemplateParams + + +if TYPE_CHECKING: + from collections.abc import Generator + + from ..kernel_inputs import KernelInputs + + +class TemplateConfigHeuristics: + """Base class for generating sets of configs for an associated template.""" + + def should_run(self, inputs: KernelInputs) -> bool: + """ + hookup to check whether the configs are right to run at all e.g. you can check + max-autotune specific to your heuristic here or other things + If this returns False, get_template_configs will yield no configs + + Args: + inputs: KernelInputs + """ + return True + + def get_template_configs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[KernelTemplateParams, None, None]: + """ + Get template configs for the given inputs. + + Prefer to override the _get_template_configs_impl method + to leverage things like should_run + """ + if not self.should_run(kernel_inputs): + return + + # Generate configs and fuse with extra_kwargs + for config_dict in self._get_template_configs_impl(kernel_inputs, op_name): + # Fuse extra_kwargs into config + yield DictKernelTemplateParams(config_dict) + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Get template configs for the given inputs. + This is the main entry point for template-specific logic. + """ + # base implementation yields no entries + yield from [] + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + """ + Get extra kwargs for the given inputs/op for the template. + + Use this to return kwargs that are needed for the template, but + do not change depending on the config/choice, but are rather + always the same, for all configs + """ + return {} + + def adjust_kernel_inputs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> KernelInputs: + """ + Adjust kernel inputs for the given inputs/op for the template. + + override this to adjust the kernel inputs e.g. (un)squeezing + """ + return kernel_inputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/contiguous_mm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/contiguous_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b65eba9c76cfbaa23d67cca2a6fb0d51d317dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/contiguous_mm.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import torch + +from ..ir import get_free_symbols +from ..kernel.mm import ( + addmm_contiguous_subgraph_template, + mm_contiguous_subgraph_template, +) +from ..kernel_inputs import KernelInputs, MMKernelInputs +from ..utils import use_contiguous +from .base import TemplateConfigHeuristics +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Generator + + +@register_template_heuristic(mm_contiguous_subgraph_template.uid, None, op_name="mm") +@register_template_heuristic( + addmm_contiguous_subgraph_template.uid, None, op_name="addmm" +) +class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics): + """empty heuristics to skip contiguous mm on not cuda""" + + +@register_template_heuristic( + mm_contiguous_subgraph_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="mm", +) +@register_template_heuristic( + addmm_contiguous_subgraph_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="addmm", +) +class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Get all the valid k_splits for the given m, n, k. + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + # Check for unbacked symbols - if found, yield nothing + unbacked_symbols = any( + len(get_free_symbols(itr, unbacked_only=True)) > 0 + for itr in ( + *kernel_inputs.shapes_symbolic(), + *kernel_inputs.strides_symbolic(), + ) + ) + if unbacked_symbols: + return + mat2 = kernel_inputs.mat1mat2()[1] + if mat2.get_layout().is_contiguous(): + # no need for contiguous decomposition + return + m, n, k = kernel_inputs.mnk_symbolic() + if not use_contiguous(m, n, k): + return + yield {} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/cutedsl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/cutedsl.py new file mode 100644 index 0000000000000000000000000000000000000000..db337b9d8a271d25f28c55a23aaa2dc91e56b0bf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/cutedsl.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import auto, Enum +from itertools import product + +import torch._inductor.config as config + + +class TensorMapUpdateMode(Enum): + """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" + + SMEM = auto() + GMEM = auto() + + +@dataclass(frozen=True) +class CuTeGemmConfig: + TILE_M: int = 128 + TILE_N: int = 192 + CLUSTER_M: int = 2 + CLUSTER_N: int = 1 + USE_2_CTA: bool = False + TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM + + +def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + For information regarding valid config sets, see: + https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py + """ + + # Tile_n is always the same regardless of 2cta + tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] + + # Valid clusters + clusters_no_2cta = [ + (1, 1), + (1, 2), + (1, 4), + (1, 8), + (1, 16), + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + clusters_2cta = [ + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + + configs: list[CuTeGemmConfig] = [] + + for use_2cta, cluster_set, tile_m_range in [ + (False, clusters_no_2cta, [64, 128]), + (True, clusters_2cta, [128, 256]), + ]: + for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( + [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], + tile_m_range, + tile_n_vals, + cluster_set, + ): + configs.append( + CuTeGemmConfig( + tile_m, + tile_n, + cluster_m, + cluster_n, + USE_2_CTA=use_2cta, + TENSORMAP_UPDATE_MODE=tensormap_update_mode, + ) + ) + + return configs + + +def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + """ + + config_tuples = [ + (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), + (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), + (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), + (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), + (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), + (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + ] + + return [CuTeGemmConfig(*args) for args in config_tuples] + + +def get_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + + Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures + or unstable results. By default, autotuning is disabled and we return only + a single baseline config. + """ + if ( + config.cutedsl_enable_autotuning + and config.max_autotune_gemm_search_space == "EXHAUSTIVE" + ): + return get_exhaustive_groupgemm_configs() + elif config.cutedsl_enable_autotuning: + return get_default_groupgemm_configs() + else: + return [get_default_groupgemm_configs()[0]] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/decompose_k.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/decompose_k.py new file mode 100644 index 0000000000000000000000000000000000000000..7954396a10861b39748ad73075b343286551a102 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/decompose_k.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import sympy + +import torch + +from ..ir import get_free_symbols +from ..kernel.mm import decompose_k_subgraph_template +from ..kernel_inputs import KernelInputs, MMKernelInputs +from ..utils import get_k_splits +from ..virtualized import V +from .base import TemplateConfigHeuristics +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Generator + + +@register_template_heuristic(decompose_k_subgraph_template.uid, None, op_name="mm") +class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics): + """empty heuristics to skip decompose k on anything not cuda""" + + +# on CUDA, we don't support hip for decompose_k yet +@register_template_heuristic( + decompose_k_subgraph_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="mm", +) +# TODO(coconutruben): enable decompose k on AMD by removing the register bool +# and benchmarking it for performance and stability +# TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia) +# by either adding specific register_template_heuristic tags, or setting the +# device to None (enabled on all devices) +class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Get all the valid k_splits for the given m, n, k. + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + + # Check for unbacked symbols - if found, yield nothing + unbacked_symbols = any( + len(get_free_symbols(itr, unbacked_only=True)) > 0 + for itr in ( + *kernel_inputs.shapes_symbolic(), + *kernel_inputs.strides_symbolic(), + ) + ) + if unbacked_symbols: + return + + m, n, k = kernel_inputs.mnk_symbolic() + k_splits = get_k_splits(m, n, k) + for k_split in k_splits: + if not V.graph.sizevars.statically_known_true( + sympy.Eq(sympy.Mod(k, k_split), 0) + ): + continue + yield {"k_split": k_split} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/gemm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..2d56f4c481ccd0601d75b8867a48634c7001abc3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/gemm.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .. import config as inductor_config +from .base import TemplateConfigHeuristics + + +if TYPE_CHECKING: + from ..kernel_inputs import KernelInputs + + +class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics): + def should_run(self, inputs: KernelInputs) -> bool: + """ + simple base override for GEMM family templates that run only in max-autotune + """ + return inductor_config.max_autotune or inductor_config.max_autotune_gemm diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/params.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/params.py new file mode 100644 index 0000000000000000000000000000000000000000..92b130217e3d19507b51e7bd384072548c67abb4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/params.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class KernelTemplateParams(ABC): + """Abstract base class for kernel template parameters.""" + + @abstractmethod + def to_kwargs(self) -> dict[str, Any]: + """Convert params to kwargs dict for template.choice_or_none()""" + + @abstractmethod + def to_serializeable_dict(self) -> dict[str, Any]: + """Convert params to serializable dict for storage/caching""" + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> KernelTemplateParams: + """Create params instance from dict""" + + +class DictKernelTemplateParams(KernelTemplateParams): + """Simple implementation that wraps a kwargs dict""" + + # NOTE: this is a compatibility layer, until every template + # has time to define their own params class, with meaningful + # defaults etc. + + def __init__(self, kwargs: dict[str, Any]): + self.kwargs = kwargs + + def to_kwargs(self) -> dict[str, Any]: + return self.kwargs.copy() + + def to_serializeable_dict(self) -> dict[str, Any]: + return self.kwargs.copy() + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> DictKernelTemplateParams: + return cls(data) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/registry.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..247c78fd557580e33474c8550e645c372db49903 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/registry.py @@ -0,0 +1,175 @@ +""" +Template heuristic registry system for PyTorch Inductor. + +This module provides a centralized registration system for template heuristics, +allowing automatic registration based on device type and conditional registration +for CUDA vs ROCm based on torch.version.hip. +""" + +from __future__ import annotations + +import contextlib +import logging +from typing import Any, Optional, TYPE_CHECKING, Union + +from .base import TemplateConfigHeuristics + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +# Module-wide registry for template heuristics +_TEMPLATE_HEURISTIC_REGISTRY: dict[ + tuple[Union[str, None], ...], type[TemplateConfigHeuristics] +] = {} + +# Manual cache for successful lookups only (fallback instances are not cached) +_HEURISTIC_CACHE: dict[tuple[str, str, str], TemplateConfigHeuristics] = {} + +log = logging.getLogger(__name__) + + +def register_template_heuristic( + template_name: str, + device_type: Union[str, None], + register: bool = True, + op_name: Optional[str] = None, +) -> Any: + """ + Decorator to register template heuristic classes. + + Args: + template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") + device_type: Device type ("cuda", "cpu", "xpu") + Set this to None to indicate that the heuristic is applicable to all device types. + register: Whether to register this heuristic. Caller should pass the condition directly. + op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm"). This is optional + and is only used when a template uses different heuristics for different ops + + Returns: + Decorator function that registers the class if conditions are met. + + Example: + @register_template_heuristic("mm", "cuda", register=torch.version.hip is None) + class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + pass + """ + + def decorator( + cls: type[TemplateConfigHeuristics], + ) -> type[TemplateConfigHeuristics]: + if register: + key: tuple[Union[str, None], ...] = (template_name, device_type, op_name) + _TEMPLATE_HEURISTIC_REGISTRY[key] = cls + log.info( + f"Registered template heuristic: {cls.__name__} for '{template_name=}', '{device_type=}', '{op_name=}'" # noqa: G004 + ) + return cls + + return decorator + + +def get_template_heuristic( + template_name: str, device_type: str, op_name: str +) -> TemplateConfigHeuristics: + """ + Retrieve a template heuristic instance for the given template and device type. + + Args: + template_name: Name of the template (e.g., "mm", "bmm", "scaled_mm") + device_type: Device type ("cuda", "cpu", "xpu") + op_name: Name of the operator (e.g., "mm", "bmm", "scaled_mm") + + Returns: + Template heuristic instance. If no specific heuristic is found, + returns a fallback TemplateConfigHeuristics() instance (uncached). + """ + # Check cache first + cache_key = (template_name, device_type, op_name) + if cache_key in _HEURISTIC_CACHE: + return _HEURISTIC_CACHE[cache_key] + + keys = [ + # everything is specified + (template_name, device_type, op_name), + # heuristic is valid across all devices + (template_name, None, op_name), + # heuristic is valid across all ops for that device + (template_name, device_type, None), + # heuristic is always valid for that template + (template_name, None, None), + ] + + # Look up in registry + heuristic_class = None + for key in keys: + if key in _TEMPLATE_HEURISTIC_REGISTRY: + heuristic_class = _TEMPLATE_HEURISTIC_REGISTRY[key] + break + + if heuristic_class is None: + # Log error and return fallback instance (uncached) + log.error( + "No template heuristic found - template_name=%s, device_type=%s, op_name=%s. " + "Available combinations: %s. Using fallback TemplateConfigHeuristics instance.", + template_name, + device_type, + op_name, + list(_TEMPLATE_HEURISTIC_REGISTRY.keys()), + ) + return TemplateConfigHeuristics() + + # Cache successful lookup and return + instance = heuristic_class() + _HEURISTIC_CACHE[cache_key] = instance + return instance + + +def clear_registry() -> None: + """ + Clear all registered template heuristics. + + This is primarily useful for testing purposes to ensure a clean state. + """ + _TEMPLATE_HEURISTIC_REGISTRY.clear() + _HEURISTIC_CACHE.clear() + + +@contextlib.contextmanager +def override_template_heuristics( + device_type: str, + template_op_pairs: list[tuple[str, str]], +) -> Iterator[None]: + """ + Context manager to temporarily override template heuristics with an empty heuristic. + + This is useful for testing purposes, where we want to ensure a specific template/op pair + is not used + + Args: + device_type: Device type ("cuda", "cpu", "xpu") + template_op_pairs: List of (template_name, op_name) pairs to override. + """ + # Save original entries to restore later + original_entries = {} + new_keys = [] + _HEURISTIC_CACHE.clear() + try: + for template_name, op_name in template_op_pairs: + assert op_name is not None + key = (device_type, template_name, op_name) + if key in _TEMPLATE_HEURISTIC_REGISTRY: + original_entries[key] = _TEMPLATE_HEURISTIC_REGISTRY[key] + # TemplateConfigHeuristics base class returns no entries + # so we use it for overriding + _TEMPLATE_HEURISTIC_REGISTRY[key] = TemplateConfigHeuristics + new_keys.append(key) + yield + finally: + # Restore original entries or remove if they didn't exist before + for key in new_keys: + _TEMPLATE_HEURISTIC_REGISTRY.pop(key, None) + if key in original_entries: + _TEMPLATE_HEURISTIC_REGISTRY[key] = original_entries[key] + _HEURISTIC_CACHE.clear() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..21deda557346b8adda8668699120854e705e524e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton.py @@ -0,0 +1,2649 @@ +from __future__ import annotations + +import dataclasses +import itertools +import math +import os +from functools import partial +from threading import Lock +from typing import Any, Optional, TYPE_CHECKING + +import sympy + +import torch +from torch._inductor.template_heuristics.triton_addmm import AddMMConfigMixin +from torch.utils._ordered_set import OrderedSet +from torch.utils._triton import has_triton_stable_tma_api + +from .. import config, config as inductor_config +from ..kernel.bmm import bmm_template +from ..kernel.mm import ( + blackwell_ws_persistent_device_tma_mm_template, + get_scaling_options, + get_tile_size, + mm_template, + persistent_tma_mm_template, + scaled_mm_device_tma_epilogue_scaling_template, + scaled_mm_device_tma_main_loop_scaling_template, +) +from ..kernel.mm_plus_mm import mm_plus_mm_template +from ..kernel_inputs import KernelInputs, MMKernelInputs +from ..utils import ( + get_backend_num_stages, + get_num_sms, + get_tma_workspace_arg, + TMA_DESCRIPTOR_SIZE, + using_b200, +) +from ..virtualized import V +from .gemm import GemmMaxAutotuneTemplateConfigHeuristics +from .registry import register_template_heuristic + + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + + from triton import Config as TritonConfig + + +# Gemm Configs +@dataclasses.dataclass +class BaseConfig: + """ + Base Gemm configuration used for most backends (CPU, CUDA) + """ + + block_m: int + block_n: int + block_k: int + num_stages: int + num_warps: int + hint_override: Optional[int] = dataclasses.field(kw_only=True, default=None) + + +@dataclasses.dataclass +class GemmConfig(BaseConfig): + """ + Gemm configuration used for most backends (CPU, CUDA) + """ + + group_m: int = dataclasses.field(kw_only=True, default=8) + + +ConvConfig = BaseConfig + + +# FlexAttention Configs +@dataclasses.dataclass +class FlexConfig: + """ + Base Config class for flex attention + - FlexAttn forward and backward will use this. For flex decoding, + please use FlexDecodingConfig. + + NOTE: + For flex_attn bwd block_m and block_n are reused for block_m1, block_m2, block_n1, block_n2 + + """ + + block_m: int + block_n: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexBwDConfig: + """ + Base Config class for flex attention backward + - FlexAttn backward will use this. + + Note: flex bwd configs + + Kernel Constraints: + * BLOCK_N1 % BLOCK_M1 == 0 + * BLOCK_M2 % BLOCK_N2 == 0 + + Pattern 1 - Symmetric Pairing (M, N, N, M): + - Used in autotune configs + - block_m1=M, block_n1=N, block_m2=N, block_n2=M + - Only requires checking BLOCK_N % BLOCK_M == 0 + - Second constraint (BLOCK_M2 % BLOCK_N2) automatically satisfied + + Pattern 2 - Independent Parameters (M1, N1, M2, N2): + - Used in exhaustive search for maximum flexibility + - All four parameters can be set independently + - Requires checking both constraints + + """ + + block_m1: int + block_n1: int + block_m2: int + block_n2: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexDecodeConfig: + """ + Config class for flex decoding + """ + + block_n: int + num_stages: int + num_warps: int + + +# ROCm classes +@dataclasses.dataclass +class ROCmGemmConfig(GemmConfig): + """ + ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmConvConfig(ConvConfig): + """ + ROCm subclass for Conv, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexConfig(FlexConfig): + """ + ROCm subclass for FlexAttn, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexBwDConfig(FlexBwDConfig): + """ + ROCm subclass for FlexAttn backward, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexDecodeConfig(FlexDecodeConfig): + """ + ROCm subclass for FlexDecode, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +class BaseHeuristicSingleton(type): + """ + Thread-safe implementation of single to be used in the config heuristic subclasses + to ensure heavy __init__ calls are not repeatedly run + """ + + _instances: dict[type[Any], Any] = {} + _lock: Lock = Lock() + + def __call__( + cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any + ) -> BaseConfigHeuristic: + with cls._lock: + if cls not in cls._instances: + instance = super().__call__() + cls._instances[cls] = instance + return cls._instances[cls] + + +class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): + """ + Base class for mm_configs, device specific triton kernels config inherit from here + """ + + def __init__(self) -> None: + # Whether the heuristic is used for int8. Use this when the heuristic is int8 exclusive + # but prefer the preprocess_mm_configs argument when it's used for both + self.has_int8_tensor: bool = False + # Whether to scale configs at all + # TODO(coconutruben): remove this once mm_plus_mm and tests support scaling + self.should_scale_configs: bool = True + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform. The configs are as follows: + # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.mm_configs: list[BaseConfig] = [ + GemmConfig(32, 32, 16, 1, 2), + GemmConfig(32, 32, 128, 2, 4), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(64, 32, 128, 5, 4), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(64, 64, 128, 5, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(64, 128, 64, 3, 4), + GemmConfig(64, 128, 128, 4, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(128, 128, 32, 3, 4), + GemmConfig(128, 128, 64, 3, 4), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(128, 128, 128, 4, 8), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + GemmConfig( + BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m=group_m + ) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + for group_m in [8] + ] + + # these are only used in tuned_mm when AutoHeuristic is enabled + # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned + # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 + # which saves compilation time (since less configs are autotuned) and potentially increase performance + # because the learned heuristic might predict a config that is not part mm_configs + self.extra_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 32, 16, 3, 2), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(64, 64, 128, 3, 4), + GemmConfig(128, 64, 32, 2, 2), + GemmConfig(128, 64, 64, 3, 8), + GemmConfig(128, 64, 128, 4, 8), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.int8_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(256, 128, 128, 3, 8), + ] + + self.mixed_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 128, 256, 3, 4), + GemmConfig(16, 128, 256, 5, 8), + ] + + self.persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(256, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.blackwell_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 4, 8), + GemmConfig(256, 128, 64, 3, 8), + GemmConfig(128, 256, 128, 2, 8), + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(256, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + ] + + self.blackwell_persistent_addmm_configs: list[BaseConfig] = [ + GemmConfig(256, 128, 64, 2, 4), + ] + + self.scaled_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 32, 3, 8), + GemmConfig(256, 128, 32, 3, 8), + GemmConfig(256, 64, 32, 4, 4), + GemmConfig(64, 256, 32, 4, 4), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 64, 32, 4, 4), + GemmConfig(64, 128, 32, 4, 4), + GemmConfig(128, 32, 32, 4, 4), + GemmConfig(64, 32, 32, 5, 2), + GemmConfig(256, 128, 128, 3, 8), + GemmConfig(256, 64, 128, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 64, 64, 4, 4), + GemmConfig(64, 128, 64, 4, 4), + GemmConfig(128, 32, 64, 4, 4), + GemmConfig(64, 32, 64, 5, 2), + GemmConfig(16, 32, 32, 2, 2), + GemmConfig(16, 64, 32, 2, 2), + GemmConfig(16, 128, 32, 2, 4), + GemmConfig(16, 256, 32, 2, 4), + GemmConfig(16, 32, 64, 2, 2), + GemmConfig(16, 64, 64, 2, 2), + GemmConfig(16, 128, 64, 2, 4), + GemmConfig(16, 256, 64, 2, 4), + GemmConfig(32, 32, 32, 2, 2), + GemmConfig(32, 64, 32, 2, 2), + GemmConfig(32, 128, 32, 2, 4), + GemmConfig(32, 256, 32, 2, 4), + GemmConfig(32, 32, 64, 2, 2), + GemmConfig(32, 64, 64, 2, 2), + GemmConfig(32, 128, 64, 2, 4), + GemmConfig(32, 256, 64, 2, 4), + GemmConfig(16, 32, 32, 3, 2), + GemmConfig(16, 64, 32, 3, 2), + GemmConfig(16, 128, 32, 3, 4), + GemmConfig(16, 256, 32, 3, 4), + GemmConfig(16, 32, 64, 3, 2), + GemmConfig(16, 64, 64, 3, 2), + GemmConfig(16, 128, 64, 3, 4), + GemmConfig(16, 256, 64, 3, 4), + GemmConfig(32, 32, 32, 3, 2), + GemmConfig(32, 64, 32, 3, 2), + GemmConfig(32, 128, 32, 3, 4), + GemmConfig(32, 256, 32, 3, 4), + GemmConfig(32, 32, 64, 3, 2), + GemmConfig(32, 64, 64, 3, 2), + GemmConfig(32, 128, 64, 3, 4), + GemmConfig(32, 256, 64, 3, 4), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 64, 32, 4, 2), + GemmConfig(16, 128, 32, 4, 4), + GemmConfig(16, 256, 32, 4, 4), + GemmConfig(16, 32, 64, 4, 2), + GemmConfig(16, 64, 64, 4, 2), + GemmConfig(16, 128, 64, 4, 4), + GemmConfig(16, 256, 64, 4, 4), + GemmConfig(32, 32, 32, 4, 2), + GemmConfig(32, 64, 32, 4, 2), + GemmConfig(32, 128, 32, 4, 4), + GemmConfig(32, 256, 32, 4, 4), + GemmConfig(32, 32, 64, 4, 2), + GemmConfig(32, 64, 64, 4, 2), + GemmConfig(32, 128, 64, 4, 4), + GemmConfig(32, 256, 64, 4, 4), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(16, 64, 32, 5, 2), + GemmConfig(16, 128, 32, 5, 4), + GemmConfig(16, 256, 32, 5, 4), + GemmConfig(16, 32, 64, 5, 2), + GemmConfig(16, 64, 64, 5, 2), + GemmConfig(16, 128, 64, 5, 4), + GemmConfig(16, 256, 64, 5, 4), + GemmConfig(32, 32, 32, 5, 2), + GemmConfig(32, 64, 32, 5, 2), + GemmConfig(32, 128, 32, 5, 4), + GemmConfig(32, 256, 32, 5, 4), + GemmConfig(32, 32, 64, 5, 2), + GemmConfig(32, 64, 64, 5, 2), + GemmConfig(32, 128, 64, 5, 4), + GemmConfig(32, 256, 64, 5, 4), + GemmConfig(16, 32, 32, 6, 2), + GemmConfig(16, 64, 32, 6, 2), + GemmConfig(16, 128, 32, 6, 4), + GemmConfig(16, 256, 32, 6, 4), + GemmConfig(16, 32, 64, 6, 2), + GemmConfig(16, 64, 64, 6, 2), + GemmConfig(16, 128, 64, 6, 4), + GemmConfig(16, 256, 64, 6, 4), + GemmConfig(32, 32, 32, 6, 2), + GemmConfig(32, 64, 32, 6, 2), + GemmConfig(32, 128, 32, 6, 4), + GemmConfig(32, 256, 32, 6, 4), + GemmConfig(32, 32, 64, 6, 2), + GemmConfig(32, 64, 64, 6, 2), + GemmConfig(32, 128, 64, 6, 4), + GemmConfig(32, 256, 64, 6, 4), + GemmConfig(64, 16, 256, 5, 4), + GemmConfig(64, 32, 256, 5, 4), + GemmConfig(64, 128, 128, 2, 4), + GemmConfig(64, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 2, 4), + GemmConfig(128, 256, 128, 4, 8), + GemmConfig(256, 128, 128, 2, 4), + GemmConfig(256, 128, 128, 2, 8), + ] + + self.scaled_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 4, 8), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 5, 4), + GemmConfig(128, 128, 128, 5, 8), + GemmConfig(128, 128, 128, 6, 8), + GemmConfig(128, 128, 64, 4, 8), + GemmConfig(64, 32, 256, 5, 4), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(64, 128, 256, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + ] + + # TODO: Unify with other gemm patterns, mm_plus_mm currently follows + # slightly different pattern than rest + self.mm_plus_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 32, 3, 8), + GemmConfig(64, 64, 32, 4, 16), + GemmConfig(64, 32, 32, 4, 8), + GemmConfig(32, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 1, 8), + GemmConfig(64, 64, 64, 1, 8), + GemmConfig(32, 32, 128, 1, 8), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(32, 32, 16, 1, 2), + ] + + self.conv_configs: list[BaseConfig] = [ + ConvConfig(64, 256, 16, 2, 4), + ConvConfig(256, 64, 16, 2, 4), + ConvConfig(1024, 16, 16, 1, 8), + ConvConfig(128, 128, 32, 2, 8), + ConvConfig(64, 64, 32, 2, 4), + ConvConfig(64, 256, 32, 2, 8), + ConvConfig(256, 64, 32, 2, 8), + ] + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(128, 64, 3, 4), + FlexConfig(128, 128, 3, 4), + FlexConfig(128, 128, 2, 8), + FlexConfig(128, 128, 1, 8), + FlexConfig(64, 128, 3, 4), + FlexConfig(64, 64, 3, 4), + ] + + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK_M, BLOCK_N, BLOCK_N, BLOCK_M, s, w) + for BLOCK_M in [32, 64] + for BLOCK_N in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK_M >= 128 or BLOCK_N >= 128 else [4]) + if BLOCK_N % BLOCK_M == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(64, 3, 2), + FlexDecodeConfig(32, 3, 2), + FlexDecodeConfig(128, 3, 2), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2, num_stages, num_warps) + for BLOCK_M1 in [16, 32, 64, 128] + for BLOCK_N1 in [16, 32, 64, 128] + for BLOCK_M2 in [16, 32, 64, 128] + for BLOCK_N2 in [16, 32, 64, 128] + for num_stages in [1, 3, 4] + for num_warps in [2, 4, 8] + if BLOCK_N1 % BLOCK_M1 == 0 + and BLOCK_M2 % BLOCK_N2 == 0 # kernel static assertions + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(block_n, num_stages, num_warps) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[Optional[int], ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Construct key for finding duplicate configs + key: tuple[Optional[int], ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.hint_override, + num_warps, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "hint_override": conf.hint_override, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(conf.num_stages, num_warps, **kwargs) + + def _scale_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + scale: float, + has_int8_tensor: bool, + exclude: Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool], + hint_override: Optional[int] = None, + ) -> list[BaseConfig]: + """ + Scales and filters matrix multiplication configs based on input size. + """ + if not self.should_scale_configs: + return configs + from ..runtime.runtime_utils import next_power_of_2 + + min_block_size = 16 + min_block_size_k = 32 if (has_int8_tensor or self.has_int8_tensor) else 16 + + scaled_configs = [] + for hint_override in [None] + config.multi_kernel_hints: + m_hint = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + hint_override=hint_override, + ) + ), + min_block_size, + ) + n_hint = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + hint_override=hint_override, + ) + ), + min_block_size, + ) + k_hint = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + hint_override=hint_override, + ) + ), + min_block_size_k, + ) + + for c in configs: + scaled_config = dataclasses.replace( + c, + block_m=max(min(int(c.block_m * scale), m_hint), min_block_size), + block_n=max(min(int(c.block_n * scale), n_hint), min_block_size), + block_k=max(min(int(c.block_k * scale), k_hint), min_block_size_k), + hint_override=hint_override, + ) + + if not exclude( + scaled_config.block_m, scaled_config.block_n, scaled_config.block_k + ): + scaled_configs.append(scaled_config) + + return scaled_configs + + def _get_exceeding_shared_memory_checker( + self, + ) -> Optional[Callable[[BaseConfig, int], bool]]: + """ + Returns a function that checks whether a given configuration exceeds the available shared memory for the device. + If the device does not report available shared memory, returns None. + """ + + try: + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + if hasattr(props, "shared_memory_per_block_optin"): # for NVidia GPUs + sm_available = int(props.shared_memory_per_block_optin) + elif hasattr(props, "shared_memory_per_block"): # for ROCm + sm_available = int(props.shared_memory_per_block) + else: + return None + + except Exception: + # If CUDA is not available or properties cannot be queried, return None + return None + + # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation. + def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool: + shared_mem_accum = dtype_size * ( + gemm_config.block_m * gemm_config.block_k + + gemm_config.block_n * gemm_config.block_k + ) + return shared_mem_accum * gemm_config.num_stages > sm_available + + return exceeds + + def _prune_exceeding_max_shared_mem_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + if dtype_size <= 0: + return configs + + is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker() + if is_exceeding_shared_memory is None: + return configs + + return [c for c in configs if not is_exceeding_shared_memory(c, dtype_size)] + + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + is_exceeding_shared_memory = self._get_exceeding_shared_memory_checker() + + pruned_configs = [] + for gemm_config in configs: + # Will use more shared memory than available + if is_exceeding_shared_memory and is_exceeding_shared_memory( + gemm_config, dtype_size + ): + continue + + NUM_REG = 255 + acc_regs = math.ceil( + gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32) + ) + # Lower bound for register spillage, if exceeds the kernel will certainly spill + if acc_regs > NUM_REG: + continue + + pruned_configs.append(gemm_config) + + return pruned_configs + + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Filter configs based on specific requirements. + Subclasses can override this to implement custom filtering logic. + """ + return configs + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, + dtype_size: int = 0, + op_name: str = "mm", # For preprocessing overrides e.g. on CPU + ) -> Generator[TritonConfig, None, None]: + configs = self._filter_configs(configs) + scaled_configs = self._scale_mm_configs( + m, n, k, configs, scale, has_int8_tensor, exclude + ) + + # Filter out configs that require more shared memory than is available. + if config.max_autotune_prune_choices_based_on_shared_mem: + scaled_configs = self._prune_exceeding_max_shared_mem_configs( + scaled_configs, dtype_size + ) + + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + assert dtype_size > 0, "dtype_size must be provided for exhaustive search" + scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size) + return self._finalize_mm_configs(scaled_configs) + + def triton_config( + self, num_stages: int, num_warps: int, **kwargs: Any + ) -> TritonConfig: + from triton import Config as TritonConfig # type: ignore[attr-defined] + + return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps) + + def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.mm_configs) + + def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial( + self.preprocess_mm_configs, configs=self.conv_configs, op_name="conv" + ) + + # Flex attn helpers + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(block_n=64, num_stages=1, num_warps=2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class CPUConfigHeuristic(BaseConfigHeuristic): + """ + CPU-specific config heuristic with CPU-specific optimizations. + """ + + def _get_cpu_exclude_function( + self, method: str = "bmm" + ) -> Callable[[sympy.Integer, sympy.Integer, sympy.Integer], bool]: + """ + Get CPU-specific exclude function based on method type. + Returns a function that can be used as exclude condition. + Moved from mm_common._is_large_block_for_cpu and refactored to return a function. + """ + if method in ("conv"): + + def exclude_conv( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 256 or n > 256 or k > 256: + return True + return m * n * k > 2**17 + + return exclude_conv + elif method in ("mm", "addmm", "int_mm"): + + def exclude_mm( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + return m * n > 2**13 + + return exclude_mm + else: # Default to bmm implementation for unknown methods + + def exclude_bmm( + m: sympy.Integer, n: sympy.Integer, k: sympy.Integer + ) -> bool: + if m > 128 or n > 128 or k > 128: + return True + return m * n > 2**12 + + return exclude_bmm + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, + dtype_size: int = 0, + op_name: str = "mm", # For preprocessing overrides e.g. on CPU + ) -> Generator[TritonConfig, None, None]: + """ + CPU-specific preprocessing that applies CPU-specific scaling (0.5) and exclusion logic. + """ + # Get CPU-specific exclude function based on operation type + cpu_exclude_fn = self._get_cpu_exclude_function(op_name) + + # Apply CPU-specific scaling (0.5) and exclusion logic + return super().preprocess_mm_configs( + m, + n, + k, + configs=configs, + has_int8_tensor=has_int8_tensor, + scale=0.5, + exclude=cpu_exclude_fn, + dtype_size=dtype_size, + op_name=op_name, + ) + + +class CUDAConfigHeuristic(BaseConfigHeuristic): + """ + Child class for CUDA device specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + self.sm_120_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 2, 4), + (torch.float32, 128): FlexConfig(128, 32, 2, 4), + (torch.float32, 256): FlexConfig(64, 16, 2, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 2, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 2, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 2, 4), + (torch.float16, 64): FlexConfig(128, 64, 2, 4), + (torch.float16, 128): FlexConfig(128, 64, 2, 8), + (torch.float16, 256): FlexConfig(32, 64, 2, 4), + } + + self.sm_100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 192): FlexConfig(32, 64, 2, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 192): FlexConfig(128, 128, 1, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 192): FlexConfig(128, 128, 1, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.h100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.a100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(128, 32, 3, 4), + (torch.float32, 256): FlexConfig(64, 16, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 3, 4), + (torch.float16, 64): FlexConfig(128, 64, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(32, 64, 3, 4), + } + + # Overwriting the configs omitting BLOCK_N of size 128 that cause ULFs + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK_M, BLOCK_N, BLOCK_N, BLOCK_M, s, 4) + for BLOCK_M in [32, 64] + for BLOCK_N in [32, 64] + for s in [1, 3, 4, 5] # num_stages + if BLOCK_N % BLOCK_M == 0 + ] + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(64, 64, 3, 4) + if capability >= (12, 0): + default_config = self.sm_120_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (10, 0): + default_config = self.sm_100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability == (9, 0): + default_config = self.h100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (8, 0): + default_config = self.a100_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + major, minor = capability + if dtype == torch.float32: + capability_class = "float32" + elif major == 12: + capability_class = "sm12x" + elif major >= 10: + capability_class = "sm10x" + elif capability == (9, 0): + capability_class = "sm90" + elif major >= 8: + capability_class = "sm8x" + else: + capability_class = "baseline" + + # fmt: off + config_map = { + "float32": lambda h: FlexBwDConfig(16, 16, 16, 16, 1, 4), + "baseline": lambda h: FlexBwDConfig(16, 16, 16, 16, 1, 4), + "sm90": lambda h: ( + FlexBwDConfig(64, 64, 64, 64, 3, 4) if h < 64 else + FlexBwDConfig(64, 128, 128, 64, 3, 8) if h <= 128 else + FlexBwDConfig(64, 64, 64, 64, 2, 4) + ), + "sm10x": lambda h: ( + FlexBwDConfig(64, 128, 128, 64, 3, 4) if h <= 128 else + FlexBwDConfig(64, 64, 64, 64, 1, 8) if h <= 192 else + FlexBwDConfig(64, 64, 64, 64, 1, 4) + ), + "sm8x": lambda h: ( + FlexBwDConfig(32, 128, 128, 32, 3, 4) + if h < 64 + else FlexBwDConfig( + 64, 64, 64, 64, 3 if minor == 6 and h == 128 else 2, 4 + ) + ), + "sm12x": lambda h: ( + FlexBwDConfig(32, 128, 128, 32, 3, 4) + if h < 64 + else FlexBwDConfig( + 64, 64, 64, 64, 3 if minor == 6 and h == 128 else 2, 4 + ) + ), + } + # fmt: on + + if head_dim <= 256: + default_config = config_map[capability_class](head_dim) + else: + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + capability = torch.cuda.get_device_capability() + + default_config = FlexDecodeConfig(64, 1, 2) + + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + if capability in [(9, 0), (10, 0), (10, 3)]: # sm_90, sm_100, sm_103 + if head_dim > 128 and dtype == torch.float32: + default_config = FlexDecodeConfig(64, 1, 2) + else: + default_config = FlexDecodeConfig(64, 3, 2) + else: + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class ROCmConfigHeuristic(BaseConfigHeuristic): + """ + Child class for ROCm specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + + self.default_num_stages = get_backend_num_stages() + + self.mm_configs: list[BaseConfig] = [ + ROCmGemmConfig( + 16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig( + 32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + ROCmGemmConfig( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + group_m=group_m, + matrix_instr_nonkdim=matrix_instr_nonkdim, + waves_per_eu=waves_per_eu, + kpack=kpack, + ) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, self.default_num_stages] + for num_warps in [4, 8] + for group_m in [4, 8, 16] + for matrix_instr_nonkdim in [0, 16] + for waves_per_eu in [0, 2] + for kpack in [2] + ] + + self.default_flex_config = { + (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4), + (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 1, 8), + (torch.float16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 256): ROCmFlexConfig(32, 64, 1, 4), + } + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w) + for BLOCK1 in [16, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for w in [4, 8] + ] + + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + ROCmFlexBwDConfig(BLOCK1, BLOCK2, BLOCK2, BLOCK1, 1, w, mfma) + for BLOCK1 in [16, 32, 64] + for BLOCK2 in [32, 64, 128] + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + for mfma in [0, 16] + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(32, 1, 4), + ROCmFlexDecodeConfig(64, 1, 4), + ROCmFlexDecodeConfig(128, 1, 4), + ROCmFlexDecodeConfig(32, 1, 8), + ROCmFlexDecodeConfig(64, 1, 8), + ROCmFlexDecodeConfig(128, 1, 8), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexBwDConfig] = [ + # See Note: flex bwd configs + ROCmFlexBwDConfig( + BLOCK_M1, + BLOCK_N1, + BLOCK_M2, + BLOCK_N2, + num_stages, + num_warps, + mfma, + wpeu, + ) + for BLOCK_M1 in [16, 32, 64, 128] + for BLOCK_N1 in [16, 32, 64, 128] + for BLOCK_M2 in [16, 32, 64, 128] + for BLOCK_N2 in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + if BLOCK_N1 % BLOCK_M1 == 0 + and BLOCK_M2 % BLOCK_N2 == 0 # kernel static assertions + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + # these cause AMD compile to crash + pruned_configs = [ + c + for c in configs + if not ( + getattr(c, "matrix_instr_nonkdim", 0) == 2 + and getattr(c, "kpack", 0) == 2 + ) + ] + return pruned_configs + + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + ROCm specific filtering + """ + for c in configs: + c.num_stages = self.default_num_stages + return super()._filter_configs(configs) + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Defaults for AMD triton backend kern args if not set + matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu = getattr(conf, "waves_per_eu", 0) + kpack = getattr(conf, "kpack", 2) + + if matrix_instr_nonkdim != 0 and ( + conf.block_m % matrix_instr_nonkdim != 0 + or conf.block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.num_warps, + waves_per_eu, + matrix_instr_nonkdim, + kpack, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + # AMD GPU crashes if group_m = 0 + if group_m is not None and group_m <= 0: + group_m = 8 + if group_m is not None: + key += (group_m,) + + if waves_per_eu != 0: + waves_per_eu = int(8 // conf.num_warps) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": conf.num_warps, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(128, 64, 1, 8) + default_config = self.default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = ROCmFlexConfig(32, 16, 1, 4) + else: + default_config = ROCmFlexConfig(64, 32, 1, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = ROCmFlexBwDConfig(16, 16, 16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = ROCmFlexBwDConfig(64, 64, 64, 64, 1, 4) + elif head_dim == 128: + default_config = ROCmFlexBwDConfig(64, 128, 128, 64, 1, 8) + else: + default_config = ROCmFlexBwDConfig(64, 64, 64, 64, 1, 4) + else: + default_config = ROCmFlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = ROCmFlexDecodeConfig(64, 1, 4) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class XPUConfigHeuristic(BaseConfigHeuristic): + """ + Placeholder child class for Intel GPU specific overrides. + """ + + def __init__(self) -> None: + super().__init__() + self.xpu_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 1, 16), + (torch.float32, 128): FlexConfig(128, 32, 1, 16), + (torch.float32, 256): FlexConfig(64, 16, 1, 8), + (torch.bfloat16, 64): FlexConfig(128, 64, 1, 16), + (torch.bfloat16, 128): FlexConfig(128, 64, 1, 16), + (torch.bfloat16, 256): FlexConfig(32, 64, 1, 4), + (torch.float16, 64): FlexConfig(128, 64, 1, 16), + (torch.float16, 128): FlexConfig(128, 64, 1, 16), + (torch.float16, 256): FlexConfig(32, 64, 1, 4), + } + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(32, 16, 2, 4), + FlexConfig(128, 64, 2, 16), + FlexConfig(128, 64, 2, 8), + FlexConfig(128, 32, 2, 16), + FlexConfig(128, 32, 2, 8), + ] + self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [] + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [] + + if not bool(os.getenv("CI")): + self.flex_attn_bwd_autotune_configs += [ + # See Note: flex bwd configs + FlexBwDConfig(BLOCK1, BLOCK2, BLOCK2, BLOCK1, s, w) + for BLOCK1 in [32, 64] + for BLOCK2 in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + if BLOCK2 % BLOCK1 == 0 + ] + self.flex_decode_autotune_configs += [ + FlexDecodeConfig(32, 1, 2), + FlexDecodeConfig(32, 1, 1), + FlexDecodeConfig(32, 2, 2), + FlexDecodeConfig(32, 2, 1), + FlexDecodeConfig(64, 1, 2), + FlexDecodeConfig(64, 1, 1), + FlexDecodeConfig(64, 2, 2), + FlexDecodeConfig(64, 2, 1), + ] + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 1, 8) + else: + default_config = FlexConfig(128, 64, 1, 16) + default_config = self.xpu_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 1, 4) + else: + default_config = FlexConfig(64, 32, 1, 8) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexBwDConfig]: + flex_attn_bwd_configs: list[FlexBwDConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = FlexBwDConfig(64, 64, 64, 64, 1, 8) + elif head_dim == 128: + default_config = FlexBwDConfig(64, 128, 64, 128, 1, 8) + else: + default_config = FlexBwDConfig(64, 64, 64, 64, 1, 8) + else: # modest hardware or extremely large head_dim + default_config = FlexBwDConfig(16, 16, 16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + return configs + + +class MTIAConfigHeuristic(BaseConfigHeuristic): + """ + Placeholder child class for MTIA specific overrides. + """ + + +# Template-specific mixin classes +class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics): + """ + Mixin class that converts config lists to template kwargs. + This handles the logic that was previously in choices.get_mm_configs. + + This mixin expects to be used with BaseConfigHeuristic or its subclasses. + """ + + # Type annotations to ensure the mixin works with BaseConfigHeuristic + get_mm_configs: Callable[[], partial[Generator[TritonConfig, None, None]]] + get_exhaustive_mm_configs: Callable[ + [], partial[Generator[TritonConfig, None, None]] + ] + _filter_configs: Callable[[list[BaseConfig]], list[BaseConfig]] + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + assert isinstance(kernel_inputs, MMKernelInputs) + m, n, k = kernel_inputs.mnk_symbolic() + # Calculate allow_tf32 + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((m % 16) == 0 and (n % 16) == 0 and (k % 8) == 0) + ) + + return { + "ALLOW_TF32": allow_tf32, + } + + def _valid(self, kernel_inputs: KernelInputs) -> bool: + return True + + def _get_config_generator( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + """ + Get the appropriate config generator based on search space. + Can be overridden by subclasses for template-specific behavior. + """ + # Handle exhaustive search case + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return self.get_exhaustive_mm_configs() + else: + return self.get_mm_configs() + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Convert config lists to template kwargs. + This replaces the logic from choices.get_mm_configs and inlines mm_options. + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + input_nodes = kernel_inputs.nodes() + if len(input_nodes) < 2: + raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}") + if not self._valid(kernel_inputs): + return + + # Extract M, N, K from kernel_inputs + m, n, k = kernel_inputs.mnk_symbolic() + + # Extract dtype and device_type from kernel_inputs + dtype = kernel_inputs.dtype() + + # Get the appropriate config generator + configs = self._get_config_generator() + + # Generate and process configs + for c in configs(m, n, k, dtype_size=dtype.itemsize, op_name=op_name): + template_kwargs = self._convert_config_to_template_kwargs( + c, + m, + n, + k, + kernel_inputs.out_dtype(), + ) + yield template_kwargs + + def _convert_config_to_template_kwargs( + self, + triton_config: TritonConfig, + m: sympy.Integer, + n: sympy.Integer, + k: sympy.Integer, + out_dtype: torch.dtype, + ) -> dict[str, Any]: + """ + Convert triton config to template kwargs. + Moved from mm_common.mm_options. + """ + # Calculate EVEN_K symbolic + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(k, triton_config.kwargs["BLOCK_K"]) + == triton_config.kwargs["BLOCK_K"] + ) + + # Build options dict + + options_dict = dict( + EVEN_K=even_k_symbolic, + USE_FAST_ACCUM=False, # Option for _scaled_mm + ACC_TYPE=self._get_acc_type(out_dtype), + num_stages=triton_config.num_stages, + num_warps=triton_config.num_warps, + **triton_config.kwargs, + ) + + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in triton_config.kwargs: + group_m = triton_config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + + def _get_acc_type(self, dtype: torch.dtype) -> str: + """ + Get accumulator type for the given dtype. + Moved from mm_common.acc_type. + """ + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +# INT8 specific mixin to filter correctly +class INT8MMTemplateConfigMixin(MMTemplateConfigMixin): + """ + Ensure that we feed in has_int8_tensor=True + """ + + def __init__(self) -> None: + super().__init__() + self.has_int8_tensor = True + + +# MMPlusMM specific mixin to avoid running _scale_mm_configs +class MMPlusMMTemplateConfigMixin(MMTemplateConfigMixin): + """ + Ensure that _should_scale_configs is False + """ + + # TODO(coconutruben): remove this once all tests work + # with proper scaling on mm_plus_mm + def __init__(self) -> None: + super().__init__() + self.should_scale_configs = False + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + assert isinstance(kernel_inputs, MMKernelInputs), "Expect MMKernelInputs" + m, n, k = kernel_inputs.mnk_symbolic() + for kwargs in super()._get_template_configs_impl(kernel_inputs, op_name): + # Apply BLOCK_K constraint specific to mm_plus_mm + # see https://github.com/triton-lang/triton/issues/1298 + # BLOCK_K = K causes llvm error + if V.graph.sizevars.statically_known_lt(kwargs.get("BLOCK_K", k), k): + yield kwargs + + +class TMAWorkspaceMixin(MMTemplateConfigMixin): + """ + Small mixin to ensure that the workspace arg is correct for TMA + and TMA specific filtering can happen. + """ + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + kwargs["workspace_arg"] = get_tma_workspace_arg( + num_tma_descriptors=2, + device=kernel_inputs.device(), + ) + return kwargs + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + TMA specific filtering, as num_warps=2 not safe for TMA + """ + configs = [c for c in configs if c.num_warps != 2] + return super()._filter_configs(configs) + + +# TMA-specific mixin for TMA templates +class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin): + """ + TMA-specific mixin that uses persistent configs and adds TMA options. + This inherits from MMTemplateConfigMixin and overrides config generation. + """ + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate TMA template configs by calling super and adding TMA-specific options. + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + "TMATemplateConfigMixin requires MMKernelInputs" + ) + mat1, mat2 = kernel_inputs.mat1mat2() + tma_opts = { + "A_ROW_MAJOR": not mat1.layout.is_transposed(), + "B_ROW_MAJOR": not mat2.layout.is_transposed(), + "NUM_SMS": get_num_sms(), + "TMA_SIZE": TMA_DESCRIPTOR_SIZE, + "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), + "tma_store": config.triton.enable_template_tma_store, + "transpose_discontiguous_tensor_descriptors_override": True, + } + # Get base template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + yield {**template_kwargs, **tma_opts} + + +# TMA mixins for Blackwell templates +class BlackwellTMATemplateConfigMixin(TMATemplateConfigMixin): + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate TMA template configs by calling super and adding TMA-specific options. + """ + base_ops = { + "NUM_SMS": get_num_sms(), + # TODO: Consider making this tunable. + "FLATTEN": True, + } + # Get base template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + # Some Triton versions requires num_warps >= 4 for WS + # to avoid compilation issues. Triton disables WS if num_warps < 4 + # or num_stages < 2. Similar issues have been seen with num_stages=1 + ws = ( + template_kwargs["num_warps"] >= 4 and template_kwargs["num_stages"] >= 2 + ) + yield { + **template_kwargs, + **base_ops, + "WARP_SPECIALIZE": ws, + "EPILOGUE_SUBTILE": config.triton.enable_epilogue_subtiling, + } + + +# Scaled MM-specific mixin for scaled MM templates +class BaseScaledMMConfigMixin(MMTemplateConfigMixin): + """ + This is a base that handles the common case for ScaledMM + + The TMA and non-TMA should build on top of this + """ + + def adjust_kernel_inputs( + self, kernel_inputs: KernelInputs, op_name: str + ) -> KernelInputs: + """ + for scaled_mm, we need to unsqueeze scale tensors, and bias + """ + assert isinstance(kernel_inputs, MMKernelInputs), ( + "Expect MMKernelInputs for scaled MM" + ) + inputs = super().adjust_kernel_inputs(kernel_inputs, op_name) + nodes = inputs.nodes() + mat_a, mat_b, scale_a, scale_b, *bias = nodes + bias = bias[0] if bias else None + # Prepare triton input nodes and create kernel_inputs at the top + from ..lowering import lowerings as L + + aten = torch.ops.aten + if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + bias = L[aten.unsqueeze](bias, 0) + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + nodes = [mat_a, mat_b, scale_a, scale_b] + if bias: + nodes.append(bias) + return MMKernelInputs( + nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx + ) + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate scaled MM template configs with scaled MM-specific options. + Handles the remaining logic from mm_common, including assertions. + """ + kernel_inputs = self.adjust_kernel_inputs(kernel_inputs, op_name) + input_nodes = kernel_inputs.nodes() + # Initial assertion from mm_common.scaled_mm_options + assert len(input_nodes) >= 4, ( + f"scaled_mm requires at least 4 inputs, got {len(input_nodes)}" + ) + + # Extract scale tensors (typically scale_a and scale_b are input_nodes[2] and input_nodes[3]) + scale_a = input_nodes[2] + scale_b = input_nodes[3] + + # Scale compatibility assertion from mm_common.scaled_mm_options + def are_compatible_scales(size_a: Any, size_b: Any) -> bool: + # Same sized scales are compatible + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + assert isinstance(kernel_inputs, MMKernelInputs), ( + f"{self.__class__.__name__} requires MMKernelInputs" + ) + + if not self._valid(kernel_inputs): + return + + # Get base template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, op_name + ): + # Add scaled MM-specific options (moved from mm_common.scaled_mm_options) + # Override accumulator type for scaled MM + template_kwargs["ACC_TYPE"] = "tl.float32" + + yield template_kwargs + + +class ScaledMMConfigMixin(BaseScaledMMConfigMixin): + """Mixing for scaled mm with the regular mm template""" + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + from ..kernel.mm_common import scale_mm_epilogue + + return { + **kwargs, + "suffix_args": kernel_inputs.count - 2, + "epilogue_fn": scale_mm_epilogue(), + "epilogue_fn_hash": "scale_mm_epilogue", + } + + def _valid(self, kernel_inputs: KernelInputs) -> bool: + assert isinstance(kernel_inputs, MMKernelInputs), ( + "Expect MMKernelInputs for ScaledMMConfigMixin" + ) + _, _, k = kernel_inputs.mnk_symbolic() + if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)): + # Triton crashes however uncommon for real workloads + return False + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)): + return False + return True + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Filter out bad configs for specific hardware. + On AMD MI350X (GFX 9.5+), skip configs with BLOCK_K<=64 due to lack of corresponding MFMA instructions. + """ + + def should_skip_mi350x_config(config: BaseConfig) -> bool: + """Skip config if BLOCK_K<=64 on MI350X (GFX 9.5+)""" + try: + return ( + config.block_k <= 64 + and torch.version.hip is not None + and torch.cuda.get_device_capability() >= (9, 5) + ) + except RuntimeError: + # If no HIP GPUs are available, we can't check device capability + # so we don't skip any configs + return False + + filtered_configs = [c for c in configs if not should_skip_mi350x_config(c)] + return super()._filter_configs(filtered_configs) + + +# Scaled TMA-specific mixin for scaled MM templates with TMA +class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin): + """ + Scaled TMA-specific mixin that extends BaseScaledMMConfigMixin with TMA functionality. + This is for scaled MM templates that use device TMA. + This inherits from BaseScaledMMConfigMixin and adds TMA-specific options. + """ + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + TMA specific filtering: + - num_warps=2 not safe for TMA + - block_k >= 32 required for TMA (requires inner-most dimension >= 32) + """ + configs = [c for c in configs if c.num_warps != 2 and c.block_k >= 32] + return super()._filter_configs(configs) + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate scaled TMA template configs with both scaled MM and TMA-specific options. + """ + # Get base scaled MM template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + # Add TMA-specific options for device TMA scaled MM + template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + template_kwargs["NUM_SMS"] = get_num_sms() + template_kwargs["TMA_EXPERIMENTAL_API"] = not has_triton_stable_tma_api() + + yield template_kwargs + + +# Scaled Blackwell TMA-specific mixin for scaled MM templates with TMA +class ScaledBlackwellTMAConfigMixin( + BlackwellTMATemplateConfigMixin, ScaledMMConfigMixin +): + """ + Scaled Blackwell TMA-specific mixin that extends ScaledMMConfigMixin with TMA functionality. + This is for scaled MM templates that use device TMA on Blackwell. + This inherits from ScaledMMConfigMixin, which inherits the scale_mm_epilogue, and adds TMA-specific options. + """ + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Warp specialization-specific filtering (BlackwellTMATemplateConfigMixin) + (compilation issues occur in some versions of Triton) + - num_warps < 4 unsafe for warpspec + - num_stages < 2 unsafe for warpspec + + TMA-specific filtering: + - block_k >= 32 required for TMA (requires inner-most dimension >= 32) + """ + configs = [c for c in configs if c.block_k >= 32] + return super()._filter_configs(configs) + + +# Template-specific heuristic classes using multiple inheritance + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +@register_template_heuristic( + bmm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDAMMTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + """Standard MM template heuristic for CUDA""" + + +@register_template_heuristic( + mm_template.uid, "cuda", register=torch.version.hip is None, op_name="addmm" +) +@register_template_heuristic( + bmm_template.uid, "cuda", register=torch.version.hip is None, op_name="baddbmm" +) +class CUDAAddMMTemplateConfigHeuristic(AddMMConfigMixin, CUDAMMTemplateConfigHeuristic): + """Addmm specific mixin for CUDA""" + + +# TODO(coconutruben): deprecate once autoheuristic is deprecated +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="mm-ah", +) +class CUDAMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, CUDAConfigHeuristic): + """Standard MM template heuristic for CUDA using the extra mm configs only (for autoheuristic)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.extra_mm_configs + self.exhaustive_configs = self.extra_mm_configs + + +@register_template_heuristic( + persistent_tma_mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDAPersistentTMATemplateConfigHeuristic( + TMATemplateConfigMixin, CUDAConfigHeuristic +): + """Persistent TMA template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use persistent_mm_configs + self.mm_configs = self.persistent_mm_configs + + +@register_template_heuristic( + blackwell_ws_persistent_device_tma_mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDABlackwellPersistentTMATemplateConfigHeuristic( + BlackwellTMATemplateConfigMixin, CUDAConfigHeuristic +): + """Blackwell Persistent TMA template""" + + def __init__(self) -> None: + super().__init__() + self.mm_configs = self.blackwell_persistent_mm_configs + + +@register_template_heuristic( + persistent_tma_mm_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="addmm", +) +class CUDAAddmmPersistentTMATemplateConfigHeuristic( + AddMMConfigMixin, CUDAPersistentTMATemplateConfigHeuristic +): + """Addmm specific mixin for CUDA""" + + +@register_template_heuristic( + blackwell_ws_persistent_device_tma_mm_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="addmm", +) +class CUDABlackwellAddmmPersistentTMATemplateConfigHeuristic( + AddMMConfigMixin, CUDABlackwellPersistentTMATemplateConfigHeuristic +): + """Addmm extension for DataCenter Blackwell Templates""" + + def __init__(self) -> None: + super().__init__() + # NOTE: to ensure that we pass tests, addmm needs a small config + self.mm_configs = ( + self.blackwell_persistent_mm_configs + + self.blackwell_persistent_addmm_configs + ) + + +@register_template_heuristic( + mm_template.uid, "cuda", register=torch.version.hip is None, op_name="scaled_mm" +) +class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeuristic): + """Scaled MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + configs = [c for c in configs if c.block_k >= 32] + return super()._filter_configs(configs) + + +@register_template_heuristic( + scaled_mm_device_tma_epilogue_scaling_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="scaled_mm", +) +class CUDAScaledTMAEpilogueScalingTemplateConfigHeuristic( + ScaledTMAConfigMixin, CUDAConfigHeuristic +): + """Scaled TMA template heuristic for CUDA: epilogue scaling variants (TensorWise, RowWise)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_persistent_mm_configs for TMA + self.mm_configs = self.scaled_persistent_mm_configs + + +@register_template_heuristic( + scaled_mm_device_tma_main_loop_scaling_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="scaled_mm", +) +class CUDAScaledTMAMainLoopScalingTemplateConfigHeuristic( + ScaledTMAConfigMixin, CUDAConfigHeuristic +): + """ + Scaled TMA template heuristic for CUDA: + main loop scaling variants (BlockWise1x128, BlockWise1x32, BlockWise1x16, BlockWise128x128) + """ + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_persistent_mm_configs for TMA + self.mm_configs = self.scaled_persistent_mm_configs + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Generate main loop scaling kernel inputs. + """ + mat_a, mat_b, scale_a, scale_b = kernel_inputs._input_nodes + scale_a_size, scale_b_size = scale_a.get_size(), scale_b.get_size() + + scale_option_a, scale_option_b = get_scaling_options( + mat_a, mat_b, scale_a_size, scale_b_size + ) + tile_size_a = get_tile_size(scale_option_a) + tile_size_b = get_tile_size(scale_option_b) + + # Get base scaled MM template configs from superclass + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, + op_name, + ): + # Add scaling-specific options for main loop scaling variants + + # Inductor templates require compile-time constants passed in as tl.constexpr values. + # In cases in which the block size (BLOCK_*) is smaller than the tile size (128, 32, 16), + # scales must be broadcasted to BLOCK_* (rather than to a tile_sizextile_size chunk). + + template_kwargs["TILE_SIZE_A"] = tile_size_a + template_kwargs["TILE_SIZE_B"] = tile_size_b + + template_kwargs["MIN_BLOCK_TILE_AM"] = min( + template_kwargs["BLOCK_M"], tile_size_a + ) + template_kwargs["MIN_BLOCK_TILE_AK"] = min( + template_kwargs["BLOCK_K"], tile_size_a + ) + template_kwargs["MIN_BLOCK_TILE_BK"] = min( + template_kwargs["BLOCK_K"], tile_size_b + ) + template_kwargs["MIN_BLOCK_TILE_BN"] = min( + template_kwargs["BLOCK_N"], tile_size_b + ) + + yield template_kwargs + + +@register_template_heuristic( + blackwell_ws_persistent_device_tma_mm_template.uid, # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin + "cuda", + register=torch.version.hip is None, + op_name="scaled_mm", +) +class CUDAScaledBlackwellTMATemplateConfigHeuristic( + ScaledBlackwellTMAConfigMixin, CUDAConfigHeuristic +): + """Scaled Blackwell TMA template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_persistent_mm_configs for TMA + # TODO: Tune scaled_persistent_mm_configs for Blackwell + self.mm_configs = self.scaled_persistent_mm_configs + + +@register_template_heuristic( + mm_plus_mm_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDAMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, CUDAConfigHeuristic +): + """MM Plus MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is None, + op_name="int_mm", +) +class CUDAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CUDAConfigHeuristic): + """Int8 MM template heuristic for CUDA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +# ROCm template-specific classes + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +@register_template_heuristic( + bmm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class ROCmMMTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic): + """Standard MM template heuristic for ROCm""" + + +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + mm_template.uid, "cuda", register=torch.version.hip is not None, op_name="addmm" +) +# TODO(coconutruben): replace with template.name once templates are importable +@register_template_heuristic( + bmm_template.uid, "cuda", register=torch.version.hip is not None, op_name="baddbmm" +) +class ROCmAddMMTemplateConfigHeuristic(AddMMConfigMixin, ROCmMMTemplateConfigHeuristic): + """Addmm specific mixin for ROCm""" + + +# TODO(coconutruben): deprecate once autoheuristic is deprecated +@register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None) +class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic): + """Standard MM template heuristic for ROCm using the extra mm configs only (for autoheuristic)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.extra_mm_configs + self.exhaustive_configs = self.extra_mm_configs + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="scaled_mm", +) +class ROCmScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, ROCmConfigHeuristic): + """Scaled MM template heuristic for ROCm (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic( + mm_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="int_mm", +) +class ROCmInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, ROCmConfigHeuristic): + """Int8 MM template heuristic for ROCm""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic( + mm_plus_mm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class ROCmMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, ROCmConfigHeuristic +): + """MM Plus MM template heuristic for ROCm""" + + def __init__(self) -> None: + super().__init__() + # self.default_num_stages is used to make sure all configs have that in ROCm land + # for mm_plus_mm, we actually just want stages = 1, as pipelining brings no benefits + self.default_num_stages = 1 + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# CPU template-specific classes + + +@register_template_heuristic(mm_template.uid, "cpu") +@register_template_heuristic(bmm_template.uid, "cpu") +class CPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, CPUConfigHeuristic): + """Standard MM template heuristic for CPU""" + + +@register_template_heuristic(mm_template.uid, "cpu", op_name="addmm") +@register_template_heuristic(bmm_template.uid, "cpu", op_name="baddbmm") +class CPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, CPUMMTemplateConfigHeuristic): + """Addmm specific mixin for CPU""" + + +@register_template_heuristic(mm_template.uid, "cpu", op_name="scaled_mm") +class CPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CPUConfigHeuristic): + """Scaled MM template heuristic for CPU (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic(mm_template.uid, "cpu", op_name="int_mm") +class CPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, CPUConfigHeuristic): + """Int8 MM template heuristic for CPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic(mm_plus_mm_template.uid, "cpu") +class CPUMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, CPUConfigHeuristic +): + """MM Plus MM template heuristic for CPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# XPU template-specific classes + + +@register_template_heuristic(mm_template.uid, "xpu") +@register_template_heuristic(bmm_template.uid, "xpu") +class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic): + """Standard MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + + # TODO(etaf): Design proper exhaustive search space for XPU. + self.exhaustive_configs = self.mm_configs + + +@register_template_heuristic(mm_template.uid, "xpu", op_name="addmm") +@register_template_heuristic(bmm_template.uid, "xpu", op_name="baddbmm") +class XPUAddmmTemplateConfigHeuristic(AddMMConfigMixin, XPUMMTemplateConfigHeuristic): + """Addmm specific mixin for XPU""" + + +@register_template_heuristic( + persistent_tma_mm_template.uid, + "xpu", +) +class XPUPersistentTMATemplateConfigHeuristic( + TMATemplateConfigMixin, XPUConfigHeuristic +): + """Persistent TMA template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use persistent_mm_configs + self.mm_configs = self.persistent_mm_configs + + +@register_template_heuristic(persistent_tma_mm_template.uid, "xpu", op_name="addmm") +class XPUAddmmPersistentTMATemplateConfigHeuristic( + AddMMConfigMixin, XPUPersistentTMATemplateConfigHeuristic +): + """Addmm specific mixin for XPU""" + + +@register_template_heuristic(mm_template.uid, "xpu", op_name="scaled_mm") +class XPUScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, XPUConfigHeuristic): + """Scaled MM template heuristic for XPU (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic(mm_template.uid, "xpu", op_name="int_mm") +class XPUInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, XPUConfigHeuristic): + """Int8 MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic(mm_plus_mm_template.uid, "xpu") +class XPUMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, XPUConfigHeuristic +): + """MM Plus MM template heuristic for XPU""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs + + +# MTIA template-specific classes + + +@register_template_heuristic(mm_template.uid, "mtia") +@register_template_heuristic(bmm_template.uid, "mtia") +class MTIAMMTemplateConfigHeuristic(MMTemplateConfigMixin, MTIAConfigHeuristic): + """Standard MM template heuristic for MTIA""" + + +@register_template_heuristic(mm_template.uid, "mtia", op_name="addmm") +@register_template_heuristic(bmm_template.uid, "mtia", op_name="baddbmm") +class MTIAAddMMTemplateConfigHeuristic(AddMMConfigMixin, MTIAMMTemplateConfigHeuristic): + """Addmm specific mixin for MTIA""" + + +@register_template_heuristic(mm_template.uid, "mtia", op_name="scaled_mm") +class MTIAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, MTIAConfigHeuristic): + """Scaled MM template heuristic for MTIA (non-TMA)""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use scaled_mm_configs + self.mm_configs = self.scaled_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.scaled_mm_configs + + +@register_template_heuristic(mm_template.uid, "mtia", op_name="int_mm") +class MTIAInt8MMTemplateConfigHeuristic(INT8MMTemplateConfigMixin, MTIAConfigHeuristic): + """Int8 MM template heuristic for MTIA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use int8_mm_configs + self.mm_configs = self.int8_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.int8_mm_configs + + +@register_template_heuristic(mm_plus_mm_template.uid, "mtia") +class MTIAMMPlusMMTemplateConfigHeuristic( + MMPlusMMTemplateConfigMixin, MTIAConfigHeuristic +): + """MM Plus MM template heuristic for MTIA""" + + def __init__(self) -> None: + super().__init__() + # Override mm_configs to use mm_plus_mm_configs + self.mm_configs = self.mm_plus_mm_configs + # NOTE: overriding exhaustive configs here to be the same as mm_configs + # as we haven't validated exhaustive support here yet + # TODO(coconutruben): remove this once we have validated exhaustive support + # for scaled_mm + self.exhaustive_configs = self.mm_plus_mm_configs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton_addmm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton_addmm.py new file mode 100644 index 0000000000000000000000000000000000000000..a6643d1ce2a90de0f31ef07e6f20d689d9d101b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_inductor/template_heuristics/triton_addmm.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from ..kernel.mm_common import addmm_epilogue +from .base import TemplateConfigHeuristics + + +if TYPE_CHECKING: + from ..kernel_inputs import KernelInputs + + +class AddMMConfigMixin(TemplateConfigHeuristics): + """ + Simple mixin to handle scalars for addmm like operators (addmm, baddbmm) + """ + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + kwargs = super().get_extra_kwargs(kernel_inputs, op_name) + assert op_name in [ + "addmm", + "baddbmm", + ], f"op_name={op_name} invalid for AddMMConfigMixin" + alpha = kernel_inputs.get_scalar("alpha") + beta = kernel_inputs.get_scalar("beta") + return { + **kwargs, + "epilogue_fn": addmm_epilogue(kernel_inputs.out_dtype(), alpha, beta), + "epilogue_fn_hash": str( + ["addmm_epilogue", kernel_inputs.out_dtype(), alpha, beta] + ), + "prefix_args": 1, + } diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3ebcb34cb875884dd67afcc1b95301e5c102992 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/autograd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..642687975e6503efad288f033e16c1cf6dc30495 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/autograd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/custom_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/custom_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a7b9c8bedb242bdb0927b0d3bf6d279214070fc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/custom_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/effects.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/effects.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37d35e0f82be1756c0481bfd276964b4394704aa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/effects.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64d5a6807254e6ee2b17d6c0e890603f2b54d5eb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4200aeb383d559c2d20fa163cabe251dfc2b42d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_profile.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_profile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeea06375b312ac50e27e0bdf5dafd3172a5ced9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/fake_profile.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/infer_schema.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/infer_schema.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2798d534761a643f818b06bf857ba2e2bd9595fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/infer_schema.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/opaque_object.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/opaque_object.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d125fe4b546c53227f29baa50106cda30d05788 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/opaque_object.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/simple_registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/simple_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21cbaaf7a00d10c6cea444e19e159efc98e0ab1b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/simple_registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/triton.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/triton.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56fd85d9986d4b5cb0460a21ee67e4db3ec5268e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/triton.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6fcf2e5943c3819d984e2e3826ff491f544347f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_library/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faf5998bfd7b5c262e1bbe4b9c4fae2f2eaccdf8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44aa4fbf0c5343f991a5a34b911224883d91b37a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..393e42b06d15cf4736c23a03e87d05468ee0ab35 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py @@ -0,0 +1,435 @@ +# mypy: allow-untyped-defs +import math +from functools import partial +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +import torch._refs.linalg as linalg +from torch import Tensor +from torch._prims_common import ( + check_fp_or_complex, + check_is_matrix, + Dim, + DimsType, + ELEMENTWISE_TYPE_PROMOTION_KIND, + IntLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + elementwise_type_promotion_wrapper, + out_wrapper, +) + + +__all__ = [ + "diagonal", + "matrix_norm", + "norm", + "svd", + "svdvals", + "vector_norm", + "vecdot", + "cross", +] + + +def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): + """ + Checks related to the dtype kwarg in `linalg.*norm` functions + """ + if dtype is not None: + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", + ) + torch._check( + utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), + lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format( + fn_name=fn_name, + d="complex" if utils.is_complex_dtype(x_dtype) else "real", + dtype=dtype, + ), + ) + torch._check( + utils.get_higher_dtype(dtype, x_dtype) == dtype, + lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " + f"without narrowing to the specified dtype ({dtype})", + ) + + +import operator + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._decomp.decompositions import pw_cast_for_opmath + + +@register_decomposition(torch._ops.ops.aten.linalg_cross) +@out_wrapper() +@pw_cast_for_opmath +def cross(a: Tensor, b: Tensor, dim: int = -1): + torch._check( + a.ndim == b.ndim, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + a.size(dim) == 3 and b.size(dim) == 3, + lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}", + ) + a, b = torch.broadcast_tensors(a, b) + dim = utils.canonicalize_dim(a.ndim, dim) + idx = torch.arange(3, device=a.device) + return a.index_select(dim, (idx + 1) % 3) * b.index_select( + dim, (idx + 2) % 3 + ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + + +def diagonal( + input: TensorLikeType, + *, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2) + + +def _check_vector_norm_args( + x: TensorLikeType, ord: Union[float, int] = 2, dim: Optional[DimsType] = None +): + from torch.fx.experimental.symbolic_shapes import sym_or + + if not (ord < 0.0 or ord == float("inf")): + return + + torch._check( + sym_or( + x.numel() != 0, + not isinstance(dim, IntLike) and dim is not None and len(dim) != 0, + ), + lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor " + "because the operation does not have an identity", + ) + + shape = x.shape + if dim is not None and not isinstance(dim, IntLike): + for d in dim: + torch._check( + sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0), + lambda: f"linalg.vector_norm cannot compute the {ord} norm on the " + f"dimension {d} because this dimension is empty and the " + "operation does not have an identity", + ) + + +@register_decomposition(torch._ops.ops.aten.linalg_vector_norm) +@out_wrapper(exact_dtype=True) +def vector_norm( + x: TensorLikeType, + ord: Union[float, int] = 2, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_or_false + + check_fp_or_complex(x.dtype, "linalg.vector_norm") + + if isinstance(dim, Dim): + dim = [dim] # type: ignore[assignment] + + _check_vector_norm_args(x, ord, dim) + + _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm") + + computation_dtype, result_dtype = utils.reduction_dtypes( + x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype + ) + + to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype) + + # Implementation + if ord == 0.0: + return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) + elif ord == float("inf"): + return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + elif ord == float("-inf"): + return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + else: + # From here on the computation dtype is important as the reduction is non-trivial + x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment] + reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) + + is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if dim == []: + dim = None + + if (dim is None and x.numel() == 1) or ( + dim is not None + and (x.ndim > 0 and all(guard_or_false(x.shape[d] == 1) for d in dim)) + ): + if x.ndim > 64: + raise RuntimeError( + f"Received a tensor with {x.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + x = torch.abs(x) + if keepdim or x.ndim == 0: + return to_result_dtype(x).contiguous() + elif dim is None: + return to_result_dtype(x).flatten()[0] + else: + new_shape = [s for d, s in enumerate(x.shape) if d not in dim] + return to_result_dtype(x.view(new_shape)).contiguous() + + if not (is_ord_even and utils.is_float_dtype(x.dtype)): + x = torch.abs(x) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] + + +def _backshift_permutation(dim0, dim1, ndim): + # Auxiliary function for matrix_norm + # Computes the permutation that moves the two given dimensions to the back + ret = [i for i in range(ndim) if i != dim0 and i != dim1] + ret.extend((dim0, dim1)) + return ret + + +def _inverse_permutation(perm): + # Given a permutation, returns its inverse. It's equivalent to argsort on an array + return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))] + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def matrix_norm( + A: TensorLikeType, + ord: Union[float, str] = "fro", + dim: DimsType = (-2, -1), + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # shape + check_is_matrix(A, "linalg.matrix_norm") + # dim + + dim = utils.canonicalize_dims(A.ndim, dim) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" + ) + torch._check( + # pyrefly: ignore [index-error] + dim[0] != dim[1], + # pyrefly: ignore [index-error] + lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})", + ) + # dtype arg + _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm") + + if isinstance(ord, str): + # ord + torch._check( + ord in ("fro", "nuc"), + lambda: f"linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc" + ) + + if ord == "fro": + return vector_norm(A, 2, dim, keepdim, dtype=dtype) + else: # ord == "nuc" + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + # pyrefly: ignore [index-error] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: + # ord + abs_ord = abs(ord) + torch._check( + abs_ord in (2, 1, float("inf")), + lambda: f"linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2 + ) + + max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + + def _max_min_wrapper(A, dim): + # pyrefly: ignore [unsupported-operation] + if A.size(dim) == 0 and ord > 0.0: + new_size = list(A.size()) + if keepdim: + new_size[dim] = 1 + else: + del new_size[dim] + return torch.zeros(new_size, dtype=A.dtype, device=A.device) + else: + return max_min(A, dim) + + if abs_ord == 2.0: + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + # pyrefly: ignore [index-error] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = _max_min_wrapper(svdvals(prims.transpose(A, perm)), dim=-1) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: # 1, -1, inf, -inf + # pyrefly: ignore [bad-unpacking] + dim0, dim1 = dim + if abs_ord == float("inf"): + dim0, dim1 = dim1, dim0 + if not keepdim and (dim0 < dim1): + dim1 -= 1 + return _max_min_wrapper( + vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 + ) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def norm( + A: TensorLikeType, + ord: Optional[Union[float, str]] = None, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + if dim is not None: + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) in (1, 2), + lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", + ) + elif ord is not None: + torch._check( + A.ndim in (1, 2), + lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D", + ) + + if ord is not None and ( + (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2) + ): + if dim is None: + dim = (0, 1) + return matrix_norm(A, ord, dim, keepdim, dtype=dtype) + else: + if ord is None: + ord = 2.0 + return vector_norm(A, ord, dim, keepdim, dtype=dtype) # type: ignore[arg-type] + + +# CompositeImplicitAutograd +@out_wrapper("U", "S", "Vh", exact_dtype=True) +def svd(A: TensorLikeType, full_matrices: bool = True) -> tuple[Tensor, Tensor, Tensor]: + return prims.svd(A, full_matrices=full_matrices) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def svdvals(A: TensorLikeType) -> Tensor: + return svd(A, full_matrices=False)[1] + + +# CompositeImplicitAutograd +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("x", "y"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: + check_fp_or_complex(x.dtype, "linalg.vecdot") + return (x.conj() * y).sum(dim=dim) + + +def _pivots_to_permutation(pivots, shape, *, inverse=False): + perm = torch.empty(shape, dtype=torch.int32, device=pivots.device) + perm[..., :] = torch.arange(shape[-1], dtype=torch.int32, device=pivots.device) + indices = range(shape[-1]) + if inverse: + indices = reversed(indices) + + if len(shape) > 1: + for i in indices: + j_s = pivots[..., i] + perm_i = perm[..., i].clone() + j_idx = torch.meshgrid( + *[torch.arange(s, device=perm.device) for s in j_s.shape], indexing="ij" + ) + (j_s,) + perm_j = perm[j_idx] + perm.index_put_(j_idx, perm_i) + perm[..., i].copy_(perm_j) + + else: + for i in indices: + j = pivots[i] + perm_i = perm[i].clone() + perm_j = perm[j].clone() + perm[i].copy_(perm_j) + perm[j].copy_(perm_i) + + return perm + + +def _apply_pivots(a, pivots, shape, *, inverse=False): + perm = _pivots_to_permutation(pivots - 1, shape, inverse=inverse) + + if len(shape) == 1: + return a[perm, :] + else: + idx = torch.meshgrid( + *[torch.arange(s, device=a.device) for s in perm.shape], indexing="ij" + )[:-1] + (perm, slice(None)) + return a[idx] + + +def linalg_lu_solve_out_mps(LU, pivots, B, *, left=True, adjoint=False, out): + if out.numel() == 0: + return + + if not left: + adjoint = not adjoint + B = B.mH + + if adjoint: + lu_ = LU.mH + x = torch.linalg.solve_triangular(lu_, B, left=True, upper=False) + x = torch.linalg.solve_triangular( + lu_, x, left=True, upper=True, unitriangular=True + ) + x = _apply_pivots(x, pivots, LU.shape[:-1], inverse=True) + else: + x = _apply_pivots(B, pivots, LU.shape[:-1]) + x = torch.linalg.solve_triangular( + LU, x, left=True, upper=False, unitriangular=True + ) + x = torch.linalg.solve_triangular(LU, x, left=True, upper=True) + + if not left: + x = x.mH + + out.copy_(x) + + +mps_lib = torch.library.Library("aten", "IMPL", "MPS") # noqa: TOR901 +mps_lib.impl("aten::linalg_lu_solve.out", linalg_lu_solve_out_mps) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84d73cf6aaae1965c87ba650c45f77b8da6532bf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c2ef67bd9d44a21f9d3673ba631c0840740ced --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/__init__.py @@ -0,0 +1 @@ +__all__: list[str] = [] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beb28ca525e8c210d84b6824f484c2c3068f5c87 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..135788a439de5cf7882f659133ce63649d8308e2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py @@ -0,0 +1,1293 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +from collections.abc import Callable +from functools import wraps +from typing import Concatenate, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) +from torch._refs import _make_inplace + + +__all__ = [ + "alpha_dropout", + "celu", + "celu_", + "channel_shuffle", + "dropout", + "elu", + "elu_", + "gelu", + "glu", + "group_norm", + "hardshrink", + "hardtanh", + "hinge_embedding_loss", + "huber_loss", + "l1_loss", + "layer_norm", + "leaky_relu", + "log_softmax", + "margin_ranking_loss", + "mish", + "mish_", + "mse_loss", + "nll_loss", + "pairwise_distance", + "pdist", + "poisson_nll_loss", + "prelu", + "relu", + "relu6", + "selu", + "selu_", + "smooth_l1_loss", + "softmax", + "softmin", + "softplus", + "softshrink", + "tanhshrink", + "threshold", + "threshold_", + "triplet_margin_loss", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +Tensor = torch.Tensor +aten = torch._ops.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. + + """ + + return ( + refs._uniform_helper( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return self + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + +def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]: + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + # pyrefly: ignore [unsupported-operation] + a = args[0] + if "inplace" not in kwargs: + kwargs["inplace"] = False + # pyrefly: ignore [unsupported-operation] + if kwargs["inplace"]: + torch._check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + kwargs["inplace"] = False + kwargs["out"] = a + return fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + return _fn + + +# celu is implemented specially because it has an alpha argument +# celu is very similar to elu +@register_decomposition(aten.celu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def celu( + a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.celu + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if alpha is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + rhs = alpha * torch.expm1(torch.true_divide(a, alpha)) # type: ignore[arg-type] + else: + rhs = torch.expm1(a) + + return torch.where(a > 0, a, rhs) + + +@_inplace_wrapper +@out_wrapper() +def dropout( + a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return a + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(a) + + if p == 0: + return a + + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale + + +@register_decomposition(aten.elu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def elu( + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.elu + """ + if inplace: + raise NotImplementedError + + # nb. This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + torch._check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) + + +@register_decomposition(aten.relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu + """ + + if inplace: + raise NotImplementedError + + return torch.where(torch.le(a, 0), 0, a) + + +@register_decomposition(aten.channel_shuffle) +@out_wrapper() +def channel_shuffle(input: TensorLikeType, groups: int) -> TensorLikeType: + """ + Reference implementation of :func:`torch.nn.functional.channel_shuffle`. + """ + from torch._meta_registrations import device_hint + + torch._check( + input.dim() > 2, + lambda: f"channel_shuffle expects input with > 2 dims, but got input with sizes {list(input.size())}", + ) + c = input.shape[1] + torch._check( + groups > 0, + lambda: f"Number of groups to divide channels in must be positive. Value of groups:{groups}", + ) + torch._check( + (c % groups) == 0, + lambda: f"Number of channels must be divisible by groups. Got {c} channels and {groups} groups.", + ) + n = input.shape[0] + cg = c // groups + dhw = input.shape[2:] + + if input.numel() == 0 or ( + device_hint(input) == "cuda" and (groups == 1 or groups == c) + ): + return input.view(input.shape) + + return ( + input.reshape(n, groups, cg, *dhw) + .transpose(1, 2) + .reshape(input.shape) + .contiguous() + ) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + +def layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.layer_norm`. + """ + return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0] + + +@register_decomposition(aten.leaky_relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def leaky_relu( + a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.leaky_relu + """ + + if inplace: + raise NotImplementedError + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(negative_slope), python_type): + msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope)) + + +@register_decomposition(aten.mish) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.mish + """ + + if inplace: + raise NotImplementedError + return a * torch.tanh(torch.nn.functional.softplus(a)) + + +@register_decomposition(aten.selu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.selu + """ + if inplace: + raise NotImplementedError + + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + + rhs = alpha * torch.expm1(a) + + return scale * torch.where(a > 0, a, rhs) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# CompositeImplicitAutograd - don't register decomp +def softmin( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# softplus is implemented specially because it has beta and threshold arguments +@register_decomposition(aten.softplus) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def softplus( + a: TensorLikeType, + beta: Optional[NumberType] = None, + threshold: NumberType = 20, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.softplus + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if beta is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(beta), python_type): + msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + scaled_input = a * beta + rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta) # type: ignore[arg-type] + + else: + scaled_input = a + rhs = torch.log1p(torch.exp(scaled_input)) + + return torch.where(scaled_input > threshold, a, rhs) + + +@aten.hardshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.hardshrink) +@out_wrapper() +def hardshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # hardshrink(x) = x if x > lambd + # = x if x < -lambd + # = 0 otherwise + return torch.where(torch.abs(a) <= lambd, 0, a) + + +@aten.softshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.softshrink) +@out_wrapper() +def softshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # softshrink(x) = x - lambd if x > lambd + # = x + lambd if x < -lambd + # = 0 otherwise + torch._check( + 0 <= lambd <= torch.finfo(a.dtype).max, + lambda: f"lambda must be in range [0, {torch.finfo(a.dtype).max}] for input dtype {a.dtype}, but found {lambd}", + ) + # We implement this in one torch.where to generate better code in the backward + # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211 + # We multiply by 0 for dealing with nans + return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, a * 0) + + +# Losses +def _reduction_int_to_str(reduction: int) -> str: + from torch._decomp.decompositions import Reduction + + if reduction == Reduction.NONE.value: + return "none" + elif reduction == Reduction.MEAN.value: + return "mean" + elif reduction == Reduction.SUM.value: + return "sum" + else: + raise ValueError(f"{reduction} is not a valid value for reduction") + + +def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType: + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def _check_reduction_value(reduction: str): + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + +# This helper function maps depreciated arguments, "size_average" and "reduce" +# to their corresponding "reduction" string argument +def _get_string_reduction_arg( + *, size_average: Optional[bool], reduce: Optional[bool] +) -> str: + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + return ret + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.abs(input - target) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def smooth_l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.smooth_l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + + if beta == 0.0: + return torch.nn.functional.l1_loss( + input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + else: + loss = torch.abs(input - target) + # pyrefly: ignore [unsupported-operation] + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return _apply_loss_reduction(loss, reduction) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@register_decomposition(aten.margin_ranking_loss) +def margin_ranking_loss( + input1: TensorLikeType, + input2: TensorLikeType, + target: TensorLikeType, + margin: float = 0.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = max(0, -target * (input1 - input2) + margin) + if input1.ndim != input2.ndim or input1.ndim != target.ndim: + raise RuntimeError( + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} " + ) + _check_reduction_value(reduction) + loss = torch.clamp_min(-target * (input1 - input2) + margin, 0) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def mse_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.pow(input - target, 2) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hinge_embedding_loss) +def hinge_embedding_loss( + input: TensorLikeType, + target: TensorLikeType, + margin: float = 1.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = input if y == 1 + # = max(0, margin - input) if y == -1 + _check_reduction_value(reduction) + margin_clamp = torch.clamp_min(margin - input, 0) + output_margin = torch.where(target != 1, margin_clamp, 0) + output_self = torch.where(target != -1, input, 0) + loss = output_margin + output_self + return _apply_loss_reduction(loss, reduction) + + +def _nll_loss_nd( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType], + reduction: str, + ignore_index: int, +) -> TensorLikeType: + torch._check( + input.ndim > 0 and input.ndim <= 3, + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", + ) + + torch._check( + (input.ndim == 1) or (input.shape[0] == target.shape[0]), + lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", + ) + + _check_reduction_value(reduction) + + flat_target = torch.flatten(target) + ignore_classes_mask = torch.eq(flat_target, ignore_index) + + # TODO: Enable data-dependent checks with debug mode + # TODO: This check does not work with FakeTensor inputs; See Issue #85834 + # Explicit cast for class_check to bool; See Issue #78071 + """ + from torch._subclasses.fake_tensor import FakeTensor + num_classes = input.shape[1] if input.ndim > 1 else input.shape[0] + valid_classes_mask = torch.logical_and( + (flat_target >= 0), (flat_target < num_classes) + ) + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) + torch._check( + isinstance(target, FakeTensor) or bool(class_check.item()), + lambda: "A target class is out-of-bounds and not the ignore index.", + ) + """ + + ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) + class_weight = ( + torch.scalar_tensor(1, dtype=input.dtype, device=input.device) + if weight is None + else weight[flat_target] + ) + current_weight = torch.where( + ignore_classes_mask, + ignore_class_weight, + class_weight, + ) + + if input.ndim == 1: + # implicit batch size = 1 + # input (1 batch size, C classes) + loss = -input[target] * current_weight + elif input.ndim == 2: + # input (N batch size, C classes) + batch_size = input.shape[0] + loss = -input[torch.arange(batch_size), target] * current_weight + else: + # 3D case (N batch size, C classes, K dimensions) + # input (N batch size, C classes, K) + batch_size = input.shape[0] + extent = input.shape[2] + numel = batch_size * extent + indices = torch.arange(numel) + bdx = indices // extent + kdx = indices % extent + loss = -input[bdx, flat_target, kdx] * current_weight + loss = torch.reshape(loss, target.shape) + + if reduction == "none": + return loss + elif reduction == "sum": + return torch.sum(loss) + else: + # calculate weighted mean of the loss function + return torch.sum(loss) / torch.sum(current_weight) + + +@register_decomposition(aten.nll_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def nll_loss( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.nll_loss + """ + torch._check( + input.ndim > 0, + lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", + ) + + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + # Convert these options for consistency with the eager mode + if size_average is not None or reduce is not None: + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([]) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. See the discussion in + # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if input.numel() == 0 and target.numel() == 0: + if reduction == "none": + return torch.zeros_like(target) + elif reduction == "sum": + return torch.empty_like(target) + else: + return torch.full_like(target, float("nan")) + + # The _nll_loss_nd helper function handles the most common cases. + # ndim == 1 (Single Example) + # => Batch Size: 1, Input: (C), Target: () + # ndim == 2 (k = 1) + # => Batch Size: N, Input: (N, C), Target: (N) + # ndim == 3 (k > 1) + # => Batch Size: N, Input: (N, C, K), Target: (N, K) + if input.ndim <= 3: + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + + # For ndim > 3, we reshape the input and target to 3-D case. + # Input (N batch-size, C classes, k-dimensions) + # Target (N batch-size, k-dimensions) + torch._check( + input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], + lambda: ( + "Expected input and target to both have ndim > 0 and " + "target.shape[1:] == input.shape[2:], but got " + f"target.shape {target.shape} and input.shape {input.shape}" + ), + ) + + batch_size = input.shape[0] + num_classes = input.shape[1] + out_size = [batch_size] + list(target.shape[1:]) + + input = torch.reshape(input, [batch_size, num_classes, -1]) + target = torch.reshape(target, [batch_size, -1]) + if reduction != "none": + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + else: + result = _nll_loss_nd(input, target, weight, reduction, ignore_index) + # reshape flattened inner-dim to original k-dimensions + return torch.reshape(result, out_size) + + +# TODO: This ref supports int reduction and out kwarg to be compatible with ATen: +# https://github.com/pytorch/pytorch/issues/83931 +# TODO: Could be rewritten to support complex: +# https://github.com/pytorch/pytorch/pull/85041 +@register_decomposition(aten.huber_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def huber_loss( + input: TensorLikeType, + target: TensorLikeType, + reduction: Union[str, int] = "mean", + delta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.huber_loss + """ + if type(reduction) is int: + reduction = _reduction_int_to_str(reduction) + _check_reduction_value(reduction) # type: ignore[arg-type] + torch._check( + delta > 0, + lambda: "huber_loss does not support non-positive values for delta.", + ) + z = (input - target).abs() + loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta)) + return _apply_loss_reduction(loss, reduction) # type: ignore[arg-type] + + +# tanhshrink does not use _make_elementwise_unary_reference because it does not support out +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def tanhshrink(a: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.tanhshrink + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + return a - torch.tanh(a) + + +@register_decomposition(aten.threshold) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def threshold( + a: TensorLikeType, + threshold: NumberType, + value: Union[bool, int, float], + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.threshold + """ + + if inplace: + raise NotImplementedError + + return torch.where(a <= threshold, value, a) + + +# CompositeImplicitAutograd - don't register decomp +# No elementwise type promotion - core op doesn't explicitly type promote +def triplet_margin_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + + # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined + # since it's a pure Python implementation. Use this helper instead. + return _triplet_margin_with_distance_loss( + anchor=anchor, + positive=positive, + negative=negative, + distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps), + margin=margin, + swap=swap, + reduction=reduction, + ) + + +# Pure Python impl - don't register decomp and don't add a ref. Defined as a +# helper here since triplet_margin_loss can be nicely implemented with it. +def _triplet_margin_with_distance_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + *, + distance_function: Optional[ + Callable[[TensorLikeType, TensorLikeType], TensorLikeType] + ] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> TensorLikeType: + _check_reduction_value(reduction) + + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + torch._check( + a_dim == p_dim and p_dim == n_dim, + lambda: ( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ), + ) + + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hardtanh) +@_inplace_wrapper +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def hardtanh( + a: TensorLikeType, + min_val: NumberType = -1, + max_val: NumberType = 1, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.hardtanh + """ + if inplace: + raise NotImplementedError + if utils.is_boolean_dtype(a.dtype): + raise RuntimeError("Bool inputs not supported for hardtanh") + + # preserve legacy behavior of boundaries not causing type promotion + if utils.is_integer_dtype(a.dtype): + min_val = int(min_val) # type: ignore[arg-type] + max_val = int(max_val) # type: ignore[arg-type] + if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)): + raise RuntimeError( + "Cannot do hardtanh on an unsigned type with negative limits" + ) + + if min_val > max_val: # type: ignore[operator] + raise ValueError("min_val cannot be greater than max_val") + + return torch.clamp(a, min_val, max_val) # type: ignore[arg-type] + + +@register_decomposition(aten.gelu) +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.gelu + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + a_cube = a * a * a + inner = kBeta * (a + kKappa * a_cube) + return 0.5 * a * (1 + torch.tanh(inner)) + elif approximate == "none": + kAlpha = M_SQRT1_2 + return a * 0.5 * (1 + torch.erf(a * kAlpha)) + else: + raise RuntimeError("approximate argument must be either none or tanh.") + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def poisson_nll_loss( + input: TensorLikeType, + target: TensorLikeType, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.poisson_nll_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + if log_input: + loss = torch.exp(input) - target * input + else: + loss = input - target * torch.log(input + eps) + + if full: + stirling_term = ( + target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target) + ) + # avoid inplace add + loss = loss + stirling_term.masked_fill(target <= 1, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.prelu) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "weight"), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.prelu + """ + torch._check( + isinstance(a, TensorLike), + lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", + ) + torch._check( + isinstance(weight, TensorLike), + lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", + ) + + if weight.numel() != 1: + torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + channel_size = a.shape[1] if a.ndim >= 2 else 1 + torch._check( + weight.numel() == channel_size, + lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers =" + f" {weight.numel()} and channel size = {channel_size}.", + ) + + torch._check( + weight.ndim == 0 or weight.ndim == 1, + lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " + f"ndim = {weight.ndim}", + ) + if a.ndim == 0: + weight = weight[0] if weight.ndim == 1 else weight + else: + weight = prims.broadcast_in_dim( + weight, a.shape, () if weight.ndim == 0 else (0 if a.ndim == 1 else 1,) + ) + + return torch.where(a > 0, a, a * weight) + + +@register_decomposition(aten.relu6) +@_inplace_wrapper +@out_wrapper() +def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu6 + """ + if inplace: + raise NotImplementedError + + # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126 + # It may be better to use clamp here, but we use hardtanh to replicate + # the behavior of the existing implementation + return torch.nn.functional.hardtanh(a, 0, 6) + + +@register_decomposition(aten.glu) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: + dim = utils.canonicalize_dims(a.ndim, dim) + torch._check( + a.shape[dim] % 2 == 0, + lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", + ) + b, c = torch.tensor_split(a, 2, dim) + + return b * torch.sigmoid(c) + + +@register_decomposition(aten.pairwise_distance) +@out_wrapper() +def pairwise_distance( + x1: TensorLikeType, + x2: TensorLikeType, + p: NumberType = 2.0, + eps: NumberType = 1e-6, + keepdim=False, +) -> TensorLikeType: + return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim) + + +@register_decomposition(aten.pdist) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: + torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") + torch._check(p >= 0, lambda: "pdist only supports non-negative p values") + # For p == 2 we can use an efficient implementation, but other values of p + # require creating a much bigger tensor for an intermediate step + if p == 2: + aTa = torch.mm(a, a.T) + aTa_diag = torch.diag(aTa) + t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0)) + else: + t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) + i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) + return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +@register_decomposition(aten.pixel_shuffle) +@out_wrapper() +def pixel_shuffle(self: Tensor, upscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] // upscale_factor**2 + HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5) + return ( + self.view( + *batch, + C_out, + upscale_factor, + upscale_factor, + self.shape[-2], + self.shape[-1], + ) + .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +@register_decomposition(aten.pixel_unshuffle) +@out_wrapper() +def pixel_unshuffle(self: Tensor, downscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] * downscale_factor**2 + HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5) + return ( + self.view( + *batch, + self.shape[-3], + HW_out[0], + downscale_factor, + HW_out[1], + downscale_factor, + ) + .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db51a97f73ff14c296420767be37c9dc4a89b05 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/special/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7351fb8f10cad27819c4ee7cf17805a3f3d37bc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/special/__init__.py @@ -0,0 +1,238 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch import Tensor +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper +from torch._refs import ( + _make_alias, + _make_elementwise_binary_reference, + _make_elementwise_unary_reference, +) + + +__all__ = [ + "bessel_j0", + "bessel_j1", + "entr", + "erfcx", + "expit", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "log_softmax", + "multigammaln", + "ndtr", + "ndtri", + "softmax", + "spherical_bessel_j0", + "xlog1py", + "zeta", +] +aten = torch._ops.ops.aten + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j0(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j1(a) + + +@register_decomposition(aten.special_entr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def entr(a: TensorLikeType) -> TensorLikeType: + return torch.where( + torch.isnan(a), + a, + torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)), + ) + + +@register_decomposition(aten.special_erfcx) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def erfcx(a: TensorLikeType) -> TensorLikeType: + return prims.erfcx(a) + + +# alias for sigmoid +expit = _make_alias(torch.sigmoid, "expit") + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i0e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i0e(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1e(a) + + +@register_decomposition(aten.special_log_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def log_ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + t = a * M_SQRT1_2 + return torch.where( + a < 1.0, + torch.log(torch.special.erfcx(-t) / 2) - t * t, + torch.log1p(-torch.erfc(t) / 2), + ) + + +@register_decomposition(aten.logit) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: + if eps is None: + eps = -1.0 + lo = eps + hi = 1 - eps + self = torch.where(self < lo, lo, torch.where(self > hi, hi, self)) + return torch.log(torch.true_divide(self, torch.sub(1, self))) + + +@register_decomposition(aten.special_xlog1py) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. + if isinstance(a, TensorLike) and isinstance(b, Number): + # pyrefly: ignore [bad-argument-type] + b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + # pyrefly: ignore [bad-argument-type] + a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@register_decomposition(aten.mvlgamma) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType: + c = 0.25 * p * (p - 1) * math.log(math.pi) + b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device) + return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c + + +@register_decomposition(aten.special_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + a_sqrt_2 = a * M_SQRT1_2 + return (1 + torch.erf(a_sqrt_2)) * 0.5 + + +@register_decomposition(aten.special_ndtri) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtri(a: TensorLikeType) -> TensorLikeType: + return prims.ndtri(a) + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.spherical_bessel_j0(a) + + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a9052a22b8acdb84396ce9a10c1068c0a7f7376 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..092fafe4596f1476ea4b7fc39f5763adef159c09 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15e6783f76833ff77f61e3200ff4276ebee8864f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c82a7ddb74f43b405ab638290c1e99273b81eade Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84e6777c9e673b5686f124da1ae8815258fba5ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83cf6f8faf668dce91a3c3f22acb419ef1391908 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26775ae89bb9d0287c1c8c10bac2e6077877ce70 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e2c7a9cf550c50c7c521395e89381ed3534d537 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aed80a07abc032112cd31dc2d6ee6ab5b015d36a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f938a3d853a635592c5654bb64a43db677e2b302 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bf2a0d467f82ada8bdd55a7396e49364d7298f9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab4a816261dc088cb740c6dd85575de8a36a0f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__init__.py @@ -0,0 +1,9 @@ +from ._core import ComplexTensor +from ._ops import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"] + +ComplexTensor.__module__ = __name__ +ComplexTensorMode.__module__ = __name__ +is_complex_tensor.__module__ = __name__ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a50fa249fdd855a957207916f4901620bdceeb8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/_core.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/_core.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..095febc462e662c6720ff23f2d3e93efe67184f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/_core.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_core.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_core.py new file mode 100644 index 0000000000000000000000000000000000000000..edd7568b2ef06dc3f3e6e9e2f67a586aa15c984f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_core.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +from torch import Tensor +from torch.autograd import Function + + +if TYPE_CHECKING: + from torch._ops import OpOverload + from torch._prims_common import DeviceLikeType + from torch.autograd.function import FunctionCtx + + +class ComplexTensor(Tensor): + """A class that decomposes all ops on complex Tensors into their real and imaginary parts.""" + + _re: Tensor + _im: Tensor + + def __new__(cls, real: Tensor, imag: Tensor) -> Self: + """Initialize a ComplexTensor from its real and imaginary parts.""" + from ._ops.common import REAL_TO_COMPLEX + + shape = real.shape + device = real.device + + # TODO (hameerabbasi): `torch.compile` sometimes fails here without making these + # contiguous. Why? + real = real.contiguous() + imag = imag.contiguous() + + # TODO (hameerabbasi): + # What should we do with dtype? + # We could convert to the complex type (float32 -> complex64), but we + # can't use that model for say `bfloat16` which does not have a + # corresponding complex dtype. + # If we want to support this complex rep using any float type (see + # https://github.com/pytorch/pytorch/issues/95100) + # We either need to: + # 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere + # else. + # 2) We use the real float dtype here, and it is up to the user to know + # that dtype=float here really means complex<2xSize> with dtype + # matching that of re/im parts alone + # I'm going with 1 for now, so that I can make gradcheck and some complex + # ops work properly, but might want to discuss this in the RFP. + dtype = REAL_TO_COMPLEX.get(real.dtype) + if dtype is None: + raise TypeError( + "Unsupported dtype for constituent tensors. Supported dtypes are: " + f"{set(REAL_TO_COMPLEX.keys())!r}." + ) + storage_offset = real.storage_offset() + strides = real.stride() + layout = real.layout + pin_memory = real.is_pinned() + + assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}" + assert device == imag.device, ( + f"Expected imag device {device}, got {imag.device}" + ) + assert real.dtype == imag.dtype, ( + f"Expected imag dtype {real.dtype}, got {imag.dtype}" + ) + assert pin_memory == imag.is_pinned(), ( + f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}" + ) + + res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + shape, + device=device, + dtype=dtype, + storage_offset=storage_offset, + strides=strides, + pin_memory=pin_memory, + layout=layout, + requires_grad=False, + ) + res._re = real.clone().detach() + res._im = imag.clone().detach() + + return res + + @property + def re(self) -> Tensor: + return self._re + + @property + def im(self) -> Tensor: + return self._im + + @classmethod + def __torch_dispatch__( + cls, + func: OpOverload, + types: tuple[type, ...], + args: tuple = (), + kwargs: dict | None = None, + ): + from ._ops.common import lookup_complex + + kwargs = {} if kwargs is None else kwargs + + impl = lookup_complex(func, *args, **kwargs) + if impl is None: + return NotImplemented + + return impl(*args, **kwargs) + + @staticmethod + def from_interleaved(t: Tensor) -> ComplexTensor: + t_real = torch.real(t) + t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real) + return Complex.apply(t_real, t_imag) + + def as_interleaved(self) -> Tensor: + return torch.complex(self.real, self.imag) + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, Tensor], + meta: Any, + outer_size: tuple[int, ...], + outer_stride: tuple[int, ...], + ) -> ComplexTensor: + assert meta is None + re, im = inner_tensors["re"], inner_tensors["im"] + return ComplexTensor(re, im) + + def __tensor_flatten__(self) -> tuple[list[str], Any]: + return ["re", "im"], None + + def __repr__(self, *, tensor_contents=None) -> str: + return f"ComplexTensor(real={self.re!r}, imag={self.im!r})" + + def is_pinned(self, device: DeviceLikeType | None = None) -> bool: + return self.re.is_pinned(device) + + +class Complex(Function): + @staticmethod + def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override] + return ComplexTensor(real, imag) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override] + return grad_output.real, grad_output.imag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c07bdf6099b65d477e45bc7e18078eb53201dc4e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__init__.py @@ -0,0 +1,5 @@ +from . import aten, prims +from .common import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4678a60dfec591a272c3769c2d0b81f9957c8842 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/aten.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/aten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fd7cb41fad3477ac6f8acb351027cd97f682123 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/aten.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc36160a7ad61e54b4648f29976cfbcbf8d75838 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/prims.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc1a270b957327189b39df296e0460ff9985dff3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/prims.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/aten.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/aten.py new file mode 100644 index 0000000000000000000000000000000000000000..e638e5413c2cdc4878756c7878fb700d4901c551 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/aten.py @@ -0,0 +1,934 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import torch + +from .._core import ComplexTensor +from .common import ( + _get_func_name, + COMPLEX_TO_REAL, + complex_to_real_dtype, + is_complex, + OpType, + promote_tensors, + register_binary_nonlinear, + register_complex, + register_error, + register_force_test, + register_simple, + split_complex_arg, + split_complex_tensor, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing import Any + +aten = torch.ops.aten + + +def register_binary_linear(op: OpType): + def impl_with_alpha( + lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs + ) -> ComplexTensor: + return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs) + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs) + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + u = op(a_r, b_r, *args, **kwargs) + v = op(a_i, b_i, *args, **kwargs) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + return register_complex(op, impl) + + +@register_complex(aten.real) +def real_impl(self: ComplexTensor) -> torch.Tensor: + re, _ = split_complex_tensor(self) + return re + + +@register_complex(aten.imag) +def imag_impl(self: ComplexTensor) -> torch.Tensor: + _, im = split_complex_tensor(self) + return im + + +@register_complex(aten.is_pinned) +def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool: + return self.is_pinned(device) + + +SIMPLE_OPS_LIST = [ + aten.slice, + aten.flatten, + aten.view, + aten.diagonal, + aten.expand, + aten.unsqueeze, + aten.unsqueeze_, + aten.mean, + aten.sum, + aten.clone, + aten.neg, + aten.flip, + aten.permute, + aten.repeat, + aten.index_select, + aten.split, + aten.split_with_sizes, + aten.cumsum, + aten.detach, + aten.select, + aten.squeeze, + aten.zero_, + aten.transpose, + aten.t, + aten.gather, +] + +for simple_op in SIMPLE_OPS_LIST: + globals()[_get_func_name(simple_op)] = register_simple(simple_op) + +# TODO (hameerabbasi): Not being tested +SIMPLE_FORCE_TESTED_OPS = [ + aten.copy, + aten.col2im, + aten.alias, + aten.lift_fresh, + aten._unsafe_view, + aten.index, + aten._neg_view, + aten.avg_pool2d, + aten.avg_pool3d, + aten.avg_pool2d_backward, + aten.avg_pool3d_backward, + aten.masked_scatter_backward, + aten.select_backward, + aten.slice_backward, + aten.embedding, +] + +for simple_op in SIMPLE_FORCE_TESTED_OPS: + globals()[_get_func_name(simple_op)] = register_force_test( + simple_op, register_simple(simple_op) + ) + +del simple_op + +# some binary ops which we can stamp out +mul_impl = register_binary_nonlinear(aten.mul) +mul__impl = register_binary_nonlinear(aten.mul_) +mm_impl = register_binary_nonlinear(aten.mm) +dot_impl = register_binary_nonlinear(aten.dot) +bmm_impl = register_binary_nonlinear(aten.bmm) + +# TODO (hameerabbasi): Not being tested +convolution_impl = register_force_test( + aten.convolution, register_binary_nonlinear(aten.convolution) +) + +slice_scatter_impl = register_force_test( + aten.slice_scatter, register_binary_linear(aten.slice_scatter) +) +select_scatter_impl = register_force_test( + aten.select_scatter, register_binary_linear(aten.select_scatter) +) + +add_impl = register_binary_linear(aten.add) +add__impl = register_binary_linear(aten.add_) +sub_impl = register_binary_linear(aten.sub) +sub__impl = register_binary_linear(aten.sub_) +diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter) +fill__impl = register_binary_linear(aten.fill_) + + +@register_complex(aten.rsub) +def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor: + if alpha is None: + return torch.sub(rhs, lhs) # type: ignore[bad-return] + return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return] + + +@register_complex(aten.div) +@register_complex(aten.true_divide) +def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None): + if rounding_mode is not None: + raise NotImplementedError( + "`rounding_mode` other than `None` not implemented for`ComplexTensor`." + ) + a_r, a_i = split_complex_arg(lhs) + if not is_complex(rhs): + return ComplexTensor(a_r / rhs, a_i / rhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + num_r = a_r * b_r + a_i * b_i + num_i = a_i * b_r - a_r * b_i + den = b_r * b_r + b_i * b_i + return ComplexTensor( + (num_r / den).to(out_dt), + (num_i / den).to(out_dt), + ) + + +@register_complex(aten.reciprocal) +def reciprocal_impl(self: ComplexTensor): + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + den = self_r * self_r + self_i * self_i + return ComplexTensor( + aten.div(self_r, den).to(out_dt), + aten.div(-self_i, den).to(out_dt), + ) + + +# reductions +@register_complex(aten.prod) +def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + dtype = kwargs.pop("dtype", out_dt) + kwargs["dtype"] = complex_to_real_dtype(self.dtype) + + prod_r = torch.prod(torch.abs(self), *args, **kwargs) + sum_phi = torch.sum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return] + + +@register_complex(aten.pow) +def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor: + out_dt, (self, exponent) = promote_tensors(self, exponent) + return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.cumprod) +def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + dtype = kwargs.pop("dtype", self.dtype) + kwargs["dtype"] = complex_to_real_dtype(dtype) + + prod_r = torch.cumprod(torch.abs(self), *args, **kwargs) + sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v) + + +# unary funcs, +# most of these are simple or require some kind of identity +@register_complex(aten.abs) +def abs_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + result = torch.hypot(x, y) + return result.to(out_dt) + + +@register_complex(aten.angle) +def angle_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.atan2(y, x) + + +@register_complex(aten.acos) +def acos_impl(self: ComplexTensor) -> ComplexTensor: + _, y = split_complex_tensor(self) + acosh_z = torch.acosh(self) + assert isinstance(acosh_z, ComplexTensor) + acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z) + sign_im = 2 * torch.signbit(y) - 1 + return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re)) + + +@register_complex(aten.asin) +def asin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + asinh_iz = torch.asinh(ComplexTensor(-y, x)) + assert isinstance(asinh_iz, ComplexTensor) + asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz) + return ComplexTensor(asinh_iz_im, -asinh_iz_re) + + +@register_complex(aten.atan) +def atan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.atanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.asinh) +def asinh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.acosh) +def acosh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.atanh) +def atanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + ret = 0.5 * ( + torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y)) + ) + assert isinstance(ret, ComplexTensor) + ret_re, ret_im = split_complex_tensor(ret) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.cos) +def cos_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return] + + +@register_complex(aten.cosh) +def cosh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.cosh(x) * torch.cos(y) + v = torch.sinh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.sin) +def sin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + sinh_iz = torch.sinh(ComplexTensor(-y, x)) + assert isinstance(sinh_iz, ComplexTensor) + sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz) + return ComplexTensor(sinh_iz_im, -sinh_iz_re) + + +@register_complex(aten.sinh) +def sinh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.sinh(x) * torch.cos(y) + v = torch.cosh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.tan) +def tan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.tanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.tanh) +def tanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + _2x = 2 * x + _2y = 2 * y + _d = torch.cosh(_2x) + torch.cos(_2y) + _2xsh = torch.sinh(_2x) + + out_re = _2xsh / _d + out_im = torch.sin(_2y) / _d + + return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt)) + + +@register_complex(aten.exp) +def exp_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + ex = torch.exp(x) + u = ex * torch.cos(y) + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.expm1) +def expm1_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + # TODO (hameerabbasi): The two lines below may have numerical issues + ex = torch.exp(x) + u = ex * torch.cos(y) - 1 + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.log) +def log_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + re = torch.log(torch.abs(self)) + im = torch.angle(self) + return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.log1p) +def log1p_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + # TODO (hameerabbasi): The line below may have numerical issues + return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return] + + +@register_complex(aten.any) +def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs) + + +@register_complex(aten.all) +def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) & torch.any(y, *args, **kwargs) + + +@register_complex(aten.eq) +def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_arg(self) + b_r, b_i = split_complex_arg(rhs) + return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.ne) +def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.isnan) +def isnan_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isnan(re) | torch.isnan(im) + + +@register_complex(aten.isinf) +def isinf_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isinf(re) | torch.isinf(im) + + +@register_complex(aten.isfinite) +def isfinite_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isfinite(re) & torch.isfinite(im) + + +@register_complex(aten.isclose) +def isclose_impl( + self: ComplexTensor, + rhs: ComplexTensor, + rtol=1e-5, + atol=1e-8, + equal_nan: bool = False, +) -> torch.Tensor: + abs_diff = torch.abs(self - rhs) + abs_other = torch.abs(rhs) + basic_condition = abs_diff <= (rtol * abs_other + atol) + + # This is the nontrivial part + if equal_nan: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + + a_r_nan = torch.isnan(a_r) + b_r_nan = torch.isnan(b_r) + a_i_nan = torch.isnan(a_i) + b_i_nan = torch.isnan(b_i) + a_nan = a_r_nan | a_i_nan + + # This logical expression makes sure that the isnan of both the real and imaginary parts + # matches (so 1 + nan*i doesn't equal nan + 1*i) + equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan + return basic_condition | equal_nan_condition + + return basic_condition + + +ERROR_OPS_LIST = [ + aten.lt, + aten.le, + aten.gt, + aten.ge, + aten.amin, + aten.amax, + aten.clamp, + aten.ceil, + aten.floor, + aten.minimum, + aten.maximum, + aten.trunc, + aten.sign, + aten.argmax, + aten.argmin, + aten.sort, + aten.topk, + aten.round, + aten.fmod, +] + + +ERROR_TYPES = { + aten.minimum: RuntimeError, + aten.maximum: RuntimeError, + aten.argmax: RuntimeError, + aten.argmin: RuntimeError, + aten.sort: RuntimeError, + aten.topk: RuntimeError, +} + + +for err_op in ERROR_OPS_LIST: + globals()[_get_func_name(err_op)] = register_error( + err_op, ERROR_TYPES.get(err_op, NotImplementedError) + ) + +del err_op + + +@register_complex(aten.masked_scatter) +def masked_scatter_impl( + self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor +) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + source_r, source_i = split_complex_arg(source) + ret_r = torch.masked_scatter(self_r, mask, source_r) + ret_i = torch.masked_scatter(self_i, mask, source_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.where) +def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor: + x_r, x_i = split_complex_arg(x) + y_r, y_i = split_complex_arg(y) + + ret_r = torch.where(mask, x_r, y_r) + ret_i = torch.where(mask, x_i, y_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.full_like) +def full_like_impl( + input: ComplexTensor, + fill_value: complex, + *args, + dtype: torch.dtype | None = None, + **kwargs, +) -> torch.Tensor | ComplexTensor: + # Note: Cannot be merged with the cases below due to the `fill_value` argument + input_r, input_i = split_complex_tensor(input) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + fv_r, fv_i = split_complex_arg(fill_value) + ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) + ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) + + return ComplexTensor(ret_r, ret_i) + + +def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> torch.Tensor | ComplexTensor: + self_re, self_im = split_complex_tensor(self) + + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return op(self_re, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + ret_re = op(self_re, *args, **kwargs) + ret_im = op(self_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +LIKE_OPS_LIST = [ + aten.empty_like, + aten.zeros_like, + aten.randn_like, + aten.new_zeros, +] + +for like_op in LIKE_OPS_LIST: + globals()[_get_func_name(like_op)] = register_like(like_op) + +del like_op + + +@register_complex(aten.cat) +def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor: + tensors_r = [] + tensors_i = [] + + for t in tensors: + t_r, t_i = split_complex_arg(t) + tensors_r.append(t_r) + tensors_i.append(t_i) + + ret_r = torch.cat(tensors_r, dim=dim) + ret_i = torch.cat(tensors_i, dim=dim) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.sgn) +def sgn_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + abs_self = torch.abs(ComplexTensor(self_r, self_i)) + mask = (self_r != 0) | (self_i != 0) + masked_sgn = ComplexTensor( + (self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt) + ) + return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return] + + +@register_complex(aten.sqrt) +def sqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_sqrt = torch.sqrt(torch.abs(self)) + self_half_angle = 0.5 * torch.angle(self) + + ret_r = self_abs_sqrt * torch.cos(self_half_angle) + ret_i = self_abs_sqrt * torch.sin(self_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.rsqrt) +def rsqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_rsqrt = torch.rsqrt(torch.abs(self)) + self_neg_half_angle = -0.5 * torch.angle(self) + + ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle) + ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.addmm) +def addmm_impl( + input: ComplexTensor, + mat1: ComplexTensor, + mat2: ComplexTensor, + out_dtype: torch.dtype | None = None, + beta: complex = 1, + alpha: complex = 1, +) -> ComplexTensor: + ret = beta * input + alpha * torch.mm(mat1, mat2) + assert isinstance(ret, ComplexTensor) + ret_r, ret_i = split_complex_tensor(ret) + if out_dtype is not None: + out_dtype = COMPLEX_TO_REAL[out_dtype] + ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype) + return ComplexTensor(ret_r, ret_i) + + +def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return (re != 0) | (im != 0) + + +def register_nonzero_impl(op: OpType): + def nonzero_impl( + self: ComplexTensor, other: ComplexTensor, *args, **kwargs + ) -> torch.Tensor: + return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs) + + func_name = _get_func_name(op) + nonzero_impl.__name__ = func_name + nonzero_impl.__qualname__ = func_name + + return register_complex(op, nonzero_impl) + + +logical_and_impl = register_nonzero_impl(aten.logical_and) +logical_or_impl = register_nonzero_impl(aten.logical_or) +logical_xor_impl = register_nonzero_impl(aten.logical_xor) + + +@register_complex(aten.logical_not) +def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.logical_not(elemwise_nonzero(self), *args, **kwargs) + + +@register_complex(aten.view_as_real) +def view_as_real_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.stack([re, im], dim=-1) + + +@register_complex(aten.linalg_vector_norm) +def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs) + + +@register_force_test(aten.copy_) +def copy__impl( + self: ComplexTensor | torch.Tensor, + src: ComplexTensor | torch.Tensor, + *args, + **kwargs, +) -> ComplexTensor | torch.Tensor: + if not self.dtype.is_complex: + warnings.warn( + "Casting complex values to real discards the imaginary part", UserWarning + ) + src_re, src_im = split_complex_arg(src) + return self.copy_(src_re) + + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + ret_re = self_re.copy_(src_re, *args, **kwargs) + ret_im = self_im.copy_(src_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten._local_scalar_dense) +def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: + x, y = split_complex_tensor(self) + u = aten._local_scalar_dense(x, *args, **kwargs) + v = aten._local_scalar_dense(y, *args, **kwargs) + return complex(u, v) + + +@register_complex(aten.allclose) +def allclose_impl( + input: torch.Tensor, + other: torch.Tensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + return torch.all( + torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) + ).item() # type: ignore[bad-return] + + +@register_complex(aten.stack) +def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: + re_im_tuples = [split_complex_arg(self_i) for self_i in self] + u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) + v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) + return ComplexTensor(u, v) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj_physical) +@register_complex(aten.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj) +def _conj_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, torch._neg_view(im)) + + +@register_complex(aten.index_add) +def index_add_impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add(dim, index, source_re) + ret_im = self_im.index_add(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.index_add_) +def index_add__impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add_(dim, index, source_re) + ret_im = self_im.index_add_(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.masked_fill) +def masked_fill_impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill(mask, value_re) + ret_im = self_im.masked_fill(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.masked_fill_) +def masked_fill__impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill_(mask, value_re) + ret_im = self_im.masked_fill_(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.constant_pad_nd) +def constant_pad_nd_impl( + self: ComplexTensor, pad, value: complex | None = None +) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + if value is None: + ret_re = aten.constant_pad_nd(self_re, pad) + ret_im = aten.constant_pad_nd(self_im, pad) + else: + value_re, value_im = split_complex_arg(value) + ret_re = aten.constant_pad_nd(self_re, pad, value_re) + ret_im = aten.constant_pad_nd(self_im, pad, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.var) +def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + self_re, self_im = split_complex_tensor(self) + return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs) + + +@register_complex(aten.scatter_add) +def scatter_add_impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + ret_re = torch.scatter_add(self_re, dim, index, src_re) + ret_im = torch.scatter_add(self_im, dim, index, src_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.scatter_add_) +def scatter_add__impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + out_re = self_re.scatter_add_(dim, index, src_re) + out_im = self_im.scatter_add_(dim, index, src_im) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.index_put_) +def index_put__impl( + self: ComplexTensor, + indices: tuple[torch.Tensor, ...], + values: ComplexTensor, + accumulate: bool = False, +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + values_re, values_im = split_complex_arg(values) + + out_re = self_re.index_put_(indices, values_re, accumulate=accumulate) + out_im = self_im.index_put_(indices, values_im, accumulate=accumulate) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.tanh_backward) +def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor): + return out_grad * (1.0 - y * y).conj_physical() + + +@register_complex(aten.diagonal_backward) +def diagonal_backward( + grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any: + if not isinstance(dt, torch.dtype): + return dt + + return COMPLEX_TO_REAL[dt] + + +def register_to_impl(op: OpType): + """Register an op similar to `aten.to`, but may have different signatures.""" + + def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor: + x, y = split_complex_tensor(self) + try: + args = tuple(_dt_to_real(a) for a in args) + kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()} + except KeyError: + return op(x, *args, **kwargs) + + return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +to_impl = register_to_impl(aten.to) +_to_copy_impl = register_to_impl(aten._to_copy) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/common.py new file mode 100644 index 0000000000000000000000000000000000000000..88532efe224bba013b221000a988b594ea01b2cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/common.py @@ -0,0 +1,317 @@ +from collections.abc import Callable +from typing import Any, overload, TypeAlias +from typing_extensions import TypeIs + +import torch +from torch import Tensor +from torch._decomp import get_decompositions +from torch._ops import OpOverload, OpOverloadPacket +from torch._refs import is_complex as _is_complex +from torch.types import Number +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from .._core import ComplexTensor + + +OpType: TypeAlias = OpOverloadPacket | OpOverload + +TableType: TypeAlias = dict[OpType, Callable] + +# Mapping from ops to implementations +COMPLEX_OPS_TABLE: TableType = {} + +COMPLEX_TO_REAL = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()} + +# Used to promote dtypes in `promote_real_cpu_tensors` +PROMOTE_TYPES = { + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.complex32: torch.complex64, +} + + +def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]: + r"""Returns True if the input is a ComplexTensor, else False + + Args: + a: any input + + Examples: + + >>> # xdoctest: +SKIP + >>> from torch.complex import ComplexTensor + >>> data = torch.zeros((3, 2), dtype=torch.complex64) + >>> ct = ComplexTensor.from_interleaved(data) + >>> is_complex_tensor(ct) + True + """ + return isinstance(obj, ComplexTensor) + + +@overload +def promote_tensors( + *tensors: ComplexTensor, +) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ... + + +@overload +def promote_tensors( + *tensors: Tensor, +) -> tuple[torch.dtype, tuple[Tensor, ...]]: ... + + +def promote_tensors( + *tensors: Tensor | ComplexTensor, +) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]: + """ + Promotes all tensors to a common dtype. + Additionally promotes CPU tensors to at least `float32`. + """ + tensor = next(t for t in tensors if isinstance(t, Tensor)) + out_dt = tensor.dtype + for t in tensors: + if isinstance(t, Tensor): + out_dt = torch.promote_types(out_dt, t.dtype) + + prom_dt = PROMOTE_TYPES.get(out_dt, out_dt) + return out_dt, tuple( + t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt) + for t in tensors + ) + + +def register_complex( + op: OpType, + func_impl: Callable | None = None, +): + """Decorator to register an implementation for some ops in some dispatch tables""" + + def inner(func): + if COMPLEX_OPS_TABLE.get(op, func) is not func: + raise RuntimeError(f"Attempted to register multiple functions for {op}") + COMPLEX_OPS_TABLE[op] = func + return func + + if func_impl is None: + return inner + + return inner(func_impl) + + +FORCE_TEST_LIST: list[OpType] = [] + + +def register_force_test(op: OpType, *args, **kwargs): + """Will attempt to test these ops even if they err on "normal" inputs""" + FORCE_TEST_LIST.append(op) + return register_complex(op, *args, **kwargs) + + +DECOMPOSITIONS = get_decompositions(list(torch.ops.aten)) # type: ignore[no-matching-overload] + + +def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None: + """ + Lookup an impl from the table. + + Try the particular overload first, then the overload packet. + + If nothing is found, try the decompositions with both. + """ + return COMPLEX_OPS_TABLE.get( + func, + COMPLEX_OPS_TABLE.get( + func.overloadpacket, + DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)), + ), + ) + + +def is_complex(x: Any, /) -> bool: + """Utility to detect if a given object is (known) to be complex.""" + return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex) + + +@overload +def split_complex_arg( + arg: Tensor | ComplexTensor, +) -> tuple[Tensor, Tensor]: ... + + +@overload +def split_complex_arg( + arg: complex | Number, +) -> tuple[Number, Number]: ... + + +def split_complex_arg( + arg: Tensor | ComplexTensor | complex | Number, +) -> tuple[Tensor, Tensor] | tuple[Number, Number]: + """ + Split a complex argument into a real/imaginary component. + + If real, use zero for the imaginary part. + """ + if isinstance(arg, ComplexTensor): + return split_complex_tensor(arg) + if isinstance(arg, Tensor): + if is_complex(arg): + return arg.real, arg.imag + return arg, torch.zeros_like(arg) + # TODO (hameerabbasi): Should there be a `torch.SymComplex`? + if isinstance(arg, complex): + return arg.real, arg.imag + if isinstance(arg, float | torch.SymFloat): + return arg, 0.0 + if isinstance(arg, int | torch.SymInt): + return arg, 0 + if isinstance(arg, bool | torch.SymBool): + return arg, False + raise TypeError(f"Expected tensor or number got, {type(arg)}") + + +def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]: + """Split a ComplexTensor into its real and imaginary parts.""" + return complex_tensor.re, complex_tensor.im + + +def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype: + """Convert a complex dtype to the dtype of its real part. Return other dtypes as-is.""" + return COMPLEX_TO_REAL.get(dtype, dtype) + + +def _get_op_name(op: OpType) -> str: + """Get the op name from the op.""" + if isinstance(op, OpOverload): + op = op.overloadpacket + return str(op).split(".", 1)[1] + + +def _get_func_name(op: OpType) -> str: + """Get the name of the implementation function from the op.""" + return f"{_get_op_name(op)}_impl" + + +def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError): + msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`." + + def ordered_impl(*args, **kwargs): + raise exc_type(msg) + + func_name = _get_func_name(op) + ordered_impl.__name__ = func_name + ordered_impl.__qualname__ = func_name + + return register_force_test(op, ordered_impl) + + +def register_binary_nonlinear(op: OpType) -> Callable: + """Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ...""" + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs) + imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs) + return ComplexTensor(real.to(out_dt), imag.to(out_dt)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def register_simple(op: OpType): + """Register an op which can be applied independently to the real and complex parts to get the result.""" + + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> ComplexTensor: + x, y = split_complex_tensor(self) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + raise RuntimeError( + "Non-complex `dtype` specified, please write custom impl." + ) + + if dtype in COMPLEX_TO_REAL: + assert dtype is not None + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + u = op(x, *args, **kwargs) + v = op(y, *args, **kwargs) + + u_flat, u_spec = tree_flatten(u) + v_flat, v_spec = tree_flatten(v) + assert u_spec == v_spec + out_flat = [ + ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False) + ] + return tree_unflatten(out_flat, u_spec) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any: + """Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is.""" + if ( + not isinstance(arg, ComplexTensor) + and isinstance(arg, Tensor) + and arg.dtype in COMPLEX_TO_REAL + ): + return ComplexTensor.from_interleaved(arg) + return arg + + +def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any: + """Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is.""" + if isinstance(arg, ComplexTensor): + return arg.as_interleaved() + return arg + + +class ComplexTensorMode(TorchDispatchMode): + _compile: bool + + """ A TorchDispatchMode to replace any Tensor that has a complex dtype with a ComplexTensor for the computation. """ + + def __init__(self, _dispatch_key=None, *, _compile: bool = False): + """Initialize a ComplexTensorMode. + + Args: + _dispatch_key: passed on to TorchDispatchMode + _compile: Compile the op before the computation + """ + super().__init__(_dispatch_key) + self._compile = _compile + + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[type], + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + + # TODO (hameerabbasi): Test perf with `_compile` set to `True` + if self._compile: + func = torch.compile(func) # type: ignore[bad-assignment] + + args = tree_map(_as_complex_tensor, args) + kwargs = tree_map(_as_complex_tensor, kwargs) + + return tree_map(_as_interleaved, func(*args, **kwargs)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/prims.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/prims.py new file mode 100644 index 0000000000000000000000000000000000000000..9a237b32d99042a649632a432290919ea4db9c46 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/prims.py @@ -0,0 +1,34 @@ +import torch + +from .._core import ComplexTensor +from .common import ( + complex_to_real_dtype, + register_complex, + register_force_test, + split_complex_tensor, +) + + +prims = torch.ops.prims +aten = torch.ops.aten + + +# TODO (hameerabbasi): Not being tested +@register_force_test(prims.convert_element_type) +def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor: + dtype = complex_to_real_dtype(dtype) + u, v = split_complex_tensor(x) + u_out = prims.convert_element_type(u, dtype) + v_out = prims.convert_element_type(v, dtype) + + return ComplexTensor(u_out, v_out) + + +@register_complex(prims.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj_physical(self) + + +@register_complex(prims.conj) +def conj_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj(self) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbf8bce6a236beb5bc138ac3b974a1149e757749 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/_cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bf044ec55100ad270c4c7868f0ede8b095dee9a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/_cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac88fe1dc31d0b64c8be32cf0edbec9839dc1315 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/compiler/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29e841f9715bf057be952d00675351041c381da7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dae673c7b2313480a940a9cc19517dba21d20d3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__init__.py @@ -0,0 +1,3 @@ +# pyrefly: ignore [deprecated] +from .autocast_mode import autocast +from .grad_scaler import GradScaler diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e973e65ddbe641c6e36166ee0efbebedfcaf42d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18dcd5998e7ff0db2e6fab85d8f4bb7eb491da09 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..216e10a0eee8f8510711cb355137c1e0fad30d7a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/autocast_mode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/autocast_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f81060d4a01fc6857138c49ec8276bee59b90d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/autocast_mode.py @@ -0,0 +1,71 @@ +# mypy: allow-untyped-defs +import sys +from typing import Any +from typing_extensions import deprecated + +import torch + + +__all__ = ["autocast"] + + +@deprecated( + "`torch.cpu.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cpu', args...)` instead.", + category=FutureWarning, +) +class autocast(torch.amp.autocast_mode.autocast): + r""" + See :class:`torch.autocast`. + ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead. + """ + + # TODO: remove this conditional once we stop supporting Python < 3.13 + # Prior to Python 3.13, inspect.signature could not retrieve the correct + # signature information for classes decorated with @deprecated (unless + # the __new__ static method was explicitly defined); + # + # However, this issue has been fixed in Python 3.13 and later versions. + if sys.version_info < (3, 13): + + def __new__( + cls, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True, + ): + return super().__new__(cls) + + def __init_subclass__(cls): + pass + + def __init__( + self, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True, + ): + if torch._jit_internal.is_scripting(): + self._enabled = enabled + self.device = "cpu" + self.fast_dtype = dtype + return + super().__init__( + "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) + + def __enter__(self): + if torch._jit_internal.is_scripting(): + return self + return super().__enter__() + + # TODO: discuss a unified TorchScript-friendly API for autocast + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if torch._jit_internal.is_scripting(): + return + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + if torch._jit_internal.is_scripting(): + return func + return super().__call__(func) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/grad_scaler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..aefaa1c323f5ff9089fc69c7a7aabbb380cc7233 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cpu/amp/grad_scaler.py @@ -0,0 +1,35 @@ +from typing_extensions import deprecated + +import torch + + +__all__ = ["GradScaler"] + + +class GradScaler(torch.amp.GradScaler): + r""" + See :class:`torch.amp.GradScaler`. + ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead. + """ + + @deprecated( + "`torch.cpu.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cpu', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + super().__init__( + "cpu", + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bbf9b2d657ffcc92745a93bf0b63f5b0660bde8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55c9de32ee44aaf8609f32f2f4e4a3e7356d3dd8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_composable_state.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_composable_state.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76d9a373c33c308086a8208a17f7cac016b1812a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_composable_state.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_dist2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_dist2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab580bc9f09134dada8c3d82e2517f5323a510d2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_dist2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6dc631579e5ca2dca729b9e1170e02998eaff35 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..298c8108eb7d57f1bb6ed607b0e5569b0a6134b4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_mesh_layout.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_mesh_layout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eb31c66be111abdd3663d760bf70aa8b9f7dfdd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_mesh_layout.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_serialization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_serialization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f4fc6da64f3465736862f7d2fff108365d51631 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_serialization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65a9e6cf6f789902d021c8dc3ba1f32cbb7cb55f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/argparse_util.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/argparse_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a481722e2b9757a5b08d406079c57b8a030570f8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/argparse_util.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..653ca49beba3a2f5413be1774641c99b72944144 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/collective_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/collective_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe67d3bed7c3804e194ebf4074a37b7a061874c9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/collective_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/constants.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60b4fbfe66a5bc8c83cff85b0b9aa74bae39e515 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/constants.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/device_mesh.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/device_mesh.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f60461c6ec44caae9aa72abe2f860c07929a6fdd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/device_mesh.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/launch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/launch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9418ac78ed890a47aa46e42fab205228b6e19b17 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/launch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dffdd9dbac6b1dcc4fd9a2d63a74bb9018b4b668 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/remote_device.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/remote_device.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63a480c2c9b2d7243df3b34bdd990f4c7f48b9c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/remote_device.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/rendezvous.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9e2ee43b64c44a0f0b71868ab252fd7f01b0255 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/rendezvous.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/run.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/run.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d9d06118e5a02f41a5b2a141cf17e5867c513b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/run.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ee0e98be879292414deba587f706eff7d742fa1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e38281810696814a7eae148eff19b58c10e072b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__init__.py @@ -0,0 +1,3 @@ +from .checkpoint_activation import checkpoint +from .contract import _get_registry, contract +from .replicate import replicate diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b19f19d25dd314324789bca3a7427d298dc10054 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe9b864b320930788a87d203f4c9406c220bfcc4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14bdf60e6b2312c7a3ccd4d89263cd5ee29bf1c4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f29f1efa3110289b043efa02d4f845bc881b46ca Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate_with_fsdp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate_with_fsdp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1486d9a0f14d2de2a131aa118a9eed34dd185dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate_with_fsdp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/checkpoint_activation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/checkpoint_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..93ae14110ef79a3b9b065c4ca1e8af613bd90ff5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/checkpoint_activation.py @@ -0,0 +1,134 @@ +# mypy: allow-untyped-defs +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Any + +import torch +import torch.nn as nn +from torch.utils.checkpoint import ( + _checkpoint_without_reentrant_generator, + _DEFAULT_DETERMINISM_MODE, +) + +from .contract import _State, contract + + +@contextmanager +def _no_hook(module: nn.Module, user_ctx: AbstractContextManager | None = None): + r""" + Disable hooks installed by checkpoint to avoid unintentional recursion + during backward recomputation. + """ + + with user_ctx if user_ctx else nullcontext(): + orig_enable_hook = checkpoint.state(module).enable_hook + checkpoint.state(module).enable_hook = False + try: + yield + finally: + checkpoint.state(module).enable_hook = orig_enable_hook + + +class _CheckpointState(_State): + enable_hook: bool = False + _ac_generator: Generator[None, None, None] | None + + +@contract(_CheckpointState) +def checkpoint(module: nn.Module, **kwargs) -> nn.Module: + r""" + This is a composable activation checkpointing API. Unlike functional + activation checkpointing APIs, this one does not require changing model + source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs, + this one does not modify model structure or fully-qualified names either. + Under the hood, it registers activation checkpointing logic as pre- and + post-forward hooks. Hence, this API can be easily applied to any model or + sub-modules in the model. + + Args: + module (nn.Module): the target model or sub-module to apply activation + checkpointing. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> model = MyModel() + >>> checkpoint(model.l1) # apply activation checkpointing only to l1 + >>> model(torch.zeros(2, 10)).sum().backward() + + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint") + + use_reentrant = kwargs.pop("use_reentrant", False) + if use_reentrant: + raise NotImplementedError( + "use_reentrant=True is not supported in composable checkpoint. " + "Please use torch.utils.checkpoint.checkpoint instead." + ) + preserve_rng_state = kwargs.pop("preserve_rng_state", True) + user_context_fns = kwargs.pop("context_fn", None) + determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE) + debug = kwargs.pop("debug", False) + early_stop = kwargs.pop("early_stop", True) + + if kwargs: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + def forward_pre_hook( + module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: + if checkpoint.state(module).enable_hook: + + def context_fns(): + if user_context_fns is not None: + ctx1, ctx2 = user_context_fns() + return ctx1, _no_hook(module, ctx2) + else: + return nullcontext(), _no_hook(module) + + gen = _checkpoint_without_reentrant_generator( + module, + preserve_rng_state, + context_fns, + determinism_check, + debug, + early_stop, + *args, + **kwargs, + ) + checkpoint.state(module)._ac_generator = gen + next(gen) + + def forward_hook(module: nn.Module, inputs: tuple[Any, ...], output: Any) -> Any: + if checkpoint.state(module).enable_hook: + try: + gen = checkpoint.state(module)._ac_generator + assert gen is not None + next(gen) + except StopIteration: + pass + else: + raise RuntimeError( + "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!" + ) + + # Ensure that we no longer hold on to the generator. always_call=True helps ensure we + # clear this even in the case of exception in fwd pass. + checkpoint.state(module)._ac_generator = None + + checkpoint.state(module).enable_hook = True + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + module.register_forward_hook(forward_hook, prepend=True, always_call=True) + return module diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/contract.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/contract.py new file mode 100644 index 0000000000000000000000000000000000000000..c810da8cb583c1199cda7087f7feb45b8ab6c443 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/contract.py @@ -0,0 +1,259 @@ +# mypy: allow-untyped-defs +import uuid +from collections import OrderedDict +from collections.abc import Callable +from functools import wraps +from typing import Concatenate, Generic, Protocol +from typing_extensions import ParamSpec, TypeVar + +import torch +import torch.nn as nn +from torch.distributed._composable_state import _State +from torch.distributed.utils import _get_root_modules + + +_T = TypeVar("_T", covariant=True) +_P = ParamSpec("_P") + + +def generate_state_key(string="__composable_api_state_key"): + return f"{string}_{str(uuid.uuid4())}" + + +STATE_KEY = generate_state_key() +REGISTRY_KEY = generate_state_key() + + +# TODO: we can add additional info to RegistryItem to share across APIs. E.g., +# we can add args and kwargs here, and then we can detect whether fully_shard +# is combined with reentrant activation checkpointing and error out with a clear +# message. +class RegistryItem: + pass + + +_TState = TypeVar("_TState", bound="_State", covariant=True) +_M = TypeVar("_M", nn.Module, list[nn.Module]) + + +class _ContractFn(Protocol, Generic[_P, _T, _TState]): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + + def state(self, module: nn.Module) -> _TState: ... + + +def contract( + state_cls: type[_TState] = _State, # type: ignore[assignment] +) -> Callable[ + [Callable[Concatenate[_M, _P], _M]], + _ContractFn[Concatenate[_M, _P], _M, _TState], +]: + r""" + Decorate a function as a composable distributed API, where the first + argument of the function must be an :class:`nn.Module` instance or sequence + of :class:`nn.Module` instances. + + The decorator verifies that the decorated function does not modify + fully-qualified names (FQNs) for parameters, buffers, or modules. The + decorated function can return different module instances than the input + modules; the FQN invariant will be enforced following the input order. + + When a function ``func`` is decorated by ``@contract()``, a + ``.state(module: nn.Module)`` method will be installed to the decorated + function. Then you can retrieve and modify the state on a module by calling + ``func.state(module)``. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> @contract() + >>> def my_feature(module: nn.Module) -> nn.Module: + >>> my_feature.state(module).some_state = "any value" + >>> return module + >>> + >>> model = MyModel() + >>> my_feature(model.l1) + >>> assert my_feature.state(model.l1).some_state == "any value" + >>> my_feature(model.l2) + >>> model(torch.randn(2, 10)).sum().backward() + """ + + # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package + @wraps(state_cls) # type: ignore[arg-type] + def inner( + func: Callable[Concatenate[_M, _P], _M], + ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]: + @wraps(func) + def wrapper( + module: _M, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _M: + inp_module = module + modules: list[nn.Module] + if isinstance(module, nn.Module): + modules = [module] + else: + # If the user passes a sequence of modules, then we assume that + # we only need to insert the state object on the root modules + # (i.e. those without a parent) among the passed-in modules. + # pyrefly: ignore [no-matching-overload] + modules = _get_root_modules(list(module)) + state = state_cls() # shared across all modules + registry_item = RegistryItem() # shared across all modules + + # `func` is allowed to return different module instances than the + # input modules as long as FQNs are preserved following the input + # module order + all_orig_named_params: list[dict[str, nn.Parameter]] = [] + all_orig_named_buffers: list[dict[str, torch.Tensor]] = [] + all_orig_named_modules: list[dict[str, nn.Module]] = [] + + # pyrefly: ignore [bad-assignment] + for module in modules: + default_all_state: dict[Callable, _State] = OrderedDict() + default_registry: dict[str, RegistryItem] = OrderedDict() + all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, default_all_state + ) + if not isinstance(all_state, dict): + raise AssertionError( + f"Distributed composable API states corrupted: {all_state}" + ) + registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] + REGISTRY_KEY, default_registry + ) + if not isinstance(registry, dict): + raise AssertionError( + f"Distributed composable API registry corrupted: {registry}" + ) + if func in all_state or func.__name__ in registry: + raise AssertionError( + "Each distinct composable distributed API can only be applied to a " + f"module once. {func.__name__} has already been applied to the " + f"following module:\n{module}" + ) + all_state.setdefault(func, state) + registry.setdefault(func.__name__, registry_item) + + # pyrefly: ignore [missing-attribute] + all_orig_named_params.append(OrderedDict(module.named_parameters())) + # pyrefly: ignore [missing-attribute] + all_orig_named_buffers.append(OrderedDict(module.named_buffers())) + # pyrefly: ignore [missing-attribute] + all_orig_named_modules.append(OrderedDict(module.named_modules())) + + updated = func(inp_module, *args, **kwargs) + if updated is None: + updated = inp_module # type: ignore[assignment] + updated_modules: list[nn.Module] + if isinstance(updated, nn.Module): + updated_modules = [updated] + else: + updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload] + + all_new_named_params: list[dict[str, nn.Parameter]] = [] + all_new_named_buffers: list[dict[str, torch.Tensor]] = [] + all_new_named_modules: list[dict[str, nn.Module]] = [] + # pyrefly: ignore [bad-assignment] + for module in updated_modules: + # pyrefly: ignore [missing-attribute] + all_new_named_params.append(OrderedDict(module.named_parameters())) + # pyrefly: ignore [missing-attribute] + all_new_named_buffers.append(OrderedDict(module.named_buffers())) + # pyrefly: ignore [missing-attribute] + all_new_named_modules.append(OrderedDict(module.named_modules())) + + num_orig_modules = len(all_orig_named_modules) + num_new_modules = len(all_new_named_modules) + if num_orig_modules != num_new_modules: + raise AssertionError( + f"{func.__name__} should return the same number of modules as input modules" + f"Inputs: {num_orig_modules} modules\n" + f"Outputs: {num_new_modules} modules" + ) + + def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str): + if orig_fqns == new_fqns: + return + + orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns) + orig_only = orig_fqn_set - new_fqn_set + new_only = new_fqn_set - orig_fqn_set + if len(orig_only) or len(new_only): + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify FQNs.\n" + f"FQNs only in original: {orig_only}\n" + f"FQNs only in new: {new_only}" + ) + else: + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify " + "the order of FQNs.\n" + f"Original FQNs: {orig_only}\n" + f"New FQNs: {new_only}" + ) + + for orig_named_params, new_named_params in zip( + all_orig_named_params, all_new_named_params + ): + check_fqn( + list(orig_named_params.keys()), + list(new_named_params.keys()), + "Checking parameters: ", + ) + for orig_named_buffers, new_named_buffers in zip( + all_orig_named_buffers, all_new_named_buffers + ): + check_fqn( + list(orig_named_buffers.keys()), + list(new_named_buffers.keys()), + "Checking buffers: ", + ) + for orig_named_modules, new_named_modules in zip( + all_orig_named_modules, all_new_named_modules + ): + check_fqn( + list(orig_named_modules.keys()), + list(new_named_modules.keys()), + "Checking modules: ", + ) + + # TODO: verify that installed distributed paradigms are compatible with + # each other. + + # pyrefly: ignore [bad-return] + return updated + + def get_state(module: nn.Module) -> _State: + return module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, + {}, # TODO(@yhcharles): this is a temporary fix, need a better way + ).get(func) # type: ignore[call-overload] + + wrapper.state = get_state # type: ignore[attr-defined] + + return wrapper # type: ignore[return-value] + + return inner # type: ignore[return-value] + + +def _get_registry(module: nn.Module) -> dict[str, RegistryItem] | None: + r""" + Get an ``OrderedDict`` of composable APIs that have been applied to the + ``module``, indexed by the API name. If no API has been applied, then this + returns ``None``. + """ + return getattr(module, REGISTRY_KEY, None) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..108c765ba4766bf7d9110aa67e09ac02cab00410 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__init__.py @@ -0,0 +1,3 @@ +from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy + +from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee8c2b7487f3562539e1c15b21ba07ccd5701914 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7622e5a81a5005cf0e4dd6380574650846e02d9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/fully_shard.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/fully_shard.py new file mode 100644 index 0000000000000000000000000000000000000000..9e36c7b430fc89dd58cc5742f299ac607eb4367b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/fully_shard.py @@ -0,0 +1,8 @@ +# TODO: For backward compatibility, we are importing the public objects +# originally from this file. +from torch.distributed.fsdp import ( # noqa: F401 + FSDPModule, + fully_shard, + register_fsdp_forward_method, + UnshardHandle, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/replicate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdec49468703e53b0a125a0d3c71a92ec80d00c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/replicate.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +import weakref +from collections.abc import Iterable +from typing import Any, NoReturn + +import torch +import torch.nn as nn +from torch.distributed._composable_state import _State +from torch.nn.parallel import DistributedDataParallel + +from .contract import _get_registry, contract + + +_ROOT_MODULE_PREFIX = "" + + +class _ReplicateState(_State): + _ddp_weakref: weakref.ref + + def __init__(self) -> None: + super().__init__() + self.module: nn.Module = nn.ParameterList() + self.has_initialized: bool = False + self._param_list: nn.ParameterList = nn.ParameterList() + # TODO(@fegin): this variable is originally create for testing, we + # should remove this if possible. + self._orig_module = self.module + self._param_names: list[str] = [] + self._no_sync: bool = False + self._init_args: tuple[Any, ...] | None = None + self._init_kwargs: dict[str, Any] = {} + self._comm_hook_args: list[Any] = [] + + def _collect_params( + self, + module: nn.Module, + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + prefix: str = _ROOT_MODULE_PREFIX, + ) -> None: + # skip if managed by fully_sharded API + if _is_fully_sharded(module): + return + + # if a module is ignored, all descendants of the module are ignored. + if module in ignored_modules: + return + + recurse_prefix = ( + f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX + ) + + for n, p in module.named_parameters(recurse=False): + if p not in ignored_params: + self._param_list.append(p) + self._param_names.append(f"{recurse_prefix}{n}") + + for name, child_module in module.named_children(): + self._collect_params( + child_module, + ignored_modules, + ignored_params, + prefix=f"{recurse_prefix}{name}", + ) + + def lazy_init(self) -> None: + @torch._disable_dynamo(recursive=True) + def _lazy_init(): + assert self._init_args is not None + self.init(*self._init_args, **self._init_kwargs) + self.register_comm_hook() + self._init_args = () + self._init_kwargs = {} + + _lazy_init() + + def init( + self, + module: nn.Module, + ignored_modules: set[nn.Module], + **kwargs, + ) -> None: + if self.has_initialized: + return + + self.has_initialized = True + self.module = module + ignored_params = {p for m in ignored_modules for p in m.parameters()} + for submodule in module.modules(): + if _is_fully_sharded(submodule): + ignored_params.update(submodule.parameters()) + from torch.distributed.tensor.parallel.ddp import _localize_dtensor + + _localize_dtensor(module, ignored_params=ignored_params) + self._collect_params(module, ignored_modules, ignored_params) + + if "device_id" in kwargs: + # replicate() supports a small usability enhancement where + # user can pass in device_id as a Union[int, torch.device] even for + # CPU devices so users don't have to change code for CPU/GPU runs. + # We derive the right device_ids to feed into DDP to support this. + if kwargs["device_id"] is not None: + device_id = kwargs["device_id"] + # Convert to device_ids that DDP expects. + if isinstance(device_id, torch.device) and device_id.type == "cpu": + # CPU modules receive device_ids None + kwargs["device_ids"] = None + else: + # GPU modules expect device_ids=[cuda_device] + kwargs["device_ids"] = [device_id] + else: + kwargs["device_ids"] = None + kwargs.pop("device_id") + + self._ddp = DistributedDataParallel(self._param_list, **kwargs) + # Weakref to the DDP instance is currently only used for testing. + replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp) + + def register_comm_hook(self) -> None: + for comm_args, comm_kwargs in self._comm_hook_args: + self._ddp.register_comm_hook(*comm_args, **comm_kwargs) + self._comm_hook_args.clear() + + def record_init_args(self, *args, **kwargs) -> None: + self._init_args = args + self._init_kwargs = kwargs + + def forward_pre_hook( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + if self._init_args or self._init_kwargs: + self.lazy_init() + self._ddp.require_backward_grad_sync = not self._no_sync + return self._ddp._pre_forward(*args, **kwargs) + + def forward_post_hook( + self, + module: nn.Module, + input: tuple[torch.Tensor], + output: torch.Tensor, + ) -> torch.Tensor: + return self._ddp._post_forward(output) + + +def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: + raise AssertionError( + "DDP does not support deepcopy. Please use state dict for serialization." + ) + + +# Follow the same pattern as FSDP/fully_shard +class DDP: + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the DDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `DDP<...>` class + # and index 1 is the `DDP` class itself + orig_cls = cls.__mro__[2] + return orig_cls.__new__(orig_cls, *args, **kwargs) + + def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None: + """ + Sets if the module should sync gradients. This can be used to implement + gradient accumulation without communication. + + Args: + requires_gradient_sync (bool): Whether to reduce gradients for the + module's parameters. + """ + replicate.state(self)._no_sync = not requires_gradient_sync # type: ignore[arg-type] + + def register_comm_hook(self, *args, **kwargs) -> None: + replicate.state(self)._comm_hook_args.append((args, kwargs)) # type: ignore[arg-type] + + +@contract(state_cls=_ReplicateState) +def replicate( + module: nn.Module, + ignored_modules: Iterable[torch.nn.Module] | None = None, + **kwargs, +) -> nn.Module: + r"""Replicates a module + + Args: + module (torch.nn.Module): module to replicate + + Example:: + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + torch._C._log_api_usage_once("torch.distributed.replicate") + + # TODO(fegin): using kwargs is not a good idea if we would like to make + # replicate a formal API to replace DDP. + if "device_id" in kwargs: + if not isinstance(kwargs["device_id"], (int, torch.device)): + raise RuntimeError( + "Expected device_id to be int or torch.device, " + f"but got {type(kwargs['device_id'])}" + ) + + if _is_fully_sharded(module): + raise RuntimeError( + "Cannot apply `replicate()` on a Module already managed by `fully_shard`" + ) + + if ignored_modules is None: + ignored_modules = {} + else: + ignored_modules = set(ignored_modules) + + state = replicate.state(module) + module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True) + device_mesh = kwargs.get("device_mesh") + if device_mesh is not None: + root_mesh = device_mesh._get_root_mesh() + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh != device_mesh: + # TODO: This is a temporary work around to enable DDP + TP. + # We should do the logic in DDP so that the 2D implementation is + # sound and the state_dict works out of the box. + # + # This won't conflict with what is done in DDP class as the module + # replicate is going to pass is NOT the original module. + from torch.distributed.tensor.parallel.ddp import ( + _localize_dtensor, + _reconstruct_dtensor, + ) + + module.register_forward_pre_hook(_reconstruct_dtensor) + module.register_forward_hook(_localize_dtensor) + + module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type] + + state.record_init_args(module, ignored_modules, **kwargs) + + # Place DDP leftmost for highest priority in the method resolution order + cls = module.__class__ + dct = {"__deepcopy__": unimplemented_deepcopy} + new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct) + module.__class__ = new_cls + return module + + +def _is_fully_sharded(module: nn.Module) -> bool: + r"""Check if module is marked with fully_shard.""" + registry = _get_registry(module) + if registry is None: + return False + return "fully_shard" in registry diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/replicate_with_fsdp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/replicate_with_fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..6c242323bcac82a55198f6f768ff5bd60c01595f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_composable/replicate_with_fsdp.py @@ -0,0 +1,408 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import overload + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable_state import _get_module_state, _insert_module_state +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.fsdp._fully_shard._fsdp_api import ( + MixedPrecisionPolicy, + OffloadPolicy, +) +from torch.distributed.fsdp._fully_shard._fsdp_common import ( + DDPMeshInfo, + detect_compiled_autograd, +) +from torch.distributed.fsdp._fully_shard._fsdp_init import ( + _get_device_from_mesh, + _get_managed_states, + _init_default_fully_shard_mesh, + _move_states_to_device, +) +from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup +from torch.distributed.fsdp._fully_shard._fsdp_state import ( + _register_group_forward_hooks, + FSDPState, +) +from torch.distributed.fsdp._fully_shard._fully_shard import ( + _unimplemented_deepcopy, + FSDPModule, +) +from torch.distributed.tensor import DeviceMesh, init_device_mesh +from torch.distributed.utils import _get_root_modules + +from .contract import _get_registry, contract + + +cls_to_replicate_cls: dict[type, type] = {} + +_ROOT_MODULE_PREFIX = "" + +logger = logging.getLogger("torch.distributed._composable.replicate_with_fsdp") + + +class _ReplicateStateContext: + """This has state shared across Replicate states.""" + + def __init__(self) -> None: + # All Replicate states in the root state's module tree + self.all_states: list[_ReplicateState] = [] + # Iteration's forward root runs the once-per-forward logic; this root + # may not be the overall root set by lazy initialization in cases where + # only a submodule runs forward (e.g. encoder-only for eval) + self.iter_forward_root: _ReplicateState | None = None + # Final callback should only be queued once per backward + self.post_backward_final_callback_queued: bool = False + # Whether to finalize backward in this backward's final callback + self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: torch.Event | None = None + + +def _get_module_replicate_state(module: nn.Module) -> _ReplicateState | None: + """Checks if module state is ReplicateState""" + state = _get_module_state(module) + if isinstance(state, _ReplicateState): + return state + return None + + +class _ReplicateState(FSDPState): + """ + Replicate state functionality is adapted from FSDP state. + In the future, could experiment with inheriting from it instead. + """ + + def __init__(self) -> None: + super().__init__() + self._state_ctx = _ReplicateStateContext() # type: ignore[assignment] + + # Define a separate init since `__init__` is called in the contract + def init( + self, + modules: tuple[nn.Module, ...], + device: torch.device, + mp_policy: MixedPrecisionPolicy, + auto_reshard_after_forward: bool = False, + ) -> None: + for module in modules: + _insert_module_state(module, self) + self._modules = modules + # pyrefly: ignore [read-only] + self._device = device + self._device_handle = _get_device_handle(device.type) + self._mp_policy = mp_policy + self._auto_reshard_after_forward = auto_reshard_after_forward + if len(modules) == 1: + self._pre_forward_hook_handle = modules[0].register_forward_pre_hook( + self._pre_forward, prepend=True, with_kwargs=True + ) + self._post_forward_hook_handle = modules[0].register_forward_hook( + self._post_forward, prepend=False + ) + else: + hook_handle = _register_group_forward_hooks( + modules, + self._pre_forward, + self._post_forward, + self._modules_to_run_forward, + ) + self._pre_forward_hook_handle = hook_handle + self._post_forward_hook_handle = hook_handle + + def _lazy_init(self) -> None: + """ + Lazy initialization represents when all modules' parallelisms have + finalized (e.g. Replicate has been applied to all desired modules). This + means that we can determine which state is the root, and we do so by + the 1st state to run forward. + """ + if self._is_root is not None: + return # no-op: already initialized + self._is_root = True + if len(self._modules) > 1: + raise RuntimeError( + f"Replicate requires a single root module but got {self._modules}" + ) + detect_compiled_autograd() + root_module = self._modules[0] + visited_states: set[_ReplicateState] = set() + for module_name, module in root_module.named_modules(): + if (state := _get_module_replicate_state(module)) is None: + continue + if module is not root_module: + if state not in visited_states and state._is_root is not None: + raise RuntimeError( + "Replicate state has already been lazily initialized for " + f"{module_name}\nReplicate requires running forward through " + "the root module first" + ) + state._is_root = False + self._state_ctx.all_states.append(state) + # pyrefly: ignore [bad-argument-type] + visited_states.add(state) + if self._fsdp_param_group and self._auto_reshard_after_forward: + # For the root, do not reshard after forward since for training, + # the parameters would be freed and all-gathered immediately + self._fsdp_param_group.post_forward_mesh_info = None + self._init_fqns() + self._init_shared_state() + # Run parameter group lazy inits after initializing FQNs for improved + # error messages + for state in self._state_ctx.all_states: # type: ignore[assignment] + if state._fsdp_param_group: # type: ignore[union-attr] + state._fsdp_param_group.lazy_init() # type: ignore[union-attr] + + +def replicate_impl( + module, + mesh: DeviceMesh, + *, + device_id: int | torch.device | None = None, + mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), + offload_policy: OffloadPolicy = OffloadPolicy(), + ignored_params: set[nn.Parameter] | None = None, +): + torch._C._log_api_usage_once("torch.distributed._composable.replicate_with_fsdp") + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + raise ValueError( + f"replicate does not support containers that do not implement forward: {module}" + ) + + mesh = mesh or _init_default_fully_shard_mesh() + if mesh.ndim != 1: + raise ValueError(f"replicate expects a 1D DeviceMesh but got {mesh}") + + else: + if mesh.mesh_dim_names is None: + raise AssertionError( + "Please init the 2D mesh for HSDP with mesh_dim_names specified" + ) + mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0) + device = _get_device_from_mesh(mesh) + + post_forward_mesh_info = None + + arg_module = module + modules = ( + (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module)) + ) + state = replicate.state(modules[0]) # type: ignore[attr-defined] # see [1] + state.init(modules, device, mp_policy) + + managed_modules = _get_managed_modules(modules, ignored_params) + params, buffers = _get_managed_states(managed_modules, ignored_params) + + _move_states_to_device(params, buffers, device) + if params: + state._fsdp_param_group = FSDPParamGroup( + params, + modules, + mesh_info, # type: ignore[arg-type] + post_forward_mesh_info, + device, + None, + mp_policy, + offload_policy, + ) + + # Place Replicate leftmost for highest priority in the method resolution order + for module in modules: + cls = module.__class__ + new_cls = cls_to_replicate_cls.get(cls) + if not new_cls: + dct = {"__deepcopy__": _unimplemented_deepcopy} + new_cls = type(f"Replicate{cls.__name__}", (ReplicateModule, cls), dct) + cls_to_replicate_cls[cls] = new_cls + module.__class__ = new_cls + return arg_module + + +@overload +# pyrefly: ignore [inconsistent-overload] +def replicate( + module: nn.Module, + *, + mesh: DeviceMesh | None = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: set[nn.Parameter] | None = ..., +) -> ReplicateModule: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def replicate( + module: list[nn.Module], + *, + mesh: DeviceMesh | None = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: set[nn.Parameter] | None = ..., +) -> list[ReplicateModule]: ... + + +@contract(state_cls=_ReplicateState) # type: ignore[misc] +def replicate( + module: nn.Module, + *, + mesh: DeviceMesh | None = None, + mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), + offload_policy: OffloadPolicy = OffloadPolicy(), + ignored_params: set[nn.Parameter] | None = None, +): + r"""Replicates a module + + Args: + module (torch.nn.Module): module to replicate + + Example:: + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + + if not is_composable_with_replicate(module): + raise RuntimeError( + "Cannot apply `replicate()` on a Module already managed by `fully_shard`" + ) + + if mesh is None: + mesh = replicate_mesh() + + return replicate_impl( + module, + mesh, + mp_policy=mp_policy, + offload_policy=offload_policy, + ignored_params=ignored_params, + ) + + +class ReplicateModule(FSDPModule): + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the FSDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class + # and index 1 is the `FSDPModule` class itself + orig_cls = cls.__mro__[3] + self = orig_cls.__new__(orig_cls, *args, **kwargs) + self.__init__(*args, **kwargs) + return self + + +def _get_managed_modules( + root_modules: tuple[nn.Module, ...], + ignored_params: set[nn.Parameter] | None = None, +) -> list[nn.Module]: + modules: list[nn.Module] = [] + root_modules_set = set(root_modules) + # Track visisted modules to avoid visiting shared modules multiple times + visited_modules: set[nn.Module] = set() + + def dfs(module: nn.Module) -> None: + """ + Runs a DFS to collect managed modules, not recursing into modules with + a non-composable API or ``replicate`` already applied. + """ + if not is_composable_with_replicate(module): + return + elif ( + module not in root_modules_set + and _get_module_replicate_state(module) is not None + ): + return # nested `fully_shard` module + visited_modules.add(module) + for submodule in module.children(): + if submodule not in visited_modules: + dfs(submodule) + modules.append(module) + + for root_module in root_modules: + dfs(root_module) + + if ignored_params is None: + return modules + + adjusted_modules = _adjust_managed_modules(modules, ignored_params) + return adjusted_modules + + +def is_composable_with_replicate(module: nn.Module) -> bool: + """Checks if replicate can be applied with module""" + registry = _get_registry(module) + if registry is None: + return True + # Registry keys by function name + return "fully_shard" not in registry + + +def replicate_mesh(): + """Creates a device mesh for replicate if the user doesn't provide one""" + if not dist.distributed_c10d.is_initialized(): + dist.distributed_c10d.init_process_group() + default_pg = dist.distributed_c10d._get_default_group() + device = torch._C._get_accelerator() + mesh = init_device_mesh( + device.type, + mesh_shape=(default_pg.size(),), + mesh_dim_names=("replicate",), + ) + return mesh + + +def _adjust_managed_modules( + modules: list[nn.Module], ignored_params: set[nn.Parameter] +) -> list[nn.Module]: + """ + Adjust the given list of managed modules by removing those with all parameters ignored. + """ + ignore_decision: dict[nn.Module, bool] = {} + new_modules = [] + for module in modules: + ignored = _ignore_module(module, ignored_params, ignore_decision) + if not ignored: + new_modules.append(module) + return new_modules + + +def _ignore_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignore_decision: dict[nn.Module, bool], +) -> bool: + """ + Decide if it is safe to ignore a module for applying replicate. + """ + if module in ignore_decision: + return ignore_decision[module] + + if len(list(module.buffers(recurse=False))) > 0: + # Cannot ignore a module with any buffer + ignore_decision[module] = False + return False + + for _, param in module.named_parameters(recurse=False): + if param not in ignored_params: + # at least one param is not ignored. So this module shouldn't be. + ignore_decision[module] = False + return False + + # Need to consider descendants of module + for child in list(module.children()): + ignore_child = _ignore_module(child, ignored_params, ignore_decision) + if not ignore_child: + # Cannot ignore module if one of its children is not ignored + ignore_decision[module] = False + return False + + # Safe to ignore module + ignore_decision[module] = True + return True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..632f28224193de697b7eb96608a262f58dd6363d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__init__.py @@ -0,0 +1,1965 @@ +from ast import Call + +from torch._ops import OpOverload + + +""" +A LocalTensor is a tensor subclass which simulates a tensor that is +distributed across SPMD ranks. A LocalTensor might be size N, but in fact +there are world_size shards/replicas of it stored internally. When you do a +plain PyTorch operation on it, we apply the operation to each shard; when you +do a collective, we do the mathematically equivalent operation on the local +shards. A LocalTensor is associated with a list of ranks which specify +which ranks it holds local tensors for. + +NB, this is NOT a DataParallel like abstraction where you can run operations +on multiple different GPUs. It is intended purely for *debugging* purposes, +the overhead is almost certainly too high to keep eight GPUs (even the C++ +autograd needs multithreading to keep up!) (It might potentially be possible +to trace through this with torch.compile and then compile it with CUDA graphs +but this is currently a non-goal.) + +We do not directly handling MPMD. However in practice even in SPMD you may +encounter divergence in behavior per rank (for example, uneven sharding +across ranks). To support scenarios like this, we provide a helper decorator +that allows you to run a function with no side effects for each LocalTensor +shard and combine results back into LocalTensor or LocalIntNode. + +NB: This is a torch dispatch Tensor subclass, as we want to assume that autograd +is SPMD, so we run it once, and dispatch the inner autograd calls to the individual +local shards. + +NOTE ABOUT MESH: This subclass requires collectives that are issued to it to +respect a DeviceMesh like abstraction. The reason for this is that when +DTensor issues us a collective for a particular rank, you will be asked to do +this on a specific process group which involves some ranks. However, this +will only be for the LOCAL PG that this particular rank is participating in; +there will be a bunch of other PGs for other nodes that you don't get to see. +We need to be able to reverse engineer all of the collectives that don't +involve the current local rank here to actually issue them. This can be done +two ways: (1) looking at the participating local ranks in the PG and computing +the complement which specifies all the other collectives you have to run, or +(2) retrieving the device mesh axis corresponding to the PG for this rank, and +then running all the fibers for this. +""" + +import contextlib +import copy +import functools +import operator +import os +import sys +import threading +from collections import defaultdict +from collections.abc import Callable, Generator, Sequence +from types import TracebackType +from typing import Any, Optional, ParamSpec, TypeVar, Union + + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + +import torch +import torch.distributed as dist +from torch import Size, SymBool, SymInt, Tensor +from torch._C import DispatchKey, DispatchKeySet, ScriptObject +from torch._export.wrappers import mark_subclass_constructor_exportable_experimental +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.distributed import DeviceMesh, ProcessGroup +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.distributed_c10d import _get_default_group +from torch.fx.experimental._constant_symnode import ConstantIntNode +from torch.nested._internal.nested_int import NestedIntNode +from torch.utils import _pytree as pytree +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode_stack, + return_and_correct_aliasing, + TorchDispatchMode, +) +from torch.utils.checkpoint import get_device_states, set_device_states + + +_R = TypeVar("_R") +_P = ParamSpec("_P") + +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + + +from . import _c10d + + +def _is_in_fake_tensor_mode() -> bool: + return any( + isinstance(mode, FakeTensorMode) for mode in _get_current_dispatch_mode_stack() + ) + + +def _reduce_multidim_lists( + lists_to_reduce: list[Any], reduce_func: Callable[[list[Any]], Any] +) -> Any: + """ + Reduces a list of multi-dimensional lists, assuming they all have + the exact same shape. + + Args: + lists_to_reduce (list): A list where each item is a multi-dimensional + list (e.g., [md_list_1, md_list_2, ...]). + All inner md_lists must have the same shape. + reduce_func (callable): A function that takes an iterable (list) of + values and returns a single reduced value. + For example: sum, max, min, or + lambda x: sum(x) / len(x) for mean. + + Returns: + A single multi-dimensional list of the same shape as the inputs, + where each value is the result of the reduce_func. + + Raises: + ValueError: If the input list is empty or if shapes are inconsistent + (which may also raise IndexError or TypeError). + """ + if not lists_to_reduce: + raise ValueError("Input 'lists_to_reduce' cannot be empty.") + + # Get the first list to inspect its structure (shape) + first_list = lists_to_reduce[0] + + # Check if the first element of this list is *also* a list. + # This determines if we are at the base case or need to recurse. + if isinstance(first_list[0], list): + # --- RECURSIVE STEP --- + # The elements are lists, so we need to go one level deeper. + + # We find the number of sub-lists from the first list. + # (e.g., for [[1,2], [3,4]], this is 2) + num_sublists = len(first_list) + + result = [] + # Iterate by the index of the sub-lists (e.g., i = 0, then i = 1) + for i in range(num_sublists): + # Build a new list to pass to the recursive call. + # This list will contain the i-th sublist from *each* of the + # input lists. + # e.g., if lists_to_reduce = [ L1, L2 ] and i = 0, + # this creates [ L1[0], L2[0] ] + sublists_to_reduce = [l[i] for l in lists_to_reduce] + + # Recurse and append the result + result.append(_reduce_multidim_lists(sublists_to_reduce, reduce_func)) + return result + else: + # --- BASE CASE --- + # The elements are values (int, float, etc.), not lists. + # We are at the innermost dimension. + + # Find the number of values in the innermost list. + # (e.g., for [1, 2], this is 2) + num_values = len(first_list) + + result = [] + # Iterate by the index of the values (e.g., i = 0, then i = 1) + for i in range(num_values): + # Get the values at this specific position (i) from *all* + # input lists. + # e.g., if lists_to_reduce = [ [1,2], [10,20] ] and i = 0, + # this creates [ 1, 10 ] + values_at_pos = [l[i] for l in lists_to_reduce] + + # Apply the user-provided reduction function to this list of values + # and append the single result. + result.append(reduce_func(values_at_pos)) + return result + + +def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool: + return ( + isinstance(op, OpOverload) + # Not precise heuristic to detect inplace operation + and op._schema.name[-1] == "_" + # Strengthen the heuristic to check that the first argument and return value are a write + and len(op._schema.arguments) > 0 + and op._schema.arguments[0].is_write + and len(op._schema.returns) > 0 + and op._schema.returns[0].is_write + ) + + +def _int_on_rank(i: "int | LocalIntNode | ConstantIntNode", r: int) -> int: + if isinstance(i, LocalIntNode): + return i._local_ints[r] + elif isinstance(i, ConstantIntNode): + return i.val + elif isinstance(i, int): + return i + else: + raise AssertionError(type(i)) + + +def _check_for_subclass(flat_args: Sequence[object]) -> bool: + return any(_check_for_subclass_arg(x) for x in flat_args) + + +def _check_for_subclass_arg(x: object) -> bool: + return ( + not isinstance(x, LocalTensor) + and isinstance(x, Tensor) + and type(x) + not in ( + Tensor, + FakeTensor, + torch.nn.Parameter, + torch.nn.Buffer, + ) + ) + + +def _map_to_rank_local_val(val: Any, rank: int) -> Any: + if isinstance(val, LocalTensor): + return val._local_tensors[rank] + if isinstance(val, SymInt): + if isinstance(val.node, LocalIntNode): + return val.node._local_ints[rank] + if isinstance(val.node, ConstantIntNode): + return val.node.val + return val + + +def _collect_accelerator_rng_states() -> dict[int, torch.Tensor]: + """ + Collects RNG state from all available acceleator devices. + + Returns: + List of RNG state tensors, one for each accelerator device. + Returns empty list if accelerator is not available. + """ + if not torch.accelerator.is_available(): + return {} + + if torch.accelerator.is_available(): + device_idx = torch.accelerator.current_device_index() + with torch.accelerator.device_index(device_idx): + return {device_idx: torch.get_device_module().get_rng_state()} + + return {} + + +def _set_accelerator_rng_states(rng_states: dict[int, torch.Tensor]) -> None: + """ + Sets RNG state for all accelerator devices from a list of states. + + Args: + rng_states: List of RNG state tensors to restore. + """ + if not torch.accelerator.is_available(): + return + + if torch.accelerator.is_available(): + for device_idx, device_rng_state in rng_states.items(): + with torch.accelerator.device_index(device_idx): + torch.get_device_module().set_rng_state(device_rng_state) + + +def _get_rng_state() -> tuple[torch.Tensor, dict[int, torch.Tensor]]: + """ + Gets CPU and accelerator (e.g., CUDA, XPU device) rng states from all devices. + """ + return (torch.get_rng_state(), _collect_accelerator_rng_states()) + + +def _set_rng_state( + cpu_state: torch.Tensor, accelerator_states: dict[int, torch.Tensor] +) -> None: + """ + Sets CPU and accelerator (e.g., CUDA, XPU device) rng states for all devices. If + the list of accelerator states is shorter than the number of devices only the + first len(accelerator_states) devices will get their rng state set. + """ + torch.set_rng_state(cpu_state) + _set_accelerator_rng_states(accelerator_states) + + +def _combine_int_rank_results(rank_results: dict[int, int]) -> int | torch.SymInt: + any_v = next(iter(rank_results.values())) + + if all(v == any_v for v in rank_results.values()): + return any_v + + return torch.SymInt(LocalIntNode(rank_results)) + + +def _combine_any_rank_results(rank_results: dict[int, Any]) -> Any: + any_v = next(iter(rank_results.values())) + + if isinstance(any_v, Tensor): + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(rank_results) + + if isinstance(any_v, int): + return _combine_int_rank_results(rank_results) + + if isinstance(any_v, torch.device): + assert all(v.type == any_v.type for v in rank_results.values()), ( + "device type should be the same" + ) + # Just use the first device - the device type is what matters, + # and LocalTensorMode runs on a single physical device anyway + return any_v + + assert all(v == any_v for v in rank_results.values()), ( + "Non Tensor or int rank results must be equal for all ranks" + ) + + return any_v + + +def _combine_rank_results(rank_results: dict[int, Any], default: Any | None) -> Any: + rank_ids = rank_results.keys() + rank_value = rank_results[next(iter(rank_ids))] + + if isinstance(rank_value, (list, tuple)): + max_rank_result_len = max(len(v) for v in rank_results.values()) + ret_list = [] + for i in range(max_rank_result_len): + rank_col_results = { + r: v[i] if i < len(v) else default for r, v in rank_results.items() + } + ret_list.append(_combine_any_rank_results(rank_col_results)) + return type(rank_value)(ret_list) + else: + return _combine_any_rank_results(rank_results) + + +def _zero_sized_like(tensor: torch.Tensor, dim: int) -> torch.Tensor: + tensor_size = list(tensor.size()) + tensor_size[dim] = 0 + empty_tensor = torch.empty(*tensor_size, dtype=tensor.dtype, device=tensor.device) + return empty_tensor + + +def _for_each_rank_run_func( + func: OpOverload | Callable[..., Any], + ranks: frozenset[int], + args: Sequence[Any], + kwargs: dict[str, Any], + *, + alias: bool = True, +) -> Any: + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + flat_args = [ + a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args + ] + + lm = enabled_local_tensor_mode() + use_per_rank_rng = lm is not None and len(lm._per_rank_rng_states) > 0 + + global_rng_state = None if use_per_rank_rng else _get_rng_state() + + flat_rank_rets = {} + + default_value: Tensor | None = None + for r in sorted(ranks): + if use_per_rank_rng: + assert lm is not None + if r in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[r]) + else: + assert global_rng_state is not None + _set_rng_state(*global_rng_state) + + rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args] + rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec) + if func is torch.ops.aten.hash_tensor.default and rank_args[0].numel() == 0: + # Special case for empty tensors, hash_tensor returns an empty tensor + rank_ret = torch.empty(0, dtype=torch.uint64, device=rank_args[0].device) + else: + rank_ret = func(*rank_args, **rank_kwargs) + flat_rank_rets[r] = rank_ret + + if use_per_rank_rng: + assert lm is not None + lm._per_rank_rng_states[r] = _get_rng_state() + + if default_value is None and func is torch.ops.aten.split.Tensor: + # If split happens over the dimension smaller than the number of chunks + # it is possible that some ranks will produce shorter lists of chunks. + # In order to make the result across all ranks of the same length we + # append empty tensors (zero size on the split dimension). + tensor = rank_flat_args[0] + split_dim = 0 if len(rank_flat_args) < 3 else rank_flat_args[2] + default_value = _zero_sized_like(tensor, split_dim) + + if _is_inplace_op(func): + alias = False + # For the in-place ops return self + ret = args[0] + if isinstance(func, OpOverload) and torch.Tag.inplace_view in func.tags: + # Ensure that wrapper tensor size is synchronized with its local tensors + ret._sync_meta() + else: + ret = _combine_rank_results(flat_rank_rets, default_value) + + if alias: + return return_and_correct_aliasing(func, args, kwargs, ret) + else: + return ret + + +def _get_extra_dispatch_keys(t: torch.Tensor) -> DispatchKeySet: + extra_dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(0) + if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Conjugate): + extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Conjugate) + if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Negative): + extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Negative) + return extra_dispatch_keys + + +class LocalIntNode: + """ + Like a LocalTensor, but for an int. We can't use a 0D tensor to represent this + because often only a SymInt is accepted where we wish to use this. + """ + + def __new__(cls, local_ints: dict[int, int]) -> "ConstantIntNode | LocalIntNode": # type: ignore[misc] + if len(set(local_ints.values())) == 1: + return ConstantIntNode(next(iter(local_ints.values()))) + return super().__new__(cls) + + def __init__(self, local_ints: dict[int, int]): + self._local_ints = local_ints + + def maybe_as_int(self) -> int | None: + return None + + def is_int(self) -> bool: + return True + + def is_float(self) -> bool: + return False + + def is_bool(self) -> bool: + return False + + def is_nested_int(self) -> bool: + return False + + def clone(self) -> "LocalIntNode": + return self + + def _str(self) -> str: + return f"LocalIntNode({self._local_ints})" + + def __str__(self) -> str: + return self._str() + + def __repr__(self) -> str: + return self._str() + + def _graph_repr(self) -> str: + return self._str() + + def is_symbolic(self) -> bool: + return False + + def is_constant(self) -> bool: + return False + + def sym_max( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + { + r: max(self._local_ints[r], _int_on_rank(other, r)) + for r in self._local_ints + } + ) + + def sym_sum(self, other: Any) -> "LocalIntNode | ConstantIntNode": + t = LocalIntNode(dict.fromkeys(self._local_ints, 0)) + for o in other: + t = t.add(o) + return t + + def neg(self) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode({r: -self._local_ints[r] for r in self._local_ints}) + + def add( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] + _int_on_rank(other, r) for r in self._local_ints} + ) + + def sub( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] - _int_on_rank(other, r) for r in self._local_ints} + ) + + def mul( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints} + ) + + def floordiv( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] // _int_on_rank(other, r) for r in self._local_ints} + ) + + def mod( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] % _int_on_rank(other, r) for r in self._local_ints} + ) + + def int_floordiv( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] // _int_on_rank(other, r) for r in self._local_ints} + ) + + def eq(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] == _int_on_rank(other, r) for r in self._local_ints} + return torch._C._get_constant_bool_symnode(len(r) == 1 and next(iter(r))) + + def ne(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] != _int_on_rank(other, r) for r in self._local_ints} + return torch._C._get_constant_bool_symnode(len(r) > 1 or next(iter(r))) + + def ge(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] >= _int_on_rank(other, r) for r in self._local_ints} + assert len(r) == 1, (self, other) + return torch._C._get_constant_bool_symnode(next(iter(r))) + + def gt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints} + assert len(r) == 1, (self, other) + return torch._C._get_constant_bool_symnode(next(iter(r))) + + def lt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] < _int_on_rank(other, r) for r in self._local_ints} + assert len(r) == 1, (self, other) + return torch._C._get_constant_bool_symnode(next(iter(r))) + + def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode": + return ConstantIntNode(num) + + +class _LocalDeviceHandle: + """ + Wrapper around device module (e.g., torch.cuda) with automatic LocalTensor semantics. + + This class wraps device modules and automatically handles per-rank operations in + LocalTensor mode: + - get_rng_state() returns a LocalTensor with per-rank states + - set_rng_state(LocalTensor) sets per-rank states + + When not in LocalTensor mode, it delegates directly to the underlying device handle. + """ + + def __init__(self, device_handle, device_type: str): + """ + Initialize the local device handle wrapper. + + Args: + device_handle: The underlying device module (e.g., torch.cuda) + device_type: Device type string (e.g., "cuda", "cpu") + """ + self._device_handle = device_handle + self._device_type = device_type + + def get_rng_state(self): + """ + Get RNG state, automatically returning LocalTensor in LocalTensor mode. + + Returns: + LocalTensor in LocalTensor mode, regular Tensor otherwise + """ + lm = enabled_local_tensor_mode() + if not lm: + return self._device_handle.get_rng_state() + + original_state = _get_rng_state() + per_rank_states = {} + + try: + for rank in lm.ranks: + # We need to set-then-get instead of directly copying lm._per_rank_rng_states[rank] + # because they have different structures: + # - lm._per_rank_rng_states[rank] is a tuple: (cpu_state, {device_idx: cuda_state}) + # - self._device_handle.get_rng_state() returns just the device-specific tensor + # So we temporarily restore the full RNG state (CPU + all CUDA devices) for this rank, + # then extract only the specific device's state tensor that we need. + if rank in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[rank]) + + per_rank_states[rank] = self._device_handle.get_rng_state() + finally: + _set_rng_state(*original_state) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(per_rank_states) + + def set_rng_state(self, state): + """ + Set RNG state, automatically handling LocalTensor input. + + Args: + state: Regular Tensor or LocalTensor with per-rank states + """ + if isinstance(state, LocalTensor): + lm = enabled_local_tensor_mode() + assert lm is not None + + # Similar to get_rng_state but in reverse: we need to convert from + # device-specific tensor format to full state tuple format. + # - state._local_tensors[rank] contains just the device-specific RNG state tensor + # - lm._per_rank_rng_states[rank] needs a tuple: (cpu_state, {device_idx: cuda_state}) + # So we set the device's state with the rank-specific tensor, then _get_rng_state() + # captures both CPU and CUDA states into the tuple format that _per_rank_rng_states expects. + for rank, rank_state in state._local_tensors.items(): + self._device_handle.set_rng_state(rank_state.to("cpu")) + lm._per_rank_rng_states[rank] = _get_rng_state() + else: + self._device_handle.set_rng_state(state.to("cpu")) + + def __getattr__(self, name): + """Delegate all other attributes to the underlying device module.""" + return getattr(self._device_handle, name) + + +class _LocalOffsetBasedRNGTracker: + """ + LocalTensor-specific RNG tracker for DTensor random operations. + + This class manages per-rank RNG states when running in LocalTensor mode, + using _LocalPhiloxState to track different offsets for each virtual rank. + It is instantiated and used by OffsetBasedRNGTracker when in LocalTensor mode. + + Much of this is derived from OffsetBasedRNGTracker: + https://github.com/pytorch/pytorch/blob/402c46503002f98ccfc023a733081fb0719223a1/torch/distributed/tensor/_random.py#L182 + """ + + def __init__(self, device_type: str = "cuda"): + """Initialize the LocalTensor RNG tracker.""" + from torch.distributed.device_mesh import _get_device_handle + + self._device_type = device_type + self._device_handle = _LocalDeviceHandle( + _get_device_handle(device_type), device_type + ) + self.distribute_region_enabled = True + self._device_mesh = None + + @property + def _device(self): + return torch.device(self._device_type, torch.cuda.current_device()) + + def _set_pre_op_offset(self, state, spec) -> None: + """Compute and set per-rank offsets before the random operation.""" + from torch.distributed.tensor._ops.utils import prod + from torch.distributed.tensor._utils import ( + _compute_local_shape_and_global_offset, + ) + from torch.distributed.tensor.placement_types import Shard + + lm = enabled_local_tensor_mode() + assert lm is not None + + state._per_rank_offsets = {} + + for rank in lm.ranks: + # compute this rank's coordinate in the mesh + mesh_coords = [] + for mesh_dim_idx in range(spec.mesh.ndim): + mesh_dim_size = spec.mesh.size(mesh_dim_idx) + # calculate rank's coordinate in this mesh dimension + num_chunks_after = 1 + for j in range(mesh_dim_idx + 1, spec.mesh.ndim): + num_chunks_after *= spec.mesh.size(j) + coord = (rank // num_chunks_after) % mesh_dim_size + mesh_coords.append(coord) + + # compute shard offset based on placements + from torch.distributed.tensor._random import ( + _calc_first_shard_size, + _calc_shard_info, + _calc_shard_linear_idx, + ) + + # Compute shard index and total number of shards on each tensor dim + shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info( + mesh_coords, spec + ) + + # compute shard linear index + shard_linear_idx = _calc_shard_linear_idx( + shard_idx_by_dim, total_num_shards_by_dim + ) + + # get current offset for this rank + current_offset = int( + state._per_rank_states[rank][8:].view(dtype=torch.int64).item() + ) + + local_shape = _calc_first_shard_size(spec) + # compute local size + local_size = prod(local_shape) + + # compute new offset (must be multiple of 4) + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + state._per_rank_offsets[rank] = current_offset + offset_incr + + def _set_post_op_offset(self, state, spec, old_offset) -> None: + """Set per-rank offsets after the random operation.""" + from torch.distributed.tensor._ops.utils import prod + + lm = enabled_local_tensor_mode() + assert lm is not None + + dtensor_shape = spec.shape + numel = prod(dtensor_shape) + # offset must be multiple of 4 + numel = (numel + 3) // 4 * 4 + + if not hasattr(state, "_per_rank_offsets"): + state._per_rank_offsets = {} + + # handle LocalIntNode old_offset (different values per rank) + if isinstance(old_offset, SymInt) and isinstance(old_offset.node, LocalIntNode): + for rank in lm.ranks: + rank_old_offset = old_offset.node._local_ints[rank] + state._per_rank_offsets[rank] = rank_old_offset + numel + else: + # same old_offset for all ranks + old_offset_int = ( + int(old_offset) if isinstance(old_offset, SymInt) else old_offset + ) + for rank in lm.ranks: + state._per_rank_offsets[rank] = old_offset_int + numel + + @contextlib.contextmanager + def _distribute_region(self, spec, generator=None): + """Context manager for LocalTensor mode distribute region.""" + lm = enabled_local_tensor_mode() + assert lm is not None + + # get base state + if generator is not None: + base_state_tensor = generator.get_state() + per_rank_states = {rank: base_state_tensor.clone() for rank in lm.ranks} + # pyrefly: ignore [bad-argument-type, bad-argument-count] + base_state_tensor = LocalTensor(per_rank_states) + else: + base_state_tensor = self._device_handle.get_rng_state() + + state = _LocalPhiloxState(base_state_tensor) + + if self.distribute_region_enabled: + # sync to rank 0's state if no explicit generator + if generator is None: + any_rank_state = lm._any_local_rng_state() + any_rank_cpu, any_rank_cuda = any_rank_state + + if self._device.type == "cuda": + assert self._device.index in any_rank_cuda + any_rank_device_state = any_rank_cuda[self._device.index] + else: + any_rank_device_state = any_rank_cpu + + from torch.distributed.tensor._random import _PhiloxState + + any_rank_philox = _PhiloxState(any_rank_device_state) + state.seed = any_rank_philox.seed + state.offset = any_rank_philox.offset + + old_offset = state.offset + self._set_pre_op_offset(state, spec) + state.apply_to_local_tensor_mode(self._device_handle) + + try: + yield + finally: + self._set_post_op_offset(state, spec, old_offset) + state.apply_to_local_tensor_mode(self._device_handle) + else: + yield + + # maybe reset generator to rank 0's state + if generator is not None: + rank_0_state = state._per_rank_states[0] + generator.set_state(rank_0_state) + + +_LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_" + + +def _is_local_tensor_attr(attr: str) -> bool: + return attr.startswith(_LOCAL_TENSOR_ATTR_PREFIX) + + +def _to_local_tensor_attr(rank: int) -> str: + return f"{_LOCAL_TENSOR_ATTR_PREFIX}{rank}" + + +def _from_local_tensor_attr(attr: str) -> int: + if not _is_local_tensor_attr(attr): + raise AssertionError(f"Invalid local tensor attr {attr}") + return int(attr[len(_LOCAL_TENSOR_ATTR_PREFIX) :]) + + +def _all_elements_same(values: list[Any]) -> bool: + if not values: + return True + first_value = values[0] + return all(value == first_value for value in values) + + +def _compute_local_tensor_meta( + local_tensors: dict[int, torch.Tensor], +) -> tuple[ + list[torch.SymInt | int], + list[torch.SymInt | int], + torch.device, + torch.dtype, + torch.layout, + DispatchKeySet, +]: + """ + Computes the meta information for a LocalTensor from its local tensors. + """ + it = iter(local_tensors.values()) + first_local_tensor = next(it) + + first_shape = first_local_tensor.shape + first_stride = first_local_tensor.stride() + dtype = first_local_tensor.dtype + device = first_local_tensor.device + layout = first_local_tensor.layout + + extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor) + + # Assert that all tensors have the same dtype, layout and dispatch keys. Due + # to uneven sharding, it is possible that tensors will have different shapes. + for local_tensor in it: + assert dtype == local_tensor.dtype, ( + "Tensors representing LocalTensor shards must have the same dtype" + ) + assert layout == local_tensor.layout, ( + "Tensors representing LocalTensor shards must have the same layout" + ) + assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), ( + "Tensors representing LocalTensor shards must have the same set of extra dispatch keys" + ) + + # Compute shape/stride. We allow for non-SPMD'ness here + local_shapes: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size + local_strides: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size + for r, local_tensor in local_tensors.items(): + for d, size in enumerate(local_tensor.shape): + local_shapes[d][r] = size + local_strides[d][r] = local_tensor.stride(d) + shape = [ + ( + first_shape[d] + if _all_elements_same(list(local_shapes[d].values())) + else torch.SymInt(LocalIntNode(local_shapes[d])) + ) + for d in range(len(first_shape)) + ] + strides = [ + ( + first_stride[d] + if _all_elements_same(list(local_strides[d].values())) + else torch.SymInt(LocalIntNode(local_strides[d])) + ) + for d in range(len(first_shape)) + ] + return shape, strides, device, dtype, layout, extra_dispatch_keys + + +class LocalTensor(torch.Tensor): + """ + LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD + (Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from + global rank ids to their corresponding local Tensor shards.Operations performed on a LocalTensor + are applied independently to each local shard, mimicking distributed computation. Collectives + and other distributed operations are handled by mapping them to the local shards as appropriate. + + Note: + This class is primarily intended for debugging and simulating distributed tensor computations + on a single process. + + """ + + # Map from global rank to the local tensor. + _local_tensors: dict[int, torch.Tensor] + # Precomputed for speed set of keys from the local tensor map. + _ranks: frozenset[int] + _size: list[torch.SymInt | int] + __slots__ = ["_local_tensors", "_ranks", "_size"] + + @staticmethod + @torch._disable_dynamo + def __new__( + cls, + local_tensors: dict[int, torch.Tensor], + requires_grad: bool = False, + ) -> "LocalTensor": + if any(t.requires_grad for t in local_tensors.values()): + raise AssertionError( + "Internal local_tensors require grad, but we will ignore those autograd graph. " + "Make a custom autograd function and make sure you detach the inner tensors." + ) + + (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( + _compute_local_tensor_meta(local_tensors) + ) + + r = torch.Tensor._make_wrapper_subclass( + cls, + shape, + strides=strides, + dtype=dtype, + device=device, + layout=layout, + # In place ops potentially change local tensor sizes (e.g. resize_). While + # executing an in-place op the return value must be the same as "self" input + # otherwise we can introduce errors due to tensor identity changes. Hence we + # need to be able to update wrapper subclass sizes after in-place ops. This + # dispatch policy allows us to do that. + dispatch_sizes_strides_policy="sizes", + requires_grad=requires_grad, + _extra_dispatch_keys=extra_dispatch_keys, + ) + + local_tensors = { + r: v if not isinstance(v, AsyncCollectiveTensor) else v.wait() + for r, v in local_tensors.items() + } + r._local_tensors = local_tensors + r._ranks = frozenset(local_tensors.keys()) + r._size = shape + return r + + @torch._disable_dynamo + @mark_subclass_constructor_exportable_experimental # type: ignore[misc] + def __init__(self, *args: Any, **kwargs: Any): + super().__init__() + + def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor": + local_tensors_copy = { + r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items() + } + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(local_tensors_copy, self.requires_grad) + + def __repr__(self) -> str: # type: ignore[override] + parts = [] + for k, v in self._local_tensors.items(): + # pyrefly: ignore [bad-argument-type] + parts.append(f" {k}: {v}") + tensors_str = ",\n".join(parts) + return f"LocalTensor(\n{tensors_str}\n)" + + def __getattr__(self, name: str) -> Any: + if _is_local_tensor_attr(name): + rank = _from_local_tensor_attr(name) + if rank not in self._ranks: + raise AttributeError(f"Local tensor has no knowledge of rank {rank}") + return self._local_tensors[rank] + return object.__getattribute__(self, name) + + def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]: + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + local_tensor_attrs = [_to_local_tensor_attr(r) for r in self._ranks] + return local_tensor_attrs, () + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, Any], + flatten_spec: tuple[Any, ...], + outer_size: torch.Size, + outer_stride: tuple[int, ...], + ) -> "LocalTensor": + assert flatten_spec is not None, ( + "Expecting spec to be not None from `__tensor_flatten__` return value!" + ) + local_tensors = { + _from_local_tensor_attr(a): t for a, t in inner_tensors.items() + } + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(local_tensors) + + @classmethod + @torch._disable_dynamo + def __torch_dispatch__( # type: ignore[override] + cls, + func: Any, + types: tuple[Any, ...], + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + if kwargs is None: + kwargs = {} + + # This is horribly inefficient + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + local_tensor = None + for arg in flat_args: + if isinstance(arg, LocalTensor): + local_tensor = arg + break + + assert local_tensor is not None, ( + "At least one of the arguments must be a LocalTensor" + ) + + # Check for unrecognized tensor subclasses (but allow regular tensors and scalars) + has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "LocalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + with LocalTensorMode(local_tensor._ranks): + return func(*args, **kwargs) + + def numpy(self, *, force: bool = False) -> Any: + if HAS_NUMPY: + return self.reconcile().numpy(force=force) + else: + raise RuntimeError("Numpy is not available") + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> torch.Tensor: + # pyrefly: ignore [bad-argument-type] + return LocalTensor( + # pyrefly: ignore [bad-argument-count] + { + r: t.contiguous(memory_format=memory_format) + for r, t in self._local_tensors.items() + } + ) + + def is_contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> bool: + return all( + t.is_contiguous(memory_format=memory_format) + for t in self._local_tensors.values() + ) + + def tolist(self) -> list[Any]: + """ + Try to reconcile, if successful convert to list, otherwise if dtype is integer, + convert to list of local integers. + """ + equal_obj = self._equal_local_tensors() + if isinstance(equal_obj, torch.Tensor): + return equal_obj.tolist() + if isinstance(equal_obj, torch.Size): + if not self.dtype.is_floating_point and not self.dtype.is_complex: + ranks = sorted(self._ranks) + local_lists = [self._local_tensors[r].tolist() for r in ranks] + return _reduce_multidim_lists( + local_lists, + lambda values: torch.SymInt( + LocalIntNode(dict(zip(ranks, values, strict=True))) + ), + ) + + raise RuntimeError("Cannot convert local tensor to list") + + def reconcile(self) -> torch.Tensor: + """ + Reconciles the LocalTensor into a single torch.Tensor by ensuring all local + shards are identical and returning a detached clone of one of them. + + Note: + This method is useful for extracting a representative tensor from a LocalTensor + when all shards are expected to be the same, such as after a collective operation + that synchronizes all ranks. + """ + + # Force all local tensor shards across ranks to be the same + equal_obj = self._equal_local_tensors() + assert isinstance(equal_obj, torch.Tensor), ( + "LocalTensor shards must be the same to reconcile" + ) + cl = equal_obj.clone().detach() + cl.requires_grad_(self.requires_grad) + return cl + + def _equal_local_tensors(self) -> torch.Tensor | torch.Size | None: + it = iter(self._local_tensors.values()) + t1 = next(it) + if all(t2.equal(t1) for t2 in it): + return t1 + if all(t2.shape == t1.shape for t2 in it): + return t1.shape + return None + + def _sync_meta(self) -> None: + with no_dispatch(): + (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( + _compute_local_tensor_meta(self._local_tensors) + ) + self._size = shape + + +# If set to `True` the LocalTensorMode stack will be created for the whole process, +# otherwise it will be created for each thread. +_PROCESS_MODE: bool = True +_PROCESS_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] +# When running under local runner each thread must create its own local tensor mode +# so that they do not interfere with each other. +_THREAD_LOCAL_TENSOR_MODE: threading.local = threading.local() + + +def get_local_tensor_mode_list() -> list["LocalTensorMode"]: + global _PROCESS_MODE + if _PROCESS_MODE: + global _PROCESS_LOCAL_TENSOR_MODE + return _PROCESS_LOCAL_TENSOR_MODE + global _THREAD_LOCAL_TENSOR_MODE + if not hasattr(_THREAD_LOCAL_TENSOR_MODE, "value"): + _THREAD_LOCAL_TENSOR_MODE.value = [] + return _THREAD_LOCAL_TENSOR_MODE.value + + +class LocalTensorMode(TorchDispatchMode): + """ + A TorchDispatchMode that simulates SPMD (Single Program, Multiple Data) execution + for LocalTensor objects across a set of ranks. + + LocalTensorMode enables PyTorch operations to be transparently applied to each + local shard of a LocalTensor, as if they were distributed across multiple ranks. + When active, this mode intercepts tensor operations and dispatches them to each + rank's local tensor, collecting and wrapping the results as LocalTensors. It also + handles collective operations by mapping them to local implementations. + + This mode is primarily intended for debugging and simulating distributed tensor + computations on a single process, rather than for high-performance distributed + training. It maintains a stack of active modes, patches DeviceMesh coordinate + resolution, and provides utilities for temporarily disabling the mode or mapping + functions over ranks. + """ + + # What ranks this local tensor mode is operating over + def __init__(self, ranks: int | frozenset[int]): + if isinstance(ranks, int): + # assume is world size + self.ranks = frozenset(range(ranks)) + else: + assert isinstance(ranks, frozenset) + self.ranks = ranks + self._disable = True + self._old_get_coordinate = None + self._old_get_rank = None + self._old_get_local_rank = None + self._old_torch_manual_seed: Any = None + self._old_torch_initial_seed: Any = None + self._per_rank_rng_states: dict[ + int, tuple[torch.Tensor, dict[int, torch.Tensor]] + ] = {} + + self.enable_() + + def __enter__(self) -> "LocalTensorMode": + self.enable_() + get_local_tensor_mode_list().append(self) + + # _distribute_region will compute correct per-shard offsets + # but we want all ranks to start with the same state + if not _is_in_fake_tensor_mode(): + cpu_state, cuda_states = _get_rng_state() + for rank in self.ranks: + self._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + + return super().__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.disable_() + get_local_tensor_mode_list().pop() + super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_dispatch__( + self, + func: Any, + types: tuple[Any, ...], + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + if kwargs is None: + kwargs = {} + + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + + # Find all LocalTensor arguments to determine ranks + local_tensors = [a for a in flat_args if isinstance(a, LocalTensor)] + + # Check for unrecognized tensor subclasses (but allow regular tensors and scalars) + has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "LocalTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + # Factory functions convert into LocalTensor, so we don't have to + # transmute a Tensor into a LocalTensor if mutation happens... + # But if you do an operation on a Tensor, do NOT wrap it into a + # LocalTensor. This helps prevent accidents when you're doing Tensor + # operations on the inner non-wrapped tensors. + if not local_tensors: + if self._disable or any(isinstance(a, Tensor) for a in flat_args): + return func(*args, **kwargs) + + # For LocalTensors, verify they have compatible ranks + for a in flat_args: + if isinstance(a, LocalTensor): + assert a._ranks <= self.ranks, ( + f"Input LocalTensor {a} must be configured for a subset of the LocalTensorMode ranks {self.ranks}" + ) + + if func.overloadpacket == torch.ops.aten.dim: + return len(args[0]._size) + if func.overloadpacket == torch.ops.aten.sym_size: + return tuple(args[0]._size) + + if func.namespace == "c10d": + if func is torch.ops.c10d.allreduce_.default: + return _c10d._local_all_reduce_(*args, **kwargs) + elif func is torch.ops.c10d.allreduce_coalesced_.default: + return _c10d._local_allreduce_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d.reduce_scatter_tensor_coalesced_.default: + return _c10d._local_reduce_scatter_tensor_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d.scatter_.default: + return _c10d._local_scatter_(*args, **kwargs) + elif func is torch.ops.c10d.broadcast_.default: + return _c10d._local_broadcast_(*args, **kwargs) + elif func is torch.ops.c10d.allgather_.default: + return _c10d._local_all_gather_(*args, **kwargs) + elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default: + return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d._allgather_base_.default: + return _c10d._local_allgather_base_(*args, **kwargs) + elif func is torch.ops.c10d._reduce_scatter_base_.default: + return _c10d._local_reduce_scatter_base_(*args, **kwargs) + elif func is torch.ops.c10d.gather_.default: + return _c10d._local_gather_(*args, **kwargs) + elif func is torch.ops.c10d.alltoall_.default: + return _c10d._local_alltoall_(*args, **kwargs) + elif func is torch.ops.c10d.alltoall_base_.default: + return _c10d._local_alltoall_base_(*args, **kwargs) + elif func is torch.ops.c10d.barrier.default: + return _c10d._local_barrier(*args, **kwargs) + elif func is torch.ops.c10d.monitored_barrier_.default: + return _c10d._local_monitored_barrier_(*args, **kwargs) + elif func is torch.ops.c10d.send.default: + return _c10d._local_send(*args, **kwargs) + elif func is torch.ops.c10d.recv_.default: + return _c10d._local_recv_(*args, **kwargs) + elif func is torch.ops.c10d.recv_any_source_.default: + return _c10d._local_recv_any_source_(*args, **kwargs) + raise NotImplementedError(f"{func} not implemented") + + if func.namespace == "_c10d_functional" or func.namespace == "_dtensor": + if func is torch.ops._dtensor.shard_dim_alltoall.default: + return _c10d._local_functional_shard_dim_alltoall(*args, **kwargs) + elif func is torch.ops._c10d_functional.all_gather_into_tensor.default: + return _c10d._local_functional_all_gather_into_tensor(*args, **kwargs) + elif func is torch.ops._c10d_functional.reduce_scatter_tensor.default: + return _c10d._local_functional_reduce_scatter_tensor(*args, **kwargs) + elif func is torch.ops._c10d_functional.all_to_all_single.default: + return _c10d._local_functional_all_to_all_single(*args, **kwargs) + else: + with LocalTensorMode(self.ranks): + return func._op_dk( + DispatchKey.CompositeExplicitAutograd, *args, **kwargs + ) + + if func.namespace == "profiler": + return func(*args, **kwargs) + + if func.namespace == "_c10d_functional_autograd": + raise NotImplementedError(f"{func} not implemented") + + if func.namespace == "symm_mem": + raise NotImplementedError(f"{func} not implemented") + + return _for_each_rank_run_func(func, self.ranks, args, kwargs, alias=True) + + def disable_(self): + if self._disable: + return + + self._unpatch_device_mesh() + self._unpatch_random_functions() + self._disable = True + + def enable_(self): + if not self._disable: + return + + self._patch_device_mesh() + self._patch_random_functions() + self._disable = False + + @contextlib.contextmanager + def disable(self) -> Generator[None, None, None]: + """ + Disables LocalTensorMode temporarily. Primarily is intended to be used to perform + rank specific computations and merge results back before enabling LocalTensorMode back. + """ + + # don't unpatch again if already disabled + if self._disable: + try: + yield + finally: + # re-disable if the yield messed + # with the state + self.disable_() + return + + self.disable_() + try: + yield + finally: + self.enable_() + + def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor: + """ + Creates a LocalTensor instance by mapping rank id to ids local shard. + """ + + with self.disable(): + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor({r: cb(r) for r in self.ranks}) + + def tensor_map( + self, tensor: LocalTensor, cb: Callable[[int, Tensor], Tensor | None] + ) -> LocalTensor: + """ + Creates a LocalTensor instance by mapping rank id to ids local shard. + """ + + with self.disable(): + results = {} + for r in self.ranks: + if r in tensor._local_tensors: + m = cb(r, tensor._local_tensors[r]) + if m is not None: + results[r] = m + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(results) + + def _any_local_rng_state(self) -> tuple[torch.Tensor, dict[int, torch.Tensor]]: + return self._per_rank_rng_states[next(iter(self.ranks))] + + def _patch_device_mesh(self) -> None: + assert self._old_get_coordinate is None + assert self._old_get_rank is None + assert self._old_get_local_rank is None + self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment] + self._old_get_rank = DeviceMesh.get_rank # type: ignore[assignment] + self._old_get_local_rank = DeviceMesh.get_local_rank # type: ignore[assignment] + DeviceMesh.get_coordinate = _LocalDeviceMesh.get_coordinate # type: ignore[method-assign] + DeviceMesh.get_rank = _LocalDeviceMesh.get_rank # type: ignore[method-assign] + DeviceMesh.get_local_rank = _LocalDeviceMesh.get_local_rank # type: ignore[method-assign] + + def _unpatch_device_mesh(self) -> None: + assert self._old_get_coordinate is not None + assert self._old_get_rank is not None + assert self._old_get_local_rank is not None + DeviceMesh.get_coordinate = self._old_get_coordinate + DeviceMesh.get_rank = self._old_get_rank + DeviceMesh.get_local_rank = self._old_get_local_rank + # pyrefly: ignore [bad-assignment] + self._old_get_coordinate = None + # pyrefly: ignore [bad-assignment] + self._old_get_rank = None + # pyrefly: ignore [bad-assignment] + self._old_get_local_rank = None + + def _patch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is None: + self._old_torch_manual_seed = torch.random.manual_seed + torch.random.manual_seed = _LocalRandom.torch_manual_seed + torch.manual_seed = _LocalRandom.torch_manual_seed + + if self._old_torch_initial_seed is None: + self._old_torch_initial_seed = torch.random.initial_seed + torch.random.initial_seed = _LocalRandom.torch_initial_seed + torch.initial_seed = _LocalRandom.torch_initial_seed + + def _unpatch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is not None: + torch.random.manual_seed = self._old_torch_manual_seed + torch.manual_seed = self._old_torch_manual_seed + self._old_torch_manual_seed = None + + if self._old_torch_initial_seed is not None: + torch.random.initial_seed = self._old_torch_initial_seed + torch.initial_seed = self._old_torch_initial_seed + self._old_torch_initial_seed = None + + +class _LocalRandom: + """ + Holds implementations of random functionality that must be patched while running + under LocalTensorMode. + """ + + @staticmethod + def torch_manual_seed(seed) -> torch._C.Generator: + """LocalTensor-aware version of torch.random.manual_seed.""" + if ( + (lm := enabled_local_tensor_mode()) + and isinstance(seed, torch.SymInt) + and isinstance(seed.node, LocalIntNode) + ): + from torch.random import _manual_seed_impl + + for rank in sorted(lm.ranks): + rank_seed = seed.node._local_ints[rank] + _manual_seed_impl(rank_seed) + lm._per_rank_rng_states[rank] = _get_rng_state() + return torch.random.default_generator + from torch.random import _manual_seed_impl + + result = _manual_seed_impl(seed) + + if lm is not None and len(lm._per_rank_rng_states) > 0: + cpu_state, cuda_states = _get_rng_state() + for rank in lm.ranks: + lm._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + + return result + + @staticmethod + def torch_initial_seed(): + """LocalTensor-aware version of torch.random.initial_seed.""" + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + return torch.random.default_generator.initial_seed() + rank_seeds = {} + + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = torch.random.default_generator.initial_seed() + + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + + return torch.random.default_generator.initial_seed() + + +# Save the original get_coordinate method before any patching + + +class _LocalDeviceMesh: + """ + Holds implementations of DeviceMesh functionality that must be patched while running + under LocalTensorMode. + """ + + @staticmethod + def get_coordinate(self: DeviceMesh) -> list[int] | None: + # NB: In order to support submeshes the code below recreates for each + # rank submesh with the same mesh dimensions as current mesh. We are + # doing this because when submesh is created it is created for a particular + # rank (therefore below we are patching get_rank method). We are trying to + # limit the invasiveness of local tensor. + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + + coords: list[dict[int, int]] = [{} for _ in range(self.ndim)] + for r in lm.ranks: + rank_tensor = self._layout.remap_to_tensor(self._rank_map) + rank_coords = (rank_tensor == r).nonzero().tolist() + assert len(rank_coords) == 1 + for d, c in enumerate(rank_coords[0][1:]): + coords[d][r] = c + + out = [torch.SymInt(LocalIntNode(c)) for c in coords] + # The output contains coordinates for each of the ranks with respect to + # their meshes formed from root mesh and selecting the same dimensions + # as the current mesh. + return out # type: ignore[return-value] + + @staticmethod + def get_rank(self) -> int | SymInt: + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + return torch.SymInt(LocalIntNode(local_ints={r: r for r in lm.ranks})) + + @staticmethod + def get_local_rank(self, mesh_dim: int | str | None = None) -> int | SymInt: + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + if isinstance(mesh_dim, str): + mesh_dim = self._mesh_dim_names.index(mesh_dim) + + # Compute local rank for each global rank + # get_coordinate returns a list of SymInt, one per mesh dimension + # We need to extract the coordinate for the specified mesh_dim + coords = _LocalDeviceMesh.get_coordinate(self) + assert coords is not None + return coords[mesh_dim] + + +def reconcile_args(args: Any, kwargs: dict[str, Any] | None = None) -> Any: + """ + Reconciles arguments by converting any LocalTensor instances in the input + arguments to their underlying torch.Tensor representation. + + This function is typically used to prepare arguments for functions that + expect standard torch.Tensor objects, by flattening the input arguments, + replacing LocalTensor instances with their reconciled (standard tensor) + versions, and then reconstructing the original argument structure. + + Args: + args: Positional arguments, possibly containing LocalTensor instances. + kwargs: Keyword arguments, possibly containing LocalTensor instances. + + Returns: + Any: The arguments with all LocalTensor instances replaced by their reconciled torch.Tensor equivalents, + preserving the original structure. + """ + if kwargs is None: + kwargs = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + reconciled_args = [ + a.reconcile() if isinstance(a, LocalTensor) else a for a in flat_args + ] + return pytree.tree_unflatten(reconciled_args, args_spec) + + +def local_tensor_mode() -> LocalTensorMode | None: + """ + Returns the current active LocalTensorMode if one exists. + + This function checks the global stack of LocalTensorMode instance. If there + is at least one LocalTensorMode active, it returns the most recently entered + (top of the stack) LocalTensorMode. If no LocalTensorMode is active, it returns None. + + Returns: + Optional[LocalTensorMode]: The current LocalTensorMode if active, else None. + """ + local_tensor_mode_list = get_local_tensor_mode_list() + if len(local_tensor_mode_list) > 0: + return local_tensor_mode_list[-1] + return None + + +def enabled_local_tensor_mode() -> LocalTensorMode | None: + """ + Returns the current active LocalTensorMode only if it's enabled. + + This is a convenience function that combines the common pattern of checking + if local_tensor_mode() is not None and not disabled. + + Returns: + Optional[LocalTensorMode]: The current LocalTensorMode if active and enabled, else None. + """ + lm = local_tensor_mode() + if lm is not None and not lm._disable: + return lm + return None + + +def maybe_run_for_local_tensor(func: Callable[_P, _R]) -> Callable[_P, _R]: + """ + Decorator that ensures a function is executed for each local tensor shard + when running under LocalTensorMode. If not in LocalTensorMode, the function + is executed normally. When in LocalTensorMode, the function is run for each + rank, and the results are collected appropriately. + + This decorator is useful for functions that exhibit non-SPMD behavior, such + as those requiring rank specific actions. For example, a function that computes + offset into input tensor based on rank. + + Note that the function being decorated must not have any side effects and + contain operations for a single rank only. For example, wrapping a function + that performs a collective operation will not work. + + Args: + func (Callable[..., Any]): The function to be decorated. + + Returns: + Callable[..., Any]: The wrapped function that handles LocalTensorMode logic. + """ + + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if not (lm := enabled_local_tensor_mode()): + return func(*args, **kwargs) + ret = None + with lm.disable(): + ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False) + + return ret + + return wrapper + + +def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: + """ + Context manager that disables LocalTensorMode for the duration of the context. + """ + lm = local_tensor_mode() + return lm.disable() if lm is not None else contextlib.nullcontext() + + +def maybe_enable_local_tracker( + device_type: str, distribute_region_enabled: bool, spec, generator +): + """ + Returns a context manager for LocalTensor-mode RNG tracking if local tensor mode is enabled. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + distribute_region_enabled: Whether distribute region is enabled + spec: The DTensorSpec + generator: Optional torch.Generator + + Returns: + Context manager from local_tracker._distribute_region if local tensor mode is enabled, + otherwise None. + """ + if enabled_local_tensor_mode(): + local_tracker = _LocalOffsetBasedRNGTracker(device_type) + local_tracker.distribute_region_enabled = distribute_region_enabled + return local_tracker._distribute_region(spec, generator) + + return None + + +def get_generator_seed_for_device_type(device_type: str): + """ + Gets the generator seed for a specific device type, handling LocalTensor mode appropriately. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + + Returns: + If in LocalTensor mode with per-rank RNG states: + - Returns int if all ranks have the same seed + - Returns SymInt(LocalIntNode) if ranks have different seeds + Otherwise: + - Returns int seed from the device's RNG state + """ + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + device_module = torch.get_device_module(device_type) + + original_state = _get_rng_state() + + rank_seeds = {} + try: + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = int( + device_module.get_rng_state()[:8].view(torch.int64).item() + ) + finally: + # restore original state + _set_rng_state(*original_state) + + unique_seeds = set(rank_seeds.values()) + if len(unique_seeds) == 1: + return next(iter(unique_seeds)) + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + else: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + + +import threading +from queue import Queue + + +_LOCAL_RUNNER_MODE: "LocalRunnerMode | None" = None + + +class LocalRunnerMode: + """ + A class for running multiple SPMD functions concurrently, however at any point + in time only one function can be running. The main use case for the local runner + mode is to enable SPMD functions to be able to use send and recv to communicate + with each other. Without local runner mode send and recv are not supported. + """ + + runner_context = threading.local() + + def __init__( + self, ranks: frozenset[int] | int, concurrency: int, fn: Callable[[int], None] + ): + if isinstance(ranks, int): + ranks = frozenset(range(ranks)) + self._ranks = ranks + self._fn = fn + self._run_lock = threading.Lock() + self._run_id = -1 + self._run_cond = threading.Condition(self._run_lock) + + self._recv_objects: dict[int, dict[int, Queue]] = { + dst: {src: Queue() for src in ranks} for dst in ranks + } + self._runners = [ + threading.Thread(target=self._run, args=(i,), name="LocalRunnerMode") + for i in range(concurrency) + ] + self._process_mode = True + + def __enter__(self) -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is None, "LocalRunnerMode is already running" + _LOCAL_RUNNER_MODE = self + + global _PROCESS_MODE + self._process_mode = _PROCESS_MODE + _PROCESS_MODE = False + for r in self._runners: + r.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + for r in self._runners: + r.join() + global _LOCAL_RUNNER_MODE + _LOCAL_RUNNER_MODE = None + + global _PROCESS_MODE + _PROCESS_MODE = self._process_mode + + def _run(self, id: int) -> None: + LocalRunnerMode.runner_context.id = id + # Only one thread can run at a time, hence must acquire the lock + try: + self._acquire_run_lock() + self._fn(id) + finally: + self._release_run_lock() + + def _acquire_run_lock(self) -> None: + self._run_lock.acquire() + self._run_id = LocalRunnerMode.runner_context.id + + def _release_run_lock(self) -> None: + self._run_id = -1 + self._run_lock.release() + + def _assert_holds_run_lock(self) -> None: + assert self._run_id == LocalRunnerMode.runner_context.id, ( + "Calling thread does not hold the run lock" + ) + + def _get_recv_object(self, src: int, dst: int) -> object | None: + peers = [src] if src != -1 else list(self._ranks) + recv_objects = self._recv_objects[dst] + + for p in peers: + if not recv_objects[p].empty(): + return recv_objects[p].get() + + return None + + def _signal_send(self, src: int, dst: int, obj: object) -> None: + assert obj is not None, "Cannot signal None" + # Only a single thread a time executes so it is safe to mutate + # read objects queue (executing thread is already holding the lock) + self._recv_objects[dst][src].put(obj) + # Signal directly condition variable since the calling thread is already + # holding the lock + self._run_cond.notify_all() + + def _wait_recv(self, src: int, dst: int, post: Callable[[object], None]) -> None: + # Wait for the object to be available + while True: + obj = self._get_recv_object(src, dst) + if obj is not None: + post(obj) + # Note that we are not releasing the lock here, since the thread + # will continue to run and therefore must hold the lock + return + self._run_cond.wait() + + @staticmethod + def current() -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" + return _LOCAL_RUNNER_MODE + + +class _LocalPhiloxState: + """ + LocalTensor-aware version of _PhiloxState that manages per-rank RNG states. + This class handles the case where the generator state is a LocalTensor, allowing + different offsets and seeds for different virtual ranks. + + Note: This is designed to be used as a drop-in replacement for _PhiloxState + when working with LocalTensors in the DTensor random ops implementation. + """ + + def __init__(self, state: torch.Tensor): + assert isinstance(state, LocalTensor), ( + "_LocalPhiloxState requires a LocalTensor" + ) + self._local_tensor = state + self._per_rank_states = { + rank: local_state.to("cpu") + for rank, local_state in state._local_tensors.items() + } + + @property + def state(self): + return LocalTensor(self._per_rank_states) # type: ignore[name-defined] + + @property + def offset(self) -> int | SymInt: + from torch.distributed.tensor._random import _PhiloxState + + offsets = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + offsets[rank] = rank_philox.offset + + if len(set(offsets.values())) == 1: + return next(iter(offsets.values())) + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return SymInt(LocalIntNode(offsets)) + + @offset.setter + def offset(self, offset: int | SymInt) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(offset, SymInt) and isinstance(offset.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_offset = offset.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.offset = rank_offset + else: + offset_int = int(offset) if isinstance(offset, SymInt) else offset + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.offset = offset_int + + @property + def seed(self) -> int | SymInt: + from torch.distributed.tensor._random import _PhiloxState + + seeds = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + seeds[rank] = rank_philox.seed + + if len(set(seeds.values())) == 1: + return next(iter(seeds.values())) + return SymInt(LocalIntNode(seeds)) + + @seed.setter + def seed(self, seed: int | SymInt) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(seed, SymInt) and isinstance(seed.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_seed = seed.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.seed = rank_seed + else: + seed_int = int(seed) if isinstance(seed, SymInt) else seed + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.seed = seed_int + + def apply_to_local_tensor_mode(self, device_handle) -> None: + """ + Apply per-rank RNG states to the LocalTensorMode's tracked states. + This updates both the device RNG state and the LocalTensorMode's _per_rank_rng_states. + + Args: + device_handle: The device handle to use for setting RNG state (_LocalDeviceHandle) + """ + if not enabled_local_tensor_mode(): + return + + assert hasattr(self, "_per_rank_offsets") + + for rank in sorted(self._per_rank_states.keys()): + offset_value = self._per_rank_offsets[rank] + if isinstance(offset_value, SymInt): + if isinstance(offset_value.node, LocalIntNode): + offset_value = offset_value.node._local_ints[rank] + else: + offset_value = int(offset_value) + + offset_tensor = torch.tensor( + [offset_value], dtype=torch.uint64, device="cpu" + ).view(torch.uint8) + self._per_rank_states[rank][8:] = offset_tensor + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + device_handle.set_rng_state(LocalTensor(self._per_rank_states)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e00080875140aac24836355e599780dc1eba5eb7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__pycache__/_c10d.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__pycache__/_c10d.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..102f26fdfdc80681bce07c44894281e51039b33a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/__pycache__/_c10d.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/_c10d.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/_c10d.py new file mode 100644 index 0000000000000000000000000000000000000000..b3eca57402c56d8b5e9cdb216245ee2652f5250e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_local_tensor/_c10d.py @@ -0,0 +1,1060 @@ +import functools +import math +import operator +from collections.abc import Callable, Sequence +from datetime import timedelta + +import torch +from torch._C import ScriptObject +from torch._C._distributed_c10d import FakeWork, PythonCallbackWork +from torch.distributed._mesh_layout import _MeshLayout +from torch.distributed.distributed_c10d import ( + _check_op, + _get_default_group, + _resolve_process_group, + GroupName, + ProcessGroup, + ReduceOp, + Work, +) + + +# NOTE: Most of the c10d collectives often take a Tensor[] (or Tensor[][]) +# when you would expect Tensor (or Tensor[]). In fact, there will only ever +# be one Tensor in this case; the old signature was to support dispatching a +# collective on multiple devices (ala DataParallel) but we don't support that +# API anymore. Note that we are not 100% consistent about this; some more +# modern collectives like _allgather_base_ got rid of the unnecessary list. +# When in doubt, consult the code that dispatches to the collective on the PG +# in distributed_c10d.py e.g., work = group.allgather([tensor_list], [tensor], +# opts) indicates its always a list. + + +def _gcd_list(numbers: Sequence[int]) -> int: + return 0 if not numbers else functools.reduce(math.gcd, numbers) + + +def _indices_to_layout(indices: list[int]) -> tuple[tuple[int, ...], tuple[int, ...]]: + # Base case: A single index represents a point, not a dimension. + if len(indices) <= 1: + return (), () + + # The smallest stride is likely the GCD of the differences between consecutive indices. + # For a sorted, unique list, all differences will be positive. + diffs = [indices[i] - indices[i - 1] for i in range(1, len(indices))] + last_stride = _gcd_list(diffs) + + assert last_stride != 0, ( + # This case should not be reached if indices are unique and sorted. + "Cannot determine stride; indices may not be unique." + ) + + # Identify the starting index of each "row" in the last dimension. + # An index starts a new row if the preceding index (index - stride) is not present. + indices_set = set(indices) + higher_dim_indices = [indices[0]] + for index in indices[1:]: + if (index - last_stride) not in indices_set: + higher_dim_indices.append(index) + + # From the number of rows, we can deduce the shape of the last dimension. + assert len(indices) % len(higher_dim_indices) == 0, ( + "Indices do not form a regular grid. " + f"Found {len(higher_dim_indices)} subgroups for {len(indices)} total elements." + ) + last_shape = len(indices) // len(higher_dim_indices) + + # Recurse on the higher-dimensional indices (the start of each row). + higher_shapes, higher_strides = _indices_to_layout(higher_dim_indices) + + # Combine the results from the recursion with the current dimension's results. + final_shapes = higher_shapes + (last_shape,) + final_strides = higher_strides + (last_stride,) + + return final_shapes, final_strides + + +def _prepare_collective_groups( + process_group_so: ScriptObject | ProcessGroup, +) -> tuple[list[int], list[int], int]: + process_group = ( + ProcessGroup.unbox(process_group_so) + if isinstance(process_group_so, ScriptObject) + else process_group_so + ) + + ranks = torch.distributed.get_process_group_ranks(process_group) + assert ranks + # TODO: We can handle permutations but the layout inference algorithm will + # lose the permutation so we will have to reapply it + assert ranks == sorted(ranks), ranks + offset = ranks[0] + ranks = [r - offset for r in ranks] + + shape, strides = _indices_to_layout(ranks) + layout = _MeshLayout(shape, strides) + + global_pg = _get_default_group() + group_offsets = layout.complement(global_pg.size()).all_ranks_from_zero() + + return ranks, group_offsets, offset + + +# NB: There are two flavors of the collectives: regular and functional. Regular collectives +# allocate outputs to write the result to, accept process group and support async ops (return +# work object). Functional collectives expect the implementation to allocate outputs, accept +# process group name that must be resolved and do not support async ops (return output). +def _local_functional_all_gather_into_tensor( + tensor: torch.Tensor, group_size: int, group_name: GroupName +) -> torch.Tensor: + # "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor" + from . import LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + gathered_tensor = torch.cat(group_tensors, dim=0) + + for rank in group_ranks: + output_local_tensors[rank] = gathered_tensor.clone() + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_reduce_scatter_tensor( + tensor: torch.Tensor, reduce_op: str, group_size: int, group_name: GroupName +) -> torch.Tensor: + # "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor" + from . import _zero_sized_like, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + reduced_tensor = _local_reduce(reduce_op, group_tensors) + + scattered_tensor = torch.split( + reduced_tensor, + reduced_tensor.size(0) // len(group_ranks), + dim=0, + ) + + for i, rank in enumerate(group_ranks): + if i < len(scattered_tensor): + output_local_tensors[rank] = scattered_tensor[i].clone() + else: + output_local_tensors[rank] = _zero_sized_like(reduced_tensor, 0) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_shard_dim_alltoall( + tensor: torch.Tensor, gather_dim: int, shard_dim: int, group_name: GroupName +) -> torch.Tensor: + # "shard_dim_alltoall(Tensor input, int gather_dim, int shard_dim, str group_name) -> Tensor" + from . import _zero_sized_like, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + gathered_tensor = torch.cat(group_tensors, dim=gather_dim) + + split_tensor = torch.split( + gathered_tensor, + gathered_tensor.size(shard_dim) // len(group_ranks), + dim=shard_dim, + ) + + for i, rank in enumerate(group_ranks): + if i < len(split_tensor): + output_local_tensors[rank] = split_tensor[i].clone() + else: + output_local_tensors[rank] = _zero_sized_like( + gathered_tensor, shard_dim + ) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_all_to_all_single( + tensor: torch.Tensor, + output_split_sizes: list[torch.SymInt], + input_split_sizes: list[torch.SymInt], + group_name: GroupName, +) -> torch.Tensor: + # "all_to_all_single(Tensor input, SymInt[] output_split_sizes, SymInt[] input_split_sizes, str group_name) -> Tensor" + from . import LocalIntNode, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + split_local_sizes: dict[int, list[int]] = {} + for input_split_size in input_split_sizes: + if isinstance(input_split_size, torch.SymInt) and isinstance( + input_split_size.node, LocalIntNode + ): + local_ints = dict(input_split_size.node._local_ints.items()) + else: + local_ints = {rank: int(input_split_size) for rank in tensor._local_tensors} + for rank, split_size in local_ints.items(): + if rank not in split_local_sizes: + split_local_sizes[rank] = [] + split_local_sizes[rank].append(split_size) + + split_local_tensors: dict[int, list[torch.Tensor]] = {} + + for rank, split_sizes in split_local_sizes.items(): + split_local_tensors[rank] = list( + torch.split(tensor._local_tensors[rank], split_sizes) + ) + + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + if not all(rank in split_local_tensors for rank in group_ranks): + continue + + for i, dst in enumerate(group_ranks): + splits = [] + for j, src in enumerate(group_ranks): + splits.append(split_local_tensors[src][i]) + output_local_tensors[dst] = torch.cat(splits) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_broadcast_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + root_rank: int, + root_tensor: int, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)" + from . import LocalTensor + + assert len(tensors) == 1 + assert root_tensor == 0 + tensor = tensors[0] + + ranks, group_offsets, offset = _prepare_collective_groups(process_group_so) + + # We're going to assume SPMD where for every rank group the root_rank is + # the same relative to others + relative_root_rank = root_rank - offset + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the broadcast on them + group_ranks = [group_offset + r for r in ranks] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + source_rank = group_offset + relative_root_rank + source_tensor = tensor._local_tensors[source_rank] + + # Broadcast the source tensor to all ranks in this group + for rank in group_ranks: + if source_rank != rank: + tensor._local_tensors[rank].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (tensors, work_so) + + +def _local_reduce( + reduce_op: ReduceOp | str, + tensors: list[torch.Tensor], +) -> torch.Tensor: + if reduce_op == ReduceOp.SUM or reduce_op == "sum": + op = operator.add + elif reduce_op == ReduceOp.AVG or reduce_op == "avg": + op = None + elif reduce_op == ReduceOp.PRODUCT or reduce_op == "product": + op = operator.mul + elif reduce_op == ReduceOp.MIN or reduce_op == "min": + op = torch.minimum + elif reduce_op == ReduceOp.MAX or reduce_op == "max": + op = torch.maximum + elif reduce_op == ReduceOp.BAND or reduce_op == "band": + op = torch.bitwise_and + elif reduce_op == ReduceOp.BOR or reduce_op == "bor": + op = torch.bitwise_or + elif reduce_op == ReduceOp.BXOR or reduce_op == "bxor": + op = torch.bitwise_xor + elif reduce_op == ReduceOp.PREMUL_SUM or reduce_op == "premul_sum": + raise NotImplementedError("PREMUL_SUM: need to add binding for scaling factor") + else: + raise NotImplementedError(f"ReduceOp {reduce_op} not implemented") + + if reduce_op == ReduceOp.AVG or reduce_op == "avg": + return functools.reduce(operator.add, tensors) / len(tensors) + else: + assert op is not None + return functools.reduce(op, tensors) + + +def _local_all_reduce_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + sparse_indices: torch.Tensor | None = None, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "__torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, " + # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + from . import LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allreduce on them + group_ranks = [group_offset + r for r in ranks] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + # Collect tensors from the specified ranks in this group + group_tensors = [] + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + # Perform the reduction operation + reduced_tensor = _local_reduce(reduce_op, group_tensors) + + # Update all tensors in the group with the reduced result + for rank in group_ranks: + tensor._local_tensors[rank].copy_(reduced_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (tensors, work_so) + + +def _local_allreduce_coalesced_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work" + from . import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allreduce on all tensors together + group_ranks = [group_offset + r for r in ranks] + + # For each tensor, perform the reduction operation + for tensor in tensors: + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + # Collect tensors from the specified ranks in this group + group_tensors = [] + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + # Perform the reduction operation + reduced_tensor = _local_reduce(reduce_op, group_tensors) + + # Update all tensors in the group with the reduced result + for rank in group_ranks: + tensor._local_tensors[rank].copy_(reduced_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_reduce_scatter_tensor_coalesced_( + output_tensors: list[torch.Tensor], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, " + # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, " + # "int timeout=-1) -> __torch__.torch.classes.c10d.Work" + + from . import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allreduce on all tensors together + group_ranks = [group_offset + r for r in ranks] + + # For each tensor, perform the reduction operation + for input_tensor, output_tensor in zip(input_tensors, output_tensors): + assert isinstance(input_tensor, LocalTensor), ( + "Input tensor must be a LocalTensor" + ) + assert isinstance(output_tensor, LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + # Collect tensors from the specified ranks in this group + group_inputs = [] + for rank in group_ranks: + group_inputs.append(input_tensor._local_tensors[rank]) + + # Perform the reduction operation + reduced_input = _local_reduce(reduce_op, group_inputs) + + reduced_input_splits = torch.split( + reduced_input, reduced_input.size(0) // len(group_ranks), dim=0 + ) + + # Update all tensors in the group with the reduced result + for i, rank in enumerate(group_ranks): + output_tensor._local_tensors[rank].copy_(reduced_input_splits[i]) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_allgather_base_( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup + # process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + gathered_tensor = torch.cat(gathered_tensors, dim=0) + + for rank_i in group_ranks: + output_tensor._local_tensors[rank_i].copy_(gathered_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + +def _local_reduce_scatter_base_( # type: ignore[no-untyped-def] + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, + # __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, + # bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)" + + from . import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + reduced_tensor = _local_reduce(reduce_op, gathered_tensors) + + scattered_tensor = torch.split( + reduced_tensor, + reduced_tensor.size(0) // len(group_ranks), + dim=0, + ) + + for i, rank_i in enumerate(group_ranks): + output_tensor._local_tensors[rank_i].copy_(scattered_tensor[i].clone()) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + +def _local_all_gather_( + output_tensors: list[list[torch.Tensor]], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[list[torch.Tensor]], ScriptObject]: + # "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, " + # "int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); + + from . import LocalTensor + + assert len(output_tensors) == 1 + assert len(input_tensors) == 1 + + input_tensor = input_tensors[0] + # pyrefly: ignore [bad-assignment] + output_tensors = output_tensors[0] + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + for i in range(len(output_tensors)): + assert isinstance(output_tensors[i], LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the all_gather on them + group_ranks = [group_offset + r for r in ranks] + + # For each rank in the group, gather from their input tensor + for i, rank_i in enumerate(group_ranks): + # allgather object happens to create pure tensor, so we special case it here + source_tensor = input_tensor + if isinstance(input_tensor, LocalTensor): + source_tensor = input_tensor._local_tensors[rank_i] + # pyrefly: ignore [missing-attribute] + output_tensors[i].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + # pyrefly: ignore [bad-return] + return ([output_tensors], work_so) + + +def _local_allgather_into_tensor_coalesced_( + output_tensors: list[torch.Tensor], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + async_op: bool = True, +) -> ScriptObject: + # "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) " + # "-> __torch__.torch.classes.c10d.Work" + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + # Each output tensor should be sized to hold all gathered inputs + # outputs[i] will contain all inputs[i] from all ranks + assert len(output_tensors) == len(input_tensors), ( + f"Number of outputs ({len(output_tensors)}) must match number of inputs ({len(input_tensors)})" + ) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allgather_into_tensor on them + group_ranks = [group_offset + r for r in ranks] + + # For each input/output pair + for input_tensor, output_tensor in zip(input_tensors, output_tensors): + assert isinstance(input_tensor, LocalTensor), ( + "Input tensor must be a LocalTensor" + ) + assert isinstance(output_tensor, LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + # Gather input_tensor from all ranks into output_tensor + # The output should be a concatenation of all inputs along the first dimension + gathered_tensors = [] + for rank in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank]) + + # Concatenate along first dimension and copy to output + if gathered_tensors: + concatenated = torch.cat(gathered_tensors, dim=0) + for rank in group_ranks: + output_tensor._local_tensors[rank].copy_(concatenated) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_gather_( + output_tensors: list[list[torch.Tensor]], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + root_rank: int, + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, " + # "bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work" + raise NotImplementedError( + "LocalTensor does not support MPMD operations like gather " + "(only root rank receives data). Use SPMD collective operations like allgather instead." + ) + + +def _local_scatter_( + output_tensors: list[torch.Tensor], + input_tensors: list[list[torch.Tensor]], + process_group_so: ScriptObject, + root_rank: int, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, " + # "bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + + from . import LocalTensor + + assert len(output_tensors) == 1 + assert len(input_tensors) == 1 + output_tensor = output_tensors[0] + # pyrefly: ignore [bad-assignment] + input_tensors = input_tensors[0] + + ranks, group_offsets, offset = _prepare_collective_groups(process_group_so) + + # We're going to assume SPMD where for every rank group the root_rank is + # the same relative to others + relative_root_rank = root_rank - offset + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert len(ranks) == len(input_tensors), (ranks, input_tensors) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the scatter on them + group_ranks = [group_offset + r for r in ranks] + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + # Root rank scatters its input tensors to all ranks in this group + for i, rank in enumerate(group_ranks): + input_tensor = input_tensors[i] + assert isinstance(input_tensor, LocalTensor) + # Each rank i gets the i-th input tensor from the root + source_tensor = input_tensor._local_tensors[ + group_offset + relative_root_rank + ] + output_tensor._local_tensors[rank].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (output_tensors, work_so) + + +def _local_alltoall_( + output_tensors: list[torch.Tensor], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, " + # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"; + + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert len(input_tensors) == len(output_tensors) == len(ranks), ( + f"Number of input tensors ({len(input_tensors)}), " + f"output tensors ({len(output_tensors)}), and ranks ({len(ranks)}) must match" + ) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the alltoall on them + group_ranks = [group_offset + r for r in ranks] + + # In alltoall, rank i sends input_tensors[j] to rank j and receives into output_tensors[i] from rank j + for i, rank_i in enumerate(group_ranks): + output_tensor = output_tensors[i] + assert isinstance(output_tensor, LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + for j, rank_j in enumerate(group_ranks): + input_tensor = input_tensors[j] + assert isinstance(input_tensor, LocalTensor), ( + "Input tensor must be a LocalTensor" + ) + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + + # Rank i's j-th input tensor goes to rank j's i-th output tensor + source_tensor = input_tensor._local_tensors[rank_i] + output_tensor._local_tensors[rank_j].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (output_tensors, work_so) + + +def _local_alltoall_base_( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + output_split_sizes: list[int], + input_split_sizes: list[int], + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"; + + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + # Convert split sizes to lists if they aren't already + if output_split_sizes is not None: + output_split_sizes = list(output_split_sizes) + if input_split_sizes is not None: + input_split_sizes = list(input_split_sizes) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the alltoall_base on them + group_ranks = [group_offset + r for r in ranks] + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + for i, rank_i in enumerate(group_ranks): + # Split input tensor from rank_i according to input_split_sizes + rank_tensor = input_tensor._local_tensors[rank_i] + + if input_split_sizes is not None and len(input_split_sizes) > 0: + # Split the input tensor + input_splits = torch.split(rank_tensor, input_split_sizes, dim=0) + else: + # No split sizes specified, split evenly + split_size = rank_tensor.size(0) // len(group_ranks) + input_splits = torch.split(rank_tensor, split_size, dim=0) + + # Send each split to the corresponding rank + for j, rank_j in enumerate(group_ranks): + if j < len(input_splits): + split_tensor = input_splits[j] + + # Determine where to place this split in the output tensor + if output_split_sizes is not None and len(output_split_sizes) > 0: + # Calculate offset based on output split sizes + output_offset = sum(output_split_sizes[:i]) if i > 0 else 0 + end_offset = ( + output_offset + output_split_sizes[i] + if i < len(output_split_sizes) + else output_tensor._local_tensors[rank_j].size(0) + ) + else: + # No output split sizes, use even splits + split_size = output_tensor._local_tensors[rank_j].size( + 0 + ) // len(group_ranks) + output_offset = i * split_size + end_offset = min( + (i + 1) * split_size, + output_tensor._local_tensors[rank_j].size(0), + ) + + # Copy the split to the appropriate section of the output tensor + output_section = output_tensor._local_tensors[rank_j][ + output_offset:end_offset + ] + if output_section.numel() > 0: + # Reshape split_tensor to match output_section if necessary + if split_tensor.size() != output_section.size(): + split_tensor = split_tensor.view(output_section.size()) + output_section.copy_(split_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_barrier( + tensor: torch.Tensor, + process_group_so: ScriptObject, + device_ids: list[int], + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"; + + from . import LocalTensor + + # Barrier is a synchronization primitive - in local simulation, + # we don't need to do any actual work since all "ranks" are in the same process + # Just validate that the tensor is a LocalTensor + assert isinstance(tensor, LocalTensor) + + # In a real distributed setting, barrier would synchronize all processes + # In local simulation, this is essentially a no-op since all ranks are local + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_monitored_barrier_( + tensor: torch.Tensor, + process_group_so: ScriptObject, + device_ids: list[int], + timeout: int, + wait_all_ranks: bool, +) -> None: + # "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int[] device_ids, int timeout, bool wait_all_ranks) -> ()"; + + from . import LocalTensor + + # Monitored barrier is a synchronization primitive with monitoring - in local simulation, + # we don't need to do any actual work since all "ranks" are in the same process + # Just validate that the tensor is a LocalTensor + assert isinstance(tensor, LocalTensor) + + # In a real distributed setting, monitored barrier would synchronize all processes + # and provide monitoring capabilities. In local simulation, this is essentially a no-op + # since all ranks are local and no actual synchronization is needed + return + + +def _local_send( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + dst: int, + tag: int, +) -> ScriptObject: + # "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int dst, int tag) -> __torch__.torch.classes.c10d.Work"; + + from . import LocalRunnerMode, LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + src = int(tensor.__src_rank__) + + LocalRunnerMode.current()._signal_send(src, dst, tensor._local_tensors[src]) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_recv_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + src: int, + tag: int, +) -> ScriptObject: + # "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int src, int tag) -> __torch__.torch.classes.c10d.Work"; + from . import LocalRunnerMode, LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + dst = int(tensor.__src_rank__) + + def _recv_and_store(timeout: timedelta) -> bool: + def _wait_and_store(obj: object) -> None: + assert isinstance(obj, torch.Tensor), "Expected to receive a Tensor" + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + tensor._local_tensors[dst] = obj + + LocalRunnerMode.current()._wait_recv(src, dst, _wait_and_store) + return True + + work = PythonCallbackWork(_recv_and_store) + work_so = Work.boxed(work) + return work_so + + +def _local_recv_any_source_( + tensors: list[torch.Tensor], process_group_so: ScriptObject, tag: int +) -> ScriptObject: + # "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int tag) -> __torch__.torch.classes.c10d.Work"; + + return _local_recv_(tensors, process_group_so, -1, tag) + + +def _attach_rank(tensor: torch.Tensor, rank: int) -> torch.Tensor: + """ + Attaches rank as an attribute to given tensor so that the send or recv implementation + knows which rank initiates the operation (note under local tensor mode ). + """ + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + tensor = tensor._local_tensor + + tensor.__src_rank__ = rank # type: ignore[attr-defined] + return tensor + + +def local_p2p_op( + dst: torch.SymInt, + tensor: torch.Tensor, + op: Callable[[torch.Tensor, int], Work | None], +) -> Work | None | list[Work | None]: + """ + Runs a point-to-point (P2P) operation for all combinations of source and destination ranks. + """ + _check_op(op) + + from . import LocalIntNode + + assert isinstance(dst.node, LocalIntNode), ( + "Expected 'dst' to be a LocalIntNode where the value is the destination rank and key is the source rank" + ) + + w = [] + for s, d in dst.node._local_ints.items(): + tensor = _attach_rank(tensor, s) + w.append(op(tensor, d)) + return w + + +def wait_all(work: Work | None | list[Work | None]) -> None: + """ + Waits for all work objects in the input to complete. + + A single Work object, None, or a list of Work objects (possibly containing None). + If None, does nothing. If a single Work, waits for it to complete. If a list, waits + for each non-None Work in the list to complete. + """ + + if work is None: + return + if isinstance(work, Work): + work = [work] + for w in work: + if w is None: + continue + w.wait() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e13bcc86e5095a0762417cf0c6cfdaa20951ee5d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__init__.py @@ -0,0 +1,74 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from .int_tuple import ( + as_tuple, + crd2crd, + crd2idx, + elem_scale, + flatten, + has_none, + idx2crd, + inner_product, + IntTuple, + is_int, + is_tuple, + match_structure, + product, + shape_div, + signum, + slice_, + suffix_product, + tuple_max, +) +from .layout import ( + coalesce, + complement, + composition, + cosize, + filter, + is_layout, + Layout, + LayoutBase, + left_inverse, + logical_divide, + logical_product, + make_layout, + right_inverse, + size, + slice_and_offset, + tiled_divide, + tiled_product, + zipped_divide, + zipped_product, +) +from .typing import Integer diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50994747431755daea4b16be4953da3f147a127f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/int_tuple.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/int_tuple.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e9c58f28915112e898c688a26cd33386a11379 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/int_tuple.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/layout.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/layout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06fa3f5975994be5340bcb205bceab3123347d6a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/layout.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/typing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/typing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b05672085d035a78217131dd55c78b48553ce80 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/__pycache__/typing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/int_tuple.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3406a7399b1af1f892c17e2cf34755aeed244c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/int_tuple.py @@ -0,0 +1,269 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Functions for manipulating IntTuples +""" + +from functools import reduce +from itertools import chain +from typing import TypeAlias +from typing_extensions import TypeIs + +from .typing import Integer + + +# Type aliases for better readability +IntTuple: TypeAlias = int | tuple["IntTuple", ...] + + +def is_int(x: object) -> TypeIs[int]: + return isinstance(x, Integer) + + +def is_tuple(x: object) -> TypeIs[tuple]: + return isinstance(x, tuple) + + +def as_tuple(x: IntTuple) -> tuple[IntTuple, ...]: + if is_int(x): + return (x,) + return x + + +def match_structure(a: IntTuple, b: IntTuple) -> bool: + if is_int(a) and is_int(b): + return True + if is_tuple(a) and is_tuple(b): + return len(a) == len(b) and all(match_structure(x, y) for x, y in zip(a, b)) + return False + + +def flatten(t: IntTuple) -> tuple[int, ...]: + if is_tuple(t): + if len(t) == 0: + return () + else: + return tuple(i for a in t for i in flatten(a)) + else: + return (t,) + + +def signum(a: int) -> int: + return bool(a > 0) - bool(a < 0) + + +def product(a: IntTuple) -> int: + if is_tuple(a): + return reduce(lambda val, elem: val * product(elem), a, 1) + else: + return a + + +def inner_product(a: IntTuple, b: IntTuple) -> int: + if is_tuple(a) and is_tuple(b): # tuple tuple + assert len(a) == len(b) + return sum(inner_product(x, y) for x, y in zip(a, b)) + else: # "int" "int" + assert not is_tuple(a) and not is_tuple(b) + return a * b + + +def tuple_max(a: IntTuple) -> int: + if is_tuple(a): + return max(tuple_max(x) for x in a) + else: + return a + + +def elem_scale(a: IntTuple, b: IntTuple) -> IntTuple: + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(elem_scale(x, y) for x, y in zip(a, b)) + else: # tuple "int" + raise AssertionError("Invalid combination: tuple with int") + else: + if is_tuple(b): # "int" tuple + return elem_scale(a, product(b)) + else: # "int" "int" + return a * b + + +# Inclusive prefix ceil div with output congruent to input a +def shape_div(a: IntTuple, b: IntTuple) -> IntTuple: + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(shape_div(x, y) for x, y in zip(a, b)) + else: # tuple "int" + # r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))] + r = [] + for v in a: + r.append(shape_div(v, b)) + b = shape_div(b, product(v)) + return tuple(r) + else: + if is_tuple(b): # "int" tuple + return shape_div(a, product(b)) + else: # "int" "int" + assert a % b == 0 or b % a == 0 + return (a + b - 1) // b + + +# Exclusive suffix product with output congruent to input a (lexicographic) +def suffix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple: + # TODO: With all these length asserts, may want to create a zip_strict wrapper. + if is_tuple(a): + if is_tuple(init): # tuple tuple + assert len(a) == len(init) + return tuple(suffix_product(x, i) for x, i in zip(a, init)) + else: # tuple "int" + # Process from right to left for lexicographic ordering + # r = [prefix_product(a[len(a)-1],init)] + + # [prefix_product(a[i],init := init * product(a[i+1])) for i in range(len(a)-1,0)].reverse() + r = [] + + # Calculate products from right to left, appending to list + for i in range(len(a) - 1, -1, -1): + r.append(suffix_product(a[i], init)) + init = init * product(a[i]) + + # Reverse to get correct lexicographic order + r.reverse() + return tuple(r) + else: + if is_tuple(init): # "int" tuple + raise AssertionError("Invalid combination: int with tuple init") + else: # "int" "int" + return init + + +def idx2crd(idx: IntTuple, shape: IntTuple, stride: IntTuple | None = None) -> IntTuple: + if stride is None: + stride = suffix_product(shape) + + if is_tuple(idx): + if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple + assert len(idx) == len(shape) and len(stride) == len(shape) + return tuple(idx2crd(i, s, d) for i, s, d in zip(idx, shape, stride)) + else: # tuple "int" "int" + raise AssertionError("Invalid combination: tuple with int stride") + else: + if is_tuple(shape) and is_tuple(stride): # "int" tuple tuple + assert len(shape) == len(stride) + return tuple(idx2crd(idx, s, d) for s, d in zip(shape, stride)) + else: # "int" "int" "int" + assert not is_tuple(shape) and not is_tuple(stride) + return (idx // stride) % shape # all are ints after type checks + + +def crd2idx( + crd: IntTuple | None, shape: IntTuple, stride: IntTuple | None = None +) -> int: + if stride is None: + stride = suffix_product(shape) + + if is_tuple(crd): + if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple + assert len(crd) == len(shape) and len(stride) == len(shape) + return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride)) + else: # tuple "int" "int" + raise AssertionError(f"Invalid combination: crd={crd}, shape={shape}") + else: + if crd is None: + crd = 0 + + if is_tuple(shape) and is_tuple(stride): # "int" tuple tuple + assert len(shape) == len(stride) + result = 0 + # Process from right to left for lexicographic ordering + for i in range(len(shape) - 1, 0, -1): + result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) + crd = crd // product(shape[i]) + if len(shape) > 0: + result += crd2idx(crd, shape[0], stride[0]) + return result + else: # "int" "int" "int" + assert not is_tuple(shape) and not is_tuple(stride) + return crd * stride # all are ints after type checks + + +# Transform crd into the dst_shape's iteration space +def crd2crd( + crd: IntTuple, dst_shape: IntTuple, src_shape: IntTuple | None = None +) -> IntTuple: + if is_tuple(crd): + if is_tuple(dst_shape): # tuple tuple + assert len(crd) == len(dst_shape) + return tuple(crd2crd(x, y) for x, y in zip(crd, dst_shape)) + else: # tuple "int" + # Ambiguous unless we have src_shape + assert src_shape is not None + return crd2idx(crd, src_shape) + else: + if is_tuple(dst_shape): # "int" tuple + return idx2crd(crd, dst_shape) + else: # "int" "int" + assert crd < dst_shape + return crd + + +# Filter trg according to crd: keep only elements of trg that are paired with None +def slice_(crd: None | tuple | int, trg: tuple | int) -> tuple | int: + if is_tuple(crd): + if is_tuple(trg): # tuple tuple + assert len(crd) == len(trg) + # match C++ behavior of `filter_tuple` using `tuple_cat(...)` + return tuple( + chain( + *filter( # type: ignore[arg-type] # filter returns Iterator which is compatible + lambda x: x != (), + [slice_(c, s) for c, s in zip(crd, trg)], + ) + ) + ) + else: + raise AssertionError("Invalid combination: tuple crd with int trg") + elif crd is None: + # match C++ behavior `return cute::tuple{b};` + return (trg,) + else: + return () + + +# Determine if None appears at any of an int_tuples' terminals +def has_none(a: None | tuple | int) -> bool: + if is_tuple(a): + return any(has_none(v) for v in a) + else: + return a is None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/layout.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..0adf94b5b142b925f7ab35dc82f46c4bf509001a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/layout.py @@ -0,0 +1,470 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Definition of CuTe Layouts and functions to manipulate them which works with the order +of lexicographic instead of co-lexicographic as implemented in the original layout.py +""" + +from itertools import chain +from typing import TypeAlias +from typing_extensions import Self, TypeIs + +from .int_tuple import ( + crd2idx, + flatten, + has_none, + IntTuple, + is_int, + is_tuple, + product, + slice_, + suffix_product, +) + + +# Type aliases +CoordinateType: TypeAlias = ( + int | IntTuple | tuple[object, ...] | None +) # Input for slice_ and crd2idx functions + + +class LayoutBase: + pass + + +def is_layout(x: object) -> TypeIs["Layout"]: + return isinstance(x, LayoutBase) + + +class Layout(LayoutBase): + def __init__(self, _shape: IntTuple, _stride: IntTuple | None = None) -> None: + self.shape = _shape + if _stride is None: + self.stride = suffix_product(self.shape) + else: + self.stride = _stride + + # operator == + def __eq__(self, other: object) -> bool: + if not isinstance(other, Layout): + return False + return self.shape == other.shape and self.stride == other.stride + + # operator len(L) (len [rank] like tuples) + def __len__(self) -> int: + if is_tuple(self.shape): + return len(self.shape) + else: + return 1 + + # operator () (map coord to idx) + def __call__(self, *args: CoordinateType) -> Self | int: + """ + Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + OR + Slice the layout and return the sublayout (Coord has an Underscore slice op) + + Follow the same behavior of `Layout::operator(Coord const&)` in cute C++ + """ + if has_none(args): + if len(args) == 1: + return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride)) + else: + return Layout(slice_(args, self.shape), slice_(args, self.stride)) + else: + if len(args) == 1: + return crd2idx(args[0], self.shape, self.stride) # type: ignore[arg-type] + else: + return crd2idx(args, self.shape, self.stride) # type: ignore[arg-type] + + # operator [] (get-i like tuples) + def __getitem__(self, i: int) -> Self: + if is_tuple(self.shape): + return Layout(self.shape[i], self.stride[i]) # type: ignore[index] + else: + assert i == 0 + return Layout(self.shape, self.stride) + + # size(layout) Size of the domain + def size(self) -> int: + return product(self.shape) + + # cosize(layout) Size of the codomain + def cosize(self) -> int: + return self(self.size() - 1) + 1 # type: ignore[operator] + + # print and str + def __str__(self) -> str: + return f"{self.shape}:{self.stride}" + + # error msgs and representation + def __repr__(self) -> str: + return f"Layout({self.shape},{self.stride})" + + +# Type aliases +LayoutOrIntTuple: TypeAlias = Layout | IntTuple +LayoutProfile: TypeAlias = tuple[object, ...] | Layout | None +LayoutInput: TypeAlias = Layout | IntTuple | tuple[object, ...] | None + + +# Make Layout from a list of layouts (each layout it's own mode in the result) +def make_layout(*layouts: Layout | tuple[Layout, ...]) -> Layout: + if len(layouts) == 1 and not is_layout(layouts[0]): + layouts = layouts[0] + + shape, stride = zip(*((a.shape, a.stride) for a in layouts)) # type: ignore[union-attr] + return Layout(shape, stride) + + +# Size of the domain +def size(layout: LayoutOrIntTuple) -> int: + if is_layout(layout): + return layout.size() + return product(layout) + + +# Size of the codomain +def cosize(layout: Layout) -> int: + return layout.cosize() + + +# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function +def coalesce(layout: Layout, profile: LayoutProfile = None) -> Layout: + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout( + chain( + (coalesce(layout[i], profile[i]) for i in range(len(profile))), # type: ignore[arg-type] + (layout[i] for i in range(len(profile), len(layout))), + ) + ) + + result_shape = [1] + result_stride = [0] + # Since we now follow lexicographic order, we need to process from right to left. + # And to make implementation more efficient, we append to the end of list and reverse it in the end. + for shape, stride in zip( + reversed(flatten(layout.shape)), reversed(flatten(layout.stride)) + ): + # skip their shape-1s + if shape == 1: + continue + # replace our shape-1 with anything + elif result_shape[-1] == 1: + result_shape[-1] = shape + result_stride[-1] = stride + # merge modes if the shape*stride match + elif result_shape[-1] * result_stride[-1] == stride: + result_shape[-1] = result_shape[-1] * shape + # append a new mode + else: + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + result_shape.reverse() + result_stride.reverse() + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them +def filter(layout: Layout, profile: LayoutProfile = None) -> Layout: + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout( + chain( + (filter(layout[i], profile[i]) for i in range(len(profile))), # type: ignore[arg-type] + (layout[i] for i in range(len(profile), len(layout))), + ) + ) + + result_shape = [] + result_stride = [] + for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)): + # skip their shape-1s and stride-0s + if not (shape == 1 or stride == 0): + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 0: + return Layout(1, 0) + else: + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout composition +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def composition(layoutA: Layout, layoutB: LayoutInput) -> Layout: + if layoutB is None: + return layoutA + elif is_int(layoutB): + return composition(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout( + chain( + (composition(layoutA[i], layoutB[i]) for i in range(len(layoutB))), # type: ignore[arg-type] + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ) + elif is_tuple(layoutB.shape): + return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) # type: ignore[arg-type, attr-defined] + + if layoutB.stride == 0: + return Layout(layoutB.shape, 0) + else: + result_shape = [] + result_stride = [] + rest_shape = layoutB.shape + rest_stride = layoutB.stride + flat_A = coalesce(layoutA) + # when left layout is multi-dimensional sublayout, aka, self = (a,b,...,c):(x,y,...,z), layout = s:d, + # for integral s and d means that we want: + # (1) “remove” the first d elements from left, starting from rightmost. (This will increase the stride.) + # (2) “keep” the first s of those strided elements. (This does not affect the stride.) + # For example, if self = (6,2):(2,1), layout = (3:2) + # Step 1: remove the first 2 elements from self with stride increase, i.e., (6,2):(2,1) -> (6,1):(2,2) + # Step 2: keep the first 3 of those strided elements, i.e., (6,1):(2,2) -> (3,1):(2,2) + # Because we are going lexicographically, we go through left layout from right to left. + for curr_shape, curr_stride in zip( + reversed(flatten(flat_A.shape)[1:]), reversed(flatten(flat_A.stride)[1:]) + ): + assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 # type: ignore[operator] + new_shape = min(max(1, curr_shape // rest_stride), rest_shape) # type: ignore[operator] + + if new_shape != 1: + result_shape.append(new_shape) # Append to end, will reverse later + result_stride.append(rest_stride * curr_stride) + + rest_shape = rest_shape // new_shape # type: ignore[operator] + rest_stride = -( + -rest_stride // curr_shape # type: ignore[operator] + ) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride) + + # When left has single-size sublayout or reach the last sublayout, aka, left = a:b, layout = s:d, + # the result is rather trivial: left o layout = a:b o s:d = s:(b*d). + # For example, if self = (6:2), layout = (3:2), the result is (3:(2*2)) = (3:4). + if rest_shape != 1 or len(result_shape) == 0: + result_shape.append(rest_shape) # Append to end, will reverse later + result_stride.append(rest_stride * flatten(flat_A.stride)[0]) + + # Reverse the lists because we build lists in reverse order (append to end), this way it is more efficient. + result_shape.reverse() + result_stride.reverse() + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) # type: ignore[arg-type] + else: + return Layout(tuple(result_shape), tuple(result_stride)) # type: ignore[arg-type] + + +# Layout complement +def complement(layout: LayoutOrIntTuple, max_idx: int = 1) -> Layout: + if is_int(layout): + return complement(Layout(layout)) + + result_shape = [] + result_stride = [] + current_idx = 1 + + sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) # type: ignore[union-attr] + for stride, shape in sorted_DS: + if stride == 0 or shape == 1: + continue + + in_bound = current_idx <= shape * stride + # To support symbolic value which can't be evaluated now + assert (type(in_bound) is not bool) or in_bound + + result_shape.append(stride // current_idx) + result_stride.append(current_idx) + current_idx = shape * stride + + result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div + result_stride.append(current_idx) + # This is different from original pycute implementation, because we want to follow the lexicographic order here + # where the right-most dimension is the innermost dimension (smallest stride). + result_shape.reverse() + result_stride.reverse() + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout right inverse +def right_inverse(layout: LayoutOrIntTuple | None) -> Layout | None: + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + + result_shape = [] + result_stride = [] + current_idx = 1 + + flat_shape = flatten(layout.shape) # type: ignore[union-attr] + flat_stride = flatten(layout.stride) # type: ignore[union-attr] + sorted_DSA = sorted(zip(flat_stride, flat_shape, suffix_product(flat_shape))) # type: ignore[arg-type] + for stride, shape, rstride in sorted_DSA: + if shape == 1: + continue + if current_idx != stride: + break + + result_shape.append(shape) + result_stride.append(rstride) + current_idx = shape * stride + + result_shape.reverse() + result_stride.reverse() + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout left inverse +def left_inverse(layout: LayoutOrIntTuple | None) -> Layout | None: + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + return right_inverse(make_layout(complement(layout), layout)) # type: ignore[arg-type] + + +# Split a layout by the composition of B and the "rest" +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout: + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout( + chain( + ( + logical_divide(layoutA[i], layoutB[i]) # type: ignore[arg-type] + for i in range(len(layoutB)) + ), + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ) + + return composition( + layoutA, + make_layout(layoutB, complement(layoutB, size(layoutA))), + ) + + +# Reproduce a layoutA over a layoutB +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_product(layoutA: Layout, layoutB: LayoutInput) -> Layout: + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout( + chain( + ( + logical_product(layoutA[i], layoutB[i]) # type: ignore[arg-type] + for i in range(len(layoutB)) + ), + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ) + + return make_layout( + layoutA, + composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB), + ) + + +# Gather the modes from a hierarchical logical_divide or logical_product +def hier_unzip( + splitter: object, + layoutA: Layout, + layoutB: LayoutInput, +) -> Layout: + if layoutB is None: + return make_layout(Layout(1, 0), layoutA) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + # A layout with shape ((A,a),(B,b),(C,c)) + split = make_layout( + hier_unzip(splitter, layoutA[i], layoutB[i]) # type: ignore[arg-type] + for i in range(len(layoutB)) + ) + # Gather to shape ((A,B,C,...),(a,b,c,...,y,z)) + return make_layout( + make_layout(split[i][0] for i in range(len(layoutB))), # type: ignore[arg-type] + make_layout( + chain( # type: ignore[arg-type] + (split[i][1] for i in range(len(layoutB))), + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ), + ) + + # splitter must return a rank-2 layout + return splitter(layoutA, layoutB) # type: ignore[operator] + + +# Apply logical divide hierarchically and gather the split modes into two modes +def zipped_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout: + return hier_unzip(logical_divide, layoutA, layoutB) + + +# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode +def tiled_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout: + result = zipped_divide(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) # type: ignore[arg-type] + + +# Apply logical product hierarchically and gather the split modes into two modes +def zipped_product(layoutA: Layout, layoutB: LayoutInput) -> Layout: + return hier_unzip(logical_product, layoutA, layoutB) + + +# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode +def tiled_product(layoutA: Layout, layoutB: LayoutInput) -> Layout: + result = zipped_product(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) # type: ignore[arg-type] + + +def slice_and_offset(crd: tuple[object, ...], layout: Layout) -> tuple[Layout, int]: + return ( + Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)), + crd2idx(crd, layout.shape, layout.stride), # type: ignore[arg-type] + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/typing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6fe0a9c66e800186b4d84ef52cc12d6baeb1f6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_pycute/typing.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from abc import ABC + + +class Integer(ABC): # noqa: B024 # Uses __subclasshook__ instead of abstract methods + @classmethod + def __subclasshook__(cls, c: type) -> bool: + if c in [bool, float]: + return False + + return issubclass(c, int) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85a313c779e7aa87726f425146048fcd37efd261 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__init__.py @@ -0,0 +1 @@ +from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..016544309bcb5f60f0cb3bf4bf6093f63fee5e50 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae764e8c769c01ad50ca2d772b36d7087c584983 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..051a72aa04ddeb55af9f6ad20c4dd11429dab69d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a666db21eccd3c18ed59cba9c6b8fc454ba18ab7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d9ca75dbc5ce40be21a87fe34acd762f5633ba0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f24895d79c3828ffc47fea488ec7f8200838e77 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b3b1ac599d2ca794dbb05e83f9cd981c14de194 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6fd641b3f9443faa64b6b54c3ab209f8167a56f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/_utils.py @@ -0,0 +1,32 @@ +from collections.abc import Sequence + +import torch +from torch.distributed._shard.metadata import ShardMetadata + + +DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor." + + +def narrow_tensor_by_index( + tensor: torch.Tensor, + offsets: Sequence[int], + sizes: Sequence[int], +) -> torch.Tensor: + """ + Narrow the tensor according to ``offsets`` and ``sizes``. + """ + narrowed_tensor = tensor + for idx, (offset, size) in enumerate(zip(offsets, sizes)): + if size < tensor.size(idx): + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrowed_tensor.narrow(idx, offset, size) + return narrowed_tensor + + +def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor: + """ + Narrow the tensor according to the metadata + """ + return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/api.py new file mode 100644 index 0000000000000000000000000000000000000000..82589119d7afa6086b6b6289954d88676516b620 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/api.py @@ -0,0 +1,305 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import distributed_c10d +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .sharder import Sharder +from .sharding_plan import ShardingPlan +from .sharding_spec import ChunkShardingSpec, ShardingSpec + + +def _shard_tensor( + tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None +) -> ShardedTensor: + """ + Given a :class:`torch.Tensor`, it shards that tensor according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor needs to be sharded. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + + .. warning:: + Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + if not tensor.is_contiguous(): + raise ValueError("input tensor is not a contiguous Tensor") + + pg = ( + process_group + if process_group is not None + else distributed_c10d._get_default_group() + ) + world_size = dist.get_world_size(pg) + current_rank = dist.get_rank(pg) + + # Validate src_rank and sharding_spec are same across all ranks. + gathered_list = [None] * world_size + dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg) + + for idx, entry in enumerate(gathered_list): + if src_rank != entry[0]: # type: ignore[index] + raise ValueError( + f"src_rank={src_rank} on rank: {current_rank} does not " # type: ignore[index] + f"match with src_rank={entry[0]} on rank: {idx}" # type: ignore[index] + ) + if sharding_spec != entry[1]: # type: ignore[index] + raise ValueError( + f"sharding_spec={sharding_spec} on rank: {current_rank} does not " # type: ignore[index] + f"match with sharding_spec={entry[1]} on rank: {idx}" # type: ignore[index] + ) + + st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg) + + return st + + +def shard_parameter( + module: torch.nn.Module, + param_name: str, + sharding_spec: ShardingSpec, + src_rank=0, + process_group=None, +): + """ + Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that + module, it shards that parameter according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + This method replaces ``module.param_name`` with a + :class:`torch.distributed._sharded_tensor.ShardedTensor` + + Args: + module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded. + param_name (str): Name of the parameter of ``module`` that needs to be sharded. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + .. warning:: + Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + # Perform some validation first. + if not hasattr(module, param_name): + raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`") + + tensor = getattr(module, param_name) + if not isinstance(tensor, torch.Tensor): + raise ValueError( + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) + + if not tensor.is_contiguous(): + raise ValueError(f"param: {param_name} is not a contiguous Tensor") + + st = _shard_tensor(tensor, sharding_spec, src_rank, process_group) + + # Replace param with ShardedTensor. + module.register_parameter(param_name, nn.Parameter(st)) + + +# Tracks the current process group in the load context manager. +_CURRENT_PROCESS_GROUP: dist.ProcessGroup | None = None + + +@contextmanager +def load_with_process_group(process_group): + """ + Context manager to set the process group with which to load a ShardedTensor. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is not None: + raise RuntimeError( + 'ProcessGroup already set by previous "load_with_process_group" ' + "context manager" + ) + _CURRENT_PROCESS_GROUP = process_group + try: + yield process_group + finally: + _CURRENT_PROCESS_GROUP = None + + +def _get_current_process_group(): + """ + Retrieves the current process group set by ``load_with_process_group``. + If not set, it just returns the default group. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is None: + return distributed_c10d._get_default_group() + else: + return _CURRENT_PROCESS_GROUP + + +def _reshard_output( + module: torch.nn.Module, resharding_spec: ShardingSpec +) -> torch.nn.Module: + """ + Hook a module with output resharding in the forward pass according + to the given ``resharding_spec``. + + Args: + module (:class:`torch.nn.Module`): Module whose output needs to be resharded. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how the output of the module will be resharded. + + Returns: + A :class:`torch.nn.Module` object with reshard API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + return output.reshard(resharding_spec) + return output + + module.register_forward_hook(hook_func) + return module + + +def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module: + """ + Hook a module with local shards collection in the forward pass. + + This API is typically used to convert a sharded representation back to data parallel + representation. In particular, it returns the local tensor for this Shard. If the + size along the sharding dimension for the local tensor is 1, this dimension is removed + from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically + a local Tensor of size [16] across each rank and not [1, 16] across each rank. + + Args: + module (:class:`torch.nn.Module`): Module whose output is ShardedTensor and the + local tensor value needs to be returned. + + Returns: + A :class:`torch.nn.Module` object with collection API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + local_tensor = output.local_tensor() + # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec + sharding_spec = output._sharding_spec + if ( + isinstance(sharding_spec, ChunkShardingSpec) + and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type] + ): + local_tensor = local_tensor.squeeze( + output._sharding_spec.dim # type: ignore[attr-defined] + ) + return local_tensor + + module.register_forward_hook(hook_func) + return module + + +def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None): + """ + Shards a given module according to the provided sharding `plan`. This method + first shards all the parameters according to the given sharding `plan`. Then if + `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it + will tag the output of modules according `output_plan`, convert the module's + output back to data parallel according to `return_local_tensor`. + + Needs to be called on all ranks in an SPMD fashion. + + Args: + module (:class:`torch.nn.Module`): The module to apply sharding to + plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`): + The ShardingPlan which specified param name to ShardingSpec to apply to + each parameter. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the module that would be sharded and scattered across the rest + of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + """ + # record Sharder paths for sanity check on the plan to ensure items in the plan + # does not conflict with the submodule tree that the Sharder is working with + sharder_paths = [] + for name, spec in plan.plan.items(): + if isinstance(spec, Sharder): + sharder_paths.append(name) + + # shard the parameter according to the ShardingPlan + for name, spec in plan.plan.items(): + if isinstance(spec, ShardingSpec): + # if found a sharding spec, try to shard the parameter + module_path, _, param_name = name.rpartition(".") + + for sharder_path in sharder_paths: + if module_path.startswith(sharder_path): + raise RuntimeError( + f"ShardingPlan is in-valid, trying to shard a parameter: {name}," + f" but there's already a Sharder entry for module {sharder_path}," + f" parameter sharding should not conflict with the submodule tree" + f" that a Sharder is working with!" + ) + + mod = module.get_submodule(module_path) + shard_parameter( + mod, param_name, spec, src_rank=src_rank, process_group=process_group + ) + elif isinstance(spec, Sharder): + parent_mod_path, _, _mod_name = name.rpartition(".") + if name == "": + raise KeyError("Module path must not be empty for custom sharder!") + mod = module.get_submodule(name) + parent_mod = module.get_submodule(parent_mod_path) + sharded_mod = spec.shard(mod) + # swap this submodule with the sharded module + parent_mod.mod_name = sharded_mod + else: + raise TypeError( + f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'" + ) + + # reshard output if there's an entry in `reshard_output` for this module + if plan.output_plan is not None: + for module_path, output_spec in plan.output_plan.items(): + if isinstance(output_spec, ShardingSpec): + mod = module.get_submodule(module_path) + _reshard_output(mod, output_spec) + else: + raise TypeError( + f"Only `ShardingSpec` is supported as output_plan for '{module_path}'" + ) + # convert the output back to data parallel for the modules appears in + # `return_local_tensor` of the plan, we will call `_collect_local_shard` + # to collect the local tensor for output of modules + if plan.return_local_tensor is not None: + for module_path in plan.return_local_tensor: + mod = module.get_submodule(module_path) + _collect_local_shard(mod) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85915636a014640d8fff5a29db602c4a114f1b1d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__init__.py @@ -0,0 +1,19 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed.checkpoint` package. +import sys +import warnings + +import torch +from torch.distributed.checkpoint import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._shard.checkpoint` will be deprecated, " + "use `torch.distributed.checkpoint` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edadd4c543314ef0ab3177d92d2be546fccdaab2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/common_op_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/common_op_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c98b8c87ca2c7ceb1608a59673738a7e57333035 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/common_op_utils.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs + +import torch +from torch.utils import _pytree as pytree + + +def _basic_validation(op, args=(), kwargs=None): + """ + Common validation across all ops go in here. + """ + from torch.distributed._shard.sharded_tensor import ShardedTensor + + if len(args) == 0 and (kwargs is None or len(kwargs) == 0): + raise ValueError(f" No input for '{op.__name__}'!") + + # Validate types + has_distributed_tensor = False + + def is_distributed_tensor(e): + nonlocal has_distributed_tensor + if isinstance(e, ShardedTensor): + has_distributed_tensor = True + + pytree.tree_map_(is_distributed_tensor, args) + pytree.tree_map_(is_distributed_tensor, kwargs) + + if not has_distributed_tensor: + raise TypeError( + f"torch function '{op.__name__}', with args: {args} and " + f"kwargs: {kwargs} are called without any distributed tensor!" + ) + + # Validate all distributed tensors use the same PG. + cur_pg: torch.distributed.ProcessGroup | None = None + + def validate_pg(e): + nonlocal cur_pg + if isinstance(e, ShardedTensor): + if cur_pg is not None and e._process_group is not cur_pg: + raise RuntimeError( + "All distributed tensors should use the " + "same ProcessGroup if used together in an op." + ) + cur_pg = e._process_group + + pytree.tree_map_(validate_pg, args) + pytree.tree_map_(validate_pg, kwargs) + + +def _register_default_op(op, decorator): + @decorator(op) + def tensor_default_op(types, args=(), kwargs=None, pg=None): + """ + Handles ``__torch_function__`` dispatch for the default tensor ops that + behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or + ``torch.Tensor.dtype``. We simply lower to the real op call with + DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__`` + to avoid recursions. + """ + if kwargs is None: + kwargs = {} + + with torch._C.DisableTorchFunctionSubclass(): + return op(*args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/metadata.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..63ef073b1c494ab450bca79c83f3867548140fd8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/metadata.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import reduce + +from torch.distributed.remote_device import _remote_device + + +@dataclass +class ShardMetadata: + """ + Represents a shard of the overall Tensor including its + offsets, lengths and device placement. + + Args: + shard_offsets(List[int]): Offsets in the original tensor indicating + the start offsets for this shard. Should have the same rank as + the original tensor. + shard_sizes(List[int]): Integers indicating the size of each + dimension for this shard. Should have the same rank as the + original tensor. + placement(:class:`torch.distributed._remote_device`): + Specifies the placement of this shard. + """ + + __slots__ = ["shard_offsets", "shard_sizes", "placement"] + + shard_offsets: list[int] + shard_sizes: list[int] + placement: _remote_device | None + + def __init__( + self, + shard_offsets: list[int], + shard_sizes: list[int], + placement: str | _remote_device | None = None, + ): + self.shard_offsets = shard_offsets + self.shard_sizes = shard_sizes + if isinstance(placement, str): + self.placement = _remote_device(placement) + else: + self.placement = placement + if len(self.shard_offsets) != len(self.shard_sizes): + raise ValueError( + f"shard_offsets and shard_sizes should have " + f"the same number of elements, found {len(self.shard_offsets)} " + f"and {self.shard_sizes} respectively" + ) + + for i in range(len(self.shard_offsets)): + if self.shard_offsets[i] < 0: + raise ValueError("shard_offsets should be >=0") + if self.shard_sizes[i] < 0: + raise ValueError("shard_sizes should be >= 0") + + def __hash__(self): + def _hash_reduce(a, b): + return (a << 8) + hash(b) + + res = reduce(_hash_reduce, self.shard_offsets, 37) + res = reduce(_hash_reduce, self.shard_sizes, res) + res = _hash_reduce(res, self.placement) + return res diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/op_registry_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/op_registry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..12e0b1895e2f053e6c4a975cb6d3c0118baf50bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/op_registry_utils.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import functools +from inspect import signature + +from .common_op_utils import _basic_validation + + +""" +Common utilities to register ops on ShardedTensor +and PartialTensor. +""" + + +def _register_op(op, func, op_table): + """ + Performs basic validation and registers the provided op in the given + op_table. + """ + if len(signature(func).parameters) != 4: + raise TypeError( + f"Custom sharded op function expects signature: " + f"(types, args, kwargs, process_group), but received " + f"signature: {signature(func)}" + ) + + op_table[op] = func + + +def _decorator_func(wrapped_func, op, op_table): + """ + Decorator function to register the given ``op`` in the provided + ``op_table`` + """ + + @functools.wraps(wrapped_func) + def wrapper(types, args, kwargs, process_group): + _basic_validation(op, args, kwargs) + return wrapped_func(types, args, kwargs, process_group) + + _register_op(op, wrapper, op_table) + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..effae2e3cd1b89027cf06bf6e603e6ca84551520 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__init__.py @@ -0,0 +1,53 @@ +from collections.abc import Iterator +from typing import Union + +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .api import ShardedOptimizer + + +def named_params_with_sharded_tensor( + module: nn.Module, + prefix: str = "", + recurse: bool = True, +) -> Iterator[tuple[str, nn.Parameter | ShardedTensor]]: + r"""Returns an iterator over module parameters (together with the + ShardedTensor parameters), yielding both the name of the parameter + as well as the parameter itself. This is typically passed to a + :class:torch.distributed._shard.sharded_optim.ShardedOptimizer + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + (str, Union[Tensor, ShardedTensor]): Tuple containing + the name and parameter (or ShardedTensor parameter) + + Example:: + + >>> # xdoctest: +SKIP + >>> model = torch.nn.Linear(*linear_size) + >>> shard_parameter(model, "weight", spec) + >>> for name, param in named_params_with_sharded_tensor(model): + >>> if name in ['weight']: + >>> print(param.size()) + + """ + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + memo = set() + for mod_prefix, mod in modules: + # find all sharded tensor params + for name, val in vars(mod).items(): + if isinstance(val, ShardedTensor) and val not in memo: + memo.add(val) + name = mod_prefix + ("." if mod_prefix else "") + name + yield name, val + + # find all nn.Parameters + for name, val in module.named_parameters(): + yield name, val diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cb79f145c631dac2913d75a2f38209b20dd6efa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74b556eaa4b4450879076c96bb1a7f2f6b60e92d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2989e85496090782fcdb39d0e6613b82155ea23c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/api.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +from collections.abc import Mapping +from typing import Any + +import torch.optim as optim +from torch import Tensor +from torch.distributed._shard.sharded_tensor import ShardedTensor + + +class ShardedOptimizer(optim.Optimizer): + def __init__( + self, + named_params: Mapping[str, Tensor | ShardedTensor], + optimizer_class, + *optimizer_args, + **optimizer_kwargs, + ): + """ + ShardedOptimizer collects all tensors and local shard tensors of + ShardedTensor, then use these tensors as ``params`` for optimizers + + Args: + named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict + of parameters, where key is the parameter key, value is either + Tensor or ShardedTensor parameter. + optimizer_class (torch.optim.Optimizer): the Optimizer to use + locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc. + *optimizer_args: the arguments to initialize the optimizer. + **optimizer_kwargs: the key-word arguments to initialize the optimizer. + + """ + tensors: list[Tensor] = [] + for value in named_params.values(): + if isinstance(value, ShardedTensor): + tensors.extend( + local_shard.tensor for local_shard in value.local_shards() + ) + else: + tensors.append(value) + + self.named_params = named_params + self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs) + self.param_groups = self._optim.param_groups + self.state = self._optim.state + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + r"""Resets the gradients of all optimized :class:`torch.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). + """ + self._optim.zero_grad(set_to_none) + + def step(self, closure=None): + r"""Performs a single optimization step (parameter update). + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + self._optim.step(closure) + + def state_dict(self) -> dict[str, Any]: + """ + Returned state and param_groups will contain parameter keys + instead of parameter indices like torch.optim.Optimizer. + This allows for advanced functionality like optimizer re-sharding to be implemented. + """ + # TODO: implement state_dict + raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!") + + def load_state_dict(self, state_dict: Mapping[str, Any]): + r"""Loads the ShardedOptimizer state. + + Args: + state_dict (dict): ShardedOptimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # TODO: implement load_state_dict + raise NotImplementedError( + "ShardedOptimizer load_state_dict not implemented yet!" + ) + + def add_param_group(self, param_group: Any): + r"""Add a new param group""" + # TODO: implement add_param_group + raise NotImplementedError( + "ShardedOptimizer add_param_group not implemented yet!" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3af3ed3595378ca8522384f295ef6ea9930ebf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +import functools +from typing import TYPE_CHECKING + +import torch +from torch.distributed._shard.op_registry_utils import _decorator_func + +from .api import ( + _CUSTOM_SHARDED_OPS, + _SHARDED_OPS, + Shard, + ShardedTensor, + ShardedTensorBase, + ShardedTensorMetadata, + TensorProperties, +) +from .metadata import ShardMetadata # noqa: F401 + + +if TYPE_CHECKING: + from torch.distributed._shard.sharding_spec import ShardingSpec +else: + ShardingSpec = "ShardingSpec" + + +def empty( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with uninitialized data. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def ones( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` with the scalar value 1. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=1, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def zeros( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with the scalar value 0. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=0, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def full( + sharding_spec: ShardingSpec, + size, + fill_value, + *, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype + is inferred from fill_value. If dtype is specified, it will override the + inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion. + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + fill_value (Scalar) - the value to fill the output tensor with. + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type] + return sharded_tensor + + +def rand( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)`. The shape of the tensor is defined by the + variable argument `size`. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def randn( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + with mean `0` and variance `1` (also called standard normal distribution). The shape + of the tensor is defined by the variable argument `size`. Needs to be called on all ranks + in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.normal_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def init_from_local_shards( + local_shards: list[Shard], *global_size, process_group=None, init_rrefs=False +) -> ShardedTensor: + """ + Creates an :class:`ShardedTensor` from local shards and the global metadata. + Needs to be called on all ranks in an SPMD fashion. + + Args: + local_shards (List[:class `torch.distributed._shard.sharded_tensor.Shard`]): A list + of shards that represent the local shards on this rank. + global_size (int...): a list, tuple, or `torch.Size` of integers defining the + shape of the overall sharded tensor. + + Keyword args: + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object handle on this rank + + + Examples: + Suppose we want construct a sharded tensor on two ranks, global size = (10, 5), + each shard have a (5, 5) local tensor, we can do it like below: + + on rank 0: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[0, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:0/cuda:0" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + + on rank 1: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[5, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:1/cuda:1" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + """ + return ShardedTensor._init_from_local_shards( + local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs + ) + + +def state_dict_hook(module, destination, prefix, local_metadata): + """ + Hook to add ShardedTensor to Module's ``state_dict``. Needs to be + registered to the Module using + :meth:`torch.nn.Module._register_state_dict_hook`. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name, attr in submodule.__dict__.items(): + if isinstance(attr, ShardedTensor): + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + destination[key] = attr + + +def pre_load_state_dict_hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """ + Pre-load state dict hook to add ShardedTensor to the module. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name in submodule.__dict__: + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + if key in state_dict: + if isinstance(state_dict[key], ShardedTensor): + setattr(submodule, attr_name, state_dict[key]) + + +def custom_sharded_op_impl(func): + """ + Provides a way for users to write their own custom sharded operator. This + can be used to override existing ShardedTensor operators or write a new + one not supported by ShardedTensor. If the operator in question is covered + by ``__torch_function__`` dispatch and has a ShardedTensor as any of its + parameters, the function provided will be invoked for that operator. + + Example:: + >>> # xdoctest: +SKIP + >>> @custom_sharded_op_impl(torch.nn.functional.linear) + >>> def my_custom_sharded_linear(types, args, kwargs, process_group): + >>> ... + >>> # xdoctest: +SKIP("Undefined variables") + >>> input = torch.rand(10, 32) + >>> weight = sharded_tensor.rand(32, 16) + >>> bias = torch.rand(16) + >>> # This will call 'my_custom_sharded_linear' + >>> torch.nn.functional.linear(input, weight, bias) + + The types, args and kwargs parameters are the same parameters that are + passed to ``__torch_function__`` dispatch API + (https://pytorch.org/docs/stable/notes/extending.html#extending-torch). + There is an additional ``process_group`` parameter which is the + process_group used for the ShardedTensor and can be used by + implementations for communications within a sharded implementation. + + Args: + func(Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.nn.functional.linear) + """ + return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS) + + +def _sharded_op_impl(func): + """ + Decorator to register a default sharded op. + """ + return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS) + + +# Import all builtin sharded ops +from ._ops import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e44eb5c94bb2018254eb370dc613ce5c10b55e18 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13cc29b12f53219bc412efc39ac3a98110de16c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02f4a39e394d8c8155601f381e84264b893119a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdc7d1852fd9debc64314ac8cde07ff6a1ee7ab0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1038226e7a334195ba9dc03362f043c9bacdb2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4ae142ff287c17467440b5856dfecee2984e28e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9719c5ce3c1dc3702c8f38deeb86ad213c77a488 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba0e7f464bb8b99e65bb79f4bd8c1d4e6f2c8086 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be6d01fc8e54ee214fafa847c9261db375d8b87e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -0,0 +1,13 @@ +import torch.distributed._shard.sharded_tensor._ops.misc_ops +import torch.distributed._shard.sharded_tensor._ops.tensor_ops + +# Import all ChunkShardingSpec ops +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import ( + sharded_embedding, +) +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import ( + sharded_embedding_bag, +) + +from .binary_cmp import allclose, equal +from .init import constant_, kaiming_uniform_, normal_, uniform_ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d7c323271d96ff0ef2a4c3cfca857457f154e88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d65c2f08b15c98b3a785052de72186bbc101cd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98a36dc022014db087d445cf8386db3b141fa46d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8be9fa0d9562aff0bff318ce9ae86f3b64e9e59b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f153ebbbb6dde3be3b8702e81daf2c3af938d0c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..760723cbb0e9a396cea9bc6a85b7dd680c67c157 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..0a356e524a47a6f1e73022a707f19d7ddb8c935d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -0,0 +1,115 @@ +# mypy: allow-untyped-defs +import functools + +from torch.distributed._shard.common_op_utils import _basic_validation +from torch.distributed._shard.sharded_tensor import ( + _sharded_op_impl, + Shard, + ShardedTensor, +) + + +def _sharded_op_common(op, early_stop_func, extra_check): + """ + Inject sharded tensor op registration with common logics executed before + different behaviors are done on either local shards or a local tensor. + + Example:: + >>> # xdoctest: +SKIP("Undefined variables") + >>> op = torch.transpose + >>> @_sharded_op_impl(op) + >>> @_sharded_op_common(op, early_stop_func, extra_check) + >>> def sharded_tensor_op(types, args, kwargs, process_group): + >>> ... + >>> + >>> st = sharded_tensor.rand(32, 16) + >>> st.transpose(1, 2) + >>> # This will call '_sharded_op_common' + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + + Return: + func (Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.transpose) + """ + + def decorator_sharded_func(wrapped_func): + @functools.wraps(wrapped_func) + def wrapper(types, args=(), kwargs=None, pg=None): + _basic_validation(op, args, kwargs) + + # pyrefly: ignore [index-error] + st = args[0] + if kwargs is None: + kwargs = {} + if extra_check: + extra_check(*args, **kwargs) + if early_stop_func: + early_stop = early_stop_func(*args, **kwargs) + if early_stop: + return st + return wrapped_func(types, args, kwargs, pg) + + return wrapper + + return decorator_sharded_func + + +def _register_sharded_op_on_local_shards( + op, early_stop_func=None, extra_check=None, customized_func=None +): + """ + Handles ``__torch_function__`` dispatch for ops which are performed on + each shard of the sharded tensor such as elementwise op like + ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. + + For more complicated ops, a customized func can be used to generate + the new shards and sharded tensor size. + + This function expects that the original ShardingSpec for the ShardedTensor + is preserved irrespective of whether or not a customized function is used. + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + customized_func (Callable, optional): the func for customized logic + to generate new shards and sharded tensor size. + Default: if ``None``, we simply lower to the real op call with + all local shards of the st. + + Return: + func (Callable): registered implementation for sharded op for + ``__torch_function__`` dispatch. + """ + + @_sharded_op_impl(op) + @_sharded_op_common(op, early_stop_func, extra_check) + def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): + # pyrefly: ignore [index-error] + st = args[0] + st_metadata = st.metadata() + local_shards = st.local_shards() + local_shards_new = [] + if customized_func: + local_shards_new, st_metadata = customized_func(args, kwargs, pg) + else: + for local_shard in local_shards: + args = (local_shard.tensor, *args[1:]) + local_shards_new.append( + Shard(op(*args, **kwargs), local_shard.metadata) + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards_new, + st_metadata, + process_group=pg, + init_rrefs=st._init_rrefs, + sharding_spec=st.sharding_spec(), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py new file mode 100644 index 0000000000000000000000000000000000000000..0548b81fb90af087593d05695418664c6d109f2d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -0,0 +1,78 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +import torch.distributed.distributed_c10d as distributed_c10d +from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor + + +def _communicate_result(result, pg): + # Gather results from all ranks. + if result: + result_tensor = torch.ones(1, device=torch.device(torch.cuda.current_device())) + else: + result_tensor = torch.zeros(1, device=torch.device(torch.cuda.current_device())) + + dist.all_reduce(result_tensor, group=pg) + + expected_result = torch.ones( + 1, device=torch.device(torch.cuda.current_device()) + ) * dist.get_world_size(pg) + + return torch.equal(result_tensor, expected_result) + + +def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): + if len(args) != 2: + raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}") + + st1 = args[0] + st2 = args[1] + if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)): + raise TypeError( + f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor" + ) + + # Verify same PG + if st1._process_group != st2._process_group: + return False + + if distributed_c10d._rank_not_in_group( + st1._process_group + ) or distributed_c10d._rank_not_in_group(st2._process_group): + return distributed_c10d._rank_not_in_group( + st1._process_group + ) == distributed_c10d._rank_not_in_group(st2._process_group) + + # Verify metadata + if st1.metadata() != st2.metadata(): + return _communicate_result(False, st1._process_group) + + # Verify number of local shards + st1_local_shards = st1.local_shards() + st2_local_shards = st2.local_shards() + if len(st1_local_shards) != len(st2_local_shards): + return _communicate_result(False, st1._process_group) + + # kwargs must be dict-like + if kwargs is None: + kwargs = {} + # Verify each local shard + for idx in range(len(st1_local_shards)): + if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata: + return _communicate_result(False, st1._process_group) + if not cmp_fun( + st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs + ): + return _communicate_result(False, st1._process_group) + + return _communicate_result(True, st1._process_group) + + +@_sharded_op_impl(torch.equal) +def equal(types, args, kwargs, process_group): + return binary_cmp(torch.equal, types, args, kwargs, process_group) + + +@_sharded_op_impl(torch.allclose) +def allclose(types, args, kwargs, process_group): + return binary_cmp(torch.allclose, types, args, kwargs, process_group) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e576b45ebeeda7661e0011b6a100cd60d0f5f4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -0,0 +1,164 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed._shard.sharded_tensor as sharded_tensor +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + + +def validate_param(param, param_name): + if param is None: + raise ValueError(f"param: {param_name} shouldn't be None!") + + +@_sharded_op_impl(torch.nn.init.uniform_) +def uniform_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensor in tensor.local_shards with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + Args: + tensor: tensor sharded across devices + a: the lower bound of the uniform distribution + b: the upper bound of the uniform distribution + """ + validate_param(kwargs, "kwargs") + # pyrefly: ignore [unsupported-operation] + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + # pyrefly: ignore [unsupported-operation] + a = kwargs["a"] + validate_param(a, "a") + # pyrefly: ignore [unsupported-operation] + b = kwargs["b"] + validate_param(b, "b") + + for shard in sharded_tensor.local_shards(): + torch.nn.init.uniform_(shard.tensor, a=a, b=b) + return sharded_tensor + + +@_sharded_op_impl(torch.nn.init.normal_) +def normal_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensors in tensor.local_shards with values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + Args: + tensor: tensor sharded across devices + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + """ + validate_param(kwargs, "kwargs") + # pyrefly: ignore [unsupported-operation] + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + # pyrefly: ignore [unsupported-operation] + mean = kwargs["mean"] + validate_param(mean, "mean") + # pyrefly: ignore [unsupported-operation] + std = kwargs["std"] + validate_param(std, "std") + + for shard in sharded_tensor.local_shards(): + torch.nn.init.normal_(shard.tensor, mean=mean, std=std) + return sharded_tensor + + +@_sharded_op_impl(torch.nn.init.kaiming_uniform_) +def kaiming_uniform_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensors in tensor.local_shards with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + Also known as He initialization. + Args: + tensor: tensor sharded across devices + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + """ + validate_param(kwargs, "kwargs") + # pyrefly: ignore [unsupported-operation] + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + # pyrefly: ignore [unsupported-operation] + a = kwargs["a"] + validate_param(a, "a") + # pyrefly: ignore [unsupported-operation] + mode = kwargs["mode"] + validate_param(mode, "mode") + # pyrefly: ignore [unsupported-operation] + nonlinearity = kwargs["nonlinearity"] + validate_param(nonlinearity, "nonlinearity") + + for shard in sharded_tensor.local_shards(): + torch.nn.init.kaiming_uniform_( + shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity + ) + return sharded_tensor + + +@_sharded_op_impl(torch.nn.init.constant_) +def constant_(types, args=(), kwargs=None, pg=None): + r""" + Fills the input ShardedTensor with the value \text{val}val. + Args: + tensor: tensor sharded across devices + val: the value to fill the tensor with + """ + validate_param(kwargs, "kwargs") + # pyrefly: ignore [unsupported-operation] + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + # pyrefly: ignore [unsupported-operation] + val = kwargs["val"] + validate_param(val, "val") + for shard in sharded_tensor.local_shards(): + torch.nn.init.constant_(shard.tensor, val=val) + return sharded_tensor + + +tensor_like_creation_op_map = { + torch.full_like: sharded_tensor.full, + torch.empty_like: sharded_tensor.empty, + torch.zeros_like: sharded_tensor.zeros, + torch.ones_like: sharded_tensor.ones, + torch.rand_like: sharded_tensor.rand, + torch.randn_like: sharded_tensor.randn, +} + + +# tensor ops that behave the same as the default tensor +def register_tensor_creation_op(op): + @_sharded_op_impl(op) + def tensor_creation_op(types, args=(), kwargs=None, pg=None): + """ + Handles ``__torch_function__`` dispatch for tensor creation ops that + takes a ShardedTensor as argument, such as ``torch.zeros_like`` or + ``torch.full_like``. + """ + creation_op = tensor_like_creation_op_map.get(op) + if creation_op is None: + raise RuntimeError(f"Tensor creation {op} not supported!") + if kwargs is None: + kwargs = {} + + # pyrefly: ignore [index-error] + st = args[0] + + new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs) # type: ignore[operator] + return new_st + + +register_tensor_creation_op(torch.full_like) +register_tensor_creation_op(torch.empty_like) +register_tensor_creation_op(torch.zeros_like) +register_tensor_creation_op(torch.ones_like) +register_tensor_creation_op(torch.rand_like) +register_tensor_creation_op(torch.randn_like) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8b84c1684c32456989e3998b3d4c30c34cb5dbf4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -0,0 +1,12 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + + +# This is used by `_apply()` within module.py to set new +# parameters after apply a certain method, we should follow +# the future behavior of overwriting the existing tensor +# instead of doing in-place change using `.data = `. +@_sharded_op_impl(torch._has_compatible_shallow_copy_type) +def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None): + return False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b7ad7c77b1b7948f5464cde0bee0f703d738fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -0,0 +1,222 @@ +# mypy: allow-untyped-defs +import copy + +import torch +from torch.distributed._shard.common_op_utils import _register_default_op +from torch.distributed._shard.sharded_tensor import ( + _sharded_op_impl, + Shard, + ShardedTensor, +) + +from ._common import _register_sharded_op_on_local_shards + + +# Tensor properties access +_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.size, _sharded_op_impl) +_register_default_op(torch.Tensor.dim, _sharded_op_impl) +_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl) +_register_default_op(torch.Tensor.contiguous, _sharded_op_impl) +_register_default_op(torch.Tensor.is_floating_point, _sharded_op_impl) + +# __reduce_ex__ to dispatch to get_state/set_state +_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl) + +# autograd related properties +_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined] +# TODO: set grad with a ShardedTensor that consists of all local grads +_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl) # type: ignore[union-attr] +_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[union-attr] +_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined] + + +# device property is ambiguous as from a global prospective, +# ShardedTensor.device consists of multiple devices (might even across hosts) +# We choose to return the current device of the local tensor to represent +# the device property on each rank +@_sharded_op_impl(torch.Tensor.device.__get__) +def tensor_device(types, args=(), kwargs=None, pg=None): + # pyrefly: ignore [index-error] + self_st = args[0] + # Validate types + if not isinstance(self_st, ShardedTensor): + raise TypeError("input needs to be a ShardedTensor") + dev: torch.device + if self_st._local_shards: + dev = self_st._local_shards[0].tensor.device + elif pg and pg._get_backend_name() == "gloo": + dev = torch.device("cpu") + else: + dev = torch.device(torch.cuda.current_device()) + return dev + + +@_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined] +def st_is_meta(types, args=(), kwargs=None, pg=None): + # pyrefly: ignore [index-error] + return args[0].local_tensor().is_meta + + +def sharded_type_as_check(*args, **kwargs): + """ + Perform extra checks for the sharded_type_as op such as the input needs to + be either a Tensor or ShardedTensor. + + Args: same as ``torch.Tensor.type_as``. + + Return: None + """ + if len(args) < 2: + raise ValueError("Needs to give a tensor to cast type as!") + if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor): + raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!") + + +def same_dtype(*args, **kwargs): + """ + When the dtype is the same, return the original ShardedTensor. + + Args: same as ``torch.Tensor.type_as``. + + Return (bool): Whether to return early or not. + """ + return args[0].dtype == args[1].dtype + + +def sharded_type_as(args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op. + + Args: same as ``torch.Tensor.type_as``. + + Return: + new_local_shards (List[Shard]): Local shards for the new sharded tensor. + st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor. + """ + st = args[0] + tensor = args[1] + if isinstance(tensor, ShardedTensor): + tensor = tensor.local_tensor() + new_local_shards = [ + Shard(shard.tensor.type_as(tensor), shard.metadata) + for shard in st.local_shards() + ] + st_meta = copy.deepcopy(st._metadata) + st_meta.tensor_properties.dtype = tensor.dtype + return new_local_shards, st_meta + + +_register_sharded_op_on_local_shards( + torch.Tensor.type_as, + early_stop_func=same_dtype, + extra_check=sharded_type_as_check, + customized_func=sharded_type_as, +) + + +def sharded_deepcopy(args, kwargs, pg): + # NOTE: we directly implement deepcopy magic method + # instead of using the default tensor.__deepcopy__ + # and implement clone(). This is because the default + # tensor deepcopy copies every attribute, but the + # process_group in ShardedTensor cannot be deep copied. + self_st = args[0] + new_local_shards = copy.deepcopy(self_st.local_shards()) + new_metadata = copy.deepcopy(self_st.metadata()) + return new_local_shards, new_metadata + + +_register_sharded_op_on_local_shards( + torch.Tensor.__deepcopy__, + customized_func=sharded_deepcopy, +) + + +@_sharded_op_impl(torch.Tensor.copy_) +def sharded_inplace_copy(types, args, kwargs, pg): + # NOTE: inplace op don't need to rewrap + kwargs = {} if kwargs is None else kwargs + self_st = args[0] + new_st = args[1] + nonblocking = kwargs.get("non_blocking", False) + for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()): + if local_shard.metadata != new_shard.metadata: + raise RuntimeError( + "inplace copy can only happen between two ShardedTensor with same metadata!" + ) + for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()): + local_shard.tensor.copy_(new_shard.tensor, nonblocking) + + return self_st + + +def sharded_clone(args, kwargs, pg): + self_st = args[0] + desire_memory_format = kwargs.get("memory_format", None) + if desire_memory_format and desire_memory_format != torch.preserve_format: + raise RuntimeError("Only support torch.preserve_format for ShardedTensor!") + cloned_local_shards = [ + Shard( + local_shard.tensor.clone(memory_format=desire_memory_format), + metadata=copy.deepcopy(local_shard.metadata), + ) + for local_shard in self_st.local_shards() + ] + new_metadata = copy.deepcopy(self_st.metadata()) + return cloned_local_shards, new_metadata + + +_register_sharded_op_on_local_shards( + torch.Tensor.clone, + customized_func=sharded_clone, +) + + +def sharded_detach(args, kwargs, pg): + self_st = args[0] + detached_local_shards = [ + Shard( + local_shard.tensor.detach(), + metadata=copy.deepcopy(local_shard.metadata), + ) + for local_shard in self_st.local_shards() + ] + new_metadata = copy.deepcopy(self_st.metadata()) + new_metadata.tensor_properties.requires_grad = False + return detached_local_shards, new_metadata + + +_register_sharded_op_on_local_shards( + torch.Tensor.detach, + customized_func=sharded_detach, +) + + +@_sharded_op_impl(torch.Tensor.requires_grad_) +def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): + # pyrefly: ignore [index-error] + self_st = args[0] + # Validate types + if not isinstance(self_st, ShardedTensor): + raise TypeError("input needs to be a ShardedTensor") + + if kwargs is None: + kwargs = {} + + requires_grad = args[1] if len(args) > 1 else kwargs.get("requires_grad", True) + if requires_grad == self_st.requires_grad: + return self_st + + for local_shard in self_st.local_shards(): + local_shard.tensor.requires_grad_(requires_grad) + + # update the wrapper class property + with torch._C.DisableTorchFunctionSubclass(): + self_st.requires_grad_(requires_grad) + # update the metadata in the meanwhile + self_st._metadata.tensor_properties.requires_grad = requires_grad + return self_st diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e8677d6ae7c91cf8d871ff697e057b554b794c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/api.py @@ -0,0 +1,1368 @@ +# mypy: allow-untyped-defs +from __future__ import annotations # type: ignore[attr-defined] + +import copy +import operator +import threading +import warnings +import weakref +from dataclasses import dataclass +from functools import reduce +from typing import cast, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +from torch._utils import _get_device_module +from torch.distributed import distributed_c10d, rpc +from torch.distributed._shard._utils import DEPRECATE_MSG +from torch.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) +from torch.distributed._shard.sharding_spec.api import ( + _dispatch_custom_op, + _has_custom_op, +) +from torch.distributed.remote_device import _remote_device +from torch.utils import _pytree as pytree + +from .metadata import ShardedTensorMetadata, TensorProperties +from .reshard import reshard_local_shard, reshuffle_local_shard +from .shard import Shard +from .utils import ( + _flatten_tensor_size, + _parse_and_validate_remote_device, + _validate_output_tensor_for_gather, + build_global_metadata, + build_metadata_from_local_shards, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from torch.distributed._shard.metadata import ShardMetadata + + +# Tracking for sharded tensor objects. +_sharded_tensor_lock = threading.Lock() +_sharded_tensor_current_id = 0 +_sharded_tensor_map: dict[int, weakref.ReferenceType[ShardedTensor]] = {} + +# Default sharded ops +_SHARDED_OPS: dict[Callable, Callable] = {} + +# Customized user ops +_CUSTOM_SHARDED_OPS: dict[Callable, Callable] = {} + + +def _register_remote_shards( + sharded_tensor_id: int, rrefs: list[rpc.RRef[Shard]], rpc_rank: int +): + with _sharded_tensor_lock: + if sharded_tensor_id not in _sharded_tensor_map: + raise RuntimeError( + f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}" + ) + + sharded_tensor = _sharded_tensor_map[sharded_tensor_id]() + if sharded_tensor is None: + raise RuntimeError("ShardedTensor weakref has been deallocated") + else: + sharded_tensor._register_remote_shards(rrefs, rpc_rank) + + +class ShardedTensorBase(torch.Tensor): + _sharding_spec: shard_spec.ShardingSpec + _metadata: ShardedTensorMetadata + _local_shards: list[Shard] + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + # Use __new__ to construct a wrapper tensor, for recording tensor + # properties and logging purposes. + torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor") + + # check sharding spec and build sharded tensor metadata + if not isinstance(sharding_spec, shard_spec.ShardingSpec): + raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}") + + sizes = _flatten_tensor_size(size) + dtype = kwargs["dtype"] + layout = kwargs["layout"] + pin_memory = kwargs["pin_memory"] + requires_grad = kwargs["requires_grad"] + + if dtype is None: + dtype = torch.get_default_dtype() + + tensor_properties = TensorProperties( + dtype, layout, requires_grad, pin_memory=pin_memory + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + sizes, tensor_properties=tensor_properties + ) + + r = torch.Tensor._make_wrapper_subclass( + cls, + sizes, + dtype=dtype, + layout=layout, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + # set sharding spec + r._sharding_spec = sharding_spec + # set metadata + r._metadata = sharded_tensor_metadata + # set local shards + r._local_shards = [] + return r + + def metadata(self) -> ShardedTensorMetadata: + """ + Returns a :class:`ShardedTensorMetadata` object corresponding to the + metadata for the entire tensor. + """ + return self._metadata + + def local_shards(self) -> list[Shard]: + """ + Returns a list of :class:`Shard' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + @classmethod + def _init_from_local_shards_and_global_metadata( + cls, + local_shards: list[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + sharding_spec=None, + ) -> ShardedTensorBase: + """ + Initialize a ShardedTensorBase with local shards and a global + ShardedTensorMetadata built on each rank. + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + if tensor_properties.layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor_base = ShardedTensorBase.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor_base._local_shards = local_shards + return sharded_tensor_base + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + raise RuntimeError( + f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} " + "but the there is no custom __torch_dispatch__ implementation for it." + ) + + +class ShardedTensor(ShardedTensorBase): + """ + ShardedTensor is an torch.Tensor subclass to represent Tensors that are sharded + across multiple devices and multiple processes. + + ShardedTensor is initialized in an SPMD like fashion where each rank + initializes the ShardedTensor. The ShardedTensor object on each rank + then only stores the local shard for the Tensor and provides global + metadata for all the shards. + + ShardedTensor doesn't provide any Tensor like operations but is a wrapper + providing the Tensor representing the local shard and the global metadata. + Using these, users can build their custom distributed._sharded computations + on top of this primitive. The local shards are all initialized using the + create_op specified by tensor_init_params.create_op, e.g., torch.ones, or + torch.empty + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + .. note:: ShardedTensor uses collectives to do various operations, i.e. it + uses all_gather to do cross rank validations. For NCCL-based process + groups, internal tensor representations of objects must be moved to the + GPU device before communication takes place. In this case, the device + used is given by ``torch.cuda.current_device()`` and it is the user's + responsibility to ensure that this is set so that each rank has an + individual GPU, via ``torch.cuda.set_device()`` + + """ + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + self = super().__new__(cls, sharding_spec, *size, **kwargs) + return self + + def __init__( + self, + sharding_spec: shard_spec.ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, + ): + # prepare initialization, initialize fields like + # _process_group, _local_shards, etc. + self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + if layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if memory_format != torch.contiguous_format: + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) + + self._metadata.tensor_properties.memory_format = memory_format + + current_rank = dist.get_rank() # global rank + + for shard_metadata in self._metadata.shards_metadata: + rank, device = _parse_and_validate_remote_device( + self._process_group, shard_metadata.placement + ) + if rank == current_rank: + local_tensor = _create_tensor_from_params( + shard_metadata.shard_sizes, + local_device=device, + tensor_properties=self._metadata.tensor_properties, + ) + self._local_shards.append(Shard(local_tensor, shard_metadata)) + + # do post initialization (i.e. register sharded_tensor_id, initialize_rpc) + self._post_init() + + def _prepare_init(self, process_group=None, init_rrefs=False): + self._init_rrefs = init_rrefs + self._sharded_tensor_id = None + + self._process_group = self._normalize_pg(process_group) + self._remote_shards: dict[int, list[rpc.RRef[Shard]]] = {} + + def _post_init(self): + # Initialize RPC if available. + if self._init_rrefs: + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + # pyrefly: ignore [bad-assignment] + self._sharded_tensor_id = _sharded_tensor_current_id + # pyrefly: ignore [unsupported-operation] + _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self) + _sharded_tensor_current_id += 1 + + if not rpc._is_current_rpc_agent_set(): + raise RuntimeError( + "RPC Framework needs to be initialized using" + " torch.distributed.rpc.init_rpc if init_rrefs is set to True" + ) + self._init_rpc() + + def __del__(self): + # Clean up the global map. + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + if ( + hasattr(self, "_sharded_tensor_id") + and self._sharded_tensor_id in _sharded_tensor_map + ): + _sharded_tensor_map.pop(self._sharded_tensor_id) # type: ignore[call-overload] + + def _init_rpc(self): + # Validate PG and RPC ranks match. + pg_rank = dist.get_rank() + rpc_rank = rpc.get_worker_info().id + if pg_rank != rpc_rank: + raise ValueError( + f"Default ProcessGroup and RPC ranks must be " + f"the same for ShardedTensor, found process group rank: " + f"{pg_rank} and RPC rank: {rpc_rank}" + ) + + self._remote_shards = {} + + # Gather all the sharded tensor ids. + worker_infos = rpc._get_current_rpc_agent().get_worker_infos() + rank_to_name = {} + name_to_rank = {} + + for worker_info in worker_infos: + rank_to_name[worker_info.id] = worker_info.name + name_to_rank[worker_info.name] = worker_info.id + + all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id) + + # Share the local shards to the entire world. + futs = [] + rpc_rank = rpc.get_worker_info().id + for rank in range(dist.get_world_size()): + # Skip self. + if rank == dist.get_rank(): + continue + + if len(self.local_shards()) != 0: + rrefs: list[rpc.RRef[Shard]] = [ + rpc.RRef(shard) for shard in self.local_shards() + ] + fut = rpc.rpc_async( + rank, + _register_remote_shards, + args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank), + ) + futs.append(fut) + + torch.futures.wait_all(futs) + + # Barrier for all RPCs to finish on all ranks. + rpc.api._all_gather(None) + + def _get_preferred_device(self) -> torch.device: + """ + Return the preferred device to be used when creating tensors for collectives. + This method takes into account the associated process group + """ + backend = dist.get_backend(self._process_group) + if backend == dist.Backend.NCCL: + return torch.device(torch.cuda.current_device()) + elif backend == dist.Backend.GLOO: + return torch.device("cpu") + else: + backend_config = dist.BackendConfig(backend) + for device, backend_str in backend_config.get_device_backend_map().items(): + if backend_str == backend and device != "cpu": + return torch.device( + device, _get_device_module(device).current_device() + ) + return torch.device("cpu") + + def gather( # type: ignore[override] + self, + dst: int = 0, + out: torch.Tensor | None = None, + enforce_dtype: bool = False, + dtype: torch.dtype | None = None, + ) -> None: + """ + Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the + sharded tensor. + + The API needs to be called on all ranks in SPMD fashion. All ranks should have + the same ``dst``. ``out`` should be a tensor of the same size as the overall + size of the sharded tensor on ``dst`` and ``None`` on all other ranks. + + Args: + dst(int): The rank where full tensor is constructed. + Default: 0 + out (:class `torch.Tensor`, optional): The output full tensor. + Must to be provided ONLY on ``dst`` rank. + Default: ``None`` + enforce_dtype (bool): Deprecated, please use dtype instead. Force the + gathered tensors to be the same type as input and output. + dtype (torch.dtype): Force the gathered tensors to be this dtype. + Default: ``None`` + """ + + def shard_size(shard_md): + return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] + + if enforce_dtype: + warnings.warn( + "`enforce_dtype` is deprecated. Please use `dtype` instead.", + FutureWarning, + stacklevel=2, + ) + + rank = dist.get_rank(self._process_group) + full_size = self.metadata().size + _validate_output_tensor_for_gather(rank, dst, full_size, out) + + local_shards = self.local_shards() + world_size = dist.get_world_size(self._process_group) + rank_sizes = [0 for _ in range(world_size)] + max_rank_size = 0 + shard_placement: dict[ShardMetadata, tuple[int, int]] = {} + # collect sizes + for shard_md in self.metadata().shards_metadata: + shard_rank = cast(_remote_device, shard_md.placement).rank() + assert shard_rank is not None + + shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank]) + rank_sizes[shard_rank] += shard_size(shard_md) + max_rank_size = max(max_rank_size, rank_sizes[shard_rank]) + + gather_list: list[torch.Tensor] | None + if rank == dst: + assert out is not None + if enforce_dtype: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = out.dtype + # TODO make it as a view of out tensor + gather_list = [ + torch.empty((max_rank_size,), device=out.device, dtype=dtype) + for _ in range(world_size) + ] + else: + gather_list = None + + with torch.no_grad(): + if enforce_dtype and len(local_shards) > 0: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = local_shards[0].tensor.dtype + data = torch.empty( + max_rank_size, device=self._get_preferred_device(), dtype=dtype + ) + + for shard in local_shards: + src = shard.tensor.flatten() + if src.nelement() == 0: + warnings.warn( + "Gathering a tensor with zero elements on rank " + str(rank), + stacklevel=2, + ) + continue + shard_offset = shard_placement[shard.metadata][1] + data[shard_offset : shard_offset + src.numel()].copy_(src) + + dist.gather( + tensor=data, + gather_list=gather_list, + dst=dst, + group=self._process_group, + ) + if rank != dst: + return + # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst + out = cast(torch.Tensor, out) + assert gather_list is not None + + full_size = self.metadata().size + dims = len(full_size) + for shard_md in self.metadata().shards_metadata: + rank, rank_offset = shard_placement[shard_md] + tensor = gather_list[rank] + tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)] + tensor = tensor.view(shard_md.shard_sizes) + + out_narrow_view = out + for dim in range(dims): + out_narrow_view = out_narrow_view.narrow( + dim, + shard_md.shard_offsets[dim], + shard_md.shard_sizes[dim], + ) + + out_narrow_view.copy_(tensor) + + def cpu( + self, memory_format=torch.preserve_format, process_group=None + ) -> ShardedTensor: + """ + Returns a copy of this object in CPU memory. + + If this ShardedTensor is already on CPU memory, then no copy is + performed and original object is returned. + + .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupGloo), + it is the user's responsibility to explicitly pass in a new process_group that + is compatible with CPU. + """ + # TODO: make this a __torch_function__ op once ShardedTensor becomes a + # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402 + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) + all_on_cpu = True + for meta in self.metadata().shards_metadata: + all_on_cpu &= meta.placement.device().type == "cpu" # type: ignore[union-attr] + + # if every shard is already on CPU, return the original object + if all_on_cpu: + return self + + # if not, returns a copy of this object on CPU + list_shards: list[Shard] = [] + # move all local shards to cpu, and change metadata + for shard in self._local_shards: + cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = torch.device("cpu") # type: ignore[union-attr] + list_shards.append(Shard(cpu_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cpu": # type: ignore[union-attr] + meta.placement._device = torch.device("cpu") # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + st_cpu = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cpu + + def cuda( + self, + device=None, + non_blocking=False, + memory_format=torch.preserve_format, + process_group=None, + ) -> ShardedTensor: + """ + Returns a copy of this object in CUDA memory, if the original ShardedTensor + is on CPU, we will move the local shard to the current GPU device of each + process in a SPMD fashion. + If this ShardedTensor is already on CUDA memory and local shards on each rank are + already on current device, we still returns a new ShardedTensor object with new + metadata, but no underlying data movements are performed. + .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupNCCL), + it is the user's responsibility to explicitly pass in a new process_group that + is compatible with GPU. + """ + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) + + if device is not None: + device = torch.device(device) if isinstance(device, str) else device + assert ( + isinstance(device, torch.device) + and device.index == torch.cuda.current_device() + ), ( + """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" + ) + + current_device = torch.device(torch.cuda.current_device()) + # returns a copy of ShardedTensor on CUDA current device + list_shards: list[Shard] = [] + # move all local shards to current device, and change metadata + # if local shards already on the current device, there's no + # real data movement, only the metadata are copied. + for shard in self._local_shards: + cuda_tensor = shard.tensor.cuda( + device=current_device, + non_blocking=non_blocking, + memory_format=memory_format, + ) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = current_device # type: ignore[union-attr] + + list_shards.append(Shard(cuda_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cuda": # type: ignore[union-attr] + meta.placement._device = current_device # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cuda + + def to(self, *args, **kwargs) -> ShardedTensor: + current_device: torch.device + if self._local_shards: + current_device = self._local_shards[0].tensor.device + elif self._process_group._get_backend_name() == "gloo": + current_device = torch.device("cpu") + else: + current_device = torch.device(torch.cuda.current_device()) + current_dtype = self.dtype + device_to = current_device + dtype_to = current_dtype + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_to = args[0] + elif isinstance(args[0], torch.device): + device_to = args[0] + elif isinstance(args[0], (str, int)): + device_to = torch.device(args[0]) + elif isinstance(args[0], torch.Tensor): + dtype_to = args[0].dtype + device_to = args[0].device + else: + raise RuntimeError(f"ShardedTensor.to() have wrong arguments: {args}") + elif len(args) == 2: + device_to, dtype_to = args + else: + dtype_to = kwargs.get("dtype", current_dtype) + device_to = kwargs.get("device", current_device) + + device_to = ( + torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + ) + + if device_to.type == "cuda": + # if device_to set to cuda, set to current device even + # if user specify the device index. + current_idx = torch.cuda.current_device() + if device_to.index != current_idx: + warnings.warn( + "ShardedTensor.to only move tensor to its current device" + "If you want to put to different device, use `reshard` instead.", + stacklevel=2, + ) + device_to = torch.device(current_idx) + + copy_tensor = kwargs.get("copy", False) + non_blocking = kwargs.get("non_blocking", False) + memory_format = kwargs.get("memory_format", torch.preserve_format) + process_group = kwargs.get("process_group") + + if ( + not copy_tensor + and dtype_to == current_dtype + and device_to == current_device + ): + # already have correct dtype and device, return itself + return self + + # returns a copy of ShardedTensor on CUDA current device + list_shards: list[Shard] = [] + + for shard in self._local_shards: + new_tensor = shard.tensor.to( # type: ignore[call-overload] + device=device_to, + dtype=dtype_to, + non_blocking=non_blocking, + copy=copy_tensor, + memory_format=memory_format, + ) + metadata = copy.deepcopy(shard.metadata) + if metadata.placement is not None: + metadata.placement._device = device_to + list_shards.append(Shard(new_tensor, metadata)) + + # update metadata + st_meta = copy.deepcopy(self.metadata()) + st_meta.tensor_properties.dtype = dtype_to + for meta in st_meta.shards_metadata: + meta.placement._device = device_to # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_to = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_to + + @classmethod + def _normalize_pg( + cls, process_group: dist.ProcessGroup | None + ) -> dist.ProcessGroup: + if process_group is not None: + return process_group + return distributed_c10d._get_default_group() + + @classmethod + def _init_from_local_shards( + cls, + local_shards: list[Shard], + *global_size, + process_group=None, + init_rrefs=False, + ): + # recalc metadata handles special ST creation cases like each rank only has tensor available + # caller need to provide None on the unknown dimension of the global size + # We will change None into zeros and go through the same amount of checks as before to create ST + # and use all_gather to calculate the offsets and global size for metadata + # It is compatible with the current use case since, conventionally we don't pass None as global size + # Therefore the old path won't trigger the new feature + recalc_metadata = False + for dim in global_size: + if dim is None: + recalc_metadata = True + if recalc_metadata: + global_size = tuple( + 0 if dim_size is None else dim_size for dim_size in global_size + ) + # STEP 1: Validate the Shardmetadatas locally + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + world_size = dist.get_world_size(process_group) + + local_sharded_tensor_metadata: ShardedTensorMetadata | None = None + global_tensor_size = _flatten_tensor_size(global_size) + + if len(local_shards) > 0: + local_sharded_tensor_metadata = build_metadata_from_local_shards( + local_shards, global_tensor_size, current_rank, process_group + ) + + # STEP 2. Validate metadata across ranks, and build a global sharded tensor + # metadata by gathering local ShardedTensorMetadata + gathered_metadatas: list[ShardedTensorMetadata | None] = [] + if world_size > 1: + gathered_metadatas = [None for _ in range(world_size)] + + dist.all_gather_object( + gathered_metadatas, local_sharded_tensor_metadata, group=process_group + ) + else: + gathered_metadatas = [local_sharded_tensor_metadata] + + global_sharded_tensor_metadata = build_global_metadata( + gathered_metadatas, recalc_metadata=recalc_metadata + ) + if recalc_metadata: + # for recalc use cases, we only support rw for now, limit the blast radius + # will modify here once we support more sharding type + assert ( + len(local_shards) > 0 + and len(global_sharded_tensor_metadata.shards_metadata) > current_rank + ), ( + f"# for metadata recalculation, local_shards must be larger than 0 " + f"actual:{len(local_shards)}, # glb metadata must be greater than any rank id, " + f"# metadata:{len(global_sharded_tensor_metadata.shards_metadata)}, rank id:{current_rank}" + ) + local_md = [ + shard_md + for shard_md in global_sharded_tensor_metadata.shards_metadata + if shard_md.placement.rank() == current_rank + ] + assert len(local_md) == 1, ( + f"should has and only has one metadata for local rank, actual:{local_md}" + ) + local_shards[0].metadata = local_md[0] + tensor_properties = global_sharded_tensor_metadata.tensor_properties + + # STEP 3: Validation done, create the actual ShardedTensor and populate fields + # prepare initialization + spec = shard_spec._infer_sharding_spec_from_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + sharded_tensor = cls.__new__( + cls, + spec, + global_sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # attach local_shards to the ShardedTensor created + sharded_tensor._local_shards = local_shards + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def _init_from_local_tensor( + cls, + local_tensor: torch.Tensor, + sharding_spec: shard_spec.ShardingSpec, + *global_size: Sequence[int], + process_group: dist.ProcessGroup | None = None, + init_rrefs=False, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor given only one local tensor, global sharded tensor + size and sharding spec on each rank. + + Args: + local_tensor (Tensor): Single tensor of local shard stored in each rank. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how to shard the Tensor. + global_size (Sequence[int]): Size of the sharded tensor. + process_group (ProcessGroup, optional): The process group to aggregate on. + Default: None + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` sharded based on the given sharding_spec with local + tensor stored in the current rank. + + Examples: + >>> # xdoctest: +SKIP + >>> # All tensors below are of torch.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2])) + >>> local_tensor + tensor([[1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6]]) # Rank 1 + >>> sharding_dim = 0 + >>> sharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + ], + ) + >>> st = ShardedTensor._init_from_local_tensor( + ... local_tensor, sharding_spec, [2, 4] + ... ) + >>> st + ShardedTensor( + ShardedTensorMetadata( + shards_metadata=[ + ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0), + ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1), + ], + size=torch.Size([2, 4]) + ) + >>> st.local_tensor() + tensor([1, 2, 3, 4]) # Rank 0 + tensor([3, 4, 5, 6]) # Rank 1 + + Warning: This API is experimental and subject to change. It lacks of a fully across + rank validations, and we only validate the local shard on the current rank. + We fully rely on the user to ensure local tensor is sharded based on the + sharding spec. + """ + if not local_tensor.is_contiguous(): + raise ValueError("local_tensor is not a contiguous Tensor.") + + global_tensor_size = _flatten_tensor_size(global_size) + tensor_properties = TensorProperties( + dtype=local_tensor.dtype, + layout=local_tensor.layout, + requires_grad=local_tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=local_tensor.is_pinned(), + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + global_tensor_size, tensor_properties + ) + + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + local_shards: list[Shard] = [] + for shard_metadata in sharded_tensor_metadata.shards_metadata: + rank, _device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + if rank == current_rank: + local_shards.append(Shard(local_tensor, shard_metadata)) + + # TODO: figure out what the API should behave when some rank have no shard + # see https://github.com/pytorch/pytorch/issues/7313 + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, + sharded_tensor_metadata, + process_group=process_group, + init_rrefs=init_rrefs, + sharding_spec=sharding_spec, + ) + + @classmethod + def _init_from_local_shards_and_global_metadata( # type: ignore[override] + cls, + local_shards: list[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + process_group=None, + init_rrefs=False, + sharding_spec=None, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor with local shards and a global + ShardedTensorMetadata built on each rank. + + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + shards_metadata = sharded_tensor_metadata.shards_metadata + + local_shard_metadatas = [] + + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + rank, local_device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + + if current_rank == rank: + local_shard_metadatas.append(shard_metadata) + + if len(local_shards) != len(local_shard_metadatas): + raise RuntimeError( + f"Number of local shards ({len(local_shards)}) does not match number of local " + f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) " + f"on rank ({current_rank}) " + ) + + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + if tensor_properties.layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor = ShardedTensor.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False): + tensor_property_or_metadata = ( + "tensor property" if is_property else "local ShardMetadata" + ) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property is incompatible with " + f"{tensor_property_or_metadata} on rank {rank}: " + f"{tensor_property_or_metadata} {prop_name}={expected}, " + f"local shard tensor {prop_name}={actual}." + ) + + for shard in local_shards: + shard_meta = shard.metadata + local_shard_tensor = shard.tensor + placement = shard_meta.placement + assert placement is not None, "Must specify placement for `Shard`!" + rank = placement.rank() + local_device = placement.device() + + _raise_if_mismatch( + tensor_properties.layout, + local_shard_tensor.layout, + "layout", + rank, + True, + ) + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) + + _raise_if_mismatch( + shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + rank, + ) + _raise_if_mismatch( + tensor_properties.pin_memory, + local_shard_tensor.is_pinned(), + "pin_memory", + rank, + True, + ) + _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank) + _raise_if_mismatch( + tensor_properties.dtype, + local_shard_tensor.dtype, + "dtype", + rank, + True, + ) + _raise_if_mismatch( + tensor_properties.requires_grad, + local_shard_tensor.requires_grad, + "requires_grad", + rank, + True, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor._local_shards = local_shards + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + def sharding_spec(self) -> shard_spec.ShardingSpec: + """ + Returns the ShardingSpec for the tensor. + """ + return self._sharding_spec + + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: + """ + Reshard a sharded tensor given the ``resharding_spec``. For now, we only support + single local shard. + + If ``resharding_spec`` is same as the original one, this becomes a no-op. + If only ``resharding_spec`` shares the same sharding dim with the original one, + we swap local shards directly. + For more generic cases, we merge different shards across different ranks and split + the local shards based on the ``resharding_spec`` via `all_to_all` collective API. + + Args: + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + + Returns: + A :class:`ShardedTensor` object whose local shards are resharded. + + Examples: + >>> # xdoctest: +SKIP + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank + >>> tensor = torch.stack([tensor, tensor]) + >>> tensor + tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6], [3, 4, 5, 6]]) # Rank 1 + tensor([[5, 6, 7, 8], [5, 6, 7, 8]]) # Rank 2 + tensor([[7, 8, 9, 10], [7, 8, 9, 10]]) # Rank 3 + >>> sharding_dim = 0 + >>> spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> current_offsets = [0] * 2 + >>> current_offsets[0] = rank * 2 + >>> shard_metadata = ShardMetadata( + shard_offsets=copy.deepcopy(current_offsets), + shard_sizes=tensor.size(), + placement=spec.placements[rank], + ) + >>> local_shards = [ + Shard( + tensor=tensor, + metadata=shard_metadata, + ) + ] + >>> st = ShardedTensor._init_from_local_shards(local_shards, tensor.size()) + >>> sharding_dim = 1 + >>> resharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> st.reshard(resharding_spec) + >>> tensor = st.local_shards()[0].tensor + >>> tensor + tensor([[1], [1], [3], [3], [5], [5], [7], [7]]) # Rank 0 + tensor([[2], [2], [4], [4], [6], [6], [8], [8]]) # Rank 1 + tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 + tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 + """ + if not isinstance( + resharding_spec, shard_spec.ChunkShardingSpec + ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec): + raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") + + num_local_shards = len(self.local_shards()) + if num_local_shards != 1: + raise NotImplementedError( + f"Only single local shard supported for reshard. Number of shards: {num_local_shards}" + ) + + if self._sharding_spec.dim == resharding_spec.dim: # type: ignore[attr-defined] + if self._sharding_spec.placements == resharding_spec.placements: # type: ignore[attr-defined] + return self + else: + local_shards, shards_metadata = reshuffle_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + else: + local_shards, shards_metadata = reshard_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + self._local_shards = local_shards + self._metadata.shards_metadata = shards_metadata + self._sharding_spec = resharding_spec + return self + + def local_tensor(self) -> torch.Tensor: + """ + Return local tensor for a sharded_tensor. For now we only support single local shard. + + Returns: + A :class:`torch.Tensor` of the local shard. + """ + num_local_shards = len(self.local_shards()) + if num_local_shards != 1: + raise NotImplementedError( + f"Only single local shard is supported. Number of shards: {num_local_shards}" + ) + return self.local_shards()[0].tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def __torch_function__(cls, func, types, args=(), kwargs=None): + def dispatch(st: ShardedTensor, func: Callable): + # Dispatch to custom user provided op first if it exists. + if func in _CUSTOM_SHARDED_OPS: + return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group) + + # Dispatch to custom sharding spec op if it has one. + if _has_custom_op(st._sharding_spec, func): + return _dispatch_custom_op( + st._sharding_spec, func, types, args, kwargs, st._process_group + ) + + if func in _SHARDED_OPS: + return _SHARDED_OPS[func](types, args, kwargs, st._process_group) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + # Find ShardedTensor instance to get process_group and sharding_spec. + st_instance = None + + def find_sharded_tensor(e): + nonlocal st_instance + if st_instance is None and isinstance(e, ShardedTensor): + st_instance = e + + pytree.tree_map_(find_sharded_tensor, args) + pytree.tree_map_(find_sharded_tensor, kwargs) + + if st_instance is not None: + return dispatch(st_instance, func) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + def is_pinned(self) -> bool: # type: ignore[override] + """ + Returns True if the sharded tensor (each local shard) resides in pinned memory. + """ + return self._metadata.tensor_properties.pin_memory + + def _register_remote_shards( + self, remote_shards: list[rpc.RRef[Shard]], rpc_rank: int + ): + self._remote_shards[rpc_rank] = remote_shards + + def remote_shards(self) -> dict[int, list[rpc.RRef[Shard]]]: + """ + Returns a Dict[int, RRef] with keys being the RPC rank and values + being RRefs to shards on that rank. Need to initialize the + RPC framework for this functionality. + + Raises an exception if ShardedTensor was created with ``init_rrefs=False`` + """ + if not self._init_rrefs: + raise RuntimeError( + "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available" + ) + return self._remote_shards + + def __hash__(self): + return id(self) + + def __repr__(self) -> str: # type: ignore[override] + return f"ShardedTensor({self._metadata})" + + @dataclass + class ProcessGroupState: + """ + State for ser-de of process group + """ + + local_rank: int + global_rank: int + local_world_size: int + global_world_size: int + + def __getstate__(self): + pg_state = ShardedTensor.ProcessGroupState( + distributed_c10d.get_rank(self._process_group), + distributed_c10d.get_rank(), + distributed_c10d.get_world_size(self._process_group), + distributed_c10d.get_world_size(), + ) + + return ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) + + def __setstate__(self, state): + self._sharded_tensor_id = None + if not distributed_c10d.is_initialized(): + raise RuntimeError( + "Need to initialize default process group using " + '"init_process_group" before loading ShardedTensor' + ) + + ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) = state + + # Setup process group + from torch.distributed._shard.api import _get_current_process_group + + self._process_group = _get_current_process_group() + + # Validate process group. + local_rank = distributed_c10d.get_rank(self._process_group) + if pg_state.local_rank != local_rank: + raise RuntimeError( + f"Local rank at save time was {pg_state.local_rank}, but at " + f"load time was {local_rank}" + ) + + global_rank = distributed_c10d.get_rank() + if pg_state.global_rank != global_rank: + raise RuntimeError( + f"Global rank at save time was {pg_state.global_rank}, but at " + f"load time was {global_rank}" + ) + + local_world_size = distributed_c10d.get_world_size(self._process_group) + if pg_state.local_world_size != local_world_size: + raise RuntimeError( + f"Local world size at save time was {pg_state.local_world_size}, " + f"but at load time was {local_world_size}" + ) + + global_world_size = distributed_c10d.get_world_size() + if pg_state.global_world_size != global_world_size: + raise RuntimeError( + f"Global world size at save time was {pg_state.global_world_size}, " + f"but at load time was {global_world_size}" + ) + + self._post_init() + + +def _create_tensor_from_params( + *size, local_device, tensor_properties: TensorProperties +): + """Helper to construct tensor from size, device and common params.""" + dtype = tensor_properties.dtype + layout = tensor_properties.layout + requires_grad = tensor_properties.requires_grad + memory_format = tensor_properties.memory_format + pin_memory = tensor_properties.pin_memory + + return torch.empty( + *size, + dtype=dtype, + layout=layout, + device=local_device, + requires_grad=requires_grad, + memory_format=memory_format, + pin_memory=pin_memory, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logger.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8cb4d18fb180ea620dd8daad60b5771a9688be --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logger.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers + + +__all__: list[str] = [] + + +def _get_or_create_logger() -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler() + logger = logging.getLogger(f"sharding-spec-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = "default", +) -> tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = type(log_handler).__name__ + return (log_handler, log_handler_name) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6832fd1ae834b6365a6b005b07bbbfffe90726 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + + +__all__: list[str] = [] + +_log_handlers: dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..466ca1a0c519ce4cc4ee24fae98ff4ddfbee300a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass, field +from enum import Enum + +import torch +from torch.distributed._shard.metadata import ShardMetadata + + +class MEM_FORMAT_ENCODING(Enum): + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: torch.dtype = field(default=torch.get_default_dtype()) + layout: torch.layout = field(default=torch.strided) + requires_grad: bool = False + memory_format: torch.memory_format = field(default=torch.contiguous_format) + pin_memory: bool = False + + def __getstate__(self): + # Since torch.memory_format cannot be pickled! + memory_format = self.memory_format + if memory_format == torch.contiguous_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + elif memory_format == torch.channels_last: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + elif memory_format == torch.preserve_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + else: + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") + + return ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state + + if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + memory_format = torch.contiguous_format + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + memory_format = torch.channels_last + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + memory_format = torch.preserve_format + else: + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) + + self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class ShardedTensorMetadata: + """ + Represents metadata for :class:`ShardedTensor` + """ + + # Metadata about each shard of the Tensor + shards_metadata: list[ShardMetadata] = field(default_factory=list) + + # Size of each dim of the overall Tensor. + size: torch.Size = field(default=torch.Size([])) + + tensor_properties: TensorProperties = field(default_factory=TensorProperties) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py new file mode 100644 index 0000000000000000000000000000000000000000..daef9c3586184e4e62b4a141ec2e43f5025bf454 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py @@ -0,0 +1,243 @@ +# mypy: allow-untyped-defs +import copy + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +from torch._C._distributed_c10d import ProcessGroup +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharding_spec._internals import ( + get_chunked_dim_size, + get_split_size, +) +from torch.distributed.nn.functional import all_to_all, all_to_all_single + +from .shard import Shard + + +def get_idx_from_placements(placements, current_rank) -> int: + """ + Return the position of the current rank in the given placements. + + Args: + placements(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`torch.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`torch.distributed._remote_device` + current_rank (int): number of current device. + + Returns: + A int which contains the position of current device in the placement list. + """ + for idx, placement in enumerate(placements): # type: ignore[attr-defined] + if current_rank == placement.rank(): # type: ignore[union-attr] + return idx + raise RuntimeError("current_rank not in the placement.") + + +def build_reshard_metadata( + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + world_size: int, +) -> tuple[list[ShardMetadata], list[int]]: + """ + Based the given sharding spec, we calculate the offset and local shard size. + We then build a ShardMetadata on top of the calculation result. + + Args: + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + world_size (int): number of ranks. + + Returns: + A Tuple of the followings: + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + A List[int] which contains the ranks in the order of placement. + """ + shard_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + shards_metadata = [None] * world_size + ranks = [] + offsets = [0] * len(st_size) + split_size = get_split_size(st_size[shard_dim], world_size) + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + ranks.append(placement.rank()) + sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx) + local_tensor_size = list(st_size) + local_tensor_size[shard_dim] = sharded_dim_size + shards_metadata[placement.rank()] = ShardMetadata( # type: ignore[call-overload] + shard_offsets=copy.deepcopy(offsets), + shard_sizes=local_tensor_size, + placement=placement, + ) + offsets[shard_dim] += sharded_dim_size + return shards_metadata, ranks # type: ignore[return-value] + + +def reshuffle_local_shard( + local_shard: torch.Tensor, + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> tuple[list[Shard], list[ShardMetadata]]: + """ + Reshuffle the local shard directly when the reshard dim is same as the original + sharding dim. Logically we do this in two step: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending the local tensor to + the new shard directly based on the resharding spec. + + Args: + local_shard (Tensor): Local tensor stored in the current rank. + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + # Get input split size for all2all. + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + split_size = get_split_size(st_size[reshard_dim], world_size) + input_split_sizes = [0] * world_size + idx = get_idx_from_placements(sharding_spec.placements, current_rank) # type: ignore[attr-defined] + new_rank = resharding_spec.placements[idx].rank() # type: ignore[union-attr, attr-defined] + input_split_sizes[new_rank] = local_shard.size(reshard_dim) + # Get output split size for all2all. + output_split_sizes = [0] * world_size + new_idx = ranks.index(current_rank) + sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx) + output_split_sizes[new_rank] = sharded_dim_size + # Get gathered_input for all2all. + local_shard = local_shard.transpose(0, reshard_dim).contiguous() + gathered_input_size = list(local_shard.size()) + gathered_input_size[0] = sharded_dim_size + gathered_input = torch.empty( + gathered_input_size, device=local_shard.device, dtype=local_shard.dtype + ) + # all2all. + local_shard = all_to_all_single( + gathered_input, + local_shard, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes, + group=pg, + ) + local_tensor = local_shard.transpose(0, reshard_dim).contiguous() + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata + + +def reshard_local_shard( + local_tensor: torch.Tensor, + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> tuple[list[Shard], list[ShardMetadata]]: + """ + Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is + different from the original sharding dim, we need to do two steps logically: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending each rank the new + shard based on the resharding spec. + + Args: + local_tensor (Tensor): Local tensor stored in the current rank. + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + current_sharding_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + + # Compute expected size + input_split_sizes = [ + metadata.shard_sizes[reshard_dim] for metadata in shards_metadata + ] + rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1)) + + if rearrange_input: + # Need to re-arrange reshard_dim of local_tensor before all2all. + indices: list[int] = [] + for metadata in shards_metadata: + offset_start_idx = metadata.shard_offsets[reshard_dim] + split_size = metadata.shard_sizes[reshard_dim] + indices += range(offset_start_idx, offset_start_idx + split_size) + local_tensor = local_tensor.index_select( + reshard_dim, torch.tensor(indices, device=local_tensor.device) + ) + + # Because reshard_dim != original shard_dim. We need to compute the + # size of tensor from each rank. + output_tensor_list = [torch.tensor(1)] * world_size + split_size = get_split_size(st_size[current_sharding_dim], world_size) + rearrange_output_list = False + indices = [] + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + sharded_dim_size = get_chunked_dim_size( + st_size[current_sharding_dim], split_size, idx + ) + output_tensor_size = list(st_size) + output_tensor_size[current_sharding_dim] = sharded_dim_size + output_tensor_size[reshard_dim] = input_split_sizes[current_rank] + output_tensor_list[placement.rank()] = torch.empty( # type: ignore[union-attr, index] + output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype + ) + indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type] + if idx != placement.rank(): # type: ignore[union-attr] + rearrange_output_list = True + + # Perform autograd enabled all2all. + input_tensor_tuple = torch.split(local_tensor, input_split_sizes, dim=reshard_dim) + input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple] + output_tensor_list = all_to_all( + output_tensor_list, + input_tensor_list, + group=pg, + ) + + if rearrange_output_list: + # Need to re-arrange original shard_dim of output_tensor_list. + output_tensor_list = [output_tensor_list[idx] for idx in indices] # type: ignore[call-overload] + local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim) + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/shard.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/shard.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9d4357436a6c15f590a4db486d9d54b6d6ca57 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/shard.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass + +import torch +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed.remote_device import _remote_device + + +@dataclass +class Shard: + """ + Container which holds the data for a shard as a Tensor and also + the associated metadata for that shard. + + Args: + tensor(torch.Tensor): Local tensor for the shard. + metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`): + The metadata for the shard, including offsets, lengths and device placement. + """ + + __slots__ = ["tensor", "metadata"] + tensor: torch.Tensor + metadata: ShardMetadata + + def __post_init__(self) -> None: + # verification between local tensor and metadata + if list(self.tensor.size()) != self.metadata.shard_sizes: + raise ValueError( + "Shard tensor size does not match with metadata.shard_lengths! " + f"Found shard tensor size: {list(self.tensor.size())}, " + f"metadata.shard_lengths: {self.metadata.shard_sizes}, " + ) + placement_device = self.metadata.placement + if ( + placement_device is not None + and placement_device.device() != self.tensor.device + ): + raise ValueError( + f"Local shard tensor device does not match with local Shard's placement! " + f"Found local shard tensor device: {self.tensor.device}, " + f"local shard metadata placement device: {placement_device.device()}" + ) + + @classmethod + def from_tensor_and_offsets( + cls, tensor: torch.Tensor, shard_offsets: list[int], rank: int + ) -> "Shard": + """ + Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank. + + Args: + tensor(torch.Tensor): Local tensor for the shard. + shard_offsets(List[int]): List of integers specify the offset + of the shard on each dimension. + rank(int): Specify the rank for the shard. + """ + shard_sizes = list(tensor.size()) + placement = _remote_device(f"rank:{rank}/{str(tensor.device)}") + shard_meta = ShardMetadata( + shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement + ) + return Shard(tensor, shard_meta) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b323da4ecbfa3adcea51367dc42a6e54d2cd1624 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/utils.py @@ -0,0 +1,325 @@ +# mypy: allow-untyped-defs +import collections.abc +import copy +import itertools +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import torch +from torch.distributed import distributed_c10d as c10d, rpc +from torch.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) + +from .metadata import ShardedTensorMetadata, TensorProperties +from .shard import Shard + + +if TYPE_CHECKING: + from torch.distributed._shard.metadata import ShardMetadata + + +def _parse_and_validate_remote_device(pg, remote_device): + if remote_device is None: + raise ValueError("remote device is None") + + worker_name = remote_device.worker_name() + rank = remote_device.rank() + device = remote_device.device() + + # Validate rank, skip validation if rank is not part of process group. + if rank is not None and not c10d._rank_not_in_group(pg): + pg_global_ranks = c10d.get_process_group_ranks(pg) + if rank not in pg_global_ranks: + raise ValueError( + f"Global rank {rank} does not exist in input process group: {pg_global_ranks}" + ) + + if worker_name is not None: + if not rpc._is_current_rpc_agent_set(): + raise RuntimeError( + f"RPC framework needs to be initialized for using worker names: {worker_name}" + ) + + workers = rpc._get_current_rpc_agent().get_worker_infos() + for worker in workers: + if worker.name == worker_name: + return worker.id, device + + raise ValueError(f"Invalid worker name: {worker_name}") + + return rank, device + + +def _validate_output_tensor_for_gather( + my_rank: int, + dst_rank: int, + size: torch.Size, + dst_tensor: torch.Tensor | None, +) -> None: + if dst_rank == my_rank: + if dst_tensor is None: + raise ValueError( + f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}" + ) + if tuple(size) != (dst_tensor.size()): + raise ValueError( + f"Argument ``dst_tensor`` have size {tuple(dst_tensor.size())}," + f"but should be {tuple(size)}" + ) + elif dst_tensor: + raise ValueError( + "Argument ``dst_tensor`` must NOT be specified on non-destination ranks." + ) + + +def _flatten_tensor_size(size) -> torch.Size: + """ + Checks if tensor size is valid, then flatten/return a torch.Size object. + """ + if len(size) == 1 and isinstance(size[0], collections.abc.Sequence): + # pyrefly: ignore [not-iterable] + dims = list(*size) + else: + dims = list(size) + + for dim in dims: + if not isinstance(dim, int): + raise TypeError(f"size has to be a sequence of ints, found: {dims}") + + return torch.Size(dims) + + +def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): + if is_local: + assert isinstance(ranks, int) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " + f"Found one local shard tensor {prop_name}={expected}, " + f"the other local shard tensor {prop_name}={actual}." + ) + else: + # compare failure check across ranks, ranks list should have two rank + assert len(ranks) == 2 + if expected != actual: + raise ValueError( + f"ShardedTensor {prop_name} property does not match from different ranks! " + f"Found {prop_name}={expected} on rank:{ranks[0]}, " + f"and {prop_name}={actual} on rank:{ranks[1]}." + ) + + +def build_metadata_from_local_shards( + local_shards: list[Shard], + global_size: torch.Size, + current_rank: int, + pg: c10d.ProcessGroup, +) -> ShardedTensorMetadata: + assert len(local_shards) > 0, "must have local shards!" + local_shard_metadatas: list[ShardMetadata] = [] + + first_shard_dtype = local_shards[0].tensor.dtype + first_shard_layout = local_shards[0].tensor.layout + first_shard_requires_grad = local_shards[0].tensor.requires_grad + first_shard_is_pinned = local_shards[0].tensor.is_pinned() + + # 1). Validate local tensors and associated metadatas + for local_shard in local_shards: + local_shard_tensor = local_shard.tensor + local_shard_meta = local_shard.metadata + local_shard_metadatas.append(local_shard_meta) + rank, local_device = _parse_and_validate_remote_device( + pg, local_shard_meta.placement + ) + + if ( + local_shard_tensor.layout != torch.strided + or local_shard_tensor.layout != first_shard_layout + ): + raise ValueError( + f"Only torch.strided layout is currently supported, but found " + f"{local_shard_tensor.layout} on rank:{current_rank}!" + ) + + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported!" + ) + + if rank != current_rank: + raise ValueError( + f"Local shard metadata's rank does not match with the rank in its process group! " + f"Found current rank in the process group: {current_rank}, " + f"local ShardMetadata placement's rank: {rank}" + ) + if local_shard_tensor.device != local_device: + raise ValueError( + f"Local shard tensor device does not match with local Shard's placement! " + f"Found local shard tensor device: {local_shard_tensor.device}, " + f"local shard metadata placement device: {local_device}" + ) + + _raise_if_mismatch( + local_shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.is_pinned(), + first_shard_is_pinned, + "pin_memory", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank + ) + _raise_if_mismatch( + local_shard_tensor.requires_grad, + first_shard_requires_grad, + "requires_grad", + current_rank, + ) + + # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then + # do all_gather to collect local_sharded_tensor_metadata from all ranks + local_tensor_properties = TensorProperties( + dtype=first_shard_dtype, + layout=first_shard_layout, + requires_grad=first_shard_requires_grad, + memory_format=torch.contiguous_format, + pin_memory=first_shard_is_pinned, + ) + + local_sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=local_shard_metadatas, + size=global_size, + tensor_properties=local_tensor_properties, + ) + + return local_sharded_tensor_metadata + + +def build_global_metadata( + gathered_metadatas: Sequence[ShardedTensorMetadata | None], + recalc_metadata: bool = False, +): + global_sharded_tensor_metadata = None + global_metadata_rank = 0 + + # pyrefly: ignore [bad-assignment] + for rank, rank_metadata in enumerate(gathered_metadatas): + if rank_metadata is None: + continue + + if global_sharded_tensor_metadata is None: + global_sharded_tensor_metadata = copy.deepcopy(rank_metadata) + global_metadata_rank = rank + else: + _raise_if_mismatch( + global_sharded_tensor_metadata.size, + rank_metadata.size, + "global_size", + [global_metadata_rank, rank], + is_local=False, + ) + + # don't need to check layout and memory format as we already checked in local shards validation stage + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.dtype, + rank_metadata.tensor_properties.dtype, + "dtype", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.requires_grad, + rank_metadata.tensor_properties.requires_grad, + "requires_grad", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.pin_memory, + rank_metadata.tensor_properties.pin_memory, + "pin_memory", + [global_metadata_rank, rank], + is_local=False, + ) + # pass all validations, extend shards metadata + global_sharded_tensor_metadata.shards_metadata.extend( + rank_metadata.shards_metadata + ) + + if global_sharded_tensor_metadata is not None: + if recalc_metadata: + recalc_global_sharded_tensor_metadata( + global_sharded_tensor_metadata, + 0, # sharded on 0th dim + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + + # check if the shards_metadata is compatible with global size of the sharded tensor. + check_tensor( + global_sharded_tensor_metadata.shards_metadata, + global_sharded_tensor_metadata.size, + ) + else: + raise ValueError("ShardedTensor have no local shards on all ranks!") + + return global_sharded_tensor_metadata + + +def recalc_global_sharded_tensor_metadata( + global_sharded_tensor_metadata: ShardedTensorMetadata, sharded_dim: int +) -> None: + # recalculate global ShardedTensorMetadata + + # reorder here in case shard metadata is not sorted on sharded_dim + placement_idx_pairs = [] + for i, shard_metadata in enumerate(global_sharded_tensor_metadata.shards_metadata): + if shard_metadata.placement: + placement_idx_pairs.append((shard_metadata.placement.rank(), i)) + else: + raise AssertionError( + "currently only support rw, it should always have valid rank info" + ) + sorted_idx = sorted(placement_idx_pairs) + shard_sizes = [ + global_sharded_tensor_metadata.shards_metadata[idx].shard_sizes[sharded_dim] + for _, idx in sorted_idx + ] + cum_sum = [0] + list(itertools.accumulate(shard_sizes)) + + for shard_id, shard_metadata in enumerate( + global_sharded_tensor_metadata.shards_metadata + ): + # update shard offset for each shard on the sharded dimension + shard_metadata.shard_offsets[sharded_dim] = cum_sum[shard_id] + for other_dim in range( + len(global_sharded_tensor_metadata.shards_metadata[0].shard_sizes) + ): + if other_dim != sharded_dim: + # shard offset for each shard on the unsharded dimension + shard_metadata.shard_offsets[other_dim] = 0 + + # update global size for ShardedTensorMetadata + global_size_list = [] + for other_dim in range( + len(global_sharded_tensor_metadata.shards_metadata[0].shard_sizes) + ): + if other_dim != sharded_dim: + global_size_list.append( + global_sharded_tensor_metadata.shards_metadata[0].shard_sizes[other_dim] + ) + else: + global_size_list.append(cum_sum[-1]) + global_sharded_tensor_metadata.size = torch.Size(global_size_list) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d91ec15775bea870b81c4b10fb1443a3fba0977 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharder.py @@ -0,0 +1,29 @@ +import abc + +import torch.nn as nn + + +class Sharder(abc.ABC): + """ + This is an interface which allows user to create more advanced + sharding strategies that are not easily be composed by the + `ShardingSpec`. + + :class:`torch.distributed._shard.sharding_plan.ShardingPlan` could + take an object of the `Sharder` and call `shard` to shard the module, + then replace the original module with sharded module returned. + """ + + @abc.abstractmethod + def shard(self, module: nn.Module) -> nn.Module: + """ + Shard a module base on the implementation of this method, and + return the sharded version of the module. + + Args: + module (:class:`torch.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`torch.nn.Module` object that represents a module + that's already been sharded. + """ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..325f7d7eb47b96a79fdc10cc2d1f072cdec9b4ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__init__.py @@ -0,0 +1 @@ +from .api import ShardingPlan, ShardingPlanner diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96e0b80aeed82955ae6b290ec154f39552a9c2e5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb21dfc8997fd5af983b357dba5cf2bdbdb7e40 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a94f4b54edf2b6c29fd9331ec5e662a793510102 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/api.py @@ -0,0 +1,86 @@ +import abc +from dataclasses import dataclass + +import torch.nn as nn +from torch.distributed._shard.sharder import Sharder +from torch.distributed._shard.sharding_spec import ShardingSpec + + +@dataclass +class ShardingPlan: + """ + Representation of a sharding plan, describes how to shard a module + across hosts. `plan` is used to shard module parameters according to the spec provided, + `output_plan` and `return_local_tensor` are optional, they are used to specify the output + layout of a module with a spec, and when to convert back to data parallel fashion. + + Args: + plan (Dict[str, Union[:class:`torch.distributed._shard.sharding_spec.ShardingSpec`, + :class:`torch.distributed._shard.sharder.Sharder`]): + a dict describes how to shard a module, there're currently two ways to shard a module: + 1. directly shard a module parameter by a `ShardingSpec`, keyed by the name of + a parameter to a `ShardingSpec`. + 2. shard a submodule by applying a `Sharder` on it, keyed by the name of a module + to a `Sharder` object. + output_plan (Dict[str, :class:`torch.distributed._shard.sharding_spec.ShardingSpec`), optional): + a dict specifies the layout of a module's output which produces a ShardedTensor, + keyed by the name of module to ShardingSpec("" in key means the root module). + Default: `None` + return_local_tensor (List[str], optional): a list of string, each element enables + a module's sharded output to be returned as a Tensor from its local shards to + ensure further processing in a data parallel fashion. ("" in list means the + root module). + Default: None + Example: + Suppose we want to shard a module with two linear layers and then run it with DDP, we also + want to convert the output of the second linear layer back to DDP, we can do it as follows: + + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> class MyModule(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.fc1 = nn.Linear() + >>> self.gelu = nn.GELU() + >>> self.fc2 = nn.Linear() + >>> self.relu = nn.Linear() + >>> + >>> def forward(self, input): + >>> return self.relu(self.fc2(self.gelu(self.fc1(input)))) + + + >>> # xdoctest: +SKIP("Undefined spec1, spec2) + >>> sharding_plan = ShardingPlan( + >>> plan={ + >>> "fc1.weight": spec1, + >>> "fc2.weight": spec2 + >>> }, + >>> output_plan={ + >>> "fc2": output_spec + >>> }, + >>> return_local_tensor=["fc2"] + >>> ) + """ + + plan: dict[str, ShardingSpec | Sharder] + output_plan: dict[str, ShardingSpec] | None = None + return_local_tensor: list[str] | None = None + + +class ShardingPlanner(abc.ABC): + """ + Default ShardingPlanner interface, can be extended and + implement advanced sharding strategies. + """ + + @abc.abstractmethod + def build_plan(self, module: nn.Module) -> ShardingPlan: + """ + Given a nn.Module, define how to shard the module across + ranks, return a ShardingPlan + Args: + module (:class:`torch.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`torch.distributed._shard.sharding_plan.ShardingPlan` object that + represents how to shard the module. + """ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd3f0a7581e8c4352eba843af6d3751bee7f387 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__init__.py @@ -0,0 +1,10 @@ +from torch.distributed._shard.metadata import ShardMetadata + +from .api import ( + _infer_sharding_spec_from_shards_metadata, + DevicePlacementSpec, + EnumerableShardingSpec, + PlacementSpec, + ShardingSpec, +) +from .chunk_sharding_spec import ChunkShardingSpec as ChunkShardingSpec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a8d92654ad3f04789123e706ad51a048458a11d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cbe080793571bd01399fbd94dbe267205fc7a91 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3fa39590d319751efb381e141b23e4e1c5821c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..520283f68e80ed6ff3605485429a2c302829874d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/_internals.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..486c62a18cd7b91e30ad21891fb0c735e28d443f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/_internals.py @@ -0,0 +1,244 @@ +# mypy: allow-untyped-defs +import math +import sys +from bisect import bisect_right, insort + +from torch.distributed._shard.metadata import ShardMetadata + + +def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata): + """ + Checks if two shards overlap. + """ + + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.shard_offsets) + for i in range(ndims): + if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]: + return False + if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]: + return False + + return True + + +def _find_nd_overlapping_shards( + shards: list[ShardMetadata], sharded_dims: list[int] +) -> tuple[int, int] | None: + """Find overlapping shards using sweep-line algorithm.""" + if len(shards) <= 1: + return None + + dims = len(sharded_dims) + if dims == 0: + return None + + sweep_dim_idx = 0 + if dims > 1: + max_size = 0 + for i, dim in enumerate(sharded_dims): + dim_size = shards[0].shard_offsets[dim] + shards[0].shard_sizes[dim] + if dim_size > max_size: + max_size = dim_size + sweep_dim_idx = i + sweep_dim = sharded_dims[sweep_dim_idx] + + sorted_indices = sorted( + range(len(shards)), + key=lambda idx: ( + shards[idx].shard_offsets[sweep_dim], + *(shards[idx].shard_offsets[d] for d in sharded_dims if d != sweep_dim), + ), + ) + active: list[tuple[int, int]] = [] + + for idx in sorted_indices: + current = shards[idx] + start = current.shard_offsets[sweep_dim] + end = start + current.shard_sizes[sweep_dim] + + cutoff = bisect_right(active, (start, sys.maxsize)) + if cutoff: + del active[:cutoff] + + for _, other_idx in active: + other = shards[other_idx] + + if _check_shard_metadata_pair_overlap(current, other): + return (other_idx, idx) + insort(active, (end, idx)) + return None + + +def _find_1d_overlapping_shards( + shards: list[ShardMetadata], dim: int +) -> tuple[int, int] | None: + # (begin, end, index_in_shards). Begin and end are inclusive. + intervals = [ + (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i) + for i, s in enumerate(shards) + ] + intervals.sort() + for i in range(len(shards) - 1): + if intervals[i][1] >= intervals[i + 1][0]: + return (intervals[i][2], intervals[i + 1][2]) + return None + + +def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]): + """ + Ensures none of the shards overlap with each other. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. + Raises: + ``ValueError`` if there's overlap in any two shards. + """ + if not shards or len(shards) == 1: + return + + sharded_dims: list[int] = [] + for dim in range(len(shards[0].shard_offsets)): + for i in range(1, len(shards)): + if ( + shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] + or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] + ): + sharded_dims.append(dim) + break + + pair: tuple[int, int] | None = None + if len(sharded_dims) == 0: + # if shard is all zeros, we should consider as pass + all_zeros: bool = all( + # strictly limited all offsets to be 0 to pass + # could loose it later on + shard.shard_offsets == [0] * len(shards[0].shard_offsets) + and math.prod(shard.shard_sizes) == 0 # one dimension is 0 + for shard in shards + ) + if all_zeros: + return + # All shards are the same, all dims are not partitioned. Choose any 2. + pair = (0, 1) + elif len(sharded_dims) == 1: + # Shards are partitioned over only one dimension. Overlap can be found + # using a O(nlogn) overlapping interval algorithm. + pair = _find_1d_overlapping_shards(shards, sharded_dims[0]) + else: + # Shards are partitioned over more than one dimension. + # Use sweep-line algorithm for O(n log n) complexity. + pair = _find_nd_overlapping_shards(shards, sharded_dims) + + if pair: + raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap") + + +def check_tensor(shards_metadata, tensor_dims) -> None: + """ + Checks if the shards_metadata is compatible with the provided tensor dims. + + Args: + shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata` + objects representing each shard of the tensor. + tensor_dims(Sequence of int): Dimensions of tensor to verify + Raises: + ``ValueError`` if not compatible. + """ + + # If the tensor's volume matches the total volume of all shards and + # all shard boundaries are within tensor dims, we have a compatible + # sharding spec for this tensor. Note that we have already verified + # we don't have overlapping shards. + tensor_rank = len(tensor_dims) + shards_rank = len(shards_metadata[0].shard_offsets) + if tensor_rank != shards_rank: + raise ValueError( + f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}" + ) + + total_shard_volume = 0 + for shard in shards_metadata: + shard_volume = 1 + for i, shard_length in enumerate(shard.shard_sizes): + shard_volume *= shard_length + if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]: + raise ValueError( + f"Shard offset {shard.shard_offsets[i]} and length " + f"{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}" + ) + total_shard_volume += shard_volume + + tensor_volume = 1 + for size in tensor_dims: + tensor_volume *= size + + if total_shard_volume != tensor_volume: + # TODO: Can we improve this error message to point out the gaps? + raise ValueError( + f"Total volume of shards: {total_shard_volume} " + f"does not match tensor volume: {tensor_volume}, in other words " + f"all the individual shards do not cover the entire tensor" + ) + + +def get_split_size(dim_size, chunks): + """ + Computes the split size inline with ``torch.chunk`` + + Args: + dim_size(int): Size of the dimension being chunked. + chunks(int): Number of chunks to create for ``dim_size``. + + Returns: + An int indicating the split size to use. + """ + return (dim_size + chunks - 1) // chunks + + +def get_chunked_dim_size(dim_size, split_size, idx): + """ + Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` + and ``split_size``. + + Args: + dim_size(int): Size of the dimension being chunked. + split_size(int): The chunk size for each chunk of ``dim_size``. + idx(int): The index of chunk whose dim size is being requested. + + Returns: + An int indicating the dim size of the chunk. + """ + return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0) + + +def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank): + """ + Generate the start pos and offset length for the current rank for + chunk sharding. + + Args: + sharding_dim_size(int): The dimension length which we shard on. + world_size(int): number of ranks. + spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`): + sharding spec. + rank(int): # of cuda process. + + Returns: + start_pos(int): start position of sharded tensor on the given rank. + chunk_size(int): chunk size of sharded tensor on the given rank. + """ + split_size = get_split_size(sharding_dim_size, world_size) + current_offsets = 0 + start_pos = current_offsets + for idx, placement in enumerate(spec.placements): + chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + if rank == placement.rank(): + start_pos = current_offsets + break + current_offsets += chunk_size + return start_pos, chunk_size # type: ignore[possibly-undefined] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/api.py new file mode 100644 index 0000000000000000000000000000000000000000..87a49abdb5c05dcfe3db1fdf734dc9f3bef3b4bf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/api.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-defs +import functools +import operator +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.op_registry_utils import _decorator_func + +from ._internals import ( + check_tensor, + get_chunked_dim_size, + get_split_size, + validate_non_overlapping_shards_metadata, +) + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +class PlacementSpec(ABC): # noqa: B024 + """ + Base class representing the placement of an entity. Subclasses of this + class can be used to specify customized placements which might not be + covered by existing APIs. + """ + + +@dataclass +class DevicePlacementSpec(PlacementSpec): + """ + Associates placement of an entity with a single device. + + Args: + device(:class:`torch.distributed._remote_device`): The device to place the entity on. + """ + + device: torch.distributed._remote_device + + def __post_init__(self): + if not isinstance(self.device, torch.distributed._remote_device): + self.device = torch.distributed._remote_device(self.device) + + +class ShardingSpec(ABC): + """ + Base class representing sharding specifications. + """ + + @abstractmethod + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + """ + Given a global tensor size, define how to shard a tensor like this shape + across ranks, return ShardedTensorMetadata + Args: + tensor_sizes (:class:`torch.Size`): + The tensor shape to shard on, a `torch.Size` object that represents the + tensor shape to be sharded according to the ShardingSpec. + tensor_properties(:class:`torch.distributed._shard.sharded_tensor.TensorProperties): + Tensor properties used to create a ShardedTensor. + Returns: + A :class:`ShardedTensorMetadata` object that encodes the information about + the layout of the ShardedTensor and its properties. + """ + + @abstractmethod + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Given a global tensor on src_rank, shard this tensor + across ranks within the process group, return a ShardedTensor. + Args: + tensor (:class:`torch.Tensor`): Tensor needs to be sharded. + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + """ + + +# Ops customized for a particular ShardingSpec. +_CUSTOM_SHARDING_SPEC_OPS: dict[str, dict[Callable, Callable]] = {} + + +def _has_custom_op(sharding_spec, op): + """ + Returns whether or not the ShardingSpec has a custom op implementation. + """ + class_name = type(sharding_spec).__qualname__ + return ( + class_name in _CUSTOM_SHARDING_SPEC_OPS + and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +def _dispatch_custom_op( + sharding_spec, op: Callable, types, args, kwargs, process_group +): + """ + Calls the custom op for this ShardingSpec if it exists. + """ + class_name = type(sharding_spec).__qualname__ + if not _has_custom_op(sharding_spec, op): + raise RuntimeError(f"Custom op: {op} not registered for {class_name}") + func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op] + return func(types, args, kwargs, process_group) + + +def custom_sharding_spec_op(sharding_spec_class, func): + """ + Decorator to allow custom registration of ops. + Args: + sharding_spec_class(type): The ShardingSpec for which we need to add this custom op. + func(Callable): The op to override (ex: torch.bmm) + """ + class_name = sharding_spec_class.__qualname__ + if class_name not in _CUSTOM_SHARDING_SPEC_OPS: + _CUSTOM_SHARDING_SPEC_OPS[class_name] = {} + return functools.partial( + _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +@dataclass +class EnumerableShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that allows users to specify a generic + sharding scheme by enumerating exactly how each shard is laid out. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. Note that none of the shards should overlap. + """ + + shards: list[ShardMetadata] + + def __post_init__(self): + if len(self.shards) == 0: + raise ValueError(f"Empty shard list provided: {self.shards}") + + # Validate each shard has same rank. + rank = -1 + for shard in self.shards: + if rank != -1 and rank != len(shard.shard_offsets): + raise ValueError( + f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}" + ) + rank = len(shard.shard_offsets) + + validate_non_overlapping_shards_metadata(self.shards) + + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + # check if shards form a valid tensor + check_tensor(self.shards, tensor_sizes) + return sharded_tensor_meta.ShardedTensorMetadata( + self.shards, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec + raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!") + + +def _infer_sharding_spec_from_shards_metadata(shards_metadata): + """ + Infer the sharding spec from the metadata of each shard of a ShardedTensor. + If the tensor is sharded only on one dimension, we can then verify whether it's + a ChunkShardingSpec or not. The way to verify it is to first get the total length + and perform a chunk sharding with the given placements to see if we can have the + same chunk size as the given shards_metadata. If not, we assume it's enum sharded. + + Args: + shards_metadata (List[ShardMetadata]): List of Metadata of local shards. + + Returns: + A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object of sharding + spec for one sharded tensor. + """ + placements = [] + chunk_sharding_dim = None + chunk_offset_list = [] + shard_size_list = [] + shard_offset_list = [] + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + placements.append(shard_metadata.placement) + local_offsets = shard_metadata.shard_offsets + chunk_offset_list.append(sum(local_offsets)) + shard_size_list.append(shard_metadata.shard_sizes) + shard_offset_list.append(shard_metadata.shard_offsets) + shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0] + # If the offset is [0, 0, ..., 0] (all zeros), + # we cannot decide whether how the tensor is sharded. + if len(shard_dims) == 0: + continue + # If the offset is [0, N, .,0, M, 0, .., 0], + # we are sure it's sharded by more than one dimension. + if len(shard_dims) != 1: + chunk_sharding_dim = None + break + # If the offset is [0, 0, .,0, M, 0, .., 0], aka, it's sharded by just + # one dimension, we need to make sure all ranks share the same dimension. + if not chunk_sharding_dim: + chunk_sharding_dim = shard_dims[0] + elif chunk_sharding_dim != shard_dims[0]: + chunk_sharding_dim = None + break + + if chunk_sharding_dim is not None: + # Ensure we infer the correct placement order from offsets + placements = [ + x + for _, x in sorted( + zip(chunk_offset_list, placements), key=operator.itemgetter(0) + ) + ] + + from .chunk_sharding_spec import ChunkShardingSpec + + chunk_spec = ChunkShardingSpec( + dim=chunk_sharding_dim, + placements=placements, + ) + + shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list]) + shard_total_length = sum(shard_sizes) + shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list]) + + chunks = len(placements) + split_size = get_split_size(shard_total_length, chunks) + chunk_shard_sizes = sorted( + [ + get_chunked_dim_size(shard_total_length, split_size, idx) + for idx in range(chunks) + ] + ) + # Should match ChunkShardingSpec offsets calculation + chunk_shard_offsets = [split_size * idx for idx in range(chunks)] + if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets: + return chunk_spec + return EnumerableShardingSpec(shards_metadata) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7b11b7c16c567b0b71fc6a0858dc58b7977ebf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -0,0 +1,229 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import cast, TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +import torch.distributed.distributed_c10d as distributed_c10d +from torch.distributed._shard._utils import narrow_tensor +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed._shard.sharded_tensor.utils import ( + _parse_and_validate_remote_device, +) + +from ._internals import get_chunked_dim_size, get_split_size +from .api import ShardingSpec + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +@dataclass +class ChunkShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that defines the placement as being sharded + across multiple devices. In particular, it represents sharding a Tensor + along a single dimension into equal chunks (similar to :meth:`torch.chunk`). + + The semantics of how a tensor is partitioned is inline with + :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the + specified ``dim`` and ``chunks`` in torch.chunk is the number of elements + in the placement specified. + + Args: + dim (int or str): + The dimension to shard on, could be an integer representing the + dimension or a string in case of named tensors where dimensions are + named. Note that named tensor support is not added yet. + placement(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`torch.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`torch.distributed._remote_device` + """ + + ShardingDim = int | str + + dim: ShardingDim + placements: list[torch.distributed._remote_device | str] + + def __post_init__(self): + self._verify_dim(self.dim) + for i, remote_device in enumerate(self.placements): + if not isinstance(remote_device, torch.distributed._remote_device): + self.placements[i] = torch.distributed._remote_device(remote_device) + + @staticmethod + def _verify_dim(dim): + # Validate the sharding spec. + # TODO: support named dimension + if isinstance(dim, str): + raise NotImplementedError( + "ChunkShardingSpec does not support named dimension yet!" + ) + + if not isinstance(dim, int): + raise ValueError(f"Sharding dim needs to be an integer, found: {dim}") + + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + tensor_num_dim = len(tensor_sizes) + + self._verify_dim(self.dim) + if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim: # type: ignore[operator] + raise ValueError(f"Invalid sharding dim: {self.dim}") + + shards_metadata = [] + sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + for idx, placement in enumerate(self.placements): + # generate ShardMetadata for each placement device + chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + shard_size = list(tensor_sizes) + current_offsets = [0] * tensor_num_dim + current_offsets[self.dim] = split_size * idx # type: ignore[index] + shard_size[self.dim] = chunked_dim_size # type: ignore[index] + + shard_metadata = ShardMetadata( + shard_offsets=current_offsets, + shard_sizes=shard_size, + placement=placement, + ) + shards_metadata.append(shard_metadata) + + return sharded_tensor_meta.ShardedTensorMetadata( + shards_metadata, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Args: + src_rank: group rank relative to ``process_group`` + + N.B. If ``process_group`` is None, ``src_rank`` is a global rank. + """ + # relative imports to avoid circular dependency + from torch.distributed._shard.sharded_tensor import ShardedTensor + + tensor_properties = sharded_tensor_meta.TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + current_rank = dist.get_rank(process_group) + current_global_rank = dist.get_rank() + tensor_meta = self.build_metadata(tensor.size(), tensor_properties) + local_shards = [] + local_tensor = None + local_metadata = None + + tensors_to_scatter = cast( + list[torch.Tensor | None], + [None] * dist.get_world_size(process_group), + ) + + sharding_dim_size = tensor.size()[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + scatter_shape = list(tensor.size()) + scatter_shape[self.dim] = split_size # type: ignore[index] + + for shard_meta in tensor_meta.shards_metadata: + remote_global_rank, device = _parse_and_validate_remote_device( + process_group, shard_meta.placement + ) + if current_rank == src_rank: + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrow_tensor(tensor, shard_meta) + if shard_meta.shard_sizes[self.dim] < split_size: # type: ignore[index] + # for the last shard that might be smaller to other shards + # resize the narrowed tensor to the same size and use it for + # the scatter collective as dist.scatter requires same size + # inputs on every rank + tensor_to_scatter = ( + narrowed_tensor.detach().clone().resize_(scatter_shape) + ) + else: + tensor_to_scatter = narrowed_tensor.detach().clone( + memory_format=torch.contiguous_format + ) + + tensors_to_scatter[ + # pyrefly: ignore [bad-argument-type] + dist.get_group_rank(process_group, remote_global_rank) + ] = tensor_to_scatter + + if current_global_rank == remote_global_rank: + local_tensor = torch.empty( + scatter_shape, + dtype=tensor.dtype, + layout=tensor.layout, + device=device, + ) + local_metadata = shard_meta + + # each rank should have local_tensor and local_metadata initialized if we build + # the metadata list in a correct way. + assert local_tensor is not None + assert local_metadata is not None + + # Scatter the shards to all ranks in the pg + # scatter takes the global rank as ``src`` + src_for_scatter = src_rank + if ( + process_group is not None + and process_group is not distributed_c10d._get_default_group() + ): + src_for_scatter = distributed_c10d.get_global_rank( + process_group, src_for_scatter + ) + + tensors_to_scatter_: list[torch.Tensor] | None = None + if current_rank == src_rank: + tensors_to_scatter_ = [] + for t in tensors_to_scatter: + assert isinstance(t, torch.Tensor) + tensors_to_scatter_.append(t) + + dist.scatter( + local_tensor, + scatter_list=tensors_to_scatter_, + src=src_for_scatter, + group=process_group, + ) + + if list(local_tensor.size()) != local_metadata.shard_sizes: + # detach again after receiving to ensure local shards remain a leaf node + local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach() + + # Sync requires_grad to local_shard. + local_tensor.requires_grad = tensor.requires_grad + + local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata)) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, tensor_meta, process_group=process_group + ) + + # Manually set sharding_spec + st._sharding_spec = self + + return st diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..702085218f37a15b25180f9d55e43aad2ff6bdf1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99cfa4b63dae05095caa03d574ef172661b84da4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a14279f15b3157bd5ca286456700eb7da5dce10a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4f1129e9620516d51e69666a841ed4909468935 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8a05fe79d19d2dc67e6ff535ae419e255192c1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py @@ -0,0 +1,350 @@ +# mypy: allow-untyped-defs + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharded_tensor._ops._common import _sharded_op_common +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharding_spec._internals import ( + get_chunk_sharding_params, + get_chunked_dim_size, + get_split_size, +) +from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from torch.distributed.nn.functional import ( + _all_gather_base, + all_reduce, + all_to_all_single, +) + + +def _chunk_sharding_spec_check(spec, op): + """ + For the given op implementation check if the sharding spec is ChunkShardingSpec. + """ + if not isinstance(spec, ChunkShardingSpec): + raise NotImplementedError( + f"Only ChunkShardingSpec supported for '{op.__name__}'." + ) + + +def _register_sharded_op_on_local_tensor( + op, early_stop_func=None, extra_check=None, customized_func=None +): + """ + Handles ``__torch_function__`` dispatch for ops which are performed on + the single local tensor of the sharded tensor such as op like + ``torch.nn.functional.softmax`` or ``torch.Tensor.view``. + + For more complicated ops, a customized func can be used to generate + the new local tensor, sharding spec and sharded tensor size. + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + customized_func (Callable, optional): the func for customized logic + to generate the new local tensor, sharding spec and sharded tensor size. + Default: if ``None``, we simply lower to the real op call with + the single local tensor of the st. + + Return: + func (Callable): registered implementation for sharded op for + ``__torch_function__`` dispatch. + """ + + @custom_sharding_spec_op(ChunkShardingSpec, op) + @_sharded_op_common(op, early_stop_func, extra_check) + def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None): + # pyrefly: ignore [index-error] + st = args[0] + sharding_spec = st.sharding_spec() + if len(st.local_shards()) != 1: + raise TypeError( + f"torch function '{op.__name__}', with args: {args} and " + f"kwargs: {kwargs} only supported for single local tensor!" + ) + st_size = st.size() + if customized_func: + local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg) + else: + args = (st.local_tensor(), *args[1:]) + local_tensor = op(*args, **kwargs) + return ShardedTensor._init_from_local_tensor( + local_tensor.contiguous(), + sharding_spec, + st_size, # type: ignore[arg-type] + process_group=pg, + init_rrefs=st._init_rrefs, + ) + + +def _handle_col_wise_sharding_base( + op_func, + col_dim, + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + mode=None, + gathered_per_sample_weights=None, + gathered_offsets=None, + padding_idx=None, +): + """ + For col-wise sharding of weight, lots of logic are common. + So we extract the common logic and put in this function: + Step 1. To get input from each rank and + Step 2. To perform the op on the concatenated tensor. + Step 3. To distribute results to each rank with col rearrangement. + Step 4. To concatenate all results from all ranks. + + Args: + op_func: operator which is applied to the input tensor. + col_dim: dim of result tensor after the operation. + input: tensor to be applied op on. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise sharded weight tensor. + pg: process group. + gathered_inputs: list of inputs from all ranks. If specified, we + don't need to communicate with each rank any more. + mode: aggregation mode of EmbeddingBag. + gathered_per_sample_weights: per_sample_weights across all ranks. + gathered_offsets: offsets across all ranks. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + + Return: final result of input being applied with the op. + """ + # run the operator's function for all the inputs. + results = [] + for i, inp in enumerate(gathered_inputs): + if op_func is torch.nn.functional.embedding_bag: + result = op_func( + inp, + local_shard, + offsets=gathered_offsets[i] if gathered_offsets is not None else None, + # pyrefly: ignore [bad-argument-type] + mode=mode, + per_sample_weights=gathered_per_sample_weights[i] + if gathered_per_sample_weights is not None + else None, + padding_idx=padding_idx, + ) + elif op_func is torch.nn.functional.embedding: + result = op_func( + inp, + local_shard, + padding_idx=padding_idx, + ) + else: + result = op_func(inp, local_shard) + results.append(torch.transpose(result, 0, col_dim)) + + # Distribute results to each rank with col rearrangement. + output = _result_distribute_with_col_rearrange( + results, input, world_size, weight, pg + ) + + # transpose the output and return result. + return torch.transpose(output, 0, col_dim) + + +def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg): + """ + For col-wise sharding of weight, we need to distribute + results to each rank. We do them in this function. + Note that, if the index in the Sharding Spec is not equal to + the rank number, we need to do the rearrangement based on the + order given by the Sharding Spec (placement). + + Args: + results: results from ops applied to inputs from all ranks. + We need to distribute them back to their original ranks. + input: tensor to be applied op to. + world_size: number of ranks. + weight: sharded weight tensor. + pg: process group. + + Return: column rearranged result. + """ + # Process results and outputs for all2all. + sharding_dim = weight._sharding_spec.dim + sharding_dim_size = weight.size(sharding_dim) + dims = list(results[0].size()) + dims[0] = sharding_dim_size + combined_results = torch.cat(results) + output = torch.empty( + *dims, device=combined_results.device, dtype=combined_results.dtype + ) + + # Compute output splits + split_size = get_split_size(sharding_dim_size, world_size) + output_split_sizes = [0] * world_size + for idx, placement in enumerate(weight._sharding_spec.placements): + output_split_sizes[placement.rank()] = get_chunked_dim_size( + sharding_dim_size, split_size, idx + ) + + # distribute the outputs using all2all. + output = all_to_all_single( + output, combined_results, output_split_sizes=output_split_sizes, group=pg + ) + + # Check if we need to rearrange columns appropriately for output. + rearrange_columns = any( + idx != placement.rank() + for idx, placement in enumerate(weight._sharding_spec.placements) + ) + if not rearrange_columns: + return output + + indices = [] + for placement in weight._sharding_spec.placements: + dim_size = output_split_sizes[placement.rank()] + start = sum( + split_size if i < placement.rank() else 0 + for i, split_size in enumerate(output_split_sizes) + ) + indices += list(range(start, start + dim_size)) + + return output.index_select(0, torch.tensor(indices, device=output.device)) + + +def _handle_max_norm_col_wise( + max_norm, + norm_type, + local_shard, + input, + world_size, + gathered_inputs, + pg, +): + """ + For col-wise sharding of weight, we need to aggregate the + norm across all ranks before we can perform the proper re-norm. + Note that, the max_norm logic is only applied to the embedding + indices that are looked up and not the whole shard. + + Args: + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + local_shard: col-wise shared local weight used for lookup. + input: tensor to be applied op to. + world_size: number of ranks. + gathered_inputs: list of inputs from all ranks. + pg: process group. + + Return: + local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger + than it. + + """ + norm_type = norm_type if norm_type is not None else 2.0 + unique_inp = torch.unique(torch.cat(gathered_inputs)) + local_shard_sum = torch.sum( + torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype + ) + # For col-wise sharding, we need to first aggregate the powered sum + # from each rank first and then calculate the norm. + local_shard_sum = all_reduce(local_shard_sum, group=pg) + local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type) + max_norm_tensor = torch.full( + (local_shard.size(0),), + float("inf"), + dtype=local_shard.dtype, + device=input.device, + ) + max_norm_tensor[unique_inp] = max_norm + local_shard_t = local_shard.t().contiguous() + normalized_tensor = torch.where( + local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm + ) + # Make sure divisor is not zero. + local_shard_norm[local_shard_norm == 0.0] = 1.0 + local_shard_norm_renormed = ( + torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm) + .t() + .contiguous() + ) + return local_shard_norm_renormed + + +def _all_gather_base_input(input, pg): + """ + Use _all_gather_base to get a concatenated input from each rank. + + Args: + input: tensor to be applied op on. + pg: process group. + + Returns: + gathered_inputs: input gathered from each rank and concat by dim 0. + """ + # allgather the inputs first. + gather_inp_size = list(input.size()) + gather_inp_size[0] = input.size(0) * dist.get_world_size(pg) + gather_inp = torch.empty(gather_inp_size, device=input.device, dtype=input.dtype) + return _all_gather_base(gather_inp, input, group=pg) + + +def _handle_row_wise_mask(gather_inp, padding_idx, weight, world_size, rank): + """ + Mask the input for embedding look-up for IDs which are not stored + on the current rank. This function also adjust the ``padding_idx`` + so that it is only used on the rank where the corresponding row is + stored. + + Note that, with ``max_norm`` flag on, only weights of rows being + looked up will be re-normed. So we need an extra row for masked ID + so that it does not affect the final result and ``max_norm``. + + Args: + gather_inp: tensor to be applied op on gathered from all ranks. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + weight: weight tensor of Embedding look-up table. + world_size: number of ranks. + rank: # of cuda process. + + Returns: + lookup_input: Tensor of masked input. + padding_idx: adjusted padding_idx. + padding_row: The extra row we used during lookup so that + looking up does not affect ``max_norm``. + """ + (start_pos, chunk_size) = get_chunk_sharding_params( + weight.size(0), world_size, weight._sharding_spec, rank + ) + mask = (gather_inp < start_pos) | (gather_inp >= start_pos + chunk_size) + lookup_input = gather_inp.clone() - start_pos + lookup_input[mask] = chunk_size + if ( + padding_idx is not None + and padding_idx >= start_pos + and padding_idx < (start_pos + chunk_size) + ): + padding_idx = padding_idx - start_pos + else: + padding_idx = None + + # When max_norm is set, it will only re-norm the row being looked up. + padding_row = torch.zeros( + 1, weight.size(1), device=gather_inp.device, dtype=weight.dtype + ) + return lookup_input, padding_idx, padding_row diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..117aed79520d9ad78c10bdd2310fb6b032c2a024 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py @@ -0,0 +1,294 @@ +# mypy: allow-untyped-defs + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from torch.distributed.nn.functional import all_gather, reduce_scatter + +from ._common import ( + _all_gather_base_input, + _handle_col_wise_sharding_base, + _handle_max_norm_col_wise, + _handle_row_wise_mask, +) + + +@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding) +def sharded_embedding(types, args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. + This method computes a sharded embedding lookup and has the following limitations: + + 1. Supports only sharding of ``weight``. + 2. Supports only ``ChunkShardingSpec``. + 3. Supports only a single local shard per rank. + 4. Supports all specs except for scale_grad_by_freq, sparse, etc. + + Based on the dimension that the weight is sharded on, there are two + algorithms: + + ROWWISE SHARDING + ================ + For row-wise sharding the weight is sharded on dimension 0. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (10 x 17) and W is sharded across + 4 GPUs creating 3 shard of (3 x 17) and 1 shard of (1 x 17). + The algorithm is as follows: + + 1. First the input is all gathered to all ranks, since this is SPMD and + input is actually sharded across all ranks. The inputs then become a + 4 (4 x 6) tensor on each rank. For example if the given input is + tensor([[6, 5, 2, 9, 6, 3], + [3, 1, 2, 4, 7, 6], + [4, 0, 4, 9, 8, 9], + [8, 6, 6, 4, 6, 1]]) + on rank 0. + Then on every rank, we will have this tensor. + If input itself is already replicated, no all-gather will be done. + 2. Next, we mask the ID which are not stored on that rank. + For example on rank 0, we store ID [0, 1, 2]. We only keep the ID + inside the set of numbers. The rest of them will be masked to an extra row. + The masked matrix will be used for embedding look up and is like: + tensor([[4, 4, 2, 4, 4, 4], + [4, 1, 2, 4, 4, 4], + [4, 0, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 1]]) + The reason of having an extra row (aka, number 4 in the example) is + because when max_norm is specified only weight which has looked will + be re-normed so mask IDs whose embeddings are not stored in current + rank will to an extra row will ensure max_norm still works as expected. + 3. If max_norm is specified, the extra row guarantees that the mask ID will + not affect the behavior of weigh re-norm. + + COLWISE SHARDING + ================ + For col-wise sharding the weight is sharded on dimension 1. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2). + The algorithm is as follows: + + 1. First the input is broadcasted to all ranks, since this is SPMD we + actually do an all_gather for all the inputs resulting in 4 (4 x 6) + inputs on each rank. + 2. Next we perform local embedding lookup operation by apply each + input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last). + This results in 4 (5 x 6 x 4) ((2 x 6 x 4) for the last) matrices + on each rank. We transpose dim 0 and dim 2. + 3. Next, we concat these 4 matrices and perform an all2all to share the + appropriate (5 x 6 x 4) or (2 x 6 x 4) matrices to each rank. + 4. Now, each rank receives a (17 x 6 x 4) matrix which is basically the + size of the result we need. + 5. If placements are not in order any appropriate rearrangement of columns + are done for the (17 x 6 x 4) matrix and finally we transpose the + dim 0 and dim 2 again. + 6. If max_norm is specified, we manually sum up the norm and renorm. Because + the renorm must be in place, we need to override the local_shard to mimic + this behavior. + """ + # Validate input params + _validate_embedding_param(args, kwargs) + + input = args[0] + weight = args[1] + max_norm = kwargs.get("max_norm") + norm_type = kwargs.get("norm_type") + padding_idx = kwargs.get("padding_idx") + + local_shard = weight.local_tensor().contiguous() + sharding_dim = weight._sharding_spec.dim + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + + if sharding_dim == 1: + output, local_shard = _handle_col_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg + ) + weight.local_shards()[0].tensor = local_shard + return output + elif sharding_dim == 0: + return _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + max_norm, + norm_type, + padding_idx, + rank, + pg, + ) + else: + raise RuntimeError( + f"nn.Embedding weight sharded on dim {sharding_dim} not supported!" + ) + + +def _validate_embedding_param(args, kwargs): + """ + Validate input params of sharded embedding op. + + Args: + input: list of ID used for lookup. + weight: sharded weight tensor. + kwargs: same as normal Embedding. + + Return: None. + """ + + input = args[0] + weight = args[1] + max_norm = kwargs.get("max_norm") + scale_grad_by_freq = kwargs.get("scale_grad_by_freq") + sparse = kwargs.get("sparse") + + # Validate types + if not isinstance(input, torch.Tensor): + raise TypeError("input need to be torch.Tensor") + if not isinstance(weight, ShardedTensor): + raise TypeError("weight needs to be ShardedTensor") + weight_size = weight.size() + if len(weight_size) != 2: + raise ValueError("Weight needs to have exactly 2 dims") + if int(torch.min(input).item()) < 0: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.min(input).item()), + weight_size[1], + ) + if int(torch.max(input).item()) >= weight_size[0]: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.max(input).item()), + weight_size[1], + ) + if scale_grad_by_freq: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!' + ) + if sparse: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "sparse" not supported!' + ) + if max_norm and max_norm <= 0.0: + raise ValueError('"max_norm" must be larger than zero!') + + if not isinstance(weight._sharding_spec, ChunkShardingSpec): + raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!") + if len(weight.local_shards()) != 1: + raise ValueError("Only one local shard supported!") + + +def _handle_col_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg +): + """ + Entry-point function to handle the logic of col-wise sharding of weight + for embedding. (Detailed explanations of the logic can be found in + the comment for sharded_embedding.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise shared local weight used for lookup. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + pg: process group. + + Returns: final result of lookup. + """ + # allgather the inputs first for non Replicated Tensor. + gathered_inputs = all_gather(input, group=pg) + + if max_norm is not None: + # max_norm changes the weight in-place + local_shard = _handle_max_norm_col_wise( + max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg + ) + + output = _handle_col_wise_sharding_base( + torch.nn.functional.embedding, + len(input.size()), + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + padding_idx=padding_idx, + ) + return (output, local_shard) + + +def _handle_row_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, rank, pg +): + """ + Entry-point function to handle the logic of row-wise sharding of weight + for embedding. (Detailed explanations of the logic can be found in + the comment for sharded_embedding.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: row-wise shared local weight used for lookup. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + rank: # of cuda process. + pg: process group. + + Returns: final result of lookup. + """ + # allgather the inputs first for non Replicated Tensor. + gather_inp = _all_gather_base_input(input, pg) + + # Mask the input according to sharding spec. + lookup_input, padding_idx, padding_row = _handle_row_wise_mask( + gather_inp, padding_idx, weight, world_size, rank + ) + + # When input is a large tensor, the value of weight is changed. + # This is a walk-around for now. GH issue: #81717 + if max_norm is not None: + torch.nn.functional.embedding( + torch.unique(lookup_input)[:-1], + local_shard, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + ) + max_norm = None + + local_input_embeddings = torch.nn.functional.embedding( + lookup_input, + torch.cat([local_shard, padding_row]), + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + ) + + # TODO: Make the result a PartialTensor. + local_shards = local_input_embeddings.chunk(pg.size()) + return reduce_scatter( + torch.empty_like(local_shards[0]), + list(local_shards), + group=pg, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py new file mode 100644 index 0000000000000000000000000000000000000000..f1581575f5f47058325af51129fd0d9d4497b1d9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -0,0 +1,479 @@ +# mypy: allow-untyped-defs + +from typing import cast + +import torch +import torch.distributed as dist +from torch._C._distributed_c10d import ReduceOp +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from torch.distributed.nn.functional import all_gather, reduce_scatter + +from ._common import ( + _all_gather_base_input, + _handle_col_wise_sharding_base, + _handle_max_norm_col_wise, + _handle_row_wise_mask, +) + + +@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding_bag) +def sharded_embedding_bag(types, args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``. + This method computes a sharded embedding bag aggregation and has the following limitations: + + 1. Supports only sharding of ``weight``. + 2. Supports only ``ChunkShardingSpec``. + 3. Supports only a single local shard per rank. + 4. Supports all specs except for scale_grad_by_freq, sparse, etc. + + Based on the dimension that the weight is sharded on, there are two + algorithms: + + ROWWISE SHARDING + ================ + For row-wise sharding the weight is sharded on dimension 0. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 4 shard of (4 x 17). + The algorithm is as follows: + + 1. First the input is all gathered to all ranks, since this is SPMD and + input is actually sharded across all ranks. The inputs then become a + 4 (4 x 6) tensor on each rank. For example if the given input is + tensor([[6, 5, 2, 9, 6, 3], + [3, 1, 2, 4, 7, 6], + [4, 0, 4, 9, 8, 9], + [8, 6, 6, 4, 6, 1]]) + on rank 0. + Then on every rank, we will have this tensor. + If input itself is already replicated, no all-gather will be done. + 2. Next, we mask the ID which are not stored on that rank. + For example on rank 0, we store ID [0, 1, 2]. We only keep the ID + inside the set of numbers. The rest of them will be masked to an extra row. + The masked matrix will be used for embedding look up and is like: + tensor([[4, 4, 2, 4, 4, 4], + [4, 1, 2, 4, 4, 4], + [4, 0, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 1]]) + 3. If ``max_norm`` is specified, the extra row guarantees that the mask ID will + not affect the behavior of weigh re-norm. + 4. The example above only happens in one rank and each rank does a very similar thing. + For "Mean" mode we need to divide by either column size (2D) or the interval length + defined by the offset (excluding the row specified in ``padding_idx``). + We also need to mask the unexisting row to neg Inf so that negative value does not + gets wiped out in the "Max" mode. + + COLWISE SHARDING + ================ + For col-wise sharding the weight is sharded on dimension 1. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2). + The algorithm is as follows: + + 1. First the input is broadcasted to all ranks, since this is SPMD we + actually do an all_gather for all the inputs resulting in 4 (4 x 6) + inputs on each rank. + 2. Next we perform local embedding bag operation under the given mode by + apply each input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last). + This results in 4 (5 x 4) ((2 x 4) for the last) matrices on each rank. + We transpose the aggregation result. + 3. Next, we concatenate these 4 matrices and perform an all2all to share the + appropriate (5 x 4) or (2 x 4) matrices to each rank. + 4. Now, each rank receives a (17 x 4) matrix which is basically the + size of the result we need. + 5. If placements are not in order any appropriate rearrangement of columns + are done for the (17 x 4) matrix and finally we transpose the output again. + 6. If max_norm is specified, we manually sum up the norm and renorm. Because + the renorm must be in place, we need to override the local_shard to mimic + this behavior. + """ + # Validate input params + _validate_embedding_bag_param(args, kwargs) + + input = args[0] + weight = args[1] + offsets = kwargs.get("offsets") + per_sample_weights = kwargs.get("per_sample_weights") + mode = kwargs.get("mode") + max_norm = kwargs.get("max_norm") + norm_type = kwargs.get("norm_type") + include_last_offset = kwargs.get("include_last_offset") + padding_idx = kwargs.get("padding_idx") + + local_shard = weight.local_tensor().contiguous() + sharding_dim = weight._sharding_spec.dim + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + if include_last_offset: + offsets = offsets[:-1] + + if sharding_dim == 1: + output, local_shard = _handle_col_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + pg, + ) + weight.local_shards()[0].tensor = local_shard + return output + elif sharding_dim == 0: + return _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + rank, + pg, + ) + else: + raise RuntimeError( + f"nn.EmbeddingBag weight sharded on dim {sharding_dim} not supported!" + ) + + +def _validate_embedding_bag_param(args, kwargs): + """ + Validate input params of sharded embeddingBag op. + + Args: + input: list of ID used for lookup and aggregation. + weight: sharded weight tensor. + kwargs: same as normal EmbeddingBag. + + Return: None. + """ + + input = args[0] + weight = args[1] + offsets = kwargs.get("offsets") + per_sample_weights = kwargs.get("per_sample_weights") + mode = kwargs.get("mode") + max_norm = kwargs.get("max_norm") + scale_grad_by_freq = kwargs.get("scale_grad_by_freq") + sparse = kwargs.get("sparse") + include_last_offset = kwargs.get("include_last_offset") + + # Validate types + if not isinstance(input, torch.Tensor): + raise TypeError("input need to be torch.Tensor") + if offsets is not None and not isinstance(offsets, torch.Tensor): + raise TypeError("offsets need to be torch.Tensor") + if per_sample_weights is not None and not isinstance( + per_sample_weights, torch.Tensor + ): + raise TypeError("per_sample_weights need to be torch.Tensor") + if not isinstance(weight, ShardedTensor): + raise TypeError("weight needs to be ShardedTensor") + if len(input.size()) > 2: + raise ValueError("Input more than 2 dims not supported") + weight_size = weight.size() + if len(weight_size) != 2: + raise ValueError("Weight needs to have exactly 2 dims") + if int(torch.min(input).item()) < 0: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.min(input).item()), + weight_size[1], + ) + if int(torch.max(input).item()) >= weight_size[0]: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.max(input).item()), + weight_size[1], + ) + if offsets is not None and len(input.size()) != 1: + raise ValueError("Input dimension needs to be exactly 1 dim") + if len(input.size()) == 1 and offsets is None: + raise ValueError("offsets is required for 1D input") + if per_sample_weights is not None and per_sample_weights.size() != input.size(): + raise ValueError( + f"per_sample_weights size {per_sample_weights.size()} not equal to input size {input.size()}" + ) + if mode is None: + mode = "mean" + if mode not in ["sum", "mean", "max"]: + raise ValueError(f"mode '{mode}' is not supported") + if scale_grad_by_freq: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!' + ) + if sparse: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "sparse" not supported!' + ) + if include_last_offset and offsets is None: + raise ValueError('offsets is required for flag "include_last_offset"!') + if include_last_offset and cast(list[int], offsets)[-1] != input.size(0): + raise ValueError( + 'offsets need to have the input size in the end when the flag "include_last_offset" is on!' + ) + + if max_norm and max_norm <= 0.0: + raise ValueError('"max_norm" must be larger than zero!') + + if not isinstance(weight._sharding_spec, ChunkShardingSpec): + raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!") + if len(weight.local_shards()) != 1: + raise ValueError("Only one local shard supported!") + + +def _handle_col_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + pg, +): + """ + Entry-point function to handle the logic of col-wise sharding of weight + for embeddingBag. (Detailed explanations of the logic can be found in + the comment for sharded_embedding_bag.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise shared local weight used for lookup. + offsets: list of start positions of each bag for 1D input. + per_sample_weights: weights for weighted sum mode. + mode: aggregation method of each bag. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + pg: process group. + + Return: + output: final result of lookup and aggregation. + local_shard: col-wise shared local weight used for lookup. + If max_norm, this will be the renormed weight. + """ + # allgather the special input of embedding bag first. + ( + gathered_inputs, + gathered_per_sample_weights, + gathered_offsets, + ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg) + + if max_norm is not None: + # max_norm changes the weight in-place + local_shard = _handle_max_norm_col_wise( + max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg + ) + + output = _handle_col_wise_sharding_base( + torch.nn.functional.embedding_bag, + 1, + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + mode=mode, + gathered_per_sample_weights=gathered_per_sample_weights, + gathered_offsets=gathered_offsets, + padding_idx=padding_idx, + ) + return (output, local_shard) + + +def _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + rank, + pg, +): + """ + Entry-point function to handle the logic of row-wise sharding of weight + for embeddingBag. (Detailed explanations of the logic can be found in + the comment for sharded_embedding_bag.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: row-wise shared local weight used for lookup. + offsets: list of start positions of each bag for 1D input. + per_sample_weights: weights for weighted sum mode. + mode: aggregation method of each bag. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + rank: # of cuda process. + pg: process group. + + Returns: + gathered_output: final result of lookup and aggregation. + """ + if input.dim() > 1 and per_sample_weights is None: + # allgather the inputs first for non Replicated Tensor. + gather_inp = _all_gather_base_input(input, pg) + else: + ( + gathered_inputs, + gathered_per_sample_weights, + gathered_offsets, + ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg) + cat_dim = 0 if input.dim() != 1 else -1 + gather_inp = torch.cat(gathered_inputs, dim=cat_dim) + if per_sample_weights is not None: + per_sample_weights = torch.cat(gathered_per_sample_weights, dim=cat_dim) + offset_add = 0 if input.dim() > 1 else input.size(0) + if offsets is not None: + offsets_list = torch.cat( + [gathered_offsets[i] + (offset_add * i) for i in range(pg.size())], + dim=cat_dim, + ) + + # Mask the input according to sharding spec. + lookup_input, padding_local, padding_row = _handle_row_wise_mask( + gather_inp, padding_idx, weight, world_size, rank + ) + if mode == "max": + padding_row[:] = -float("Inf") + + # When input is a large tensor, the value of weight is changed. + # This is a walk-around for now. GH issue: #81717. + if max_norm is not None: + torch.nn.functional.embedding_bag( + torch.unique(lookup_input)[:-1], + local_shard, + offsets=torch.tensor([0], device=local_shard.device, dtype=torch.long), + mode=mode, + per_sample_weights=None, + max_norm=max_norm, + norm_type=norm_type, + padding_idx=padding_local, + ) + max_norm = None + result = torch.nn.functional.embedding_bag( + lookup_input, + torch.cat([local_shard, padding_row]), + offsets=offsets_list if offsets is not None else offsets, # type: ignore[possibly-undefined] + mode=mode if mode != "mean" else "sum", + per_sample_weights=per_sample_weights, + max_norm=max_norm, + norm_type=norm_type, + padding_idx=padding_local, + ) + + op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX + # TODO: Make the result a PartialTensor and move the logic below there. + local_shards = result.chunk(pg.size()) + result = reduce_scatter( + torch.empty_like(local_shards[0]), + list(local_shards), + op=op, + group=pg, + ) + + # For Mean, we cannot do the division until very end because the sum of means + # not equal to the mean of sum. (Divisor is different) + if mode == "mean": + if input.dim() > 1: + padding_idx = padding_idx if padding_idx is not None else -1 + split_sizes = torch.sum( + torch.ne(input, padding_idx), dim=-1, dtype=local_shard.dtype + ) + else: + split_sizes = torch.cat( + ( + # pyrefly: ignore [unsupported-operation] + offsets[1 : offsets.size(0)] - offsets[0:-1], + # pyrefly: ignore [unsupported-operation] + (input.size(0) - offsets[-1]).unsqueeze(0), + ), + dim=-1, + ) + return torch.div(result, split_sizes.unsqueeze(1)) + + # Return the appropriate local result. + return result + + +def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg): + """ + In case we need to gather input and all other parameters of embeddingBag + ops, we need to stack all input together to perform ``all_gather`` + collective communication just once. + + Note that since offsets does not share the same size as input and + is always smaller than input, we resize it during the communication. + + Args: + input: tensor to be applied op on. + per_sample_weights: weights for weighted sum mode. + offsets: when input is 1D. offsets determines the starting + index position of each bag (sequence) in input. + pg: process group. + + Returns: + gathered_inputs: list of input tensor gathered from each rank. + gathered_per_sample_weights: list of per_sample_weights from each rank. + gathered_offsets: list of offsets from each rank. + """ + input_to_gather = [input] + if per_sample_weights is not None: + input_to_gather.append(per_sample_weights) + if offsets is not None: + input_to_gather.append(offsets.clone().resize_(input.size())) + gathered_inputs = all_gather(torch.stack(input_to_gather), group=pg) + + gathered_per_sample_weights = None + if per_sample_weights is not None: + gathered_per_sample_weights = [t[1] for t in gathered_inputs] + gathered_offsets = None + if offsets is not None: + idx = 2 if per_sample_weights is not None else 1 + gathered_offsets = [ + t[idx].resize_(offsets.size()).to(offsets.dtype) for t in gathered_inputs + ] + gathered_inputs = [t[0].to(input.dtype) for t in gathered_inputs] + return gathered_inputs, gathered_per_sample_weights, gathered_offsets diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24de2628c0ab9ceb89fa28b52753a421b58b56c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__init__.py @@ -0,0 +1,21 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. +import sys +import warnings + +import torch +from torch.distributed._shard.sharded_tensor import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharded_tensor` will be deprecated, " + "use `torch.distributed._shard.sharded_tensor` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules["torch.distributed._sharded_tensor"] = ( + torch.distributed._shard.sharded_tensor +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7958ee238ab1299bdf29d0e2b541bdbc46e790e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c74dd3633e0f5e8436b844fd2d14f3bdb00635b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__init__.py @@ -0,0 +1,22 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. +import sys +import warnings + +import torch +from torch.distributed._shard.sharding_spec import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharding_spec` will be deprecated, " + "use `torch.distributed._shard.sharding_spec` instead", + DeprecationWarning, + stacklevel=2, + ) + +import torch.distributed._shard.sharding_spec as _sharding_spec + + +sys.modules["torch.distributed._sharding_spec"] = _sharding_spec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd404cf430c4f0a37d0565a3b96f3caf83b70857 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ee29ea452143fce950421ade8803bf53776d397 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__init__.py @@ -0,0 +1,2121 @@ +from __future__ import annotations + +import math +import os +import socket +import uuid +from collections.abc import Callable, Generator +from contextlib import contextmanager +from datetime import timedelta +from enum import Enum +from functools import partial +from typing import Any, Literal + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch._C._autograd import DeviceType +from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work + + +_group_name_to_store: dict[str, c10d.Store] = {} + + +def enable_symm_mem_for_group(group_name: c10d.GroupName) -> None: + """ + Enables symmetric memory for a process group. + + Args: + group_name (str): the name of the process group. + """ + if group_name in _group_name_to_store: + return + + group = c10d._resolve_process_group(group_name) + global_ranks = sorted(c10d._world.pg_group_ranks[group].keys()) + # Different subgroups with the same name should use different stores + global_ranks_str = "_".join(map(str, global_ranks)) + store = c10d.PrefixStore( + f"symmetric_memory-{global_ranks_str}", + c10d._get_process_group_store(group), + ) + _group_name_to_store[group_name] = store + _SymmetricMemory.set_group_info( + group_name, + group.rank(), + group.size(), + store, + ) + + +_is_test_mode: bool = False +_mocked_group_names: set[str] | None = None + + +@contextmanager +def _test_mode(group_names: set[str] | None = None) -> Generator[None, None, None]: + """ + Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops + defined in the ``symm_mem`` namespace to use fallback implementations. + + The context manager is not thread safe. + """ + global _is_test_mode + global _mocked_group_names + prev = _is_test_mode + prev_group_names = _mocked_group_names + try: + _is_test_mode = True + _mocked_group_names = group_names + yield + finally: + _is_test_mode = prev + _mocked_group_names = prev_group_names + + +def is_symm_mem_enabled_for_group(group_name: c10d.GroupName) -> bool: + """ + Check if symmetric memory is enabled for a process group. + + Args: + group_name (str): the name of the process group. + """ + if _is_test_mode: + return _mocked_group_names is None or group_name in _mocked_group_names + return group_name in _group_name_to_store + + +_group_name_to_workspace_tensor: dict[str, torch.Tensor | None] = {} + + +def get_symm_mem_workspace( + group_name: c10d.GroupName, min_size: int +) -> _SymmetricMemory: + """ + Get the symmetric memory workspace associated with the process group. If + ``min_size`` is greater than the workspace associated with ``group_name``, + the workspace will be re-allocated and re-rendezvous'd. + + Args: + group_name (str): the name of the process group. + min_size (int): the size requirement for the workspace in bytes. + + Returns: + _SymmetricMemory: the symmetric memory workspace associated with the + group. + """ + enable_symm_mem_for_group(group_name) + + tensor = _group_name_to_workspace_tensor.get(group_name) + size = tensor.numel() * tensor.element_size() if tensor is not None else 0 + if tensor is None or size < min_size: + if torch.cuda.is_current_stream_capturing(): + curr_size = 0 if tensor is None else tensor.numel() * tensor.element_size() + raise RuntimeError( + f"get_symm_mem_workspace(): the requested size ({min_size} bytes) " + "is greater than the size of the currently allocated workspace " + f"({curr_size} bytes). It's currently not possible to expand the " + "workspace size during graph capture. Please invoke " + f'`get_symm_mem_workspace(group_name="{group_name}", ' + f'min_size="{min_size}")` before initiating the graph capture ' + "and try again." + ) + tensor = _SymmetricMemory.empty_strided_p2p( + (max(size, min_size),), + [1], + torch.uint8, + torch.device(f"cuda:{torch.cuda.current_device()}"), + group_name, + ) + _group_name_to_workspace_tensor[group_name] = tensor + return _SymmetricMemory.rendezvous(tensor) + + +_backend_streams: dict[int, torch.cuda.Stream] = {} + + +def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream: + if priority not in _backend_streams: + _backend_streams[priority] = torch.cuda.Stream(priority=priority) + return _backend_streams[priority] + + +def _pipelined_multi_all_gather_and_consume( + shard: list[torch.Tensor], + shard_consumer: Callable[[list[torch.Tensor], int], None], + ag_out: list[torch.Tensor], + group_name: c10d.GroupName, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + gathered = [ + all_gather_tensor(x, gather_dim=0, group=group) + for x in shard + ] + + shards = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + p2p_workspace_size_req = 0 + for x in shard: + p2p_workspace_size_req += x.numel() * x.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + for x, y in zip(shard, ag_out): + assert x.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `shard` must be contiguous" + ) + assert y.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `ag_out` must be contiguous" + ) + assert x.shape[0] * group_size == y.shape[0] + assert x.shape[1:] == y.shape[1:] + + def copy_shard(dst: list[torch.Tensor], src: list[torch.Tensor]) -> None: + for d, s in zip(dst, src): + d.copy_(s) + + def get_p2p_bufs(remote_rank: int) -> list[torch.Tensor]: + offset_bytes = 0 + bufs = [] + for x in shard: + buf = symm_mem.get_buffer( + remote_rank, + x.shape, + x.dtype, + storage_offset=offset_bytes // x.element_size(), + ) + bufs.append(buf) + offset_bytes += buf.numel() * buf.element_size() + return bufs + + local_p2p_bufs = get_p2p_bufs(rank) + + # shards[i] => shard from rank i + shards: list[list[torch.Tensor]] = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + # Parallelization strategy: after each rank copies its shard into its local + # p2p buffer, every rank issues independent p2p copy -> shard_consumer + # sequences to two streams. In addition to computation/communication + # overlapping, the strategy allows for computation/computation overlapping, + # greatly reducing quantization inefficiency. + # + # Notation: + # - "mv" for the copy to local buffer + # - "cp" for p2p copies + # - "b" for barriers + # + # Constraints: + # - The GPU scheduler may or may not overlap "mv" with the first shard_consumer. + # - "cp" from different streams cannot overlap. + # + # Ideal scenario 0 - "mv" overlaps with the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Ideal scenario 1 - "mv" is scheduled before the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ] [b][ cp ][ shard_consumer ] + # + # We haven't yet figured out a way to ensure "mv" and "b" are either + # overlapped with or scheduled before the first shard_consumer. Thus, to + # prevent suboptimal scenarios, we are giving up the chance to overlap "mv" + # and "b" with the first shard_consumer for now. + copy_shard(dst=local_p2p_bufs, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_bufs = get_p2p_bufs(remote_rank) + with stream: + copy_shard(dst=shards[remote_rank], src=remote_p2p_bufs) + shard_consumer(shards[remote_rank], remote_rank) + + if ag_out_needed: + # Copy from input to the all-gather output. Opportunistically overlap + # it with the last shard_consumer. + if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with stream: + copy_shard(dst=shards[rank], src=shard) + + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _pipelined_all_gather_and_consume( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: c10d.GroupName, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + ag_out = all_gather_tensor(shard, gather_dim=0, group=group) + shards = ag_out.chunk(group.size()) + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + + def adapter(shard: list[torch.Tensor], rank: int) -> None: + shard_consumer(shard[0], rank) + + _pipelined_multi_all_gather_and_consume( + [shard], + adapter, + [ag_out], + group_name, + ag_out_needed, + ) + + +def _pipelined_produce_and_all2all( + chunk_producer: Callable[[int, torch.Tensor], None], + output: torch.Tensor, + group_name: c10d.GroupName, + out_chunk_dim: int = 0, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + chunks = [ + chunk_producer(dst_rank, chunks[dst_rank]) + for dst_rank in range(group_size): + ] + dist.all_to_all_single(output=output, input=torch.cat(chunks)) + """ + out_chunks = output.chunk( + c10d._get_group_size_by_name(group_name), dim=out_chunk_dim + ) + p2p_workspace_size_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: + assert idx in (0, 1) + offset = 0 if idx == 0 else out_chunks[0].numel() + return symm_mem.get_buffer( + rank, out_chunks[0].shape, out_chunks[0].dtype, offset + ) + + # Prepare two local p2p buffers, so that a remote rank can pull the result + # of step [i] in one p2p buffer while the local rank can compute the + # result of step [i+1] and write it directly the other p2p buffer. + local_p2p_buf_0 = get_p2p_buf(rank, 0) + local_p2p_buf_1 = get_p2p_buf(rank, 1) + + for step in range(1, group_size): + remote_rank = (rank - step) % group_size + if step % 2 == 0: + stream = torch.cuda.current_stream() + p2p_buf = local_p2p_buf_1 + remote_p2p_buf = get_p2p_buf(remote_rank, 1) + else: + stream = backend_stream + p2p_buf = local_p2p_buf_0 + remote_p2p_buf = get_p2p_buf(remote_rank, 0) + with stream: + # Parallelization strategy: every rank issues independent compute + # -> barrier -> p2p copy sequences on two streams. In addition to + # computation/communication overlapping, the strategy allows for + # computation/computation overlapping, greatly reducing + # quantization inefficiency. + # + # Ideally, stream activities would look like this ("b" for + # barriers, "cp" for p2p copies): + # + # [rank 0] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # Note that the barriers synchronize streams with the same ID + # across ranks. They don't synchronize streams on the same rank. + # + # Since the work on both streams is independent, there's no + # guarantee that the chunk_producer from stream 0 or stream 1 will + # be scheduled first. If there is a scheduling mismatch across + # ranks, the barrier forces all ranks to wait for the slowest. + # + # When scheduling mismatches occur among ranks, the stream + # activities might look like this (note that p2p copies from + # different streams cannot overlap with each other): + # + # [rank 0] + # stream 0: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # stream 1: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # stream 1: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # + # To prevent this, we need to ensure that the chunk_producer on + # stream 1 gets scheduled first on every rank. Without access to + # the underlying kernels, CUDA offers no API to control the + # scheduling order of two independent, overlapping kernels. Our + # solution is to issue a small sleep kernel in stream 0. The sleep + # duration is insignificant, but having an extra task in stream 0 + # will almost guarantee that the chunk_producer on stream 1 gets + # scheduled first. Once the first chunk_producer is scheduled in + # the correct order, there's very little room for the scheduling + # order of subsequent kernels to be inconsistent across ranks. + if step == 2: + torch.cuda._sleep(100) + chunk_producer((rank + step) % group_size, p2p_buf) + symm_mem.barrier(channel=step % 2) + out_chunks[remote_rank].copy_(remote_p2p_buf) + # The local P2P buffer can only be overwritten by the next + # chunk_producer after all peers have finished reading from it. + symm_mem.barrier(channel=step % 2) + + # If the sleep wasn't issued in the above loop, do it now. + if group_size == 2: + torch.cuda._sleep(100) + + chunk_producer(rank, out_chunks[rank]) + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +lib = torch.library.Library("symm_mem", "DEF") # noqa: TOR901 +lib.define( + "fused_all_gather_matmul(" + "Tensor A, Tensor[] Bs, int gather_dim, str group_name, *, bool return_A = True) -> (Tensor?, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_all_gather_scaled_matmul(" + "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, " + "int gather_dim, str group_name, " + "Tensor?[] biases, " + "Tensor?[] result_scales, " + "ScalarType?[] out_dtypes, " + "bool[] use_fast_accum) -> (Tensor, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_scaled_matmul_reduce_scatter(" + "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, " + "str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, SymInt[]? output_shape, " + "Tensor? bias = None, " + "Tensor? result_scale = None, " + "ScalarType? out_dtype = None, " + "bool use_fast_accum = False) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor") +lib.define( + "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor" +) + +lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]") +""" +Given a local tensor and a group name, return a tuple of tensors that are +symmetric on other devices. The returned tensors are ordered by rank IDs. The +length of the tuple equals to the size of the group. + +Note: this API works only when `world_within_direct_access()` returns True, i.e. +only when the group is within NVLink domain or similar. It does not work across +network interfaces. +""" + + +@torch.library.impl(lib, "get_remote_tensors", "CUDA") +def _get_remote_tensors_default( + local: torch.Tensor, group_name: c10d.GroupName +) -> tuple[torch.Tensor, ...]: + hdl = rendezvous(local, group_name) + if hdl is None: + raise ValueError("Tensor is not allocated from Symmetric Memory") + + return tuple( + hdl.get_remote_tensor(peer, local.size(), local.dtype) + for peer in range(hdl.world_size) + ) + + +@torch.library.impl(lib, "get_remote_tensors", "Meta") +def _get_remote_tensors_meta( + local: torch.Tensor, group_name: c10d.GroupName +) -> tuple[torch.Tensor, ...]: + group = c10d._resolve_process_group(group_name) + return tuple(torch.empty_like(local) for _ in range(group.size())) + + +class _ScaleMode(Enum): + UNSCALED = "unscaled" + TENSOR_WISE = "tensor-wise" + ROW_WISE_SHARDED = "row-wise-sharded" + ROW_WISE_REPLICATED = "row-wise-replicated" + + +def _check_and_verify_fp8_all_gather_scale_mode( + shard: torch.Tensor, scale: torch.Tensor | None, gather_dim: int, group_size: int +) -> _ScaleMode: + full_shape = list(shard.shape) + full_shape[gather_dim] *= group_size + + if scale is None: + return _ScaleMode.UNSCALED + elif scale.shape[:-1] == shard.shape[:-1] and scale.shape[-1] == 1: + # Row-wise scaling + # + # NOTE: when the last dim of both A_shard and A_scale is one, we can't + # tell if A_scale is replicated tensor-wise scale or sharded row-wise + # scale. Treating it as row-wise scaling for safety. + return _ScaleMode.ROW_WISE_SHARDED + elif scale.numel() == 1: + return _ScaleMode.TENSOR_WISE + elif list(scale.shape[:-1]) == full_shape[:-1]: + return _ScaleMode.ROW_WISE_REPLICATED + else: + raise ValueError( + "Invalid scale shape for fp8 all-gather " + f"(shard shape: {shard.shape}, scale shape: {scale.shape})" + ) + + +def _fused_all_gather_matmul_impl( + mm_out_op: torch._ops.OpOverload, + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor | None, + kwargs_list: list[dict[str, Any]], + out_dtypes: list[torch.dtype | None], + gather_dim: int, + group_name: c10d.GroupName, + return_A: bool, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + if A_shard.dim() < 2: + raise ValueError("A_shard must be a matrix") + for B in Bs: + if B.dim() != 2: + raise ValueError("B must be a matrix") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_types) must be the same as len(Bs)") + if len(kwargs_list) != len(Bs): + raise ValueError("len(kwargs_list) must be the same as len(Bs)") + if gather_dim < 0 or gather_dim >= A_shard.dim(): + raise ValueError("Invalid gather_dim") + + group = c10d._resolve_process_group(group_name) + + if gather_dim == A_shard.ndim - 1 or gather_dim == -1: + return _fused_all_gather_matmul_last_gather_dim_impl( + mm_out_op, + A_shard, + Bs, + A_scale, + kwargs_list, + out_dtypes, + gather_dim, + group_name, + return_A, + ) + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix. + # The flattened tensor doesn't need to be contiguous (for computation + # efficiency), as _pipelined_all_gather_and_consume guarantees that shards + # passed to shard_consumer are contiguous. + A_shard_flat = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(A_shard_flat.shape[:-1]) + A_shard_flat = A_shard_flat.flatten(0, -2) + + # Helper function for reverting the above transformation + def unflatten(t: torch.Tensor) -> torch.Tensor: + return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) + + A_flat = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + outputs = [ + A_flat.new_empty(A_flat.shape[0], B.shape[1], dtype=out_dtype or B.dtype) + for B, out_dtype in zip(Bs, out_dtypes) + ] + output_shards = [output.chunk(group.size()) for output in outputs] + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group.size() + ) + + # Computing block-wise matmul along the first dim of A + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + assert A_scale is not None + A_scale_shard = A_scale.movedim(gather_dim, 0).flatten(0, -2) + A_scale_flat = A_scale_shard.new_empty( + A_scale_shard.shape[0] * group.size(), + A_scale_shard.shape[1], + ) + + def row_wise_sharded_consumer(shard: list[torch.Tensor], rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard[0], + B, + scale_a=shard[1], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_multi_all_gather_and_consume( + [A_shard_flat, A_scale_shard], + row_wise_sharded_consumer, + [A_flat, A_scale_flat], + group_name, + return_A, + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + assert A_scale is not None + A_scale_shards = ( + A_scale.movedim(gather_dim, 0).flatten(0, -2).chunk(group.size()) + ) + + def row_wise_replicated_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard, + B, + scale_a=A_scale_shards[rank], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_all_gather_and_consume( + A_shard_flat, + row_wise_replicated_consumer, + A_flat, + group_name, + return_A, + ) + else: + if scale_mode == _ScaleMode.TENSOR_WISE: + assert A_scale is not None + for kwargs in kwargs_list: + kwargs["scale_a"] = A_scale + else: + assert scale_mode == _ScaleMode.UNSCALED + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + A_shard_flat, + default_consumer, + A_flat, + group_name, + return_A, + ) + + A = unflatten(A_flat) if return_A else None + return A, [unflatten(output) for output in outputs] + + +def _pipelined_all_gather_and_consume_last_dim( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: c10d.GroupName, + ag_out_needed: bool = True, +) -> None: + p2p_workspace_size_req = 0 + p2p_workspace_size_req = shard.numel() * shard.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None: + dst.copy_(src) + + def get_p2p_buf(remote_rank: int) -> torch.Tensor: + buf = symm_mem.get_buffer( + remote_rank, + shard.shape, + shard.dtype, + ) + return buf + + local_p2p_buf = get_p2p_buf(rank) + + shards = ag_out.chunk(group_size) + + copy_shard(dst=local_p2p_buf, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_buf = get_p2p_buf(remote_rank) + with stream: + copy_shard(dst=shards[remote_rank], src=remote_p2p_buf) + shard_consumer(shards[remote_rank], remote_rank) + + if ag_out_needed: + # Copy from input to the all-gather output. Opportunistically overlap + # it with the last shard_consumer. + if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with stream: + copy_shard(dst=shards[rank], src=shard) + + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _fused_all_gather_matmul_last_gather_dim_impl( + mm_out_op: torch._ops.OpOverload, + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor | None, + kwargs_list: list[dict[str, Any]], + out_dtypes: list[torch.dtype | None], + gather_dim: int, + group_name: c10d.GroupName, + return_A: bool, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + group = c10d._resolve_process_group(group_name) + group_size = group.size() + + B_shards = [B.chunk(group.size()) for B in Bs] + + leading_dims = list(A_shard.shape[:-1]) + A_shard_flat = A_shard.flatten(0, -2) + + def unflatten(t: torch.Tensor) -> torch.Tensor: + return t.view(*leading_dims, -1) + + A_flat_out = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + outputs = [ + torch.empty( + (A_shard_flat.shape[0], B.shape[1]), + dtype=out_dtype or B.dtype, + device=A_shard.device, + ) + for B, out_dtype in zip(Bs, out_dtypes) + ] + + first = True + events = [torch.cuda.Event() for _ in outputs] + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + nonlocal first + for out, event, B_shard, kwargs in zip(outputs, events, B_shards, kwargs_list): + event.wait() + if first: + torch.ops.aten.mm.out(shard, B_shard[rank], **kwargs, out=out) + else: + out.addmm_(shard, B_shard[rank]) + event.record() + + first = False + + _pipelined_all_gather_and_consume_last_dim( + A_shard_flat, + default_consumer, + A_flat_out, + group_name, + return_A, + ) + ret_A = None + if return_A: + # This path is inefficient and will be filtered out at passes stage + # Added only for completeness. + A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1) + ret_A = unflatten(A_split_cat_out_flat) + + return ret_A, [unflatten(output) for output in outputs] + + +@torch.library.impl(lib, "fused_all_gather_matmul", "Meta") +def _fused_all_gather_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + *, + return_A: bool = True, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + if gather_dim == A.ndim - 1 or gather_dim == -1: + A_splits = A.chunk(group_size) + A_mm = torch.cat(A_splits, dim=-1) + res = [torch.matmul(A_mm, B) for B in Bs] + if return_A: + return A_mm, res + else: + return None, res + + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs] + if return_A: + return A.movedim(0, gather_dim), res + else: + return None, res + + +@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") +def _fused_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + *, + return_A: bool = True, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + all_gather_tensor(A_shard, gather_dim, group_name) @ B + + Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + if _is_test_mode: + return _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim, group_name, return_A=return_A + ) + + if _should_use_fused_all_gather_matmul_native(A_shard, Bs, gather_dim, group_name): + group = c10d._resolve_process_group(group_name) + leading_dims = list(A_shard.shape[:-1]) + leading_dims[0] *= group.size() + A, out = _fused_all_gather_matmul_native( + A_shard.flatten(0, -2), Bs[0], group_name + ) + return A.view(*leading_dims, -1), [out.view(*leading_dims, -1)] + + if _should_use_multimem_all_gather_matmul( + A_shard, gather_dim, group_name, return_A + ): + return None, _multimem_all_gather_matmul(A_shard, Bs, group_name) + + with torch.profiler.record_function("fused_all_gather_matmul"): + return _fused_all_gather_matmul_impl( + torch.ops.aten.mm.out, + A_shard, + Bs, + None, + [{} for B in Bs], + [B.dtype for B in Bs], + gather_dim, + group_name, + return_A, + ) + + +def _should_use_fused_all_gather_matmul_native( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + + return ( + "TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP" in os.environ + and A_shard.is_contiguous() + and gather_dim == 0 + # _async_input_mm requires local_M to be divisible by world_size. + and local_M % group.size() == 0 + # _async_input_mm outperforms the decomposition-based approach when the + # global M is small. + and 2048 < local_M * group.size() <= 4096 + # _async_input_mm only supports a single B. + and len(Bs) == 1 + ) + + +def _fused_all_gather_matmul_native( + A_shard: torch.Tensor, + B: torch.Tensor, + group_name: c10d.GroupName, +) -> tuple[torch.Tensor, torch.Tensor]: + symm_mem = rendezvous(A_shard, group_name) + if symm_mem is None: + symm_mem = get_symm_mem_workspace( + group_name, A_shard.numel() * A_shard.element_size() + ) + symm_mem.barrier() + buf = symm_mem.get_buffer(symm_mem.rank, A_shard.shape, A_shard.dtype) + buf.copy_(A_shard) + A_shard = buf + + rank = symm_mem.rank + world_size = symm_mem.world_size + + current_stream = torch.cuda.current_stream() + backend_stream = _get_backend_stream(priority=-1) + + symm_mem.barrier() + backend_stream.wait_stream(current_stream) + current_stream.wait_stream(backend_stream) + + A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1]) + A_signals = torch.zeros(world_size, dtype=torch.uint32, device=A_shard.device) + A_shards = A.chunk(world_size) + + A_shards[rank].copy_(A_shard) + if not torch.cuda.is_current_stream_capturing(): + _SymmetricMemory.stream_write_value32(A_signals, rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=rank, val=1, count=1) + + out = torch.ops.symm_mem._async_input_mm(A, B, A_signals, rank) + for step in range(1, world_size): + src_rank = (rank + step) % world_size + src_buf = symm_mem.get_buffer(src_rank, A_shard.shape, A_shard.dtype) + with backend_stream: + A_shards[src_rank].copy_(src_buf) + if not torch.cuda.is_current_stream_capturing(): + # cuStreamWriteValue32 issues a system level fence before the write + _SymmetricMemory.stream_write_value32(A_signals, src_rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=src_rank, val=1, count=1) + + current_stream.wait_stream(backend_stream) + backend_stream.wait_stream(current_stream) + + symm_mem.barrier() + return A, out + + +def _should_use_multimem_all_gather_matmul( + A_shard: torch.Tensor, + gather_dim: int, + group_name: c10d.GroupName, + return_A: bool, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + has_multicast_support = ( + A_shard.device.type == "cuda" + and _SymmetricMemory.has_multicast_support( + DeviceType.CUDA, A_shard.device.index + ) + ) + + return ( + has_multicast_support + and not return_A + and A_shard.is_contiguous() + and gather_dim == 0 + # The heuristic is empirical. We could refine it with a more + # sophisticated perf model. + and local_M * group.size() <= 2048 + ) + + +def _multimem_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + group_name: c10d.GroupName, +) -> list[torch.Tensor]: + group = c10d._resolve_process_group(group_name) + A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:])) + symm_mem = get_symm_mem_workspace( + group_name, A_shape.numel() * A_shard.element_size() + ) + A = symm_mem.get_buffer(symm_mem.rank, A_shape, A_shard.dtype) + torch.ops.symm_mem.multimem_all_gather_out(A_shard, group_name, A) + return [torch.matmul(A, B) for B in Bs] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta") +def _fused_all_gather_scaled_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + biases: list[torch.Tensor | None], + result_scales: list[torch.Tensor | None], + out_dtypes: list[torch.dtype | None], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group_size + ) + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + A_scale_shard = A_scale + A_scale = torch.ops._c10d_functional.all_gather_into_tensor( + A_scale.contiguous(), group_size, group_name + ) + A_scale = torch.ops._c10d_functional.wait_tensor(A_scale) + A_scale = ( + A_scale.view(group_size, *A_scale_shard.shape) + .movedim(gather_dim + 1, 1) + .flatten(0, -2) + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + A_scale = A_scale.movedim(gather_dim, 0).flatten(0, -2) + else: + assert scale_mode == _ScaleMode.TENSOR_WISE + + def scaled_matmul( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + bias: torch.Tensor | None, + result_scale: torch.Tensor | None, + out_dtype: torch.dtype | None, + use_fast_accum: bool, + ) -> torch.Tensor: + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm( + A.flatten(0, -2), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + return res.unflatten(0, leading_dims) + + return A.movedim(0, gather_dim), [ + scaled_matmul( + A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum + ).movedim(0, gather_dim) + for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip( + Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA") +def _fused_all_gather_scaled_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + biases: list[torch.Tensor | None], + result_scales: list[torch.Tensor | None], + out_dtypes: list[torch.dtype | None], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + A = all_gather_tensor(A_shard, gather_dim, group_name) + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale) + res = res.unflatten(0, leading_dims) + + The input `A_scale` can be tensor-wise, row-wise-sharded or + row-wise-replicated. + + Optimal stride order for `A_shard` - if `A_shard.movedim(gather_dim, 0)` is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + if len(biases) != len(Bs): + raise ValueError("len(biases) must be the same as len(Bs)") + if len(result_scales) != len(Bs): + raise ValueError("len(result_scales) must be the same as len(Bs)") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_dtypes) must be the same as len(Bs)") + if len(use_fast_accum) != len(Bs): + raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)") + + if _is_test_mode: + return _fused_all_gather_scaled_matmul_fallback( + A_shard, + Bs, + A_scale, + B_scales, + gather_dim, + group_name, + biases, + result_scales, + out_dtypes, + use_fast_accum, + ) + + with torch.profiler.record_function("fused_all_gather_scaled_matmul"): + A, res = _fused_all_gather_matmul_impl( + torch.ops.aten._scaled_mm.out, + A_shard, + Bs, + A_scale, + [ + { + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": fast_accum, + } + for B_scale, bias, result_scale, out_dtype, fast_accum in zip( + B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ], + out_dtypes, + gather_dim, + group_name, + True, + ) + assert A is not None + return A, res + + +def make_contiguous_for_perm( + t: torch.Tensor, + perm: list[int], +) -> torch.Tensor: + """ + Restride `t` such that `t.permute(perm)` is contiguous. + """ + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + return t.permute(perm).contiguous().permute(inv_perm) + + +def restride_A_shard_for_fused_all_gather_matmul( + t: torch.Tensor, + gather_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_all_gather_matmul` for optimal perf. + See the doc for `fused_all_gather_matmul` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(gather_dim)) + return make_contiguous_for_perm(t, perm) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") +def _fused_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: c10d.GroupName, +) -> torch.Tensor: + """ + Perform the following logic with micro-pipelined computation and + communication: + + reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + + Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no + extra copy is required for input layout transformation. Otherwise A needs + to be copied once. + """ + if _is_test_mode: + return _fused_matmul_reduce_scatter_fallback( + A, B, reduce_op, scatter_dim, group_name + ) + + with torch.profiler.record_function("fused_matmul_reduce_scatter"): + return _fused_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten.mm.out, + A=A, + B=B, + kwargs={}, + out_dtype=A.dtype, + reduce_op=reduce_op, + scatter_dim=scatter_dim, + group_name=group_name, + ) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") +def _fused_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: c10d.GroupName, +) -> torch.Tensor: + res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + res = funcol.wait_tensor(res) + return res + + +def _fused_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: torch.dtype | None, + reduce_op: str, + scatter_dim: int, + group_name: c10d.GroupName, +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if scatter_dim < 0 or scatter_dim >= A.dim(): + raise ValueError("Invalid gather_dim") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + group = c10d._resolve_process_group(group_name) + out_shape = [*A.shape[:-1], B.shape[1]] + out_shape[scatter_dim] //= group.size() + + if scatter_dim == A.ndim - 1: + B_shards = B.chunk(group.size(), dim=B.ndim - 1) + A_flat = A.flatten(0, -2) + + def _chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_flat, B_shards[rank], **kwargs, out=out) + + leading_dims = list(A.shape[:-1]) + + stacked_partials = torch.empty( + (A_flat.shape[0], B.shape[1]), + dtype=out_dtype or A.dtype, + device=A.device, + ) + + _pipelined_produce_and_all2all( + _chunk_producer, + stacked_partials, + group_name, + out_chunk_dim=1, + ) + + stacked_partials_view = stacked_partials.reshape( + *leading_dims, group.size(), -1 + ) + return reduce_fn( + stacked_partials_view, + dim=-2, + ) + + # Move the scatter_dim to the front and flatten the tensor into a 2D matrix + x = A.movedim(scatter_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + leading_dims[1] //= group.size() + x = x.flatten(0, -2) + A_shards = x.chunk(group.size()) + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, **kwargs, out=out) + + stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype) + + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + return reduce_fn( + stacked_partials.view(*leading_dims, -1) + .movedim(1, scatter_dim + 1) + .movedim(0, scatter_dim), + dim=scatter_dim, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA") +def _fused_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: c10d.GroupName, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if _is_test_mode: + return _fused_scaled_matmul_reduce_scatter_fallback( + A, + B, + A_scale, + B_scale, + reduce_op, + orig_scatter_dim, + scatter_dim_after_maybe_reshape, + group_name, + output_shape, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + with torch.profiler.record_function("fused_scaled_matmul_reduce_scatter"): + return _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten._scaled_mm.out, + A=A, + B=B, + A_scale=A_scale, + kwargs={ + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": use_fast_accum, + }, + out_dtype=out_dtype, + reduce_op=reduce_op, + orig_scatter_dim=orig_scatter_dim, + scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape, + group_name=group_name, + output_shape=output_shape, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "Meta") +def _fused_scaled_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: c10d.GroupName, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + + C = torch._scaled_mm( + A.flatten(0, -2).contiguous(), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + C = C.view(*output_shape[:-1], B.shape[1]) + res = funcol.reduce_scatter_tensor( + C, + reduce_op, + orig_scatter_dim, # need original scatter dim for 3D+ output tensor here + group_name, + ) + res = funcol.wait_tensor(res) + return res + + +def _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: torch.dtype | None, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: c10d.GroupName, + output_shape: list[int], +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if ( + scatter_dim_after_maybe_reshape < 0 + or scatter_dim_after_maybe_reshape >= A.dim() + ): + raise ValueError("Invalid scatter dim for 2D tensor input to scaled_mm") + if orig_scatter_dim < 0 or orig_scatter_dim >= len(output_shape): + raise ValueError("Invalid scatter dim for 3D+ output tensor") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + + # Move scatter to first dim, then shard the tensor along the first dim, so the chunk producer + # can perform matmuls along the first dim. + A_with_scatter_dim_0 = A.movedim(scatter_dim_after_maybe_reshape, 0) + + # To handle case where A is 3D+, reshape to 2D to prepare for mm which requires 2D inputs. + A_2D_with_scatter_dim_0 = A_with_scatter_dim_0.flatten(0, -2) + + # Partition A along the first dim to prepare for sharding across TP process group. + A_shards = A_2D_with_scatter_dim_0.chunk(group.size()) + + # Now that 'A' is sharded along the first dim, we need to update its scale(s) accordingly. + # How we do this depends on if we are using tensorwise scaling, rowwise scaling, or no scaling. + tensorwise_scaling = A_scale is not None and A_scale.numel() == 1 + rowwise_scaling = A_scale is not None and A_scale.numel() > 1 + + # For tensorwise scaling, the scale should be replicated so each shard has a copy. + if tensorwise_scaling: + A_scale_shards = [A_scale] * group.size() + + # For rowwise scaling, we need to move the scatter dim to the first dim to match the + # dim swap of the 'A' tensor. Then we can shard the scales along the first dim, just like + # the 'A' tensor. + elif rowwise_scaling: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = ( + A_scale.movedim(scatter_dim_after_maybe_reshape, 0) + .contiguous() + .flatten(0, -2) + ) + A_scale_shards = list(A_scale.chunk(group.size())) + # cuBLAS's row-wise kernel requires scales to be aligned to 16 bytes. + # When we slice them we might break this and need to reallocate them. + A_scale_shards = [ + t if t.data_ptr() % 16 == 0 else t.clone() for t in A_scale_shards + ] + else: + raise ValueError("A_scale cannot be none for scaled_mm") + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out) + + # Stacked partials will be the 2D outputs of the pipelined scaled mm, and will + # have the shape (A_with_scatter_dim_0_tensor.shape[0], B.shape[1]) to align with the formula: + # (a*b,c) @ (c,d) = (a*b,d) + stacked_partials = A_with_scatter_dim_0.new_empty( + A_2D_with_scatter_dim_0.shape[0], B.shape[1], dtype=out_dtype or A.dtype + ) + + # Execute the pipelined mm/scaled_mm. + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # We now need to transform the *unreduced* stacked 2D partial mm outputs to an *unreduced* 3D+ output, + # then reduce-scatter. To do this, we first need to determine the shape of the unreduced 3D+ output, + # to reshape our stacked partials so we can apply the reduce-scatter. + # + # The *unreduced* 3D+ tensor will have dim 0 = `group_size`, as we have `group_size` instances of + # stacked partial outputs. The next dims will be A's leading dims (sharded along the original scatter dim), + # as it was the left operand of the mm op. We can use -1 as the final dim of the view to populate the rest. + stacked_partials_3D_leading_dims = [group.size()] + list( + # We use A from after the dim swap 0<=>scatter_dim, but before the flatten, + # to get the leading dims of the 3D+ view of stacked partials. + A_with_scatter_dim_0.shape[:-1] + ) + + # The `group_size` leading dim has been prepended to `stacked_partials_3D_leading_dims`, + # to capture the partial output from each rank. We need to divide the sharding/scatter dim + # by the group size. If the original scatter dim was 0, then it is now dim 1 in this + # tensor, since this new `group_size` dim was prepended. + stacked_partial_scatter_dim = orig_scatter_dim if orig_scatter_dim > 0 else 1 + stacked_partials_3D_leading_dims[stacked_partial_scatter_dim] //= group.size() + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + reduced_out = reduce_fn( + # View 2D stacked partials as 3D+ tensor of shape (`group_size`, ...) + stacked_partials.view(*stacked_partials_3D_leading_dims, -1) + # We originally swapped 0<=>scatter_dim_after_maybe_reshape. Now after + # prepending the `group_size` dim, to undo this original swap, we + # must swap 1<=>scatter_dim_after_maybe_reshape+1. + .movedim(1, scatter_dim_after_maybe_reshape + 1), + # Reduce along the `group_size` dim (0). + dim=0, + ) + + # Output shape must be scattered along original scatter dim as well. + output_shape[orig_scatter_dim] //= group.size() + out = reduced_out.view(*output_shape) + return out + + +def restride_A_for_fused_matmul_reduce_scatter( + t: torch.Tensor, + scatter_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_matmul_reduce_scatter` for optimal + perf. See the doc for `fused_matmul_reduce_scatter` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(scatter_dim)) + return make_contiguous_for_perm(t, perm) + + +def _maybe_convert_scalar_types_to_dtypes( + scalar_types: list[Any], +) -> list[torch.dtype | None]: + """ + When a list of `torch.dtype`s is passed through the dispatcher as + `ScalarType[]`, it is converted to a list of scalar type enum values. This + function converts it back to a list of `torch.dtype`s. + """ + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + _SCALAR_TYPE_TO_DTYPE = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.short, + 3: torch.int, + 4: torch.int64, + 5: torch.half, + 6: torch.float, + 7: torch.double, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex128, + 11: torch.bool, + 12: torch.qint8, + 13: torch.quint8, + 14: torch.qint32, + 15: torch.bfloat16, + 16: torch.float8_e5m2, + 17: torch.float8_e4m3fn, + 18: torch.float8_e5m2fnuz, + 19: torch.float8_e4m3fnuz, + } + if any(not isinstance(x, (type(None), int)) for x in scalar_types): + return scalar_types + + dtypes: list[torch.dtype | None] = [] + for scalar_type in scalar_types: + if scalar_type is None: + dtypes.append(scalar_type) + elif scalar_type not in _SCALAR_TYPE_TO_DTYPE: + raise ValueError(f"Unrecognized scalar type {scalar_type}") + else: + dtypes.append(_SCALAR_TYPE_TO_DTYPE[scalar_type]) + return dtypes + + +class Work(_Work): + def __init__(self) -> None: + super().__init__() + self.event = torch.cuda.Event() + self.event.record() + + def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool: + self.event.wait() + return True + + +""" +NOTE [low-contention collectives] +When a collective is overlapped with abundant compute, it makes sense to +prioritize reducing the contention between the collective and the overlapped +compute, even at the cost of a slightly slower collective. + +Common collective implementations (e.g., NCCL without user buffer +registration) optimize for throughput with no ambient compute. However, such +implementations may not be optimal when they are overlapped with compute: +- These implementations typically fuse the entire collective into a single +kernel and reserve SM resources based on the most demanding portion of the +collective, even when a large portion of the collective does not require this +much resource. +- These implementations often use SM-based P2P copy as opposed to copy +engine-based P2P copy. Copy engine-based P2P copy may not have a significant +advantage when there's no ambient compute. However, it may significantly +improve overall resource utilization in the presence of ambient compute. + +When overlapped with intensive compute (e.g., persistent matmul kernels), the +SM-usage of a collective can lead to inefficient overlapping. + +Low-contention collectives achieve their goals with the following strategies: +- Use copy engine-based copy whenever possible. +- Break down portions of a collective with different resource requirements +into multiple kernels. This improves the overlapping efficiency at the cost +of additional launching overhead. +""" + + +@torch.library.impl(lib, "_low_contention_all_gather", "Meta") +def _low_contention_all_gather_meta( + tensor: torch.Tensor, + group_name: c10d.GroupName, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:]) + + +@torch.library.impl(lib, "_low_contention_all_gather", "CUDA") +def _low_contention_all_gather( + tensor: torch.Tensor, + group_name: c10d.GroupName, +) -> torch.Tensor: + """ + Performs all-gather with symmetric memory in a low-contention fashion. + + When `tensor` is already in symmetric memory: + - The collective is carried out without using SMs. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - An extra SM-based copy is performed to copy the input data into the + symmetric memory workspace. + - Symmetric memory workspace size requirement: the size of `tensor`. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + input_is_symm_mem = True + else: + symm_mem = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + input_is_symm_mem = False + + rank = symm_mem.rank + world_size = symm_mem.world_size + + output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:]) + chunks = output.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + if not input_is_symm_mem: + local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype) + local_buf.copy_(tensor) + # pull + symm_mem.barrier() + for step in range(world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + torch._C._distributed_c10d._register_work(output, Work()) + return output + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta") +def _low_contention_reduce_scatter_meta( + tensor: torch.Tensor, + reduce_op: str, + group_name: c10d.GroupName, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.unflatten(0, (group_size, -1)).mean(dim=0) + + +def _low_contention_reduce_scatter_with_symm_mem_input( + tensor: torch.Tensor, + reduce_op: str, + symm_mem: _SymmetricMemory, +) -> torch.Tensor: + rank = symm_mem.rank + world_size = symm_mem.world_size + + assert tensor.shape[0] % world_size == 0 + a2a_res = torch.empty_like(tensor) + chunks = a2a_res.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # pull + offline reduction + symm_mem.barrier() + for step in range(world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer( + remote_rank, + chunks[0].shape, + chunks[0].dtype, + chunks[0].numel() * rank, + ) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + + ret = a2a_res.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + torch._C._distributed_c10d._register_work(ret, Work()) + return ret + + +def _low_contention_reduce_scatter_with_workspace( + tensor: torch.Tensor, + reduce_op: str, + workspace: _SymmetricMemory, +) -> torch.Tensor: + rank = workspace.rank + world_size = workspace.world_size + + assert tensor.shape[0] % world_size == 0 + chunks = tensor.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # push + offline reduction + workspace.barrier() + for step in range(world_size): + remote_rank = (rank - step) % world_size + dst_buf = workspace.get_buffer( + remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank + ) + dst_buf.copy_(chunks[remote_rank]) + workspace.barrier() + + buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype) + ret = buf.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + torch._C._distributed_c10d._register_work(ret, Work()) + return ret + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA") +def _low_contention_reduce_scatter( + tensor: torch.Tensor, + reduce_op: str, + group_name: c10d.GroupName, +) -> torch.Tensor: + """ + Performs reduce-scatter with symmetric memory in a low-contention fashion. + + This implementation performs a P2P-based all-to-all followed by an offline + reduction. + + When `tensor` is already in symmetric memory: + - Pull-based all-to-all is used. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - Push-based all-to-all is used. + - Symmetric memory workspace size requirement: the size of `tensor`. + + SM-usage: + - SM-based copy of the rank's own chunk for the all-to-all. + - Reduction on the all-to-all result. + + TODO(yifu): the SM-based copy can be avoided with a list-based reduction + kernel. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + return _low_contention_reduce_scatter_with_symm_mem_input( + tensor, reduce_op, symm_mem + ) + else: + workspace = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + return _low_contention_reduce_scatter_with_workspace( + tensor, reduce_op, workspace + ) + + +@torch.library.impl(lib, "all_to_all_vdev_2d", "Meta") +def _all_to_all_vdev_2d_meta( + input: torch.Tensor, + out: torch.Tensor, + in_splits: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: c10d.GroupName, + major_align: int | None = None, +) -> None: + return None + + +@torch.library.impl(lib, "all_to_all_vdev_2d_offset", "Meta") +def _all_to_all_vdev_2d_offset_meta( + input: torch.Tensor, + out: torch.Tensor, + in_splits_offsets: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: c10d.GroupName, +) -> None: + return None + + +# ============================================================================= +# User-facing APIs +# ============================================================================= + + +from collections.abc import Sequence +from typing import overload, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from torch._C._distributed_c10d import ProcessGroup + from torch.types import _device, _dtype, _int + + +@overload +def empty( + *size: _int, dtype: _dtype | None = None, device: _device | None = None +) -> torch.Tensor: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def empty( + size: Sequence[_int], + *, + dtype: _dtype | None = None, + device: _device | None = None, +) -> torch.Tensor: ... + + +def empty( # type: ignore[misc] + *size: Any, + dtype: _dtype | None = None, + device: _device | None = None, +) -> torch.Tensor: + r""" + Similar to :func:`torch.empty()`. The returned tensor can be used by + :func:`torch._distributed._symmetric_memory.rendezvous()` to establish a + symmetric memory tensor among participating processes. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + """ + if len(size) == 1 and isinstance(size[0], Sequence): + size = tuple(size[0]) + else: + size = tuple(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + if device is None: + device = torch.get_default_device() + + return _SymmetricMemory.empty_strided_p2p( + size=size, + stride=torch._prims_common.make_contiguous_strides_for(size), + dtype=dtype, + device=torch.device(device), + ) + + +def rendezvous( + tensor: torch.Tensor, group: Union[c10d.GroupName, ProcessGroup] +) -> _SymmetricMemory: + r""" + rendezvous(tensor, group) -> _SymmetricMemory + + Establish a symmetric memory tensor among participating processes. This is + a collective operation. + + Args: + tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric memory tensor. + It must be allocated via :func:`torch._distributed._symmetric_memory.empty()`. The shape, + dtype, and device type must be identical across all participating processes. + group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the + participating processes. This can be either a group name or a process group object. + """ + from torch._C._distributed_c10d import ProcessGroup + + if isinstance(group, str): + group_name = c10d.GroupName(group) + elif isinstance(group, ProcessGroup): + group_name = group.group_name + else: + raise TypeError(f"rendezvous: unsupported group type: {type(group)}") + + enable_symm_mem_for_group(group_name) + return _SymmetricMemory.rendezvous(tensor, group_name) + + +def is_nvshmem_available() -> bool: + r""" + is_nvshmem_available() -> bool + + Check if NVSHMEM is available in current build and on current system. + """ + try: + from torch._C._distributed_c10d import _is_nvshmem_available + except ImportError: + # Not all builds have NVSHMEM support. + return False + + # Check if NVSHMEM is available on current system. + return _is_nvshmem_available() + + +def set_backend(name: Literal["NVSHMEM", "CUDA", "NCCL"]) -> None: + r""" + Set the backend for symmetric memory allocation. This is a global setting + and affects all subsequent calls to + :func:`torch._distributed._symmetric_memory.empty()`. Note that the backend + cannot be changed once a symmetric memory tensor has been allocated. + + Args: + backend (str): the backend for symmetric memory allocation. Currently, + only `"NVSHMEM"`, `"CUDA"`, `"NCCL"` are supported. + """ + _SymmetricMemory.set_backend(name) + + +def get_backend(device: _device) -> str | None: + r""" + Get the backend for symmetric memory allocation for a given device. If not + found, return None. + + Args: + device (`torch.device` or str): the device for which to get the backend. + """ + return _SymmetricMemory.get_backend(torch.device(device)) + + +def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def] + r""" + Get the MemPool allocator for symmetric memory for a given device. + + Args: + device (`torch.device` or str): the device for which to get the MemPool + allocator. + """ + return _SymmetricMemory.get_mempool_allocator(torch.device(device)) + + +def set_signal_pad_size(size: int) -> None: + r""" + Set the signal pad size for future symmetric memory allocations. + + Signal pads are P2P-accessible memory regions used for synchronization in + symmetric memory. This function allows users to configure + the signal pad size to be proportional to their workload requirements. + + .. warning:: + This must be called before any symmetric memory allocations are made. + The size cannot be changed after allocations have been performed. + + Args: + size (int): the signal pad size in bytes. The size should be + proportional to the number of blocks launched and the world size. + + Example:: + + >>> # doctest: +SKIP + >>> # Set a larger signal pad size before any allocations + >>> torch.distributed._symmetric_memory.set_signal_pad_size(1024 * 1024) # 1MB + """ + _SymmetricMemory.signal_pad_size = size + + +def get_signal_pad_size() -> int: + r""" + Get the current signal pad size for symmetric memory allocations. + + Returns the user-configured size if set via :func:`set_signal_pad_size`, + otherwise returns the default size. + + Returns: + int: the signal pad size in bytes. + + Example:: + + >>> # doctest: +SKIP + >>> size = torch.distributed._symmetric_memory.get_signal_pad_size() + >>> print(f"Signal pad size: {size} bytes") + """ + return _SymmetricMemory.signal_pad_size + + +# An internal map from device to the symmetric memory pool for that device. +_symm_mem_pools: dict[_device, torch.cuda.MemPool] = {} + + +def get_mem_pool(device: _device) -> torch.cuda.MemPool: + """ + Get the symmetric memory pool for a given device. If not found, create a new + pool. + + The tensor allocations with this pool must be symmetric across ranks. The + allocated tensors can be used with symmetric operations, for example, + operations defined under `torch.ops.symm_mem`. + + Args: + device (`torch.device` or str): the device for which to get the symmetric memory pool. + + Returns: + `torch.cuda.MemPool`: the symmetric memory pool for the given device. + + Example:: + + >>> # doctest: +SKIP + >>> pool = torch.distributed._symmetric_memory.get_mem_pool("cuda:0") + >>> with torch.cuda.use_mem_pool(pool): + >>> tensor = torch.randn(1000, device="cuda:0") + >>> tensor = torch.ops.symm_mem.one_shot_all_reduce(tensor, "sum", group_name) + + """ + # This function is a wrapper around the `torch.cuda.MemPool` constructor. + # Due to special requirements of SymmetricMemory, we preset certain options for the pool. + # - use_on_oom=False: we don't want to lend the space of the pool for + # non-symmetric allocations because this could desync the allocation state + # across ranks. + # - no_split=True: we don't want to split segments, because today a segment + # is associated with a signal pad, if two allocated tensors share a segment + # and their kernels concurrently use (the same) signal pad, this could cause + # undefined behaviors. We could consider relaxing this in the future if we + # establish stream tracking and implicit synchronization around an + # allocation. + if device not in _symm_mem_pools: + allocator = get_mempool_allocator(device) + # Create a new pool with the given allocator and the preset options. + _symm_mem_pools[device] = torch.cuda.MemPool( + allocator, + use_on_oom=False, + no_split=True, + ) + + return _symm_mem_pools[device] + + +__all__ = [ + "empty", + "rendezvous", + "is_nvshmem_available", + "set_backend", + "get_backend", + "set_signal_pad_size", + "get_signal_pad_size", + "get_mem_pool", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca8bc95eae39ebefdd97f8805b97099a82e9a92 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -0,0 +1,1220 @@ +import logging +import os +import subprocess +import sysconfig +from typing import Any + +import torch.distributed as dist +from torch.utils._triton import has_triton + + +logger = logging.getLogger(__name__) + + +class NvshmemLibFinder: + """ + A class to find path to the NVSHMEM device library. + + Environment variable: + + `NVSHMEM_LIB_DIR` (Optional[str]): The directory where the NVSHMEM device + library is located. If not provided, it will use the default path where + NVSHMEM wheel is installed, or search for the library in common system + paths. + """ + + # Class variable to store the found library path for reuse + found_device_lib_path: str | None = None + + @classmethod + def find_device_library(cls) -> str: + """ + Find the path to the NVSHMEM device library. + + Returns: + str: The path to libnvshmem_device.bc (included). + """ + if cls.found_device_lib_path is not None: + # Return the cached path if it exists + return cls.found_device_lib_path + + # First, check if the user has specified a custom library path + user_lib_dir = os.environ.get("NVSHMEM_LIB_DIR", None) + if user_lib_dir is not None: + lib_path = os.path.join(user_lib_dir, "libnvshmem_device.bc") + if not os.path.exists(lib_path): + raise RuntimeError( + f"NVSHMEM device library not found at specified path: {user_lib_dir}" + ) + cls.found_device_lib_path = lib_path + return lib_path + + # Otherwise, search for the library in the default installation paths + paths = [ + os.path.join(sysconfig.get_path("purelib"), "nvidia", "nvshmem", "lib") + ] + + # Add common system installation paths + common_paths = [ + "/usr/local/lib", + "/usr/lib", + "/opt/nvidia/nvshmem/lib", + ] + paths.extend(common_paths) + + try: + import torch + + torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib") + so_path = os.path.join(torch_lib, "libtorch_nvshmem.so") + + if os.path.exists(so_path): + try: + result = subprocess.run( + ["readelf", "-d", so_path], + capture_output=True, + text=True, + check=True, + ) + + for line in result.stdout.splitlines(): + if ("RPATH" in line or "RUNPATH" in line) and "[" in line: + rpath = line.split("[", 1)[1].split("]", 1)[0] + for p in rpath.split(":"): + p = p.strip().replace("$ORIGIN", torch_lib) + if p and p not in paths: + paths.append(p) + except subprocess.CalledProcessError: + pass + + except ImportError: + pass + + for path in paths: + device_lib = os.path.join(path, "libnvshmem_device.bc") + if os.path.exists(device_lib): + cls.found_device_lib_path = device_lib + return device_lib + + raise RuntimeError(f"NVSHMEM device library not found. Searched: {paths}") + + +def enable_triton(lib_dir: str | None = None) -> dict[str, str]: + raise NotImplementedError( + "`enable_triton` is deprecated. " + "If you need NVSHMEM device function support for Triton, " + "please use `@requires_nvshmem` to decorate your Triton kernel. ", + ) + + +class NvshmemKernelRegistry: + """ + A class to register kernel functions that ** require NVSHMEM initialization ** + """ + + # Class variable to store the functions to be initialized + _to_init: dict[str, Any] = {} + + @classmethod + def register(cls, name: str) -> None: + """ + Register a kernel function with the given name. + + Args: + name (str): The name of the kernel function. + """ + cls._to_init.setdefault(name) + + @classmethod + def deregister(cls, name: str) -> None: + """ + Deregister a kernel function with the given name. + + Args: + name (str): The name of the kernel function. + """ + cls._to_init.pop(name, None) + + @classmethod + def has(cls, name: str) -> bool: + """ + Check if a kernel function with the given name is registered. + + Args: + name (str): The name of the kernel function. + + Returns: + bool: True if the kernel function is registered, False otherwise. + """ + return name in cls._to_init + + +def _nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] + """ + A hook function to initialize the CUModule created by `triton.jit` with + NVSHMEM device context + """ + from torch._C._distributed_c10d import _nvshmemx_cumodule_init + + jit_function = kwargs["fn"].jit_function + fn_name = jit_function.fn.__name__ + + # Only initialize NVSHMEM module for kernels registered via @requires_nvshmem + if NvshmemKernelRegistry.has(fn_name): + key = kwargs["key"] + device = kwargs["compile"]["device"] + jit_function = kwargs["fn"].jit_function + kernel_cache = jit_function.device_caches[device][0] + kernel = kernel_cache.get(key, None) + if kernel is not None: + kernel.run + # Initialize NVSHMEM for the CU module + _nvshmemx_cumodule_init(kernel.module) + else: + logger.warning( + f"It seems Triton hasn't created a kernel for function {fn_name}. " # noqa: G004 + "Please report this issue to Triton." + ) + + +if has_triton(): + from triton.runtime.jit import JITFunction, KernelInterface + + # Create a new Callable class that follows the KernelInterface protocol so + # that the Callable works with the subscript operator, e.g. `foo[(1, 1)]` + class GridCallableWithExtern(KernelInterface): + """ + `KernelInterface` invokes `self.run` in `__getitem__`, i.e. []. We + implement a `run` method by directing the call to `JITFunction.run`, + with added extern_libs kwarg, so that users don't have to pass it + """ + + def __init__(self, jit_func: JITFunction, extern_libs: dict[str, str]) -> None: + self.jit_func = jit_func + self.extern_libs = extern_libs + + def run(self, *args, **kwargs): # type: ignore[no-untyped-def] + # Call the JITFunction.run with added extern_libs kwarg + return self.jit_func.run(*args, **kwargs, extern_libs=self.extern_libs) + + +def requires_nvshmem( # type: ignore[no-untyped-def] + jit_func, # JITFunction created by triton.jit +): + """ + A decorator to register a Triton kernel function that requires NVSHMEM initialization. + + Example usage: + ``` + @requires_nvshmem + @triton.jit + def foo(...): + ... + ``` + + If you would like to specify a path to the NVSHMEM device library other + than standard search locations, you can use the following environment + variable: + ``` + export NVSHMEM_LIB_DIR=/path/to/nvshmem/lib + ``` + """ + + import triton + from triton.runtime.jit import JITFunction + + if not isinstance(jit_func, JITFunction): + raise TypeError(f"Expected a JITFunction, but got {type(jit_func)}") + + # Find the NVSHMEM device library + lib_path = NvshmemLibFinder.find_device_library() + extern_libs = {"libnvshmem_device": lib_path} + + # Register the JITFunction with the kernel registry as "to be initialized" + NvshmemKernelRegistry.register(jit_func.fn.__name__) + + # Register the NVSHMEM init function as a post-compile hook. + # [Note] This is a global setting (due to lack of Triton API exposure). To + # avoid initializing Triton kernels that do not require NVSHMEM, filtering + # is performed in the hook function itself by checking against + # NvshmemKernelRegistry. + triton.knobs.runtime.jit_post_compile_hook = _nvshmem_init_hook + + return GridCallableWithExtern(jit_func, extern_libs) + + +if has_triton(): + import triton + import triton.language as tl + from triton.language import core + + @triton.jit # type: ignore[misc] + def put(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Put tensor data from local PE to a remote PE. + + This high-level function provides a tensor-aware interface for NVSHMEM put + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + dest: Destination tensor on the remote PE. Type must match source. + source: Source tensor on the local PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a blocking operation that returns after data has been copied out + of the source array on the local PE. + - The operation does not guarantee delivery to the destination PE. + Use nvshmem_fence() for ordering or nvshmem_quiet() for completion. + + Example: + ``` + # Transfer 100 elements to PE 1 + nvshmem.put(dest_tensor, src_tensor, 100, 1) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return putmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe + ) + + @core.extern + def putmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM put""" + return core.extern_elementwise( + "", + "", + [dest, source, size_bytes, pe], + { + ( + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe number + ): ("nvshmemx_putmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def get(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Get tensor data from a remote PE to local PE. + + This high-level function provides a tensor-aware interface for NVSHMEM get + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + dest: Destination tensor on the local PE. Type must match source. + source: Source tensor on the remote PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a blocking operation that returns after data has been delivered + to the destination array on the local PE. + - The destination data is guaranteed to be available for use after the call returns. + + Example: + ``` + # Get 100 elements from PE 0 + nvshmem.get(dest_tensor, src_tensor, 100, 0) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return getmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe + ) + + @core.extern + def getmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM get""" + return core.extern_elementwise( + "", + "", + [dest, source, size_bytes, pe], + { + ( + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe number + ): ("nvshmemx_getmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def get_nbi(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Get tensor data from a remote PE to local PE, non-blocking. + + Different from the `get` function, this function returns after + initiating the operation. The operation is considered complete after a + subsequent call to `quiet`. + + Args: + dest: Destination tensor on the local PE. Type must match source. + source: Source tensor on the remote PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + + Example: + ``` + # Get 100 elements from PE 0 + nvshmem.get_nbi(dest, src, 100, 0) + # Some independent computation which overlaps with the get operation + ... + # Wait for completion of the get operation + nvshmem.quiet() + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return getmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe + ) + + @core.extern + def getmem_nbi_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM get""" + return core.extern_elementwise( + "", + "", + [dest, source, size_bytes, pe], + { + ( + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe number + ): ("nvshmemx_getmem_nbi_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def putmem_signal_block( # type: ignore[no-untyped-def] + dst, + src, + size_bytes, + signal, + sig_val, + sig_op, + pe, + ): # type: ignore[no-untyped-def] + """ + Put data to remote PE with atomic signal operation using block-scoped operation. + + This function copies data from the local PE to the remote PE and then + atomically updates a signal variable on the remote PE to indicate completion. + This enables efficient point-to-point synchronization between PEs. + + Args: + dst (tensor): A tensor on calling PE symmetric to the destination tensor on remote PE. + src (tensor): Local tensor containing the source data. + size_bytes (int64): Number of bytes to transfer. Must be positive. + signal (tensor): Symmetric signal pad with remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int32): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomic set operation + - NVSHMEM_SIGNAL_ADD (5): Atomic add operation + pe (int32): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that returns after data has been copied out + of the source array and the signal has been updated on the remote PE. + - The signal update is performed atomically with respect to other signal + operations and synchronization routines. + - The signal variable must be of type uint64_t in symmetric memory. + - Use with nvshmem_signal_wait_until() for synchronization. + + Example: + ``` + # Transfer data and set completion flag to 1 + NVSHMEM_SIGNAL_SET = 0 + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, 1024, sig_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe + ) + ``` + """ + # Ensure sig_val is 64 bits + sig_val = 0 << 32 | sig_val + return putmem_signal_block_extern_wrapper( + dst.to(tl.int64), + src.to(tl.int64), + size_bytes.to(tl.int64), + signal.to(tl.int64), + sig_val.to(tl.uint64), + sig_op, + pe, + ) + + @core.extern + def putmem_signal_block_extern_wrapper( # type: ignore[no-untyped-def] + dst, + src, + size_bytes, + signal, + sig_val, + sig_op, + pe, + _semantic=None, + ): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, size_bytes, signal, sig_val, sig_op, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("uint64"), + core.dtype("int32"), + core.dtype("int32"), + ): ("nvshmemx_putmem_signal_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + # Wait and Signal Operations + + @triton.jit # type: ignore[misc] + def wait_until(ivar, cmp_op, cmp_val): # type: ignore[no-untyped-def] + """ + Wait until a tensor variable meets a specified condition. + + This high-level function provides a tensor-aware interface for NVSHMEM wait_until + operations. It automatically handles tensor address extraction, making + the API more ergonomic and type-safe. + + Args: + ivar_tensor: Tensor to monitor (typically int64/uint64) in symmetric memory. + cmp: Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until ivar == cmp_val + - NVSHMEM_CMP_NE (1): Wait until ivar != cmp_val + - NVSHMEM_CMP_GT (2): Wait until ivar > cmp_val + - NVSHMEM_CMP_GE (3): Wait until ivar >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until ivar < cmp_val + - NVSHMEM_CMP_LE (5): Wait until ivar <= cmp_val + cmp_val: Value to compare against. + + Notes: + - This is a blocking operation that will wait indefinitely until the + condition is satisfied. + - The tensor must be in symmetric memory and accessible from other PEs. + + Example: + ``` + # Wait until flag tensor becomes 1 (set by another PE) + NVSHMEM_CMP_EQ = 0 + nvshmem.wait_until_tensor(flag_tensor, NVSHMEM_CMP_EQ, 1) + ``` + """ + tl.static_assert( + ivar.type.element_ty.itemsize == 4, + "wait_until expects a 32-bit type for the synchronization variable", + ) + return wait_until_extern_wrapper(ivar.to(tl.int64), cmp_op, cmp_val) + + @core.extern + def wait_until_extern_wrapper(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [ivar, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int32"), + core.dtype("int32"), + ): ("nvshmem_int_wait_until", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def signal_wait_until(signal, cmp, cmp_val): # type: ignore[no-untyped-def] + """ + Wait until a signal variable meets a specified condition. + + This function blocks the calling thread until the value at the specified + signal variable satisfies the given comparison condition. Signal variables + are special uint64_t symmetric objects used for efficient synchronization + with signal operations. + + Args: + signal (tensor): Symmetric signal tensor with remote PE. + Must be 8-byte aligned symmetric memory. + cmp (int32): Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until signal == cmp_val + - NVSHMEM_CMP_NE (1): Wait until signal != cmp_val + - NVSHMEM_CMP_GT (2): Wait until signal > cmp_val + - NVSHMEM_CMP_GE (3): Wait until signal >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until signal < cmp_val + - NVSHMEM_CMP_LE (5): Wait until signal <= cmp_val + cmp_val (int64): Value to compare against. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation designed specifically for signal variables. + - Signal variables are updated atomically by putmem_signal operations. + - More efficient than wait_until for signal-based synchronization patterns. + - Ensures the signal update is fully complete before returning. + - Commonly used with putmem_signal_block for producer-consumer patterns. + + Example: + ``` + # Wait for signal to be set to completion value + NVSHMEM_CMP_EQ = 0 + nvshmem.signal_wait_until(signal_ptr, NVSHMEM_CMP_EQ, 42) + ``` + """ + cmp_val = 0 << 32 | cmp_val + return signal_wait_until_extern_wrapper( + signal.to(tl.int64), cmp, cmp_val.to(tl.uint64) + ) + + @core.extern + def signal_wait_until_extern_wrapper(signal, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [signal, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int32"), + core.dtype("uint64"), + ): ("nvshmem_signal_wait_until", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @core.extern + def signal_op(sig_addr, signal, sig_op, pe, _semantic=None): # type: ignore[no-untyped-def] + """ + Perform an atomic signal operation on a remote PE. + + This function atomically updates a signal variable on the specified remote PE + using the given operation and value. This enables efficient point-to-point + synchronization and notification between PEs. + + Args: + sig_addr (int64): Symmetric address of the signal variable (uint64_t) on the remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int32): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomically set sig_addr = signal + - NVSHMEM_SIGNAL_ADD (5): Atomically set sig_addr += signal + pe (int32): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a one-sided operation - the remote PE does not need to participate. + - The signal operation is performed atomically on the remote PE. + - Can be used with signal_wait_until() on the remote PE for synchronization. + - Provides low-overhead notification mechanism between PEs. + - The signal variable must be of type uint64_t in symmetric memory. + + Example: + ```python + # Atomically set remote signal to 1 to notify completion + NVSHMEM_SIGNAL_SET = 0 + nvshmem.signal_op(remote_signal_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe) + ``` + """ + return core.extern_elementwise( + "", + "", + [sig_addr, signal, sig_op, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int32"), + core.dtype("int32"), + ): ("nvshmemx_signal_op", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + # Memory Ordering Operations + @core.extern + def fence(_semantic=None): # type: ignore[no-untyped-def] + """ + Ensure ordering of put operations to each remote PE. + + This function provides a memory fence that ensures point-to-point ordering + of remote memory operations. Put operations issued before the fence are + guaranteed to be ordered before put operations issued after the fence, + when targeting the same remote PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This provides weaker ordering guarantees than quiet(). + - Operations to each PE are ordered, but operations to different PEs + may still be reordered relative to each other. + - Does not guarantee completion of operations, only ordering. + - Non-blocking operations are not ordered by fence - use quiet() instead. + - Essential for ensuring correct ordering in communication patterns. + + Memory Ordering Guarantees: + - Put operations before fence() → ordered before → Put operations after fence() + - Ordering is maintained per-destination-PE basis + - Remote PEs can observe the enforced ordering + + Example: + ``` + # Ensure first put completes before second put to same PE + nvshmem.put(dst, src, nelems, target_pe) + nvshmem.fence() # Enforce ordering + nvshmem.put(dst2, src2, nelems, target_pe) + ``` + """ + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_fence", core.dtype("int32")), + }, + is_pure=False, + _semantic=_semantic, + ) + + @core.extern + def quiet(_semantic=None): # type: ignore[no-untyped-def] + """ + Wait for completion of all outstanding put operations. + + This function blocks until all outstanding remote memory operations issued + by the calling PE have completed. It provides stronger guarantees than + fence() by ensuring both ordering and completion of all operations. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that waits for completion. + - Ensures all previous put operations have been delivered to their destinations. + - Provides global ordering - operations to ALL PEs are ordered. + - Required to complete non-blocking operations. + - More expensive than fence() but provides stronger guarantees. + + Memory Ordering Guarantees: + - All put operations before quiet() are completed before any operations after quiet() + - Operations are visible to all PEs as having occurred before subsequent operations + - Both blocking and non-blocking operations are completed + + Example: + ``` + # Ensure all data transfers complete before setting completion flag + nvshmem.putmem_block(data_ptr, src_ptr, data_size, target_pe) + nvshmem.quiet() # Wait for data transfer completion + nvshmem.putmem_block( + flag_ptr, flag_src_ptr, 8, target_pe + ) # Signal completion + ``` + """ + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_quiet", core.dtype("int32")), + }, + is_pure=False, + _semantic=_semantic, + ) + + # PE Information Operations + @core.extern + def my_pe(_semantic=None): # type: ignore[no-untyped-def] + """ + Get the PE number of the calling PE. + + This function returns the unique identifier (PE number) of the current + processing element within the NVSHMEM job. PE numbers range from 0 to + nvshmem_n_pes() - 1. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: PE number of the calling PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - This is a pure function that returns the same value throughout execution. + - PE numbering starts from 0 and is contiguous. + - Each PE has a unique identifier within the NVSHMEM job. + - Can be called from both host and device code. + - Essential for implementing PE-specific logic and communication patterns. + + Example: + ``` + # Get current PE number for conditional logic + pe = nvshmem.my_pe() + if pe == 0: + # Root PE logic + pass + else: + # Non-root PE logic + pass + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_my_pe", core.dtype("int32"))}, + is_pure=True, + _semantic=_semantic, + ) + + @core.extern + def n_pes(_semantic=None): # type: ignore[no-untyped-def] + """ + Get the total number of PEs in the NVSHMEM job. + + This function returns the total count of processing elements (PEs) + participating in the current NVSHMEM job. This value remains constant + throughout the execution of the program. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Total number of PEs in the job (always ≥ 1). + + Notes: + - This is a pure function that returns the same value throughout execution. + - The value is determined at NVSHMEM initialization and never changes. + - Valid PE numbers range from 0 to n_pes() - 1. + - Can be called from both host and device code. + - Essential for implementing collective operations and communication patterns. + + Example: + ``` + # Broadcast from root to all other PEs + total_pes = nvshmem.n_pes() + my_rank = nvshmem.my_pe() + + if my_rank == 0: + # Send to all other PEs + for peer in range(1, total_pes): + nvshmem.putmem_block(dst_ptr, src_ptr, size, peer) + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_n_pes", core.dtype("int32"))}, + is_pure=True, + _semantic=_semantic, + ) + + # Synchronization Operations + @core.extern + def barrier_all(_semantic=None): # type: ignore[no-untyped-def] + """ + Synchronize all PEs with completion guarantee. + + This function creates a barrier across all PEs in the NVSHMEM job. It ensures + that all local and remote memory updates issued before the barrier by any PE + are completed before any PE exits the barrier. This provides both + synchronization and memory consistency. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Stronger guarantee than sync_all() - ensures completion of remote operations. + - Blocks until all PEs reach the barrier AND all memory operations complete. + - Must be called from kernels launched with cooperative launch. + - Provides full memory consistency across all PEs. + - More expensive than sync_all() due to completion guarantees. + + Memory Consistency Guarantees: + - All memory updates before barrier_all() are visible to all PEs + - All remote memory operations are completed before any PE continues + - Provides a global synchronization point with memory ordering + + Example: + ``` + # Ensure all PEs complete their work before proceeding + # All PEs execute this - it's a collective operation + nvshmem.barrier_all() + # At this point, all previous operations are complete on all PEs + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_barrier_all", core.dtype("int32"))}, + is_pure=False, + _semantic=_semantic, + ) + + @core.extern + def sync_all(_semantic=None): # type: ignore[no-untyped-def] + """ + Synchronize all PEs with local completion guarantee. + + This function creates a lightweight synchronization barrier across all PEs. + It ensures that all local store operations issued before the sync are + visible to other PEs, but does not guarantee completion of remote memory + operations initiated by the calling PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Lighter weight than barrier_all() - only ensures local store visibility. + - Does not guarantee completion of remote memory updates initiated locally. + - Must be called from kernels launched with cooperative launch. + - Suitable when only synchronization (not completion) is needed. + - More efficient than barrier_all() for synchronization-only patterns. + + Memory Consistency Guarantees: + - Local store operations are visible to other PEs + - Does NOT ensure completion of outgoing remote operations + - Provides synchronization point without full completion overhead + + Example: + ``` + # Lightweight synchronization between PEs + # All PEs execute this - it's a collective operation + nvshmem.sync_all() + # Local stores are visible, but remote ops may still be in flight + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_sync_all", core.dtype("int32"))}, + is_pure=False, + _semantic=_semantic, + ) + + # Collective Operations (mem-based APIs - sizes in bytes) + @triton.jit # type: ignore[misc] + def alltoall(team, dest, source, nelems_per_pe): # type: ignore[no-untyped-def] + """ + All-to-all tensor exchange between PEs in a team. + + This high-level function provides a tensor-aware interface for NVSHMEM alltoall + operations. Each PE sends nelems_per_pe elements to every other PE and receives + the same amount from every other PE. + + Args: + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor. Must be large enough for nelems_per_pe * n_pes elements. + source: Source tensor containing data for all PEs. Must contain nelems_per_pe * n_pes elements. + nelems_per_pe: Number of elements to exchange with each PE. + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a collective operation - all PEs in the team must participate. + - Data layout: source=[data_for_pe0, data_for_pe1, ...], dest=[data_from_pe0, data_from_pe1, ...] + + Example: + ``` + # Each PE exchanges 10 elements with every other PE + nvshmem.alltoall(0, dest_tensor, src_tensor, 10) + ``` + """ + tl.static_assert(dest.type == source.type) + size_bytes_per_pe = nelems_per_pe * dest.type.element_ty.itemsize + return alltoallmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), size_bytes_per_pe.to(tl.int64) + ) + + @core.extern # type: ignore[misc] + def alltoallmem_block_extern_wrapper( + team: Any, dest: Any, source: Any, size_bytes: Any, _semantic: Any = None + ) -> None: + """Low-level extern wrapper for NVSHMEM alltoall""" + return core.extern_elementwise( + "", + "", + [team, dest, source, size_bytes], + { + ( + core.dtype("int32"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + ): ("nvshmemx_alltoallmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def broadcast(team, dest, source, nelems, pe_root): # type: ignore[no-untyped-def] + """ + Broadcast tensor data from a root PE to all other PEs in a team. + + This high-level function provides a tensor-aware interface for NVSHMEM broadcast + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor with type information. All PEs receive data here. + source: Source tensor on the root PE. Type must match dest. + nelems: Number of elements to broadcast. + pe_root: PE number of the root PE that provides the source data. + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a collective operation - all PEs in the team must participate. + - Must be called from kernels launched with cooperative launch. + + Example: + ``` + # Broadcast 100 elements from PE 0 to all PEs + nvshmem.broadcast(0, dest_tensor, src_tensor, 100, 0) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return broadcastmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe_root + ) + + @core.extern # type: ignore[misc] + def broadcastmem_block_extern_wrapper( + team: Any, + dest: Any, + source: Any, + size_bytes: Any, + pe_root: Any, + _semantic: Any = None, + ) -> None: + """Low-level extern wrapper for NVSHMEM broadcast""" + return core.extern_elementwise( + "", + "", + [team, dest, source, size_bytes, pe_root], + { + ( + core.dtype("int32"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe_root + ): ("nvshmemx_broadcastmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + # Reduction Operation + @triton.jit # type: ignore[misc] + def reduce(team, dest, source, nreduce, operation: tl.constexpr): # type: ignore[no-untyped-def] + """ + Performs a collective reduction on tensors across a team of PEs. + + This high-level function provides a tensor-aware interface for NVSHMEM + reduction operations. It automatically infers the data type from the + input tensors and calls the appropriate underlying NVSHMEM function. + + Args: + team: The team handle for the collective (0 for NVSHMEM_TEAM_WORLD). + dest: Destination tensor for the reduction results. + source: Source tensor containing data to be reduced. Must be the same type as dest. + nreduce: The number of elements in the source tensor to reduce. + operation: The reduction operation to perform ("sum", "max", "min", "prod"). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - This is a collective operation that must be called by all PEs in the team. + - Requires a cooperative grid launch. + + Example: + ``` + # Perform a sum reduction on two tensors + nvshmem.reduce(0, dest_tensor, src_tensor, 100, "sum") + ``` + """ + tl.static_assert(dest.type == source.type) + dtype = dest.type.element_ty + return reduce_extern_wrapper( + team, + dest.to(tl.int64), + source.to(tl.int64), + nreduce.to(tl.int64), + operation, + dtype, + ) + + @core.extern # type: ignore[misc] + def reduce_extern_wrapper( + team: Any, + dest: Any, + source: Any, + nreduce: Any, + operation: str, + dtype: Any, + _semantic: Any = None, + ) -> None: + """ + Low-level extern wrapper for NVSHMEM reduction operations. + + This function provides a generic interface to NVSHMEM reduction operations, + automatically selecting the appropriate NVSHMEM function based on the data type + and operation specified. + Args: + team (int64): The team handle (0 for NVSHMEM_TEAM_WORLD). + dest (pointer): Destination pointer where reduction results are stored. + source (pointer): Source pointer containing data to be reduced. + nreduce (int64): Number of elements to reduce. + operation (str): Reduction operation ("sum", "max", "min", "prod"). + dtype: Data type specification - accepts torch.dtype, tl.dtype, str, or constexpr. + _semantic: Optional semantic information for Triton compilation. + + Raises: + ValueError: If the operation is not supported. + TypeError: If the data type is not supported. + + Example: + nvshmem.reduce(0, dest_ptr, src_ptr, 100, "sum", torch.float32) + """ + # Mapping from Triton dtype names to NVSHMEM typenames + DTYPE_TO_NVSHMEM_MAP = { + "int8": "int8", + "int16": "int16", + "int32": "int32", + "int64": "int64", + "uint8": "uint8", + "uint16": "uint16", + "uint32": "uint32", + "uint64": "uint64", + "fp16": "half", + "bf16": "bfloat16", + "fp32": "float", + "fp64": "double", + } + + # Triton dtype names are standardized as fp16, bf16, fp32, etc. + dtype_name = str(dtype).replace("tl.", "") + + if dtype_name not in DTYPE_TO_NVSHMEM_MAP: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes: {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + + # Extract operation name from constexpr if needed + op_name = operation.value if hasattr(operation, "value") else operation + + # Validate operation is supported + supported_ops = {"sum", "max", "min", "prod"} + if op_name not in supported_ops: + raise ValueError( + f"Unsupported reduction operation: '{op_name}'. Supported ops are {supported_ops}" + ) + + # Map to NVSHMEM typename and validate dtype is supported + nvshmem_typename = DTYPE_TO_NVSHMEM_MAP.get(dtype_name) + if nvshmem_typename is None: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes are {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + + # Generate NVSHMEM function name + nvshmem_func = f"nvshmem_{nvshmem_typename}_{op_name}_reduce" + + # Define function signature - all parameters are int64 in Triton (they are just ptrs) + signature = ( + core.dtype("int32"), # team handle + core.dtype("int64"), # destination pointer + core.dtype("int64"), # source pointer + core.dtype("int64"), # number of elements + ) + + return core.extern_elementwise( + "", + "", + [team, dest, source, nreduce], + {signature: (nvshmem_func, core.dtype("int32"))}, + is_pure=False, + _semantic=_semantic, + ) + + # Utility for inspecting Triton kernels + + triton_kernels: dict = {} + + def _log_triton_kernel(kernel) -> None: # type: ignore[no-untyped-def] + import atexit + import tempfile + + if dist.is_initialized() and dist.get_rank() != 0: + return + + def on_exit() -> None: + logger.info("PTX files:") + for kernel in triton_kernels: + with tempfile.NamedTemporaryFile(dir="/tmp", delete=False) as f: + f.write(kernel.asm["ptx"].encode("utf-8")) + logger.info(f"+- {kernel.name}: {f.name}") # noqa: G004 + + if len(triton_kernels) == 0: + atexit.register(on_exit) + + if kernel not in triton_kernels: + triton_kernels[kernel] = None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5559cc10fabdc1172c9a3ac95ee48ca72b2d65f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__init__.py @@ -0,0 +1,45 @@ +""" +NOTICE: DTensor has moved to torch.distributed.tensor + +This file is a shim to redirect to the new location, and +we keep the old import path starts with `_tensor` for +backward compatibility. We will remove this folder once +we resolve all the BC issues. +""" + +import sys +from importlib import import_module + + +submodules = [ + # TODO: _shards_wrapper/_utils here mainly for checkpoint BC, remove them + "_shards_wrapper", + "_utils", + "experimental", + "device_mesh", +] + +# Redirect imports +for submodule in submodules: + full_module_name = f"torch.distributed.tensor.{submodule}" + sys.modules[f"torch.distributed._tensor.{submodule}"] = import_module( + full_module_name + ) + +from torch.distributed.tensor import ( # noqa: F401 + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + init_device_mesh, + ones, + Partial, + Placement, + rand, + randn, + Replicate, + Shard, + zeros, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2f7f445cae23783ba9fa8ac0aeb05fac7cfe1e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bcfbe9dbb823f4de37b0ceda0cbac2bb2ba1da4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79208cefe3155326534426604ef71d94a3d1570e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/api.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5742156a86ca511619360038a9028b0efeeaef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/api.py @@ -0,0 +1,9 @@ +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from torch.distributed.tensor._api import * # noqa: F401, F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/placement_types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/placement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4e70dbba455471feef2326cae8ba28b32d0304 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tensor/placement_types.py @@ -0,0 +1,10 @@ +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from torch.distributed.tensor._dtensor_spec import * # noqa: F401, F403 +from torch.distributed.tensor.placement_types import * # noqa: F401, F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22e974cdd64f1082e7a89e441eb8c90163f56d3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__init__.py @@ -0,0 +1,12 @@ +from .fsdp2_mem_tracker import FSDPMemTracker +from .mem_tracker import MemTracker +from .memory_tracker import MemoryTracker +from .mod_tracker import ModTracker +from .runtime_estimator import RuntimeEstimator +from .sac_estimator import ( + MSPS, + SACEstimator, + SACGreedyOrderMeta, + SACStats, + SACTradeOffStats, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b6d5766b389dd46faf20e70cf615c349fbc39ef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcf5e2d0f8adb4a2623bb115ede17806a369ad04 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8a30824d4ced19677a42660721981d6950e65b7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3069c0b9ee02b14b8731f04909f352cd3658a566 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb6ca8073adc3ec1ff365f504ffc6aa173c01181 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c103de56ac033efa1f6eaa7f6a0053d7fc03de47 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5684daf25e61ab07c1e6b7d6528458ed82c34485 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef6b4376ddd2a8e547d7aab7e89fb75ef2a8b479 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..787bde242c0af1e4c89a881fa1a9a387e8da11f5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a71d9846f7ca982649cf4d3c09fb5af54a8cfa9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..925f8fb851b8a87a57cb380efd598fd438207290 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/common_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0188a4aa08440e05bcdbbff8c9d14c05540a7909 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/common_utils.py @@ -0,0 +1,33 @@ +import warnings + +import torch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]: + """ + Recursively extracts untyped storages from a tensor or its subclasses. + + Args: + t (torch.Tensor): The tensor to extract storages from. + + Returns: + Set[torch.UntypedStorage]: A set of untyped storages. + """ + unflattened_tensors = [t] + flattened_tensor_storages = set() + while len(unflattened_tensors) > 0: + obj = unflattened_tensors.pop() + if is_traceable_wrapper_subclass(obj): + attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] + unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) + else: + if not hasattr(obj, "untyped_storage"): + warnings.warn( + f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", + category=UserWarning, + stacklevel=2, + ) + else: + flattened_tensor_storages.add(obj.untyped_storage()) + return flattened_tensor_storages diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/fake_collectives.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/fake_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac0f8a764d3eca836de98bd82d5495817eadf5b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/fake_collectives.py @@ -0,0 +1,307 @@ +import random +from typing import Any + +import torch +from torch._C._distributed_c10d import ( + _resolve_process_group, + FakeWork, + ProcessGroup, + Work, +) +from torch.utils._pytree import tree_map_only + + +torch.distributed.batch_isend_irecv + +c10d = torch.ops.c10d +_c10d_functional = torch.ops._c10d_functional +_c10d_functional_autograd = torch.ops._c10d_functional_autograd +_dtensor = torch.ops._dtensor +used_ids: set[int] = set() + + +def generate_unique_id() -> int: + while True: + new_id = random.randint(1, 10**9) + if new_id not in used_ids: + used_ids.add(new_id) + return new_id + + +# Function to create and return FakeWork object +def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def] + work = FakeWork() + work.seq_id = generate_unique_id() + fakework_script_obj = work.boxed() + return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj + + +# Dictionary mapping collective operations to their meta functions +# All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included +# _DEPRECATED_META_FUNCTIONS = { +# "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# } +_META_FUNCTIONS = { + "broadcast_": lambda *args: create_fakework(args), + "allreduce_": lambda *args: create_fakework(args), + "allgather_": lambda *args: create_fakework(args), + "_allgather_base_": lambda *args: create_fakework(args), + "reduce_scatter_": lambda *args: create_fakework(args), + "_reduce_scatter_base_": lambda *args: create_fakework(args), + "reduce_": lambda *args: create_fakework(args, return_first_arg=False), + "gather_": lambda *args: create_fakework(args, return_first_arg=False), + "scatter_": lambda *args: create_fakework(args), + "alltoall_": lambda *args: create_fakework(args), + "alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False), + "barrier": lambda *args: create_fakework(args, return_first_arg=False), + "monitored_barrier_": lambda *args: None, + "send": lambda *args: create_fakework(args, return_first_arg=False), + "recv_": lambda *args: create_fakework(args, return_first_arg=False), + "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), +} + +lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 +for op, meta_func in _META_FUNCTIONS.items(): + lib_impl.impl(op, meta_func, "Meta") + +# List of collective operation functions including functional collectives +# Note: The following collectives might be deprecated soon hence not adding them +# depcreated_non_functional_collectives = [ +# c10d.allreduce_coalesced_.default, +# c10d.reduce_scatter_tensor_coalesced_.default, +# c10d.allgather_into_tensor_coalesced_.default, +# c10d.allgather_coalesced_.default, +# ] +non_functional_collectives: set[torch._ops.OpOverload] = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.reduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.allgather_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d._allgather_base_.default, + c10d.gather_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + c10d.alltoall_base_.default, + c10d.barrier.default, + c10d.monitored_barrier_.default, +} +functional_collectives: set[torch._ops.OpOverload] = { + _c10d_functional.broadcast.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional.reduce_scatter_tensor_out.default, + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _c10d_functional.wait_tensor.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + _c10d_functional.all_gather_into_tensor_out.default, + _c10d_functional.all_gather_into_tensor_coalesced.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + _c10d_functional.broadcast_.default, + _dtensor.shard_dim_alltoall.default, +} + +sync_ops: set[torch._ops.OpOverload] = { + c10d.barrier.default, + c10d.monitored_barrier_.default, + _c10d_functional.wait_tensor.default, +} + +collective_ops = set.union(functional_collectives, non_functional_collectives) + + +class CollectiveOp: + # Static sets for performance optimization + PG_ARG_1 = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.reduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.barrier.default, + # c10d.allreduce_coalesced_.default + } + + PG_ARG_2 = { + c10d.allgather_.default, + c10d._allgather_base_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d.gather_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + c10d.alltoall_base_.default, + # c10d.allgather_coalesced_.default, + # c10d.allgather_into_tensor_coalesced_.default + # c10d.reduce_scatter_tensor_coalesced_.default + } + + PG_ARG_3 = { + _c10d_functional.broadcast.default, + _c10d_functional.broadcast_.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor_out.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor_coalesced.default, + } + + PG_ARG_4 = { + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _dtensor.shard_dim_alltoall.default, + } + + WK_ARG_1 = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.allgather_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d._allgather_base_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + } + + WK = { + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.reduce_.default, + c10d.gather_.default, + c10d.alltoall_base_.default, + c10d.barrier.default, + } + + COMM_TENSOR_ARG_0 = { + c10d.allreduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.allgather_.default, + c10d.gather_.default, + c10d.reduce_.default, + c10d.broadcast_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + # c10d.allreduce_coalesced_.default + # c10d.allgather_coalesced_.default + # c10d.allgather_into_tensor_coalesced_.default, + } + + COMM_TENSOR_ARG_1 = { + c10d.reduce_scatter_.default, + c10d.scatter_.default, + # c10d.reduce_scatter_tensor_coalesced_.default, + } + + COMM_TENSOR_ARG_RES = { + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + } + + COMM_TENSOR_SINGLE_UNTYPED_STORAGE = { + c10d._allgather_base_.default, + _c10d_functional.broadcast.default, + _c10d_functional.broadcast_.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + } + + COMM_TENSOR_ARG_0_AND_RES = { + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _dtensor.shard_dim_alltoall.default, + } + + COMM_TENSOR_RES_SUM = { + _c10d_functional.all_gather_into_tensor_coalesced.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + } + + @staticmethod + def sum_tensors(arg: Any) -> int: + """Calculate total memory consumed by the tensors in the argument.""" + total_memory = 0 + + def sum_bytes(t: torch.Tensor) -> None: + nonlocal total_memory + total_memory += t.untyped_storage().nbytes() + + tree_map_only(torch.Tensor, sum_bytes, arg) + return total_memory + + @staticmethod + def get_process_group(func, args) -> ProcessGroup: # type: ignore[no-untyped-def] + """Retrieve the process group for collective operations, except `wait_tensor`.""" + if func in CollectiveOp.PG_ARG_1: + return ProcessGroup.unbox(args[1]) + if func in CollectiveOp.PG_ARG_2: + return ProcessGroup.unbox(args[2]) + if func in CollectiveOp.PG_ARG_3: + return _resolve_process_group(args[2]) + if func in CollectiveOp.PG_ARG_4: + return _resolve_process_group(args[3]) + raise TypeError(f"Func {func} not found in {collective_ops}") + + @staticmethod + def get_comm_tensor_size(func, res, args, kwargs) -> int: # type: ignore[no-untyped-def] + """Compute the communication tensor size, except for `wait_tensor`, `barrier`, and `monitored_barrier`.""" + if func in CollectiveOp.COMM_TENSOR_ARG_0: + return CollectiveOp.sum_tensors(args[0]) + if func in CollectiveOp.COMM_TENSOR_ARG_1: + return CollectiveOp.sum_tensors(args[1]) + if func in CollectiveOp.COMM_TENSOR_ARG_RES: + return res.untyped_storage().nbytes() + if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE: + return args[0].untyped_storage().nbytes() + if func is c10d._reduce_scatter_base_.default: + return args[1].untyped_storage().nbytes() + if func is c10d.alltoall_.default: + # TODO(@sanketpurandare) - Confirm size computation + return max( + CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1]) + ) + if func is c10d.alltoall_base_.default: + # TODO(@sanketpurandare) - Confirm size computation + return max( + args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes() + ) + if func == _c10d_functional.all_gather_into_tensor_out.default: + return args[-1].untyped_storage().nbytes() + if func in CollectiveOp.COMM_TENSOR_RES_SUM: + return CollectiveOp.sum_tensors(res) + if func in CollectiveOp.COMM_TENSOR_ARG_0_AND_RES: + # TODO(@sanketpurandare) - Confirm size computation + return args[0].untyped_storage().nbytes() + res.untyped_storage().nbytes() + raise TypeError(f"Unknown function: {func} in {collective_ops}") + + @staticmethod + def get_work(func, res) -> Work: # type: ignore[no-untyped-def] + if func in CollectiveOp.WK: + return FakeWork.unbox(res) + elif func in CollectiveOp.WK_ARG_1: + return FakeWork.unbox(res[1]) + raise TypeError(f"Func {func} not found in {collective_ops}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7db24cad45b1a69a525efab736437fc48899a6d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -0,0 +1,578 @@ +from collections.abc import Callable +from copy import deepcopy +from enum import auto, Enum +from functools import partial, wraps +from typing import Any, NamedTuple, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +import torch +import torch.distributed._tools.fake_collectives +from torch import nn, optim +from torch._guards import active_fake_mode +from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker +from torch.distributed.fsdp import FSDPModule +from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup +from torch.distributed.tensor import DTensor +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map_only +from torch.utils.weak import WeakIdKeyDictionary, weakref + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +_TOTAL_KEY = "Total" + +__all__ = ["FSDPMemTracker"] + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_Ts = TypeVarTuple("_Ts") + +c10d = torch.ops.c10d + + +class _FSDPRefType(_RefType): + """ + Enumerates categories of memory usage in FSDP modules, including parameters, gradients, activations, + and optimizer states. + + Attributes: + SHARDED_PARAM (str): Memory usage of sharded parameters. + UNSHARDED_PARAM (str): Memory usage of unsharded parameters. + SHARDED_GRAD (str): Memory usage of sharded gradients corresponding to the sharded parameters. + UNSHARDED_GRAD (str): Memory usage of unsharded gradients corresponding to the unsharded parameters. + ACT (str): Memory usage of activations and tensors from forward and AC recomputation. + TEMP (str): Memory usage of temporary tensors during the backward pass including gradients of activations. + ALL_GATHER (str): Memory usage of all_gather output tensor. + REDUCE_SCATTER (str): Memory usage of reduce_scatter input tensor. + OPT (str): Memory usage of tensors storing optimizer states. + INP (str): Memory usage of input tensors. + """ + + SHARDED_PARAM = "Sharded Param" + UNSHARDED_PARAM = "Unsharded Param" + BUFFER = "Buffer" + SHARDED_GRAD = "Sharded Grad" + UNSHARDED_GRAD = "Unsharded Grad" + ACT = "Activation" + TEMP = "Temp" + ALL_GATHER = "All Gather" + REDUCE_SCATTER = "Reduce Scatter" + OPT = "OptState" + INP = "Inputs" + + +class _SavedFSDPMethods(NamedTuple): + pre_backward: Callable + post_backward: Callable + + +class _FSDPModState(_State): + """ + Enumerates the states of FSDP modules during the forward and backward passes. + """ + + BEF_PRE_FW = "Before Pre-Forward" + AFT_PRE_FW = "After Pre-Forward" + BEF_POST_FW = "Before Post-Forward" + AFT_POST_FW = "After Post-Forward" + BEF_PRE_BW = "Before Pre-Backward" + AFT_PRE_BW = "After Pre-Backward" + BEF_POST_BW = "Before Post-Backward" + AFT_POST_BW = "After Post-Backward" + PRE_FW_AC = "Pre-Forward AC" + POST_FW_AC = "Post-Forward AC" + PEAK_FW = "Peak Forward" + PEAK_BW = "Peak Backward" + + +class _FSDPModMemStats: + """ + A class to store the memory statistics of an FSDP module. + + Args: + mod_fqn (str): The fully qualified name of the FSDP module. + + Attributes: + snapshots (Dict[_FSDPModState, Dict[torch.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states as defined by ``_FSDPModState``. Each key is a device, and + each value is another dictionary with keys as memory reference types defined by ``_FSDPRefType`` and + values as the memory consumed in bytes. + + """ + + def __init__(self, mod_fqn: str) -> None: + self.mod_fqn = mod_fqn + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[ + _FSDPModState, list[dict[torch.device, dict[str, int]]] + ] = {} + + +class _FSDPState(Enum): + PRE_FW = auto() + FW = auto() + POST_FW = auto() + PRE_BW = auto() + BW = auto() + POST_BW = auto() + + +class FSDPMemTracker(MemTracker): + """ + A ``TorchDispatchMode`` based context manager that extends ``torch.distributed._tools.mem_tracker.MemTracker`` to track + and categorize the peak memory and module-wise memory usage of FSDP modules. + + It tracks the peak memory usage across all the devices of all the FSDP modules in the module tree and categorizes + the tensor memory usage as defined by ``_FSDPRefType``. Further, it captures memory `snapshots` at different stages of + the module execution defined by ``_FSDPModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key is a reference + to a module, and each value is a ``_FSDPModMemStats`` object that stores the memory statistics of the module. + + Args: + mod (torch.nn.Module): The root FSDP module to be tracked. + optm (torch.optim.Optimizer, optional): The optimizer to be tracked. + + Note: Please refer to ``torch.distributed._tools.mem_tracker.MemTracker`` to learn about the limitations. + + Example usage + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + fmt = FSDPMemTracker(module, optimizer) + fmt.track_inputs((inp,)) + with fmt: + optimizer.zero_grad() + loss = module(inp) + print("After Forward:") + fmt.display_snapshot("current") + loss.backward() + optimizer.step() + fmt.display_snapshot("peak") + fmt.display_modulewise_snapshots(depth=3, units="MB") + + """ + + def __init__( + self, + mod: torch.nn.Module, + optm: torch.optim.Optimizer | None = None, + ) -> None: + super().__init__() + assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules" + self._root_mod = mod + self._optm = optm + self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary() + self._fsdp_state: _FSDPState = _FSDPState.PRE_FW + self._ref_class: type[_RefType] = _FSDPRefType + + def _instrument_fsdp_sharded_params_grads( + self, fsdp_param_group: FSDPParamGroup + ) -> None: + # Track sharded params and grads after initialization + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.sharded_param, + _FSDPRefType.SHARDED_PARAM, + ) + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + def _fsdp_state_pre_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]], + ) -> Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]]: + # We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params + # and `all_gather` buffers. There are three cases: + # Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats`` + # instance for the module and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + # For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op. + @wraps(orig_fsdp_state_pre_fw) + def inner( + *args: _P.args, **kwargs: _P.kwargs + ) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]: + self._fsdp_state = _FSDPState.PRE_FW + mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod) + assert mod_fqn is not None + if fsdp_mod not in self.memory_tracking: + mod_stat = _FSDPModMemStats(mod_fqn) + self.memory_tracking[fsdp_mod] = mod_stat + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_FW, []).append( + snapshot + ) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_FW, []).append( + deepcopy(snapshot) + ) + elif not self._mod_tracker.is_bw: + parents = self._mod_tracker.parents - {mod_fqn} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "FSDPMemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + + # pyrefly: ignore [bad-assignment] + args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs) + + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.unsharded_param, + _FSDPRefType.UNSHARDED_PARAM, + ) + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(fsdp_mod) + self._in_ac = True + else: + state = _FSDPModState.AFT_PRE_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + self._fsdp_state = _FSDPState.FW + return args, kwargs + + return inner + + def _fsdp_state_post_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_post_fw: Callable[_P, _R], + ) -> Callable[_P, _R]: + # We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state + # if ``reshard_after_forward`` is not ``False``. There are two cases: + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward. + @wraps(orig_fsdp_state_post_fw) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is fsdp_mod: + self._ac_mod = None + self._in_ac = False + else: + state = _FSDPModState.BEF_POST_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + self._fsdp_state = _FSDPState.POST_FW + + output = orig_fsdp_state_post_fw(*args, **kwargs) + + if not self._mod_tracker.is_bw: + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_FW, []).append( + self.get_tracker_snapshot() + ) + return output + + return inner + + def _fsdp_param_group_pre_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_pre_backward: Callable[_P, Any], + ) -> Callable[_P, None]: + # We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching + # and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module. + @wraps(orig_fsdp_param_group_pre_backward) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> None: + self._fsdp_state = _FSDPState.PRE_BW + mod_stat = self.memory_tracking[fsdp_mod] + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_BW, []).append(snapshot) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_BW, []).append( + deepcopy(snapshot) + ) + orig_fsdp_param_group_pre_backward(*args, **kwargs) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append( + self.get_tracker_snapshot() + ) + self._fsdp_state = _FSDPState.BW + + return inner + + def _fsdp_param_group_post_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_post_backward: Callable[_P, Any], + ) -> Callable[_P, None]: + # We capture the memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute + # the `unsharded` grads before the post backward and then `sharded` grads and `reduce_scatter` buffers + # after the post backward. + @wraps(orig_fsdp_param_group_post_backward) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> None: + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + unsharded_grad = fsdp_param._unsharded_param.grad + if unsharded_grad is not None: + self._update_and_maybe_create_winfos( + unsharded_grad, + _FSDPRefType.UNSHARDED_GRAD, + update_existing=True, + ) + + mod_stat = self.memory_tracking[fsdp_mod] + mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append( + self.get_tracker_snapshot() + ) + self._fsdp_state = _FSDPState.POST_BW + orig_fsdp_param_group_post_backward(*args, **kwargs) + + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_BW, []).append( + self.get_tracker_snapshot() + ) + + return inner + + def _instrument_fsdp_module(self) -> None: + # We uninstall the existing `FSDPState._pre_forward` and `FSDPState._post_forward` hooks and install + # our own hooks that wrap them. We choose this over monkey-patching `FSDPParamGroup.pre_forward` and + # `FSDPParamGroup.post_forward` because during AC these won't be called. + # TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786) + # lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`. + + # get the unique _MultiHandlers/RemoveHandlers and store in dictionary + # the _MultiHandlers object will only need to be grabbed once. + unique_handlers: dict[RemovableHandle, bool] = {} + # pyrefly: ignore # missing-attribute + for module in self._root_mod.modules(): + if isinstance(module, FSDPModule): + fsdp_state = module._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + if not unique_handlers.get(fsdp_state._pre_forward_hook_handle): + unique_handlers[fsdp_state._pre_forward_hook_handle] = True + if not unique_handlers.get(fsdp_state._post_forward_hook_handle): + unique_handlers[fsdp_state._post_forward_hook_handle] = True + # call remove on the handles once + for f_hook_handle in unique_handlers: + f_hook_handle.remove() + # pyrefly: ignore # missing-attribute + for module in self._root_mod.modules(): + if isinstance(module, FSDPModule): + fsdp_state = module._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + self._instrument_fsdp_sharded_params_grads(fsdp_param_group) + fsdp_state._pre_forward_hook_handle = ( + # pyrefly: ignore [missing-attribute] + module.register_forward_pre_hook( + self._fsdp_state_pre_forward( + module, fsdp_state._pre_forward + ), + prepend=True, + with_kwargs=True, + ) + ) + # pyrefly: ignore [missing-attribute] + fsdp_state._post_forward_hook_handle = module.register_forward_hook( + self._fsdp_state_post_forward(module, fsdp_state._post_forward), + prepend=False, + always_call=True, + ) + self._fsdp_mod_to_saved_methods[module] = _SavedFSDPMethods( + fsdp_param_group.pre_backward, + fsdp_param_group.post_backward, + ) + fsdp_param_group.pre_backward = self._fsdp_param_group_pre_backward( # type: ignore[assignment] + module, fsdp_param_group.pre_backward + ) + fsdp_param_group.post_backward = ( # type: ignore[assignment] + self._fsdp_param_group_post_backward( + module, fsdp_param_group.post_backward + ) + ) + + # pyrefly: ignore [missing-attribute] + for buffer in self._root_mod.buffers(): + self._update_and_maybe_create_winfos( + buffer, + _FSDPRefType.BUFFER, + ) + + def _instrument_optimizer(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + if self._optm is not None: + self._track_optimizer_states(_FSDPRefType.OPT, self._optm) + + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_FSDPRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + self._optm.register_step_pre_hook(_opt_step_pre_hook), + self._optm.register_step_post_hook(_opt_step_post_hook), + ) + + def _register_module_and_optimizer_hooks(self) -> None: + self._instrument_fsdp_module() + self._instrument_optimizer() + + def _deregister_module_and_optimizer_hooks(self) -> None: + for ( + fsdp_mod, + saved_methods, + ) in self._fsdp_mod_to_saved_methods.items(): + fsdp_state = fsdp_mod._get_fsdp_state() + fsdp_state._pre_forward_hook_handle.remove() + fsdp_state._post_forward_hook_handle.remove() + fsdp_state._pre_forward_hook_handle = fsdp_mod.register_forward_pre_hook( + fsdp_state._pre_forward, prepend=True, with_kwargs=True + ) + fsdp_state._post_forward_hook_handle = fsdp_mod.register_forward_hook( + fsdp_state._post_forward, prepend=False + ) + if fsdp_param_group := fsdp_state._fsdp_param_group: + fsdp_param_group.pre_backward = saved_methods.pre_backward + fsdp_param_group.post_backward = saved_methods.post_backward + self._fsdp_mod_to_saved_methods.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def track_inputs(self, inputs: tuple[Any, ...]) -> None: + """ + This is used to track the input tensors to the model and annotate them as ``Inputs``. + Args: + inputs (Tuple[Any]): A tuple containing the input data. This can include tensors + as well as other data types. Only tensors will be tracked. + """ + + def _track_inputs(t: torch.Tensor) -> None: + self._update_and_maybe_create_winfos( + t, + _FSDPRefType.INP, + ) + + tree_map_only(torch.Tensor, _track_inputs, inputs) + + def track_external( + self, *external: nn.Module | optim.Optimizer | torch.Tensor + ) -> None: + """This is no-op for ``FSDPMemTracker``""" + + def __enter__(self) -> "FSDPMemTracker": + if self._depth == 0: + self._register_module_and_optimizer_hooks() + self._track_resize() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] + for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + TorchDispatchMode.__enter__(self) + self._depth += 1 + return self + + def __exit__(self, *args: Any) -> None: + self._depth -= 1 + if self._depth == 0: + self._deregister_module_and_optimizer_hooks() + self._restore_resize() + self._mod_tracker.__exit__(*args) + TorchDispatchMode.__exit__(self, *args) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # When running this mode with DTensor, ordinarily all modes will + # run **before** subclasses get a chance to run. + # Returning NotImplemented here gives us a chance to let DTensor + # run and desugar into local tensor ops, before `MemTracker` sees them. + if any(t == DTensor for t in types): + return NotImplemented + if ( + func is torch.ops._c10d_functional.wait_tensor.default + and active_fake_mode() + ): + # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns + # a new tensor which does not happen in eager mode, when a wait_tensor is called. + # pyrefly: ignore [unsupported-operation] + res = args[0] + else: + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _FSDPRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _FSDPRefType.TEMP + else: + reftype = _FSDPRefType.ACT + if func is c10d._allgather_base_.default and self._fsdp_state in [ + _FSDPState.PRE_FW, + _FSDPState.PRE_BW, + ]: + # pyrefly: ignore [unsupported-operation] + output_tensor = args[0] + self._update_and_maybe_create_winfos( + output_tensor, + _FSDPRefType.ALL_GATHER, + update_existing=True, + ) + if ( + func is c10d._reduce_scatter_base_.default + and self._fsdp_state == _FSDPState.POST_BW + ): + # pyrefly: ignore [unsupported-operation] + input_tensor = args[1] + self._update_and_maybe_create_winfos( + input_tensor, + _FSDPRefType.REDUCE_SCATTER, + update_existing=True, + ) + + tree_map_only(torch.Tensor, partial(self._track, reftype), res) + peak_state = ( + _FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW + ) + self._update_peak_stats(peak_state) + return res diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/ilp_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/ilp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8ba4195ffd20323d419642159fe199549e3de1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/ilp_utils.py @@ -0,0 +1,292 @@ +import copy +from collections import OrderedDict +from typing import cast, TypedDict + +import numpy as np + +import torch +from torch.distributed._tools.mem_tracker import ( + _MemRefType, + _ModMemStats, + _ModState, + MemTracker, +) +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats + + +class ModOrder(TypedDict): + fw_pre_order: list[str] + bw_pre_order: list[str] + fw_post_order: list[str] + bw_post_order: list[str] + + +class ModRuntime(TypedDict): + fw: float + bw: float + + +class ModStats(TypedDict): + fqn: str + # per-module params + param_per_module: int + # per-module grads + grad_per_module: int + # total accumulated gradients up to and including this module + grad_total: int + # per module fw activation size (excluding input and output) + act_fw_per_module: int + # per module bw activation size during peak_bw + act_bw_per_module: int + # per module activation grad size during peak_bw + act_grad_per_module: int + # total activation size up to but excluding the current module + # includes input of the current module (i.e., output of previous module) + act_total: int + # Inputs to the module + input_per_module: int + # Outputs of the module + output_per_module: int + # Total fw run-time of the module + fw_runtime_per_module: float + # Total bw run-time of the module + bw_runtime_per_module: float + # Is this module a leaf module + is_leaf: bool + # Total ac run-time of the module + sac_runtime: float + # Total ac_memory for the module + sac_memory: int + # Number of piecewise-linear functions used for approximating ac tradeoff curve + n_segments: int + # Slopes of the of piecewise-linear functions + slopes: list[float] + # Intercepts of the of piecewise-linear functions + intercepts: list[float] + # X breakpoints of the of piecewise-linear functions + breakpoints: list[float] + # Original trade-off curves + tradeoff_curve: OrderedDict[float, float] + + +class ModuleInfo(TypedDict): + mod_order: ModOrder + mod_stats: list[ModStats] + + +def aggregate_stats( + model: torch.nn.Module, + mem_tracker: MemTracker, + runtime_estimator: RuntimeEstimator, + sac_estimator: SACEstimator, + dev: torch.device, +) -> ModuleInfo: + """ + Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats. + + Args: + model: nn.Module object + runtime_estimator: RuntimeEstimator object with runtime stats + mem_tracker: MemTracker object with memory stats + sac_estimator: SACEstimator object with AC tradeoff stats + dev: device the model was run on (used to extract memory stats from MemTracker) + + Returns: + ModuleInfo: A dictionary with module order and module stats. + """ + + # Memory stats + mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict( + copy.deepcopy(mem_tracker.memory_tracking) + ) + + # Runtime stats + mod_runtime_stats: dict[str, ModRuntime] = { + fqn: {"fw": v["fw"], "bw": v["bw"]} + for fqn, v in runtime_estimator.mod_runtimes.items() + } + + # Module order + mod_order: ModOrder = { + "fw_pre_order": list(runtime_estimator.mod_fw_pre_order), + "bw_pre_order": list(runtime_estimator.mod_bw_pre_order), + "fw_post_order": list(runtime_estimator.mod_fw_post_order), + "bw_post_order": list(runtime_estimator.mod_bw_post_order), + } + + # Selective Activation Checkpointing stats + sac_estimator.pwlf_sac_tradeoff_curve() + mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy( + sac_estimator.sac_mod_tradeoff_stats + ) + + module_info: ModuleInfo = { + "mod_order": mod_order, + "mod_stats": [], + } + + for mod in model.modules(): + if mod_mem_stat := mod_mem_stats.get(mod): + if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None): + sac_runtime = tradeoff_stats.sac_runtime + sac_memory = tradeoff_stats.sac_memory + n_segments = tradeoff_stats.n_segments + slopes = tradeoff_stats.slopes + intercepts = tradeoff_stats.intercepts + breakpoints = tradeoff_stats.fit_breaks + tradeoff_curve = tradeoff_stats.tradeoff_curve + is_leaf = False + else: + sac_runtime = sac_memory = n_segments = 0 + slopes = intercepts = breakpoints = [] + tradeoff_curve: OrderedDict[float, float] = OrderedDict() # type: ignore[no-redef] + is_leaf = True + mod_stat: ModStats = { + "fqn": mod_mem_stat.mod_fqn, + "param_per_module": mod_mem_stat.parameter_mem, + "grad_per_module": mod_mem_stat.parameter_mem, + "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.GRAD + ], + "act_fw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.output_mem, + ), + "act_bw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT], + ), + "act_grad_per_module": ( + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP] + - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.TEMP + ] + ), + "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][ + _MemRefType.ACT + ], + "input_per_module": mod_mem_stat.input_mem, + "output_per_module": mod_mem_stat.output_mem, + "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"], + "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"], + "is_leaf": is_leaf, + "sac_runtime": sac_runtime, + "sac_memory": sac_memory, + "n_segments": n_segments, + "slopes": slopes, + "intercepts": intercepts, + "breakpoints": breakpoints, + "tradeoff_curve": tradeoff_curve, + } + module_info["mod_stats"].append(mod_stat) + + return module_info + + +class Node(ModStats): + index: int # index according to forward pre-order + pos_fw_post_order: int # index according to forward post-order + + +class Graph: + def __init__(self, n: int) -> None: + self.nodes: list[Node] = [] + self.name2node: dict[str, Node] = {} + self.ad_matrix = np.zeros((n, n)) + self.fw_post_order: list[str] = [] + + def add_node(self, node: Node) -> None: + self.nodes.append(node) + self.name2node[node["fqn"]] = node + + +def parse_module_info(module_info: ModuleInfo) -> Graph: + """ + Parse module info and create a graph (tree) of modules. The graph will be + used by MILP solver to find optimal SAC and/or FSDP configurations. + """ + mod_stats = module_info["mod_stats"] + fw_pre_order = module_info["mod_order"]["fw_pre_order"] + # assertion and number of nodes + assert len(mod_stats) == len(fw_pre_order) + n_nodes = len(mod_stats) + + # create graph + g = Graph(n_nodes) + g.fw_post_order = module_info["mod_order"]["fw_post_order"] + + # sort the modules by pre-order and add them to the graph + module_info["mod_stats"] = sorted( + mod_stats, key=lambda x: fw_pre_order.index(x["fqn"]) + ) + for i, one_mod_stats in enumerate(mod_stats): + node: Node = cast(Node, one_mod_stats) + node["index"] = i + node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"]) + g.add_node(node) + + # set up ancestor-descendant matrix + for i in range(n_nodes): + for j in range(i, n_nodes): + if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]): + g.ad_matrix[i][j] = 1 + else: + break + + return g + + +def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + check if name_descendant is a submodule of name_ancestor, or if they are the same + """ + return name_descendant == name_ancestor or name_ancestor + "." in name_descendant + + +def is_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + if name_descendant is a submodule of name_ancestor, but not the same + """ + return name_ancestor + "." in name_descendant + + +def display_bytes(b: int, unit: str = "MiB") -> str: + """ + return a string that represent the number of bytes in a desired unit + """ + if unit == "KiB": + return f"{b / 2**10:.2f} KiB" + if unit == "MiB": + return f"{b / 2**20:.2f} MiB" + if unit == "GiB": + return f"{b / 2**30:.2f} GiB" + return f"{b:.2f} bytes" + + +def get_peak_memory_runtime_baseline(graph: Graph) -> tuple[int, float]: + """ + Get the baseline peak memory and runtime. + Baseline here means there is no FSDP or AC. + Memory includes the parameters, gradients, activations, and activation gradients. + Memory does not include e.g., optimizer states, embedding tables, etc. + + Returns: + int: peak memory in bytes + float: compute time in ms + """ + P_1 = graph.nodes[0]["param_per_module"] + num_nodes = len(graph.nodes) + peak_mem = 0 + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] + AG_i = graph.nodes[i]["act_grad_per_module"] + TA_i = graph.nodes[i]["act_total"] + peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i) + compute_time = ( + graph.nodes[0]["fw_runtime_per_module"] + + graph.nodes[0]["bw_runtime_per_module"] + ) + return (peak_mem, compute_time) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/mem_tracker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/mem_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf03f132b1a7fbd7c03ed0b1a0b03e40ebebdb2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/mem_tracker.py @@ -0,0 +1,938 @@ +import math +import os +import re +import warnings +from collections.abc import Callable +from copy import deepcopy +from enum import auto, Enum +from functools import partial, wraps +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +import torch.distributed._tools.fake_collectives +from torch import nn, optim +from torch._guards import active_fake_mode +from torch.distributed._tools.common_utils import get_untyped_storages +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed.tensor import DTensor +from torch.optim.optimizer import ( + register_optimizer_step_post_hook, + register_optimizer_step_pre_hook, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map_only +from torch.utils.weak import WeakIdKeyDictionary, weakref + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) +_TOTAL_KEY = "Total" + +__all__ = ["MemTracker"] + + +class _RefType(str, Enum): + """Base Class for defining memory reference types, categorizing tensors based on their usage within a model.""" + + +class _State(str, Enum): + """Base Class for defining module state to capture snapshots .""" + + +class _MemRefType(_RefType): + """ + An enum to define memory reference types, categorizing tensors based on their usage within a model. + + - PARAM: Tensors registered as nn.Parameter within modules. + - BUFFER: Tensors registered as nn.Buffer within modules. + - GRAD: Gradients associated with parameters. + - ACT: Tensors produced during the forward pass and recomputation in activation checkpointing. + - TMP: Temporary memory used during the backward pass, including gradients of activations. + - OPT: Tensors holding optimizer states. + - OTH: Tensors registered via `track_external` that do not fit the above categories. + """ + + PARAM = "Parameter" + BUFFER = "Buffer" + GRAD = "Gradient" + ACT = "Activation" + TEMP = "Temp" + OPT = "Optstate" + OTH = "Other" + + +class _ModState(_State): + """ + An enum to define the state of a module. + + - PRE_FW: The module is about to run the forward pass. + - POST_FW: The module has finished running the forward pass. + - PEAK_FW: The module has reached the peak memory usage during the forward pass. + - PRE_BW: The module is about to run the backward pass. + - PRE_FW_AC: The module is about to run the forward pass with activation checkpointing. + - POST_FW_AC: The module has finished running the forward pass with activation checkpointing. + - POST_BW: The module has finished running the backward pass. + - PEAK_BW: The module has reached the peak memory usage during the backward pass. + """ + + PRE_FW = "Pre-Forward" + POST_FW = "Post-Forward" + PEAK_FW = "Peak-Forward" + PRE_BW = "Pre-Backward" + PRE_FW_AC = "Pre-Forward-AC" + POST_FW_AC = "Post-Forward-AC" + POST_BW = "Post-Backward" + PEAK_BW = "Peak-Backward" + + +class _ModMemStats: + """ + A class to store the memory statistics of a module. + + Args: + mod_fqn (str): The fully qualified name of the module. + Attributes: + mod_fqn (str): The fully qualified name of the module. + parameter_mem (int): The memory usage of the parameters of the module. + buffer_mem (int): The memory usage of the buffers of the module. + input_mem (int): The memory usage of the inputs to the module. + output_mem (int): The memory usage of the outputs from the module. + snapshots (Dict[_ModState, Dict[torch.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states defined by ``_ModState``. + Note: + The memory snapshot is stored as a dictionary - Dict[torch.device, Dict[str, int]], where each key is a device, + and each value is another dictionary with keys as memory reference types defined by `_MemRefType` and + values as the memory consumed in bytes. + """ + + def __init__(self, mod_fqn: str): + self.mod_fqn = mod_fqn + self.parameter_mem: int + self.buffer_mem: int + self.input_mem: int + self.output_mem: int + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[_ModState, list[dict[torch.device, dict[str, int]]]] = {} + + +class _WeakRefInfo: + """ + Manages memory statistics and device attributes for tensor storages. + """ + + def __init__( + self, size: int, element_size: int, device: torch.device, reftype: _RefType + ) -> None: + """ + Initializes the ``_WeakRefInfo`` object with tensor storage properties. + + Args: + size (int): The number of elements in the tensor storage. + element_size (int): The size of each element in the tensor storage. + device (torch.device): The device on which the tensor is allocated. + reftype (_RefType): The reference type of the tensor. + """ + self.size = size + self.element_size = element_size + self.reftype = reftype + # pyrefly: ignore [read-only] + self.device = device + self.mem_consumed = self._calculate_mem_consumed() + + def _calculate_mem_consumed(self) -> int: + """ + Calculates the memory consumed by the tensor storage, considering device-specific allocation rules. + + Returns: + int: The memory consumed in bytes. + """ + mem = self.size * self.element_size + if self.device.type == "cuda": + return math.ceil((mem) / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + return mem + + def update_mem_consumed(self, st: torch.UntypedStorage) -> int: + """ + Updates and returns the memory consumed if the storage size has changed. + + Args: + st (torch.UntypedStorage): The tensor storage to check for size updates. + + Returns: + int: The updated memory consumed in bytes. + """ + if st.size() != self.size: + self.size = st.size() + self.mem_consumed = self._calculate_mem_consumed() + return self.mem_consumed + + @classmethod + def create_winfo( + cls, + st: torch.UntypedStorage, + device: torch.device, + reftype: _RefType, + callback: Callable[[Self, weakref.ref], Any] | None = None, + ) -> tuple[Self, weakref.ref]: + """ + Creates a new ``_WeakRefInfo`` instance and a weak reference to a ``torch.UntypedStorage`` object, + optionally attaching a callback to the weak reference. + + Args: + st (torch.UntypedStorage): The storage object for which to create the weak reference info. + device (torch.device): The device associated with the storage object. + reftype (_RefType): The type of reference, used to categorize the storage. + callback (Optional[Callable[[Self, weakref.ref]]]): A callback function that is called when + the storage object is about to be finalized (garbage collected). The callback function + should accept two arguments: the ``_WeakRefInfo`` instance and the weak reference to the storage. + Returns: + Tuple[Self, weakref.ref]: A tuple containing the newly created ``_WeakRefInfo`` instance and the + weak reference to the storage object. The weak reference may have an attached callback if provided. + """ + + winfo = cls(st.size(), st.element_size(), device, reftype) + w_st = weakref.ref(st, partial(callback, winfo) if callback else None) + return winfo, w_st + + +def _get_mem_divisor(units: str) -> int: + unit_dict = {"B": 1, "KiB": 2**10, "MiB": 2**20, "GiB": 2**30} + if units in unit_dict: + return unit_dict[units] + else: + raise ValueError( + f"Unsupported unit: {units}. Supported units are: {', '.join(unit_dict.keys())}" + ) + + +def _rounding_fn(value: int, divisor: int, precision: int) -> float | int: + return value if divisor == 1 else round(value / divisor, precision) + + +def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + divisor = _get_mem_divisor(units) + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + print( + f"Device: {dev}", + *( + f"\t{k.value}: {_rounding_fn(v, divisor, 2)} {units}" + if isinstance(k, _RefType) + else f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}" + for k, v in dev_snap.items() + ), + sep="\n", + ) + + +def _print_snapshot_tabular( + snapshot: dict[torch.device, dict[str, int]], units: str +) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + divisor = _get_mem_divisor(units) + table_data = [] + key_list = list(next(iter(snapshot.values())).keys()) + headers = ["Device"] + [ + f"{key.value}" if isinstance(key, _RefType) else f"{key}" for key in key_list + ] + + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = [str(dev)] + row.extend(f"{_rounding_fn(v, divisor, 2)} {units}" for v in dev_snap.values()) + table_data.append(row) + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +def _print_state_snapshots( + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str +) -> None: + for state, snapshot_list in snapshots.items(): + print(f"{state.value}") + for i, snapshot in enumerate(snapshot_list): + print(f"# {i + 1}:") + _print_snapshot(snapshot, units) + print() + + +def _print_state_snapshots_tabular( + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str +) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + + table_data = [] + last_state_call = None + divisor = _get_mem_divisor(units) + for state, snapshot_list in snapshots.items(): + for i, snapshot in enumerate(snapshot_list): + state_call = f"{state.value} # {i + 1}" + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = { + "State & Call": ( + state_call if state_call != last_state_call else "" + ), + "Device": str(dev), + } + last_state_call = state_call + for k, v in dev_snap.items(): + row[f"{k.value}" if isinstance(k, _RefType) else f"{k}"] = ( + f"{_rounding_fn(v, divisor, 2)} {units}" + ) + table_data.append(row) + print(tabulate(table_data, headers="keys", tablefmt="rst")) + + +class _UpdateType(Enum): + # These are used for tracking updates to the continuouly maintained memory snapshot. + # ADD - When a new tensor storage is tracked + # DEL - When a tensor storage is about to be finalized (garbage collected). + # REF - When a tensor reference is updated, for instance, the gradients are marked as + # generic backward reference types until the grad_hook categorizes them as gradients. + # SIZE - When a tensor's storage is resized. + ADD = auto() + DEL = auto() + REF = auto() + SIZE = auto() + + +class MemTracker(TorchDispatchMode): + """ + A TorchDispatchMode to track, categorize and attribute the tensor memory created or accessed within its context. + + It categorizes the tracked tensors as parameters, buffers, activations, gradients, temporary memory and optimizer states + as defined by ``_MemRefType`` within its context. It captures memory `snapshots` for the modules, called within its context, + at various states defined by ``_ModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key + is a reference to a module, and each value is a ``_ModMemStats`` object that stores the memory + statistics of the module. + + Note: + The MemTracker should be used as a context manager. The modules, optimizers, and any other tensors created within + the context of MemTracker will be tracked by default. Any tensors or stateful objects such as modules, optimizers etc. + that need to be tracked but are created outside the MemTracker should be registered using the `track_external` method. + The `track_external` method should be called before the MemTracker is used. Any tensors created outside the ``MemTracker`` + and not supplied to the `track_external` method will not be tracked by the ``MemTracker``. + + Example usage: + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + mem_tracker = MemTracker() + mem_tracker.track_external(module, optimizer, inp) + with mem_tracker as mt: + loss = module(inp) + print("After Forward:") + mt.display_snapshot("current") + loss.backward() + optimizer.step() + optimizer.zero_grad() + mt.display_snapshot("peak") + mt.display_modulewise_snapshots(depth=3, units="MiB") + + Known Limitations: + - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``. + - Resizing tensor storages directly by using non-Tensor methods other than using ``torch.Untyped_Storage.resize_`` + is not tracked. File a Github issue if you have use-cases for this. + - If the tensors are not traceable or wrappable subclasses of ``torch.Tensor``, then the tracker does not know how to + track their storages. File a Github issue if you have use-cases for this. + - During AC in the backward pass there might be misattribution between activation and temp memory, but the peak memory + will be tracked accurately. This will be fixed in the next update by hooking intricately with ``torch.uitls.checkpoint``. + """ + + def __init__(self) -> None: + self.memory_tracking = WeakIdKeyDictionary() + self._curr_mem_snap: dict[torch.device, dict[str, int]] = {} + self._peak_mem: dict[torch.device, int] = {} + self._peak_mem_snap: dict[torch.device, dict[str, int]] = {} + self._param_to_grad_hook_handles = WeakIdKeyDictionary() + self._optimizer_hook_handles: tuple[RemovableHandle, RemovableHandle] | None = ( + None + ) + # Dictionary to store the ``_WeakRefInfo`` instances corresponding to each tensor's storage. + self._WINFO = WeakIdKeyDictionary() + self._mod_tracker = ModTracker() + # This is a general memory tracker which can be used with any ``_RefType`` subclass + self._ref_class: type[_RefType] = _MemRefType + # Flags to track if we are in the AC region or optimizer step region + self._in_opt: bool = False + self._in_ac: bool = False + # Weak references to the topmost AC module currently active + self._ac_mod: weakref.ref | None = None + self._orig_resize = torch.UntypedStorage.resize_ + self._depth = 0 + + def _update_snap( + self, + u_type: _UpdateType, + winfo: _WeakRefInfo, + old_mem_consumed: int | None = None, + old_reftype: _RefType | None = None, + ) -> None: + # Initialize a flag to track if the total memory might drop to zero after updates. + maybe_zero = False + # Ensure the device entry exists in the current memory snapshot, initializing if necessary. + # pyrefly: ignore [no-matching-overload] + dev_snap = self._curr_mem_snap.setdefault( + winfo.device, dict.fromkeys(self._ref_class, 0) + ) + dev_snap.setdefault(_TOTAL_KEY, 0) + # Handle different types of updates based on the update type (`u_type`). + if u_type == _UpdateType.ADD: + # Increase the memory consumed for the specific reference type and update the total. + dev_snap[winfo.reftype] += winfo.mem_consumed + dev_snap[_TOTAL_KEY] += winfo.mem_consumed + elif u_type == _UpdateType.DEL: + # Decrease the memory consumed for the specific reference type and reduce the total. + dev_snap[winfo.reftype] -= winfo.mem_consumed + dev_snap[_TOTAL_KEY] -= winfo.mem_consumed + maybe_zero = True + elif u_type == _UpdateType.REF: + assert old_reftype is not None + # Adjust memory consumption between two reference types within the same device. + dev_snap[old_reftype] -= winfo.mem_consumed + dev_snap[winfo.reftype] += winfo.mem_consumed + elif u_type == _UpdateType.SIZE: + assert old_mem_consumed is not None + # Adjust the memory consumed for a reference type due to a change in size. + change = winfo.mem_consumed - old_mem_consumed + dev_snap[winfo.reftype] += change + dev_snap[_TOTAL_KEY] += change + maybe_zero = True + else: + raise ValueError(f"Invalid update type: {u_type}") + # Check if the total memory for the device has dropped to zero. + if maybe_zero: + if self._curr_mem_snap[winfo.device][_TOTAL_KEY] == 0: + # Remove the device entry from the memory snapshot if the total memory is zero. + del self._curr_mem_snap[winfo.device] + + def _update_and_maybe_create_winfos( + self, + t: torch.Tensor, + reftype: _RefType, + update_existing: bool = False, + ) -> set[_WeakRefInfo]: + sts = get_untyped_storages(t) + winfos = set() + for st in sts: + # Attempt to retrieve existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary. + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + # If ``_WeakRefInfo`` exists, check if the reference type needs to be updated. + old_reftype = winfo.reftype + if old_reftype != reftype: + # Update the reference type and apply changes via ``_update_snap``. + winfo.reftype = reftype + self._update_snap(_UpdateType.REF, winfo, old_reftype=old_reftype) + winfos.add(winfo) + elif update_existing: + # If no existing ``_WeakRefInfo`` is found and update_existing is True, raise an error. + raise KeyError("No existing winfo found") + else: + # If no existing _WeakRefInfo is found and update_existing is False, create a new ``_WeakRefInfo``. + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + # Store the new ``_WeakRefInfo`` and its weak reference in the tracking dictionary. + self._WINFO[st] = (winfo, w_st) + # Update the snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + winfos.add(winfo) + return winfos + + def _delete_callback(self, winfo: _WeakRefInfo, w_st: weakref.ref) -> None: + # Callback to be called when the storage object corresponding to the ``_WeakRefInfo`` + # instance is about to be finalized. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.DEL, winfo) + + def _track_resize(self) -> None: + # Need to monkey-patch this because ``torch.UntypedStorage.resize_`` is not captured + # by ``TorchDispatchMode``. + @wraps(self._orig_resize) + def resize_(st: torch.UntypedStorage, size: int) -> None: + self._orig_resize(st, size) + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None and winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + + torch.UntypedStorage.resize_ = resize_ # type: ignore[method-assign, assignment] + + def _restore_resize(self) -> None: + torch.UntypedStorage.resize_ = self._orig_resize # type: ignore[method-assign] + + def _update_peak_stats(self, peak_state: _State) -> None: + # We first capture the current memory snapshot of the current tracker state then, + # We step through each of the modules we have tracked so far in ``memory_tracking`` + # and check if it is currently active by querying ``_mod_tracker.parents`` + # If it is active, we update the per device peak memory usage for the module + # corresponding to the ``_State`` which can be ``PEAK_FW`` or ``PEAK_BW``. + curr_snap = self._curr_mem_snap + + for mod_stats in self.memory_tracking.values(): + if mod_stats.mod_fqn in self._mod_tracker.parents: + if peak_state in mod_stats.snapshots: + for dev, dev_snap in curr_snap.items(): + if mod_stats.local_peak.get(dev, 0) < dev_snap[_TOTAL_KEY]: + mod_stats.local_peak[dev] = dev_snap[_TOTAL_KEY] + mod_stats.snapshots[peak_state][-1][dev] = deepcopy( + dev_snap + ) + + for dev, dev_snap in curr_snap.items(): + if self._peak_mem.get(dev, 0) < dev_snap[_TOTAL_KEY]: + self._peak_mem[dev] = dev_snap[_TOTAL_KEY] + self._peak_mem_snap[dev] = deepcopy(dev_snap) + + def _track(self, reftype: _RefType, t: torch.Tensor) -> None: + # Get the storages of the tensor and check if we have already tracked them. + # If yes, then check if the storage size has changed and update the current snapshot. + # Else create a new ``_WeakRefInfo`` instance and add it to the dictionary. + sts = get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + if winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + return + else: + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + self._WINFO[st] = (winfo, w_st) + # Update the current snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + + def get_tracker_snapshot( + self, type: str = "current" + ) -> dict[torch.device, dict[str, int]]: + """ + Capture a snapshot of the memory usage breakdown per device, based on the specified type. + + Args: + type (str): The type of snapshot to capture. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + Returns: + Dict[torch.device, Dict[str, int]]: A dictionary where each key is a torch.device, and each value is another + dictionary. This inner dictionary has keys representing memory reference + types as defined in ``_MemRefType`` and values representing the amount of + memory consumed in bytes. + Raises: + ValueError: If an invalid type is specified. + """ + if type == "current": + return deepcopy(self._curr_mem_snap) + elif type == "peak": + return deepcopy(self._peak_mem_snap) + else: + raise ValueError(f"Invalid type {type}") + + def _track_module_params_and_buffers( + self, module: nn.Module, install_grad_hooks: bool = True + ) -> tuple[int, int]: + # Track the parameters and buffers of the module if not already tracked. + # If the parameters have gradients, track the gradients as well. + # If install_grad_hooks is True, install a gradient hook on the parameters + # to track the gradients, if it has not already been installed. + # Return the total memory consumed by the parameters and buffers. + def _grad_hook(grad: torch.Tensor) -> None: + self._update_and_maybe_create_winfos( + grad, + _MemRefType.GRAD, + ) + + param_memory = 0 + for param in module.parameters(): + winfos = self._update_and_maybe_create_winfos( + param, + _MemRefType.PARAM, + ) + param_memory += sum(winfo.mem_consumed for winfo in winfos) + if param.grad is not None: + self._update_and_maybe_create_winfos( + param.grad, + _MemRefType.GRAD, + ) + if ( + self._param_to_grad_hook_handles.get(param, None) is None + and install_grad_hooks + ): + grad_hook_handle = param.register_hook(_grad_hook) + post_acc_grad_hook_handle = param.register_post_accumulate_grad_hook( + lambda p: (_grad_hook(p.grad)) + ) + self._param_to_grad_hook_handles[param] = ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) + buffer_memory = 0 + for buffer in module.buffers(): + winfos = self._update_and_maybe_create_winfos( + buffer, + _MemRefType.BUFFER, + ) + buffer_memory += sum(winfo.mem_consumed for winfo in winfos) + return (param_memory, buffer_memory) + + def _track_inputs_or_outputs(self, args: Any) -> int: + # Calculate the memory consumed by the inputs or outputs of the module. + input_or_output_memory = 0 + + def add_inps_or_outs(t: torch.Tensor) -> None: + nonlocal input_or_output_memory + sts = get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + input_or_output_memory += winfo.mem_consumed + + tree_map_only(torch.Tensor, add_inps_or_outs, args) + return input_or_output_memory + + def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None: + # This is installed as a pre-fwd user hook with ``ModTracker.`` Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: If the module is not in the ``memory_tracking`` dictionary, we track the parameters, buffers, + # input and output memory of the module. Create a new ``_ModMemStats`` instance for the module + # and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + mod_name = self._mod_tracker.get_known_fqn(module) + assert mod_name is not None + if module not in self.memory_tracking: + mod_stats = _ModMemStats(mod_name) + param_mem, buffer_mem = self._track_module_params_and_buffers( + module, install_grad_hooks=True + ) + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.parameter_mem = param_mem + mod_stats.buffer_mem = buffer_mem + mod_stats.input_mem = input_mem + self.memory_tracking[module] = mod_stats + state = _ModState.PRE_FW + + elif self._mod_tracker.is_bw: + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(module) + self._in_ac = True + else: + parents = set(self._mod_tracker.parents) - {mod_name} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "MemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.mod_fqn = mod_name + mod_stats.input_mem = input_mem + + mem_snapshot = self.get_tracker_snapshot() + if state == _ModState.PRE_FW: + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_FW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(state, []).append(deepcopy(mem_snapshot)) + + def _post_fw_hook(self, module: nn.Module, inputs: Any, outputs: Any) -> None: + # This is installed as a post-fwd user hook with ``ModTracker``. Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward so we calculate the output memory + # of the module and update its mod_stats. + mod_stats = self.memory_tracking[module] + if self._mod_tracker.is_bw: + state = _ModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is module: + self._ac_mod = None + self._in_ac = False + else: + state = _ModState.POST_FW + output_mem = self._track_inputs_or_outputs(outputs) + mod_stats.output_mem = output_mem + mod_stats.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + + def _pre_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a pre-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module. We also initialize the ``local_peak`` and ``PEAK_BW`` snapshot for it. + # If the module is None, we skip the hook. + # This can happen since this installed inside a multi-grad hook on the module's output tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping PRE_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mem_snapshot = self.get_tracker_snapshot() + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_BW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(_ModState.PRE_BW, []).append( + deepcopy(mem_snapshot) + ) + + def _post_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a post-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module if it is not None. + # This can happen since this installed inside a multi-grad hook on the module's input tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping POST_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append( + self.get_tracker_snapshot() + ) + + def _track_optimizer_states( + self, reftype: _RefType, optimizer: optim.Optimizer + ) -> None: + for states in optimizer.state.values(): + for val in states.values(): + if isinstance(val, torch.Tensor): + self._update_and_maybe_create_winfos( + val, + reftype, + ) + + def _register_global_optimizer_hook(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_MemRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + register_optimizer_step_pre_hook(_opt_step_pre_hook), + register_optimizer_step_post_hook(_opt_step_post_hook), + ) + + def _deregister_param_and_optimizer_hooks(self) -> None: + for ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) in self._param_to_grad_hook_handles.values(): + grad_hook_handle.remove() + post_acc_grad_hook_handle.remove() + self._param_to_grad_hook_handles.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def track_external( + self, *external: nn.Module | optim.Optimizer | torch.Tensor + ) -> None: + """ + Track tensors and stateful objects like modules, optimizers etc. that are created outside the MemTracker. + + This method should be called before the ``MemTracker`` is used. Any tensors that are not module parameters, buffers, + gradients activations, or optimizer states will be categorized as ``Other``. If you want them categorized with a + custom name, please file a GitHub issue. Any tensors created outside the MemTracker and not supplied to this + method will not be be tracked by ``MemTracker``. + + Args: + *external (Union[nn.Module, optim.Optimizer, torch.Tensor]): The external modules, optimizers, and + tensors to be tracked. + """ + flat_external, _ = tree_flatten(external) + for obj in flat_external: + if isinstance(obj, torch.Tensor): + self._update_and_maybe_create_winfos( + obj, + _MemRefType.OTH, + ) + elif isinstance(obj, torch.nn.Module): + self._track_module_params_and_buffers(obj, install_grad_hooks=False) + elif isinstance(obj, optim.Optimizer): + self._track_optimizer_states(_MemRefType.OPT, obj) + elif obj is None: + continue + else: + raise TypeError( + f"Object of type {type(obj)} is not supported for tracking. " + f"Only stateful objects like modules, optimizers, and tensors are supported." + ) + + def display_snapshot( + self, type: str = "current", units: str = "B", tabulate: bool = False + ) -> None: + """ + Display the memory usage breakdown snapshot of the tracker based on the specified type and units. + + Keyword args: + type (str): The type of snapshot to display. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + units (str): The units to use for displaying memory usage. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool): Whether to display the snapshot in a tabular format. Defaults to False. + """ + snapshot = self.get_tracker_snapshot(type) + if tabulate: + _print_snapshot_tabular(snapshot, units) + else: + _print_snapshot(snapshot, units) + + def display_modulewise_snapshots( + self, depth: int = 2, units: str = "B", tabulate: bool = False + ) -> None: + """ + Print per device memory breakdown snapshot for each module called within MemTracker. + + Snapshots are displayed for the states defined by ``_ModState``. + The module hierarchy is displayed up to the specified depth. + + Keyword Args: + depth (int, optional): The depth of the module hierarchy to display. Defaults to 2. + units (str, optional): The units to use for memory tracking. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False. + """ + + def natural_sort_key(s: str) -> list[int | str]: + return [ + int(text) if text.isdigit() else text.lower() + for text in re.split("([0-9]+)", s) + ] + + for mod_stats in sorted( + self.memory_tracking.values(), + key=lambda m_stats: natural_sort_key(m_stats.mod_fqn), + ): + mod_fqn = mod_stats.mod_fqn + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + if tabulate: + _print_state_snapshots_tabular(mod_stats.snapshots, units) + else: + _print_state_snapshots(mod_stats.snapshots, units) + + def reset_mod_stats(self) -> None: + """ + Reset all the module memory stats. Clears ``memory_tracking`` dictionary. + """ + self.memory_tracking.clear() + + def __enter__(self) -> "MemTracker": + if self._depth == 0: + self._register_global_optimizer_hook() + self._mod_tracker.register_user_hooks( + self._pre_fw_hook, + self._post_fw_hook, + self._pre_bw_hook, + self._post_bw_hook, + ) + self._track_resize() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] + for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + super().__enter__() + self._depth += 1 + return self + + # pyrefly: ignore [bad-override] + def __exit__(self, *args: Any) -> None: + self._depth -= 1 + if self._depth == 0: + self._deregister_param_and_optimizer_hooks() + self._mod_tracker.clear_user_hooks() + self._restore_resize() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[no-untyped-def] + # When running this mode with DTensor, ordinarily all modes will + # run **before** subclasses get a chance to run. + # Returning NotImplemented here gives us a chance to let DTensor + # run and desugar into local tensor ops, before `MemTracker` sees them. + if any(t == DTensor for t in types): + return NotImplemented + if ( + func is torch.ops._c10d_functional.wait_tensor.default + and active_fake_mode() + ): + # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns + # a new tensor which does not happen in eager mode, when a wait_tensor is called. + # pyrefly: ignore [index-error] + res = args[0] + else: + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _MemRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _MemRefType.TEMP + else: + reftype = _MemRefType.ACT + tree_map_only(torch.Tensor, partial(self._track, reftype), res) + peak_state = _ModState.PEAK_BW if self._mod_tracker.is_bw else _ModState.PEAK_FW + self._update_peak_stats(peak_state) + return res diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/memory_tracker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/memory_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..890d2be2794a4e570085a91da1440842473c9f49 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/memory_tracker.py @@ -0,0 +1,304 @@ +# mypy: allow-untyped-defs +import operator +import pickle +from collections import defaultdict +from collections.abc import Callable, Sequence +from itertools import chain +from typing import Any, no_type_check, TYPE_CHECKING + +import torch +import torch.nn as nn +from torch.utils._python_dispatch import TorchDispatchMode + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + + +BYTES_PER_MB = 1024 * 1024.0 + + +class MemoryProfileDispatchMode(TorchDispatchMode): + """Run in ``TorchDispatchMode`` to get memory stats at operator level.""" + + def __init__(self, memory_tracker) -> None: + self.memory_tracker = memory_tracker + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + rs = func(*args, **kwargs) + if func is torch.ops.aten.detach.default: + return rs + func_name: str = ( + self.memory_tracker._cur_module_name + + "." + + func.__name__ + + "_" + + str(self.memory_tracker._operator_names[func.__name__]) + ) + self.memory_tracker._operator_names[func.__name__] = ( + self.memory_tracker._operator_names[func.__name__] + 1 + ) + self.memory_tracker._record_memory_stats(func_name) + + return rs + + +class MemoryTracker: + """ + Collect and plot the memory stats at operator level. + + Includes ``memories_allocated``, ``memories_active`` and ``memories_reserved``. + It also prints a summary for the top 20 operators that generate the most memories. + + Example usage: + + >>> # xdoctest: +SKIP(failing) + >>> net.cuda() + >>> input = input.cuda() + + >>> mem_tracker = MemoryTracker() + >>> mem_tracker.start_monitor(net) + + >>> net.zero_grad(True) + >>> loss = net(input) + >>> if isinstance(loss, dict): + >>> loss = loss['out'] + >>> loss.sum().backward() + >>> net.zero_grad(set_to_none=True) + + >>> mem_tracker.stop() + >>> mem_tracker.summary() + >>> mem_tracker.show_traces() + """ + + def __init__(self) -> None: + torch._C._log_api_usage_once("torch.distributed.memory_tracker") + self._hooks: list[RemovableHandle] = [] + self._operator_names: dict[str, int] = defaultdict(int) + self.memories_allocated: dict[int, dict[str, float]] = defaultdict() + self.memories_active: dict[int, dict[str, float]] = defaultdict() + self.memories_reserved: dict[int, dict[str, float]] = defaultdict() + self._markers: dict[str, int] = defaultdict(int) + self._cur_module_name: str = "" + self._op_index: int = 0 + self._num_alloc_retries: int = 0 + self._device_module = torch.get_device_module() + + @no_type_check + def start_monitor(self, root_module: nn.Module) -> None: + """ + Register module hooks and entering ``MemoryProfileDispatchMode``. + + This enables operator level memory stats can be tracked during module runtime. + """ + self._clear_state() + root_module.__setattr__("_memory_tracker_is_root", True) + for name, m in root_module.named_modules(): + if m is not root_module: + m.__setattr__("_memory_tracker_is_root", False) + # fused_proxy_group does not support hooks + if ".fused_proxy_grouped_embedding_bag" in name: + continue + # hook ordering with other hooks added by users is not managed, so + # the memory stats tracked here may not completely accurate. + h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name)) + h2 = m.register_forward_hook(self._create_post_forward_hook(name)) + # it does not work well with jagged tensor somehow, the root cause is not + # clear and remove it for now as it does not really capture important info. + # h3 = m.register_backward_hook(self._create_backward_hook(name)) + self._hooks.extend([h1, h2]) + self._device_module.empty_cache() + assert getattr(self, "profile_mode", None) is None + self.profile_mode = MemoryProfileDispatchMode(self) + self.profile_mode.__enter__() + + @no_type_check + def stop(self) -> None: + """ + Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level. + + Get some aggregated stats when the memory_tracker() is enabled, like ``num_alloc_retries``. + """ + self._num_alloc_retries = self._device_module.memory_stats().get( + "num_alloc_retries", 0 + ) + + for h in self._hooks: + h.remove() + self._hooks.clear() + assert getattr(self, "profile_mode", None) is not None + self.profile_mode.__exit__(None, None, None) + self.profile_mode = None + + @no_type_check + def summary(self, top: int = 20) -> None: + """ + Print out the top operators that generate the most memories. + + The number of the top operators can be configured. + """ + op_diff: dict[str, float] = defaultdict(float) + op_name, previous_allocated_memory = self.memories_allocated[0] + for i in range(1, self._op_index): + op_name, current_allocated_memory = self.memories_allocated[i] + op_diff[op_name] = current_allocated_memory - previous_allocated_memory + previous_allocated_memory = current_allocated_memory + + print("------------------------------------------------") + print(f"The number of alloc retries are: {self._num_alloc_retries}") + print(f"Top {top} ops that generates memory are:") + for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[ + :top + ]: + print(f"{k}: {v}MB") + print("------------------------------------------------") + + @no_type_check + def show_traces(self, path: str = "") -> None: + import matplotlib.pyplot as plt + + def _plot_figure(x, y_values, labels): + min_val = min(chain.from_iterable(y_values)) * 0.999 + max_val = max(chain.from_iterable(y_values)) * 1.001 + plt.figure() + for y, label in zip(y_values, labels): + plt.plot(x, y, label=label) + plt.xlabel("# Operator Calls") + plt.ylabel("Memory (MB)") + plt.legend() + for marker_name, marker in self._markers.items(): + if marker_name == "fw_bw_boundary": + plt.plot( + [marker, marker], + [min_val, max_val], + "r", + lw=2, + label=marker_name, + ) + else: + plt.plot( + [marker, marker], + [min_val, max_val], + "k-", + lw=2, + label=marker_name, + ) + + if path != "": + self.load(path) + + y_1 = [gb for (name, gb) in self.memories_allocated.values()] + y_2 = [gb for (name, gb) in self.memories_active.values()] + y_3 = [gb for (name, gb) in self.memories_reserved.values()] + x = list(range(len(y_1))) + # Split figures when there is big difference between + # "reserved_memory" and "allocated_memory" or "active_memory". + _plot_figure( + x, + [list(y_1), list(y_2), list(y_3)], + ["allocated_memory", "active_memory", "reserved_memory"], + ) + _plot_figure(x, [list(y_1)], ["allocated_memory"]) + _plot_figure(x, [list(y_2)], ["active_memory"]) + _plot_figure(x, [list(y_3)], ["reserved_memory"]) + + def save_stats(self, path: str) -> None: + """Save the stats using pickle during runtime if users want to plot the traces in other places like notebook.""" + stats = { + "memories_allocated": self.memories_allocated, + "memories_active": self.memories_active, + "memories_reserved": self.memories_reserved, + "markers": self._markers, + "num_alloc_retries": self._num_alloc_retries, + } + + with open(path, "wb") as f: + pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL) + + def load(self, path: str) -> None: + """Load the pickled memory stats to plot the traces or print the summary.""" + with open(path, "rb") as f: + stats = pickle.load(f) + + self.memories_allocated = stats["memories_allocated"] + self.memories_active = stats["memories_active"] + self.memories_reserved = stats["memories_reserved"] + self._markers = stats["markers"] + self._num_alloc_retries = stats["num_alloc_retries"] + + def _create_pre_forward_hook(self, name: str) -> Callable: + """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" + + def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: + self._cur_module_name = f"{name}.forward" + if ( + # pyrefly: ignore [invalid-argument] + hasattr(module, "_memory_tracker_is_root") + # pyrefly: ignore [not-callable] + and module._memory_tracker_is_root + ): + self._add_marker("fw_start") + + return _pre_forward_hook + + def _create_post_forward_hook(self, name: str) -> Callable: + """Insert the marker 'fw_bw_boundary' at the boundary of forward and backward pass.""" + + def _post_forward_hook( + module: nn.Module, + inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + ) -> None: + if ( + # pyrefly: ignore [invalid-argument] + hasattr(module, "_memory_tracker_is_root") + # pyrefly: ignore [not-callable] + and module._memory_tracker_is_root + ): + self._add_marker("fw_bw_boundary") + + return _post_forward_hook + + def _create_backward_hook(self, name: str) -> Callable: + """Insert the current module name with backward prefix for the operator name.""" + + def _backward_hook( + module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor + ) -> None: + self._cur_module_name = f"{name}.backward" + + return _backward_hook + + @no_type_check + def _record_memory_stats(self, fn_name: str) -> None: + """ + Record current memory allocated, current memory active and current memory reserved. + + The memory stats dict is indexed with ``self._op_index``. + """ + memory_allocated: float = self._device_module.memory_allocated() / BYTES_PER_MB + memory_reserved: float = self._device_module.memory_reserved() / BYTES_PER_MB + memory_active: float = ( + self._device_module.memory_stats().get("active_bytes.all.current", 0) + / BYTES_PER_MB + ) + self.memories_allocated[self._op_index] = (fn_name, memory_allocated) + self.memories_reserved[self._op_index] = (fn_name, memory_reserved) + self.memories_active[self._op_index] = (fn_name, memory_active) + self._op_index += 1 + + def _add_marker(self, marker_name: str) -> None: + """Set the marker's x-axis value.""" + marker_val = len(self.memories_allocated.values()) + self._markers[marker_name] = marker_val + + def _clear_state(self) -> None: + """Clear states when start_monitor() is called.""" + self._operator_names.clear() + self.memories_allocated.clear() + self.memories_active.clear() + self.memories_reserved.clear() + self._markers.clear() + self._cur_module_name = "" + self._op_index = 0 + self._num_alloc_retries = 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/mod_tracker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/mod_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..bae745bcc58040dd14a1dfbd0c2f116554870689 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/mod_tracker.py @@ -0,0 +1,259 @@ +# mypy: allow-untyped-defs +import warnings +import weakref +from collections.abc import Callable + +import torch +from torch.autograd.graph import register_multi_grad_hook +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) +from torch.utils._pytree import tree_flatten + + +__all__ = ["ModTracker"] + + +class ModTracker: + """ + ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution + so that other system can query which Module is currently being executed (or its backward is being + executed). + + You can access the ``parents`` attribute on this context manager to get the set of all the + Modules currently being executed via their fqn (fully qualified name, also used as the key within + the state_dict). + You can access the ``is_bw`` attribute to know if you are currently running in backward or not. + + Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag + will remain ``True`` after the forward until another Module is executed. If you need it to be + more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance + is possible but not done yet, please submit an issue requesting this if you need it. + + Example usage + + .. code-block:: python + + mod = torch.nn.Linear(2, 2) + + with ModTracker() as tracker: + # Access anything during the forward pass + def my_linear(m1, m2, bias): + print(f"Current modules: {tracker.parents}") + return torch.mm(m1, m2.t()) + bias + + torch.nn.functional.linear = my_linear + + mod(torch.rand(2, 2)) + + """ + + parents: set[str] + """ + A Set containing the fqn for each module currently running their forward + """ + + def __init__(self): + self.parents = {"Global"} + self._active_module_cnt = {} + self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self._seen_modules: weakref.WeakSet = weakref.WeakSet() + self._has_callback = False + self._post_bw_callbacks_to_enqueue: list[Callable] = [] + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _maybe_set_engine_callback(self): + # This assumes no concurrent calls to backward + if self._has_callback: + return + + for post_bw_callback in reversed(self._post_bw_callbacks_to_enqueue): + torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback) + self._post_bw_callbacks_to_enqueue.clear() + + def callback(): + self.parents = {"Global"} + self._has_callback = False + + torch.autograd.Variable._execution_engine.queue_callback(callback) + self._has_callback = True + + @property + def is_bw(self): + """ + A boolean marking if this is currently running during the backward pass or not + """ + return torch._C._current_graph_task_id() != -1 + + def get_known_fqn(self, mod): + """ + Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``. + """ + return self._known_modules.get(mod, None) + + def register_user_hooks( + self, + pre_fw_hook: Callable | None = None, + post_fw_hook: Callable | None = None, + pre_bw_hook: Callable | None = None, + post_bw_hook: Callable | None = None, + ): + """ + Registers user-specified hooks to be called before/after the forward/backward pass for each + module tracked by the ``ModTracker``. One or more can be ``None``. + Args: + pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the + module. It should have the following signature: + pre_fw_hook (module, input) -> None + post_fw_hook (Callable, optional): A hook to be called after the forward pass for the + module. It should have the following signature: + post_fw_hook (module, input, output) -> None + pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of + the module that require gradients. It should have the following signature: + pre_bw_hook (module, grad_output) -> None + post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of + the module that require gradients. It should have the following signature: + post_bw_hook (module, grad_input) -> None + Raises: + AssertionError: If a new hook is provided when one is already registered. + Note: + If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook will + will receive None as the module argument. + The module fqn will be present in the ``parents`` attribute when each of the hooks is called. + Hooks are intended to be used as markers only not to modify the inputs/outputs. + """ + + def set_hook(hook, user_hook, hook_name): + if hook is not None and user_hook is not None: + raise AssertionError( + f"Only one {hook_name} can be registered at a time" + f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one" + ) + return hook + + self._user_pre_fw_hook = set_hook( + pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook" + ) + self._user_post_fw_hook = set_hook( + post_fw_hook, self._user_post_fw_hook, "post_fw_hook" + ) + self._user_pre_bw_hook = set_hook( + pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook" + ) + self._user_post_bw_hook = set_hook( + post_bw_hook, self._user_post_bw_hook, "post_bw_hook" + ) + + def clear_user_hooks(self): + """ + Clears the user specified hooks registered with ``register_user_hooks`` + """ + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _get_mod_name(self, mod): + if mod not in self._known_modules: + self._known_modules[mod] = type(mod).__name__ + mod_name = self._known_modules[mod] + if mod not in self._seen_modules: + for name, submod in mod.named_children(): + self._known_modules[submod] = f"{mod_name}.{name}" + self._get_mod_name(submod) + self._seen_modules.add(mod) + return mod_name + + def _get_append_fn(self, w_mod, name, is_bw): + def fn(*args): + if is_bw: + self._maybe_set_engine_callback() + if name in self.parents and not self.is_bw: + + def custom_formatwarning(msg, category, filename, lineno, line=None): + return f"{filename}:{lineno}: {category.__name__}: {msg} \n" + + # pyrefly: ignore [bad-assignment] + warnings.formatwarning = custom_formatwarning + warnings.warn( + "The module hierarchy tracking maybe be messed up." + " Please file a bug to PyTorch, if it is the case.", + stacklevel=2, + ) + if name not in self.parents: + self._active_module_cnt[name] = 1 + self.parents.add(name) + else: + self._active_module_cnt[name] += 1 + + if self._user_pre_bw_hook is not None and is_bw: + self._user_pre_bw_hook(w_mod(), args) + + return fn + + def _get_pop_fn(self, w_mod, name, is_bw): + def fn(*args): + if self._user_post_bw_hook is not None and is_bw: + self._user_post_bw_hook(w_mod(), args) + if name in self.parents: + self._active_module_cnt[name] -= 1 + if self._active_module_cnt[name] == 0: + self.parents.remove(name) + elif not self.is_bw: + # Due to some input/output not requiring gradients, we cannot enforce + # proper nesting in backward + raise RuntimeError( + "The Module hierarchy tracking is wrong. Report a bug to PyTorch" + ) + + return fn + + def _fw_pre_hook(self, mod, input): + if torch._dynamo.eval_frame._is_in_optimized_module(): + return + + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + self._get_append_fn(w_mod, name, False)() + if self._user_pre_fw_hook is not None: + self._user_pre_fw_hook(mod, input) + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw: + if tensors: + register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True)) + else: + self._post_bw_callbacks_to_enqueue.append( + self._get_pop_fn(w_mod, name, True) + ) + + def _fw_post_hook(self, mod, input, output): + if torch._dynamo.eval_frame._is_in_optimized_module(): + return + + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + if self._user_post_fw_hook is not None: + self._user_post_fw_hook(mod, input, output) + self._get_pop_fn(w_mod, name, False)() + args, _ = tree_flatten(output) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook( + tensors, self._get_append_fn(w_mod, name, True), mode="any" + ) + + def __enter__(self): + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook( + self._fw_post_hook, always_call=True + ) + return self + + def __exit__(self, *args): + self._fw_pre_handle.remove() + self._fw_post_handle.remove() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/runtime_estimator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/runtime_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..caf399cf6a802677f754084d2d867a743036520f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/runtime_estimator.py @@ -0,0 +1,398 @@ +# Owner(s): ["module: unknown"] +from collections import defaultdict +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import active_fake_mode +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.mod_tracker import ModTracker +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._runtime_estimation import ( + _FLOAT_TYPES, + _IGNORE_OPS, + _VIEW_OPS, + get_compute_time, + get_transfer_time, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + +__all__ = ["RuntimeEstimator"] + + +class RuntimeEstimator(TorchDispatchMode): + """ + Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager + runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and + roofline cost modeling (`operator-level-cost-model`). + For modules executed under this context manager, it aggregates the forward and backward operation runtimes + and also records their execution orders. + + Attributes: + mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary + is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the + operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'. + mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order. + mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order. + mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order. + mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order. + total_runtime (float): The total estimated runtime in milliseconds. + + Note: + 1) The benchmarking estimate mode will execute kernels on GPU and assumes that every operation can run in + isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``. + 2) Currently wrapper tensor sub-classes such as ``DTensor`` won't produce correct estimates. We plan to support + them in future PRs. + 3) We only estimate the compute time, if your code has communication, it will not be considered. Again, we will + support this in future PRs. + + Example usage: + + .. code-block:: python + + runtime_estimator = RuntimeEstimator() + with FakeTensorMode(): + module = ... + optimizer = ... + inp = ... + with runtime_estimator(estimate_mode_type="operator-level-cost-model"): + loss = module(inp) + loss.backward() + optimizer.step() + optimizer.zero_grad() + runtime_estimator.display_modulewise_stats() + """ + + _no_fallback_kernel: set[torch._ops._OpNamespace] = set() + fake_mode: FakeTensorMode + + def __init__(self) -> None: + super().__init__() + self._estimate: Callable + self._estimate_mode_type: str + self._mod_tracker = ModTracker() + self.mod_runtimes: dict[str, dict[str, float]] = defaultdict( + lambda: defaultdict(lambda: 0.0) + ) + self.mod_fw_pre_order: list[str] = [] + self.mod_bw_pre_order: list[str] = [] + self.mod_fw_post_order: list[str] = [] + self.mod_bw_post_order: list[str] = [] + self.total_runtime: float = 0.0 + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 + # NB: returns fake tensors + @classmethod + def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] + cls, + func, + args, + kwargs, + orig_not_implemented_exception, + ): + """ + Runs and benchmarks a fallback kernel for a given function. + + Args: + func (Callable): The function to benchmark. + args (Tuple): The arguments to pass to the function. + kwargs (Dict[str, Any]): The keyword arguments to pass to the function. + orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel + is not implemented. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] + raise orig_not_implemented_exception + + inp_impls = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e): # type: ignore[no-untyped-def] + if cls.fake_mode.is_our_fake(e): + if e.dtype in _FLOAT_TYPES: + out = torch.rand_like(e, device=e.fake_device) + else: + out = torch.ones_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + r = func(*args, **kwargs) + warmup_iters, actual_iters = 2, 3 + for _ in range(warmup_iters): + func(*args, **kwargs) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record(torch.cuda.current_stream()) + for _ in range(actual_iters): + func(*args, **kwargs) + end_event.record(torch.cuda.current_stream()) + torch.cuda.synchronize() + cuda_time = start_event.elapsed_time(end_event) + mean_op_time = cuda_time / actual_iters + + storages = set() + + for e in flat_args: + if isinstance(e, torch.Tensor): + if not e.is_sparse: + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e): # type: ignore[no-untyped-def] + if id(e) not in inp_impls and ( + isinstance(e, torch.Tensor) + and not e.is_sparse + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, torch.Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return cls.fake_mode.fake_tensor_converter.from_real_tensor( + cls.fake_mode, e + ) + else: + return e + + return (pytree.tree_map(map_out, r), mean_op_time) + + @classmethod + def _benchmark_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using benchmarking. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + res: The result of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert isinstance(cls.fake_mode, FakeTensorMode), ( + "Initialize/Assign FakeTensorMode before using this function" + ) + mean_op_time = 0.0 + if func._overloadpacket not in _VIEW_OPS: + try: + res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel( + func, + args, + kwargs, + NotImplementedError, + ) + return (res, mean_op_time) + except NotImplementedError: + cls._no_fallback_kernel.add(func._overloadpacket) + res = func(*args, **kwargs or {}) + return (res, mean_op_time) + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert torch.cuda.is_available(), ( + "Roofline estimation needs to access CUDA capabilities to make estimations" + ) + + # Roofline Cost Model Explanation + + # The roofline cost model estimates the execution time of an operator based on + # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta). + + # Variables: + # - pi: Maximum empirical FLOPs/sec of the device + # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device + # - I: Arithmetic intensity of the operator (FLOPs/bytes) + # - op_flops: FLOPs required by the operator + # - op_bytes: Bytes transferred to and from DRAM for the operator + + # Calculation Steps: + # 1. Calculate arithmetic intensity: I = op_flops / op_bytes + # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I) + # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec + # This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I)) + # Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta) + + # Simplified Formulas: + # - compute_time = op_flops / pi + # - transfer_time = op_bytes / beta + # - estimated_op_time = max(compute_time, transfer_time) + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in _FLOAT_TYPES + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) + + def display_modulewise_stats(self, depth: int = 2) -> None: + """ + Displays module-wise statistics collected by ``RuntimeEstimator``. + + Prints the pre-forward and pre-backward execution orders. + Displays the module-wise forward and backward runtimes in milliseconds. + + Args: + depth (int): The maximum depth of module hierarchy to display (default to 2). + """ + print("Pre-Forward Execution Order: ") + for mod_fqn in self.mod_fw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + print("Pre-Backward Execution Order: ") + for mod_fqn in self.mod_bw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + for mod_fqn, runtimes in self.mod_runtimes.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print( + f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms" + ) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses + # TODO: @sanketpurandare: Add logic for incorporating communication time + res, op_time = self._estimate(func, args, kwargs) + for par in self._mod_tracker.parents: + if self._mod_tracker.is_bw: + self.mod_runtimes[par]["bw"] += op_time + else: + self.mod_runtimes[par]["fw"] += op_time + self.total_runtime += op_time + return res + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + RuntimeEstimator: The runtime estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + self._estimate_mode_type = estimate_mode_type + return self + + def __enter__(self) -> Self: + fake_mode = active_fake_mode() + assert isinstance(fake_mode, FakeTensorMode), ( + "No FakeTensorMode found, designed to used under FakeTensorMode" + ) + RuntimeEstimator.fake_mode = fake_mode + self.total_runtime = 0.0 + self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) + self.mod_fw_pre_order.clear() + self.mod_bw_pre_order.clear() + self.mod_fw_post_order.clear() + self.mod_bw_post_order.clear() + self._mod_tracker.register_user_hooks( + pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + ) + self._mod_tracker.__enter__() + super().__enter__() + return self + + # pyrefly: ignore [bad-override] + def __exit__(self, *args: Any) -> None: + print( + f"Estimated ({self._estimate_mode_type})" + f"total_time: {self.total_runtime:.3f} ms" + ) + if len(self._no_fallback_kernel) > 0: + print("no_fallback_kernel: ", list(self._no_fallback_kernel)) + super().__exit__(*args) + self._mod_tracker.clear_user_hooks() + self._mod_tracker.__exit__() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/sac_estimator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/sac_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..c43de8c2b916742cf131b1761e801b41fc689ba6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/sac_estimator.py @@ -0,0 +1,961 @@ +import math +import os +import sys +from collections import OrderedDict +from dataclasses import astuple, dataclass +from typing import Any, NamedTuple +from typing_extensions import Self + +import torch +from torch import nan, nn, UntypedStorage +from torch._guards import active_fake_mode +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.common_utils import get_untyped_storages +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.testing._internal.composite_compliance import ( + is_inplace, + is_inplace_view_fn, + is_view_fn, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten +from torch.utils.checkpoint import SAC_IGNORED_OPS + + +__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"] +aten = torch.ops.aten + +_ADDITIONAL_IGNORED_OPS = { + aten.lift_fresh.default, # type: ignore[attr-defined] + torch.ops.profiler._record_function_exit._RecordFunction, # type: ignore[attr-defined] + aten.clone.default, # type: ignore[attr-defined] # seems needed for torch.compile +} +OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + + +def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError("Please install tabulate.") from err + + # Use tabulate to print the table + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +# Based on: +# https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py#L71 +@dataclass +class _SACMetadata: + """ + Stores metadata for a single operator for SAC. + + Attributes: + func (Any): The operator function. + time_taken (float): The time taken by the operator. + memory_used (float): The memory used by the operator. + curr_idx (int): The current operator index. + output_ids (Tuple[int, ...]): The storage IDs of the operator's outputs. + inplace_info (Tuple[int, ...]): Tuple of self and parent operator for in-place operator. + is_view_like (bool): Whether the operator is view-like. + is_rand_op (bool): Whether the operator is a random operator. + """ + + func: Any + time_taken: float + memory_used: float + curr_idx: int + output_ids: tuple[int, ...] + inplace_info: tuple[int, ...] + is_view_like: bool + is_rand_op: bool + + +@dataclass +class _SACModMetadata: + """ + Stores metadata for a module for SAC. + + Attributes: + start_idx (int): The starting index of the module's operators. + force_store_random (bool): Whether to force store random operators in the module. + sac_metadata (List[_SACMetadata]): List of metadata for each operator in the module. + """ + + start_idx: int + force_store_random: bool + sac_metadata: list[_SACMetadata] + + +@dataclass +class SACStats: + """ + A class for storing Activation Checkpointing statistics corresponding to a module. + + Attributes: + func_names (List[str]): List of operator names. + runtimes (List[float]): List of operator runtimes in millliseconds. + memory (List[int]): List of operator memory usage in bytes. + view_like_ops (List[int]): Indices of view-like operators. + rand_ops (List[int]): Indices of random operators. + saved_autograd_ops (List[int]): Indices of operator results saved by autograd engine. + inplace_ops (List[Tuple[int, int]]): Tuple of indices of op and its first parent for Inplace operators. + force_store_random (bool): Whether to force store random operator results. + """ + + func_names: list[str] + runtimes: list[float] + memory: list[int] + view_like_ops: list[int] + rand_ops: list[int] + saved_autograd_ops: list[int] + inplace_ops: list[tuple[int, int]] + force_store_random: bool + + +class MSPS(NamedTuple): + """ + Represents Memory and Runtime Statistics for an operator/operator group. + + Attributes: + func_names (set[str]): Set of operator/operator group names. + op_idx (int): Operator index (group head index in case of operator groups). + memory (int): Memory usage in bytes. + runtime (float): Runtime in milliseconds. + msps (float): Memory per second calculated as memory/runtime. + """ + + func_names: set[str] + op_idx: int + memory: int + runtime: float + msps: float + + +@dataclass +class SACTradeOffStats: + """ + Stores statistics for activation-checkpointing trade-off. + + Attributes: + n_segments (int): Number of piecewise linear segments fitted to the trade-off curve. + slopes (List[float]): Slopes of the pieces of linear segments fitted to the trade-off curve. + intercepts (List[float]): Intercepts of the of the pieces of linear segments fitted to the trade-off curve. + fit_breaks (List[float]): Breakpoints of the of the pieces of linear segments fitted to the trade-off curve. + tradeoff_curve (OrderedDict[float, float]): Trade-off curve data of memory discarded vs recomputation time. + sac_memory (int): Total memory of operations available for activation checkpointing in bytes. + sac_runtime (float): Total runtime of operations available for activation checkpointing in milliseconds. + """ + + n_segments: int + slopes: list[float] + intercepts: list[float] + fit_breaks: list[float] + tradeoff_curve: OrderedDict[float, float] + sac_memory: int + sac_runtime: float + + +@dataclass +class SACGreedyOrderMeta: + """ + Stores metadata for Greedy-order SAC. + + Attributes: + recomputed_ops (set[int]): Set of operator indices to be recomputed. + stored_ops (set[int]): Set of operator indices to be stored. + inplace_op_groups (dict[int, set[int]]): Dictionary of inplace operator groups from group-head to operators. + random_ops_group (dict[int, set[int]]): Dictionary of random op group head to random ops. + msps_meta (list[MSPS]): List of Memory and Runtime Statistics for operators. + """ + + recomputed_ops: set[int] + stored_ops: set[int] + inplace_op_groups: dict[int, set[int]] + random_ops_group: dict[int, set[int]] + msps_meta: list[MSPS] + + +class SACEstimator(TorchDispatchMode): + """ + Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC). + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and + runtime trade-offs of functions or ``torch.nn.Module``s for Selective Activation Checkpointing (SAC). It provides + detailed statistics and metadata information for operators of each module and provides a greedy order for selecting + the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of discarded memory + vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood, it supports two + estimation modes, `operator-level-benchmark` and (`operator-level-cost-model` (roofline model). + + Attributes: + sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fully qualified name) to ``SACStats``. + sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``. + sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``. + + Note: + 1) This class is designed to be used under ``FakeTensorMode``. + 2) Currently, it only supports estimation of compute time and memory usage, and does not consider communication. + + Example usage: + + .. code-block:: python + + sac_estimator = SACEstimator() + with FakeTensorMode(): + module = ... + inp = ... + with sac_estimator("operator-level-cost-model"): + output = module(inp) + sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True) + """ + + def __init__(self) -> None: + self.sac_mod_stats: dict[str, SACStats] = {} + self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {} + self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {} + self._mod_tracker = ModTracker() + self._sac_metadata: list[_SACMetadata] = [] + self._sac_mod_metadata: dict[str, _SACModMetadata] = {} + self._leaf_modules: set[str] = set() + self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks( + self._pack_hook, lambda x: x + ) + self._saved_tensor_ids: set[int] = set() + self._estimate_runtime = RuntimeEstimator._roofline_estimate + + def _pack_hook(self, x: torch.Tensor) -> torch.Tensor: + # Hook function to track underlying storage IDs of tensors + # Updates the _saved_tensor_ids set with the IDs of the tensor's storages + # Used in conjunction with torch.autograd.graph.saved_tensors_hooks + untyped_storages = get_untyped_storages(x) + storage_ids = (hash(st) for st in untyped_storages) + self._saved_tensor_ids.update(storage_ids) + return x + + def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None: + # Pre-forward hook function to prepare module metadata + # Tracks module FQN, force store random flag, and ``SACModMetadata`` + # Initializes metadata for non-leaf modules, marks leaf modules + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + num_children = sum(1 for _ in mod.children()) + if num_children > 0: + force_store_random = self._get_force_store_random(inputs) + self._sac_mod_metadata[mod_fqn] = _SACModMetadata( + start_idx=len(self._sac_metadata), + force_store_random=force_store_random, + sac_metadata=[], + ) + else: + self._leaf_modules.add(mod_fqn) + + def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None: + # 1. Retrieves the module's FQN and checks if it's a leaf module + # 2. If not a leaf module, computes: + # - ``SACStats`` using the module's metadata and force store random flag + # - ``SACGreedyOrderMeta`` using the computed SAC statistics + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + if mod_fqn in self._leaf_modules: + return + else: + self.sac_mod_stats[mod_fqn] = self._get_sac_stats( + data=self._sac_mod_metadata[mod_fqn].sac_metadata, + force_store_random=self._sac_mod_metadata[mod_fqn].force_store_random, + ) + self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta( + self.sac_mod_stats[mod_fqn] + ) + + def _get_force_store_random(self, inputs: Any) -> bool: + flat_inputs, _ = tree_flatten(inputs) + return all(not isinstance(x, torch.Tensor) for x in flat_inputs) + + def _get_sac_stats( + self, data: list[_SACMetadata], force_store_random: bool + ) -> SACStats: + # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd + # inserts those during backward and it breaks the fwd-bwd alignment + filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP] + + ( + ops, + runtimes_, + memory_, + new_ids, + output_ids, + inplace_ops_, + view_like_ops_, + rand_ops_, + ) = zip(*[astuple(x) for x in filtered_data], strict=True) + + # 2. Extract the metadata information + runtimes = list(runtimes_) + memory = list(memory_) + func_names = [op._overloadpacket.__name__ for op in ops] + view_like_ops = [i for i, x in enumerate(view_like_ops_) if x] + rand_ops = [i for i, x in enumerate(rand_ops_) if x] + saved_autograd_ops = [ + i + for i, out_ids in enumerate(output_ids) + if set(out_ids).issubset(self._saved_tensor_ids) + ] + + # 3. Remap the inplace indices as we have removed OPS_TO_ALWAYS_SKIP + # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op + # to itself if the original parent is in OPS_TO_ALWAYS_SKIP. + try: + inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x] + except ValueError as err: + raise ValueError( + f"The remapping of inplace ops failed since one of the inplace op parents" + f" must have been present in {OPS_TO_ALWAYS_SKIP}" + ) from err + + # 4. The last operation is always stored as the output of the checkpoint + # block, so we can avoid recomputing it. We set the memory to zero + # instead of adding a new constraint because we want both the 0 and 1 + # endpoints for memory_budget to be valid + # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op + # might not always be correct, which would yield suboptimal policies + last_op = len(ops) - 1 + skip_ops_ = set(view_like_ops) | set({x[0] for x in inplace_ops}) + reversed_skip_ops = sorted(skip_ops_, reverse=True) + for op in reversed_skip_ops: + if op == last_op: + last_op -= 1 + + memory[last_op] = 0 + + # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``. + return SACStats( + func_names=func_names, + runtimes=runtimes, + memory=memory, + view_like_ops=view_like_ops, + rand_ops=rand_ops, + saved_autograd_ops=saved_autograd_ops, + inplace_ops=inplace_ops, # type: ignore[arg-type] + force_store_random=force_store_random, + ) + + def _get_inplace_metadata( + self, func: Any, out_storages: set[UntypedStorage] + ) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]: + # 1. Get the current index of the metadata obtained so far + curr_idx = len(self._sac_metadata) + # 2. Get the set of active modules that are not leaf + active_mod_fqns: set[str] = { + par for par in self._mod_tracker.parents if par not in self._leaf_modules + } + # 3. Output ids are the identifies of the storage objects corresponding to the tensors + output_ids = tuple(hash(st) for st in out_storages) + # 4. If the function is not inplace, return + if not is_inplace(func): + return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ()) + + op_idx = curr_idx + # 5. Initialize the parent op ids of the inplace op for each of the active modules + mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1) + for i, d in enumerate(self._sac_metadata): + # 6. Find the first occurrence of a tensor corresponding to each module that + # shares the same storage as the current tensor + past_output_ids = d.output_ids + if set(output_ids).issubset(set(past_output_ids)): + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx == -1: + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + if i >= acm_stats.start_idx: + mod_op_parent_idxs[mod_fqn] = i + else: + assert mod_fqn == "Global" + mod_op_parent_idxs[mod_fqn] = i + # 7. If no parent tensor is found, then it's probably an inplace op on the arguments + # so one can just store the current-op idx as parent idx + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx < 0: + mod_op_parent_idxs[mod_fqn] = op_idx + mod_inplace_info = { + mod_fqn: (op_idx, mod_op_parent_idxs[mod_fqn]) + for mod_fqn in active_mod_fqns + } + return curr_idx, output_ids, mod_inplace_info # type: ignore[return-value] + + def __torch_dispatch__( # type: ignore[no-untyped-def] + self, func, types, args=..., kwargs=None + ): + # 1. Get the runtime estimate + out, op_time = self._estimate_runtime(func, args, kwargs) + flat_outs, _ = tree_flatten(out) + out_storages_cuda: set[UntypedStorage] = set() + out_storages_cpu: set[UntypedStorage] = set() + cuda_devices: set[torch.device] = set() + for o in flat_outs: + if isinstance(o, torch.Tensor): + if o.device.type == "cuda": + out_storages_cuda.update(get_untyped_storages(o)) + cuda_devices.add(o.device) + else: + out_storages_cpu.update(get_untyped_storages(o)) + + # Check if there's more than 1 CUDA device + assert len(cuda_devices) <= 1, ( + f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" + ) + + # 2. Get the memory consumed by output + nbytes_cuda = sum( + math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + for st in out_storages_cuda + ) + nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu) + nbytes = nbytes_cuda + nbytes_cpu + # 3. Get the current operator index, output storage identifiers and inplace metadata + out_storages = out_storages_cuda | out_storages_cpu + curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata( + func, out_storages + ) + # 4. Determine if the function is in-place, random-op or a view-like + is_view_like = is_view_fn(func) or is_inplace_view_fn(func) + is_rand_op = torch.Tag.nondeterministic_seeded in func.tags + if is_view_like: + nbytes = 0 + # sdpa has non-deterministic seed, but might be deterministic + # if no dropout is applied + if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention": + # pyrefly: ignore [missing-attribute] + is_rand_op = kwargs.get("dropout_p", 0) != 0 + # 5. Create metadata information per active non-leaf module + for mod_fqn in self._mod_tracker.parents: + if mod_fqn in self._leaf_modules: + continue + acm = _SACMetadata( + func=func, + time_taken=op_time, + memory_used=nbytes, + curr_idx=curr_idx, + output_ids=output_ids, + inplace_info=mod_inplace_info[mod_fqn], + is_view_like=is_view_like, + is_rand_op=is_rand_op, + ) + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + acm_stats.sac_metadata.append(acm) + else: + assert mod_fqn == "Global", ( + f"Module {mod_fqn} not found in AC Mod Stats" + ) + self._sac_metadata.append(acm) + + return out + + def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta: + # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage. + # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group + # The top-most op can itself be an inplace-op or can be a non-inplace op. + # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads. + inplace_op_groups: dict[int, set[int]] = {} + inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops) + + # Initialize inplace_op_groups using inplace_op_to_group_head + for op_idx, group_head_idx in inplace_op_to_group_head.items(): + op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx}) + op_group.add(op_idx) + + # Like inplace ops, all of the random ops in the function/module should all be either recomputed or saved + # as a group. This is because, they affect the ranom seed generator. If force_store_random is set True, + # all of the random ops will be stored by default. For easy of manageability, we store the top-most random op + # as the leader of the random_ops_group. + random_ops_group: dict[int, set[int]] = {} + random_group_head_idx = min(sac_stats.rand_ops, default=-1) + has_rand_ops = bool(sac_stats.rand_ops) + if has_rand_ops: + random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops) + + # 1. Random ops are stored if force_store_random is set + # 2. View-like ops are recomputed by default + # 3. For inplace_op_groups: + # a) If the head of this group is an inplace op, then we have to store the entire group. + # b) If any op in the group is random and force_store_random is set, then entire group will be stored. + # c) If none of ops in the group are random and the head of the group is not an in-place op, then + # this group can be considered for recomputation in its entirety + stored_ops: set[int] = set() + recomputed_ops: set[int] = set() + # Case 1: + if has_rand_ops and sac_stats.force_store_random: + stored_ops.add(random_group_head_idx) + # Case 2: + recomputed_ops.update(set(sac_stats.view_like_ops)) + + for group_head_idx, op_group in inplace_op_groups.items(): + # Case 3a: + if group_head_idx in inplace_op_to_group_head: + stored_ops.add(group_head_idx) + # Case 3b: + if ( + sac_stats.force_store_random & len(op_group & set(sac_stats.rand_ops)) + > 0 + ): + stored_ops.add(group_head_idx) + + # The potential recompute candidates are populated as: + recompute_candidates: set[int] = set() + # 1) The random group head if it is not stored + if has_rand_ops and random_group_head_idx not in stored_ops: + recompute_candidates.add(random_group_head_idx) + # 2) The in-place op group heads that are not stored + recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops) + # 3) The non-inplace and non-random ops that are neither stored nor recomputed by default + recompute_candidates.update( + set(range(len(sac_stats.memory))) + - recomputed_ops + - stored_ops + - set(inplace_op_to_group_head.keys()) + - set(sac_stats.rand_ops) + ) + + # We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second + msps_meta: list[MSPS] = [] + for cand_idx in recompute_candidates: + op_indices = {cand_idx} + if cand_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand_idx]) + if has_rand_ops and cand_idx == random_group_head_idx: + op_indices.update(sac_stats.rand_ops) + + mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices) + runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices) + func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices} + msps = (mem / runtime) if runtime > 0 else sys.float_info.max + msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps)) + # We choose candidates to be recomputed based on increasing msps + msps_meta.sort(key=lambda x: x.msps, reverse=True) + return SACGreedyOrderMeta( + recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta + ) + + def _get_sac_tradeoff_pwlf_stats( + self, + sac_stats: SACStats, + greedy_order_meta: SACGreedyOrderMeta, + n_segments: int = 2, + save_tradeoff_graph: bool = False, + filename: str = "ac_tradeoff", + ) -> SACTradeOffStats: + try: + import numpy as np # type: ignore[import-not-found] + import pwlf # type: ignore[import-untyped, import-not-found] + except ImportError as err: + raise ImportError("Please install pwlf and numpy package.") from err + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + # 1. Initialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops + recomp_indices: set[int] = set() + for r_idx in recomputed_ops: + recomp_indices.add(r_idx) + if r_idx in inplace_op_groups: + recomp_indices.update(inplace_op_groups[r_idx]) + if r_idx in random_ops_group: + recomp_indices.update(random_ops_group[r_idx]) + + discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices) + recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices) + # 2. Initialize the max recomputation time and total recomputation memory + sac_runtime = sum(sac_stats.runtimes) + sac_memory = sum(sac_stats.memory) + # 3. Tradeoff curve stores the KV pair of the discarded memory to total memory and, + # recomputation time to total runtime incurred. + delta = 1e-2 + tradeoff_curve = OrderedDict() + # 4. Initialize the trade-off curve with the stats of of already chosen recomputed_ops + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 5. Update the trade-off curve with memory and runtime stats of SAC candidates in the + # greedy order of their ``MSPS``. + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 6. Finally, we add the memory and recomputation time of the always stored ops. + stored_indices: set[int] = set() + for s_idx in stored_ops: + stored_indices.add(s_idx) + if s_idx in inplace_op_groups: + stored_indices.update(inplace_op_groups[s_idx]) + if s_idx in random_ops_group: + stored_indices.update(random_ops_group[s_idx]) + discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices) + recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices) + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + x_ = list(tradeoff_curve.keys()) + y_ = list(tradeoff_curve.values()) + # 7. We shift the y values to left and x values to right to upperbound the trade-off function + # TODO: Write a better explanation why this needs to be done + x = x_[: len(x_) - 1] + y = y_[1:] + tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y) + # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve. + n_segments = max(min(len(x) - 2, n_segments), 1) + tradeoff_pwlf.fit(n_segments=n_segments) + + # save prediction graph + def save_prediction_graph( + pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str + ) -> None: + try: + import matplotlib.pyplot as plt # type: ignore[import-not-found] + import numpy as np # type: ignore[import-not-found] + except ImportError as err: + raise ImportError( + "Install matplotlib and numpy using pip: pip install matplotlib numpy" + ) from err + # predict for the determined points + xHat = np.linspace(min(x), max(x), num=10000) + yHat = pwlf_.predict(xHat) + + # plot the results + plt.figure() + plt.plot(x, y, "o", label="Shifted") + plt.plot(xHat, yHat, "-", label="Predicted") + plt.plot(x_, y_, "x", label="Original") + plt.ylabel("Recomp time / Total recomp time") + plt.xlabel("Memory discarded / Total memory") + plt.legend() + plt.title(f"{filename}") + plt.suptitle( + f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms", + fontsize=10, + ) + folder_name = "tradeoff_graphs" + if not os.path.exists(folder_name): + os.makedirs(folder_name) + # Save the plots in the folder + plt.savefig(os.path.join(folder_name, f"{filename}.png")) + + if save_tradeoff_graph: + save_prediction_graph(tradeoff_pwlf, x, y, filename) + # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions + slopes = tradeoff_pwlf.calc_slopes().tolist() + assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance( + tradeoff_pwlf.fit_breaks, np.ndarray + ) + intercepts = tradeoff_pwlf.intercepts.tolist() + fit_breaks = tradeoff_pwlf.fit_breaks.tolist() + return SACTradeOffStats( + n_segments=n_segments, + slopes=slopes, + intercepts=intercepts, # type: ignore[arg-type] + fit_breaks=fit_breaks, # type: ignore[arg-type] + tradeoff_curve=tradeoff_curve, + sac_memory=sac_memory, + sac_runtime=sac_runtime, + ) + + def display_sac_stats( + self, sac_stats: SACStats, print_tabular: bool = False + ) -> None: + """ + Displays the SAC statistics. + + Args: + sac_stats (SACStats): The SAC statistics to display. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + 1. Total Memory: The total memory usage in bytes. + 2. Total Runtime: The total runtime in milliseconds. + 3. Store Random: A flag indicating whether to force store random operator results. + + Followed by a table with the following columns: + 1. Op Idx: The operator index. + 2. Op Name: The operator name. + 3. Runtimes (ms): The operator runtime in milliseconds. + 4. Memory (B): The operator memory usage in bytes. + 5. View-like: A flag indicating whether the operator is view-like. + 6. Random: A flag indicating whether the operator is random. + 7. Saved Autograd: A flag indicating whether the operator's result is saved by autograd engine. + 8. In-place: The index of the operator's first parent, or None if not in-place. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + print( + f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms" + f" Store Random: {sac_stats.force_store_random}" + ) + table_data = [] + op_parent = dict(sac_stats.inplace_ops) + for i, fn_name in enumerate(sac_stats.func_names): + row = [ + str(i), + fn_name, + f"{sac_stats.runtimes[i]:.4f}", + str(sac_stats.memory[i]), + str(i in sac_stats.view_like_ops), + str(i in sac_stats.rand_ops), + str(i in sac_stats.saved_autograd_ops), + str(op_parent.get(i)), + ] + table_data.append(row) + # Define headers + headers = [ + "Op Idx", + "Op Name", + "Runtimes(ms)", + "Memory (B)", + "View-like", + "Random", + "Saved Autograd", + "In-place", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def display_sac_tradeoff_stats( + self, + greedy_order_meta: SACGreedyOrderMeta, + sac_stats: SACStats, + print_tabular: bool = False, + ) -> None: + """ + Displays the SAC trade-off statistics. + + Args: + greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata. + sac_stats (SACStats): The SAC statistics. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + A table with the following columns: + 1. Op Id(s): The operator index(es). + 2. Op Name(s): The operator name(s). + 3. Discarded Mem (%): The percentage of discarded memory. + 4. Discarded Mem (B): The discarded memory in bytes. + 5. Recomp time (%): The percentage of recomputed time. + 6. Recomp time (ms): The recomputed time in milliseconds. + 7. MSPS: The memory per second. + 8. Always Stored: A flag indicating whether the operator is always stored. + 9. Always Recomputed: A flag indicating whether the operator is always recomputed. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + table_data = [] + total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes) + discarded_mem: int = 0 + recomp_runtime: float = 0.0 + + def append_row( + op_indices: set[int], + func_names: set[str], + msps: float | None = None, + stored: bool | None = False, + recomputed: bool | None = False, + ) -> None: + row = [ + str(op_indices), + str(func_names), + f"{discarded_mem / total_memory:.4f}", + str(discarded_mem), + f"{recomp_runtime / total_runtime:.4f}", + str(recomp_runtime), + f"{msps:.2e}" if msps is not None else str(nan), + str(stored), + str(recomputed), + ] + table_data.append(row) + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + + for op_idx in recomputed_ops: + op_indices: set[int] = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, recomputed=True) + + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + op_indices = {cand.op_idx} + if cand.op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand.op_idx]) + if cand.op_idx in random_ops_group: + op_indices.update(random_ops_group[cand.op_idx]) + append_row(op_indices, cand.func_names, msps=cand.msps) + + for op_idx in stored_ops: + op_indices = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, stored=True) + + headers = [ + "Op Id(s)", + "Op Name(s)", + "Discarded Mem (%)", + "Discarded Mem (B)", + "Recomp time (%)", + "Recomp time (ms)", + "MSPS", + "Always Stored", + "Always Recomputed", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def pwlf_sac_tradeoff_curve( + self, + n_segments: int = 2, + save_tradeoff_graphs: bool = False, + ) -> None: + """ + Fits a piecewise linear function with the specified sumber of segments to the SAC trade-off curve of + discarded memory vs recomputation time. + + Args: + n_segments (int, optional): The number of segments to be used for fitting the piecewise linear function to + the trade-off curve. Defaults to 2. + save_tradeoff_graphs (bool, optional): Whether to save the trade-off graphs to file. Defaults to False. + + If save_tradeoff_graphs is True, the trade-off graphs are saved to file using the module FQN as the filename. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + self.sac_mod_tradeoff_stats[mod_fqn] = self._get_sac_tradeoff_pwlf_stats( + sac_stats=sac_stats, + greedy_order_meta=self.sac_mod_greedy_order_meta[mod_fqn], + n_segments=n_segments, + save_tradeoff_graph=save_tradeoff_graphs, + filename=mod_fqn, + ) + + def display_modulewise_sac_stats( + self, depth: int = 2, print_tabular: bool = False + ) -> None: + """ + Displays the SAC and trade-off statistics for each module. + + Args: + depth (int, optional): The maximum depth of modules to display. Defaults to 2. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + For each module with depth less than or equal to the specified depth: + 1. The SAC statistics for the module (using display_sac_stats). + 2. The SAC trade-off statistics for the module (using display_sac_tradeoff_stats). + + If print_tabular is True, the statistics are printed in a tabular format. + Otherwise, the statistics are printed in a plain text format. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + self.display_sac_stats(sac_stats, print_tabular) + print(f"AC Trade-off for Module: {mod_fqn} MSPS = Memory/Runtime") + self.display_sac_tradeoff_stats( + self.sac_mod_greedy_order_meta[mod_fqn], sac_stats, print_tabular + ) + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + SACEstimator: The SAC estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate_runtime = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate_runtime = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + return self + + def __enter__(self) -> Self: # type: ignore[no-untyped-def] + fake_mode = active_fake_mode() + assert isinstance(fake_mode, FakeTensorMode), ( + "SAC Estimator should be called in FakeTensorMode" + ) + RuntimeEstimator.fake_mode = fake_mode + self._mod_tracker.register_user_hooks( + pre_fw_hook=self._pre_fw_hook, + post_fw_hook=self._post_fw_hook, + ) + self._mod_tracker.__enter__() + self._saved_tensor_hook_ctx.__enter__() + return super().__enter__() + + def __exit__(self, *args: Any) -> None: # type: ignore[no-untyped-def] + self._saved_tensor_hook_ctx.__exit__() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/sac_ilp.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/sac_ilp.py new file mode 100644 index 0000000000000000000000000000000000000000..8799493f260a5967c8086aa3d24e64132cc4102d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/_tools/sac_ilp.py @@ -0,0 +1,294 @@ +import logging +import math +from enum import IntEnum + +from torch.distributed._tools.ilp_utils import Graph, is_submodule +from torch.distributed._tools.sac_estimator import SACStats + + +try: + from pulp import ( # type: ignore[import-untyped,import-not-found] + lpDot, + LpInteger, + LpMaximize, + LpMinimize, + LpProblem, + LpStatus, + lpSum, + LpVariable, + PULP_CBC_CMD, + value, + ) +except ImportError as err: + raise ImportError( + "Please install pulp package. See: https://github.com/coin-or/pulp." + ) from err + +# Create a logger object +logger = logging.getLogger(__name__) + +# Set the logging level to INFO +logger.setLevel(logging.INFO) + + +def sac_milp( + graph: Graph, + memory_budget: float, + world_size: int = 1, + ac_units: list[str] | None = None, + fsdp_units: list[str] | None = None, +) -> tuple[dict[str, float], float, int]: + """ + MILP to decide which modules to AC and how much memory to discard. + The objective is to minimize recomputation time. + The constraint is to ensure peak memory is under budget. + + Args: + graph: graph representation of the model as a module submodule tree + where each node is a submodule with memory & runtime stats + memory_budget: memory budget in GiB + world_size: number of GPUs. In the case of FSDP, world_size will be + used to compute the amount of parameter and gradient memory on each rank + ac_units: a list of user-specified AC units. + fsdp_units: a list of FSDP units. AC units cannot be supermodules of FSDP units. + + Returns: + Dict[str, float]: the optimal SAC solution, mapping from module fqn to + the percentage of activation memory to **discard** + float: the recomputation time of the optimal SAC solution + int: upper bound on the peak memory of the optimal SAC solution. + note that value of -1 means that the ILP solver failed to find a solution. + + """ + num_nodes = len(graph.nodes) + M = 10**2 # note: numerical issue may occur if M is too big + MEM_MULTIPLIER = 2**30 + + # Create a MILP problem + prob = LpProblem("SAC", LpMinimize) + + # Create decision variables + # y_i: indicator for if module i is AC'ed + y = LpVariable.matrix("y", list(range(num_nodes)), 0, 1, LpInteger) + # r_i: percentage of discarded activation memory + r = LpVariable.matrix("r", list(range(num_nodes)), 0, 1) + # d_i: discarded activation memory for module i + d = LpVariable.matrix("d", list(range(num_nodes)), 0) + # a_i: total activation memory at module i + a = LpVariable.matrix("a", list(range(num_nodes)), 0) + # m_i: memory at module i, combining parameters, gradients, and activations + m = LpVariable.matrix("m", list(range(num_nodes)), 0) + # rcp_i: percentage of recomputation time + rcp = LpVariable.matrix("rcp", list(range(num_nodes)), 0) + # rct_i: recomputation time for module i (in ms) + rct = LpVariable.matrix("rct", list(range(num_nodes)), 0) + # max_m: peak memory + max_m = LpVariable("max_m", 0) + + # Add constraints + # [Constraint] User specified AC units + if ac_units: + ac_units_set = set(ac_units) + for i in range(num_nodes): + if graph.nodes[i]["fqn"] not in ac_units_set: + prob += y[i] == 0 + + # [Constraint] AC units cannot be supmodules of user specified FSDP units + if fsdp_units: + for i in range(num_nodes): + if any( + is_submodule(fsdp_unit, graph.nodes[i]["fqn"]) + for fsdp_unit in fsdp_units + ): + prob += y[i] == 0 + + # [Constraint] No nested AC units + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if graph.ad_matrix[i][j] == 1: + prob += y[i] + y[j] <= 1 + + # [Constraint] Do not AC leaf modules + for i in range(num_nodes): + if graph.nodes[i]["is_leaf"]: + prob += y[i] == 0 + + # [Constraint] Express amount of discarded activation memory + for i in range(num_nodes): + # There are two measures for activation memory: ACM and IA + # 1. IA is the activation memory saved when not using AC + # 2. ACM is the total activation memory, including those + # that are not typically saved when not using AC + # Note: ACM >= IA + if (not graph.nodes[i]["is_leaf"]) and graph.nodes[i][ + "sac_memory" + ] < graph.nodes[i]["act_fw_per_module"]: + logger.warning("For module {%s}: ", graph.nodes[i]["fqn"]) + logger.warning( + "activation memory from memory tracker is {%d},", + graph.nodes[i]["act_fw_per_module"], + ) + logger.warning( + "activation memory from SAC estimator is {%d}.", + graph.nodes[i]["sac_memory"], + ) + logger.warning("Something is wrong. Please check!") + logger.warning("Overriding the latter with the former.") + graph.nodes[i]["sac_memory"] = graph.nodes[i]["act_fw_per_module"] + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += d[i] == ACM_i * r[i] - (ACM_i - IA_i) * y[i] + + # [Constraint] Ensure correctness of r_i + # There are two parts to its correctness + # 1. r_i > 0 only if y_i == 1 (discard only if it is an AC unit) + # 2. r_i needs to be large enough to cover the difference between + # ACM and IA. Otherwise, we are not saving any memory + for i in range(num_nodes): + prob += y[i] >= r[i] + if graph.nodes[i]["is_leaf"]: + continue + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += r[i] >= (ACM_i - IA_i) / ACM_i * y[i] + + # [Constraint] Express total activation memory in the backward pass + for i in range(num_nodes): + AG_i = graph.nodes[i]["act_grad_per_module"] / MEM_MULTIPLIER + TA_i = graph.nodes[i]["act_total"] / MEM_MULTIPLIER + # related to discarded amount of memory + pos = graph.nodes[i]["pos_fw_post_order"] + coeff = [0] * num_nodes + for p in range(pos): + j = graph.name2node[graph.fw_post_order[p]]["index"] + coeff[j] = 1 + prob += a[i] == TA_i + AG_i - lpDot(coeff, d) + + # [Constraint] Express the total amount of memory at each module + # Note that unsharded parameters and gradients are not included here + P_1 = graph.nodes[0]["param_per_module"] / MEM_MULTIPLIER + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] / MEM_MULTIPLIER + prob += m[i] == a[i] + (P_1 + TG_i) / world_size + + # [Constraint] Express peak memory + for i in range(num_nodes): + prob += max_m >= m[i] + + # [Constraint] Express percentage of recomputation time + for i in range(num_nodes): + for s in range(graph.nodes[i]["n_segments"]): + slope = graph.nodes[i]["slopes"][s] + intercept = graph.nodes[i]["intercepts"][s] + prob += rcp[i] >= slope * r[i] + intercept + + # [Constraint] Express recomputation time + # rct_i = (rcp_i * ACT_i) if y_i == 1 else 0 + for i in range(num_nodes): + ACT_i = graph.nodes[i]["sac_runtime"] + prob += rct[i] <= M * y[i] + prob += rct[i] <= ACT_i * rcp[i] + prob += rct[i] >= ACT_i * rcp[i] - M * (1 - y[i]) + + # [Constraint] Peak memory should be below budget + prob += max_m <= memory_budget + + # Set Objeictive + prob += lpSum(rct) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=180, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return {}, 0, -1 + + # Gather and return solution if optimal solution is found + ac_decisions = {} + for i in range(num_nodes): + if round(y[i].varValue) == 1: + ac_decisions[graph.nodes[i]["fqn"]] = round(r[i].varValue, 4) + recomputation_time = round(value(prob.objective), 2) + peak_mem = round(max_m.varValue * MEM_MULTIPLIER) + + return ac_decisions, recomputation_time, peak_mem + + +class SACDecision(IntEnum): + RECOMPUTE = 0 + SAVE = 1 + + +def get_optimal_checkpointing_policy_per_module( + sac_stats: SACStats, memory_budget: float +) -> list[int]: + """ + This is adapted from -- + https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375 + + Given the SACStats of a module, including list of operators, their memory, runtimes, and metadata, + decide via MILP an optimal set of operators to checkpoint under a given ``memory_budget``. + + Args: + sac_stats: the SACStats object of the module + memory_budget: a float between zero and one + + Returns: + List[int]: the decision whether each operator should be saved (1) or recomptued (0). + """ + if not (0 <= memory_budget <= 1): + raise ValueError( + f"`memory_budget` must be a float between 0 and 1. Got {memory_budget}." + ) + num_ops = len(sac_stats.func_names) + + # Create a MILP problem + prob = LpProblem("SAC-per-module", LpMaximize) + + # Create decision variables + # x[i] = 1 means the i-th operator should be saved, otherwise it should be recomputed + x = LpVariable.matrix("x", list(range(num_ops)), 0, 1, LpInteger) + + # Add constraints + # [Constraint] random ops should be saved if ``force_store_random`` is True + # otherwise, random ops should either be all recomputed or all saved + if sac_stats.force_store_random: + for i in sac_stats.rand_ops: + prob += x[i] == SACDecision.SAVE.value + else: + for i1, i2 in zip(sac_stats.rand_ops[:-1], sac_stats.rand_ops[1:]): + prob += x[i1] == x[i2] + + # [Constraint] view-like ops should always be recomputed + for i in sac_stats.view_like_ops: + prob += x[i] == SACDecision.RECOMPUTE.value + + # [Constraint] inplace ops should always be done in conjunction with its parent op + for op, op_parent in sac_stats.inplace_ops: + if op != op_parent: + prob += x[op] == x[op_parent] + else: + prob += x[op] == SACDecision.SAVE.value + + # [Constraint] saved memory should be under the ``memory_budget`` + max_memory = math.ceil(memory_budget * sum(sac_stats.memory)) + prob += lpDot(x, sac_stats.memory) <= max_memory + + # [Objective] minimize recomputation time, note the ILP is a maximization problem + # because x[i] == 1 means the op is saved (not recomputed), and thus recomputation + # time is sum(sac_stats.runtimes) - lpDot(x, sac_stats.runtimes) + prob += lpDot(x, sac_stats.runtimes) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=10, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return [] + + # Gather and return solution if optimal solution is found + return [round(x[i].varValue) for i in range(num_ops)] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06c814295699405de9a8f8cf7f6a861b07b63a05 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__init__.py @@ -0,0 +1 @@ +from .join import Join, Joinable, JoinHook diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e75bb5f86ae2c6ef764bd3d8601d27a4a5c6bb0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca4541e56d6513b070d4041b6f586d3bef790d3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6434f2121647c9aeb02db1d00e0220a36398de0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a2d4386a88c05a0035e0dddd93b411f633cf2b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..081d397a9c1f11e332f95649d362e1f3c27abe8a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -0,0 +1,321 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator +from enum import auto, Enum +from functools import partial +from typing import Any + +import torch +import torch.nn as nn +from torch.autograd.graph import save_on_cpu +from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs +from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint + + +_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module" +_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "." + + +class CheckpointImpl(Enum): + REENTRANT = auto() + NO_REENTRANT = auto() + + +class ActivationWrapper(torch.nn.Module, ABC): + """ + Base class for Activation Checkpoint and Activation Offload. + + Not meant to be instantiated directly. + """ + + def __init__(self, mod): + super().__init__() + self._checkpoint_wrapped_module = mod + # state_dict post hook to remove prefix to allow loading into a + # non-checkpoint wrapped module. + self._register_state_dict_hook(self._post_state_dict_hook) + # load_state_dict pre-hook to allow loading back into + # checkpoint-wrapped module. + self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook) + + @abstractmethod + def forward(self, *args, **kwargs): + raise ValueError("Subclasses should implement forward().") + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._checkpoint_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is a nn.Sequential.""" + return self._checkpoint_wrapped_module.__getitem__(key) # type: ignore[operator] + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.nn.Parameter]]: + """ + Override :meth:`named_parameters()` to intercept parameter names. + + remove all occurrences of ``_CHECKPOINT_PREFIX``. + """ + for param_name, param in super().named_parameters(*args, **kwargs): + yield param_name.replace(_CHECKPOINT_PREFIX, ""), param + + @staticmethod + def _post_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, + ) -> dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed. + + For ``checkpoint_wrapper``, it will strip checkpoint-wrapped module prefix, + so that this module can be loaded into non-checkpointed modules. + It would still be able to be loaded into checkpoint-wrapped modules as this class, + adds the prefix back before loading the state_dict. + """ + _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix) + return state_dict + + @staticmethod + def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` is called. + + For ``checkpoint_wrapper``, it will add back the module + prefix so that non-checkpointed modules can be loaded into + checkpoint_wrapper modules properly. + """ + _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}") + + +class OffloadWrapper(ActivationWrapper): + def forward(self, *args, **kwargs): + with save_on_cpu(pin_memory=True): + return self._checkpoint_wrapped_module(*args, **kwargs) + + +class CheckpointWrapper(ActivationWrapper): + """ + An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing. + + Note that this module is not meant to be used directly but instead, + it is to be used through the ``checkpoint_wrapper`` function. + """ + + def __init__( + self, + mod: torch.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, + checkpoint_fn=None, + **checkpoint_fn_kwargs, + ): + super().__init__(mod) + self.checkpoint_impl = checkpoint_impl + if checkpoint_fn is None: + # use torch.utils.checkpoint + self.checkpoint_fn = partial( + torch_utils_checkpoint, + use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT), + **checkpoint_fn_kwargs, + ) + else: + # Construct user-specified checkpoint function. + self.checkpoint_fn = partial( + checkpoint_fn, + **checkpoint_fn_kwargs, + ) + + def forward(self, *args, **kwargs): + # Support keyword arguments for reentrant checkpoint. Note that this + # only works if user has specified self.checkpoint_impl and is not + # using their own custom checkpoint_fn. + if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}: + # Pack the args and kwargs + flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs) + + # Function that only takes (packed) args, but can unpack them + # into the original args and kwargs for the checkpointed + # function, and runs that function. + def my_function(*inputs): + # unpack back into args and kwargs + unpacked_args, unpacked_kwargs = _unpack_kwargs(inputs, kwarg_keys) + # run original module + return self._checkpoint_wrapped_module( + *unpacked_args, **unpacked_kwargs + ) + + # Pass the function that only takes packed args into reentrant + # checkpoint API. + return self.checkpoint_fn( # type: ignore[misc] + my_function, + *flat_args, + ) + else: + return self.checkpoint_fn( # type: ignore[misc] + self._checkpoint_wrapped_module, *args, **kwargs + ) + + +def offload_wrapper(module: torch.nn.Module) -> torch.nn.Module: + """ + Wrap a module for activation offloading to CPU. + + Offloads intermediate activations to the CPU for modules wrapped with this function. + Wrappers with activation offload can be composed with ones that do recomputation-based + checkpoint to trade off increased compute versus increased CPU + memory usage and additional H2D transfers. + + Usage:: + offloaded_module = offload_wrapper(module) + outputs = checkpointed_module(inputs) + Args: + module (nn.Module): + The module to be wrapped + Returns: + (nn.Module): + Wrapped module + """ + return OffloadWrapper(module) + + +def checkpoint_wrapper( + module: torch.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, + checkpoint_fn=None, + **checkpoint_fn_kwargs, +) -> torch.nn.Module: + """ + Wrap a module for activation checkpointing. + + If the module is wrapped with this function, all subsequent calls to the module will, + automatically perform checkpointing without the user having to explicitly call ``checkpoint`` function. + + Usage:: + checkpointed_module = checkpoint_wrapper(module) + outputs = checkpointed_module(inputs) + Args: + module (nn.Module): + The module to be wrapped + checkpoint_impl (Optional[CheckpointImpl]): + The checkpointing implementation to use. Note that this will only + be passed into the ``torch.utils.checkpoint.checkpoint`` + implementation, and is ignored if a custom ``checkpoint_fn`` is + specified. Note that for implementations using reentrant checkpoint + from ``torch.utils.checkpoint``, keyword arguments will only be + supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`. + checkpoint_fn (Optional[Callable]): + Functional checkpoint implementation to use. If this is specified, + it will be used over the default ``torch.utils.checkpoint.checkpoint`` + implementation and the `checkpoint_impl` argument will be ignored. + **checkpoint_fn_kwargs: (Dict[str, Any]): Keyword arguments to pass into `checkpoint_fn`. + + Returns: + (nn.Module): + Wrapped module + """ + + if checkpoint_impl == CheckpointImpl.REENTRANT: + warnings.warn( + f"Please specify {CheckpointImpl.NO_REENTRANT} as " + f"{CheckpointImpl.REENTRANT} will soon be removed as " + "the default and eventually deprecated.", + FutureWarning, + stacklevel=2, + ) + return CheckpointWrapper( + module, + checkpoint_impl, + checkpoint_fn, + **checkpoint_fn_kwargs, + ) + + +def apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=checkpoint_wrapper, + check_fn=lambda _: True, + auto_wrap_policy: Callable[[nn.Module, bool, int], bool] | None = None, +): + """ + Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration. + + For each module within `model`, the `check_fn` is used to decide + whether `module` should be wrapped with :func:`checkpoint_wrapper` or not. + + Note:: + This function modifies `model` in place and replaces appropriate layers with + their checkpoint-wrapped modules. + Note:: + This function will not wrap the overall root module. If this is needed, please directly use + :func:`checkpoint_wrapper` or :func:`offload_wrapper`. + Usage:: + model = nn.Sequential( + nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10) + ) + check_fn = lambda l: isinstance(l, nn.Linear) + # checkpoint activations + apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn) + # Or offload activations to CPU + apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn) + Args: + model (nn.Module): + The model whose submodules should be wrapped with activation checkpointing. + checkpoint_wrapper_fn (Optional[Callable[nn.Module]]) + A ``Callable`` which will wrap modules + check_fn (Optional[Callable[nn.Module, nn.Module]]) + A lambda function which will be passed each child submodule of ``model`` and returns + ``True`` or ``False`` depending on whether the submodule should be wrapped. + auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A policy to wrap model's + submodules with AC. Note that if this is specified, it takes precedence over ``check_fn``. + Returns: None (`model` is modified inplace) + """ + # TODO: Importing inside function to avoid circular import issue between FSDP and + # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. + from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply + from torch.distributed.fsdp.wrap import ( + _Policy, + _recursive_wrap, + lambda_auto_wrap_policy, + ) + + policy = ( + auto_wrap_policy + if auto_wrap_policy is not None + else partial(lambda_auto_wrap_policy, lambda_fn=check_fn) + ) + if not callable(policy): + if not isinstance(policy, _Policy): + raise ValueError( + f"Expected {policy} to be callable or be a pre-defined wrap policy" + ) + target_module_to_kwargs = policy._run_policy( + model, ignored_modules=set(), root_kwargs={} + ) + wrap_fn = _construct_wrap_fn( + model, target_module_to_kwargs, checkpoint_wrapper_fn + ) + _post_order_apply(model, wrap_fn) + return + + _recursive_wrap( + module=model, + auto_wrap_policy=policy, # type: ignore[arg-type] + wrapper_cls=checkpoint_wrapper_fn, + ignored_modules=set(), + ignored_params=set(), + only_wrap_children=True, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b57a075ad729d0ae3004dc15585250b04810f43 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py @@ -0,0 +1,7 @@ +from . import default_hooks as default + + +LOW_PRECISION_HOOKS = [ + default.fp16_compress_hook, + default.bf16_compress_hook, +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f968bf6a022a0266e4ea0a3b679be632ed7ab77 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dedffeb2cd8b1acbb7c17f9e6864c08a69f4d6b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..76cd01c2265b1d7e5739d79b406cb94a0b0a9893 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +import functools + +import torch +import torch.distributed as dist + + +class DefaultState: + r""" + Stores state needed to perform the default communication algorithm within a communication hook. + + Args: + process_group (ProcessGroup): The process group to be used. + """ + + __slots__ = [ + "process_group", + "world_size", + "gradient_predivide_factor", + "gradient_postdivide_factor", + ] + + def __init__(self, process_group: dist.ProcessGroup): + if process_group is None: + raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.") + self.process_group = process_group + self.world_size = dist.get_world_size(process_group) + # Setting two factors `self.gradient_predivide_factor` + # and `self.gradient_postdivide_factor` to avoid underflow and overflow + self.gradient_predivide_factor = self._get_gradient_predivide_factor( + self.world_size + ) + self.gradient_postdivide_factor = ( + self.world_size / self.gradient_predivide_factor + ) + + @staticmethod + def _get_gradient_predivide_factor(world_size: int) -> float: + factor: int = 1 + while world_size % factor == 0 and world_size / factor > factor: + factor *= 2 + return float(factor) + + +class LowPrecisionState(DefaultState): + r""" + Stores state needed to perform gradient communication in a lower precision within a communication hook. + + Communication hook will cast gradients back to the original + parameter precision specified by ``parameter_type`` (default: torch.float32). + Builds on top of the :class:`DefaultState`. + + Args: + parameter_type (torch.dtype): The precision of model's parameters. + Required for a hook to cast gradients back to a parameter's precision. + """ + + __slots__ = [ + "parameter_type", + ] + + def __init__( + self, + process_group, + parameter_type=torch.float32, + ): + super().__init__(process_group) + self.parameter_type = parameter_type + + +def _decompress(state: LowPrecisionState, grad: torch.Tensor): + """ + Casts gradients back to full parameter precision so that further computation happens in full precision. + """ + orig_grad_data = grad.data + grad.data = grad.data.to(state.parameter_type) + device_type = "" + try: + if grad.device.type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + else: + device_type = grad.device.type + backend = getattr(torch, device_type) + except AttributeError as e: + raise AttributeError( + f"Device {grad.device} does not have a \ + corresponding backend registered as 'torch.device_type'." + ) from e + + # Don't let this memory get reused until after the transfer. + orig_grad_data.record_stream(backend.current_stream()) # type: ignore[arg-type] + + +def allreduce_hook(state: DefaultState, grad: torch.Tensor): + r""" + Implement the FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks. + """ + # Average grad by pre-division factor. Together pre- and post-division factors + # lead to an overall averaging by world_size, required for consistency with PyTorch DDP. + # This is a two-step process to avoid potential underflow and overflow. + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.all_reduce(grad, group=state.process_group) + # Average grad by post-division factor. + if state.gradient_postdivide_factor > 1: + grad.div_(state.gradient_postdivide_factor) + + +def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor): + r""" + Implement the FSDP communication hook for ``reduce_scatter`` algorithm. + + For sharded FSDP strategies and a necessary pre- and post-division of gradients. + + Args: + state (DefaultState): State information, configures pre- and post-division factors. + grad (torch.Tensor): An unsharded gradient for the local batch that needs to be + communicated across ranks. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + # Average grad by pre-division factor. + if state.gradient_predivide_factor > 1: + grad.div_(state.gradient_predivide_factor) + dist.reduce_scatter_tensor(output, grad, group=state.process_group) + # Average grad's shard by post-division factor. + if state.gradient_postdivide_factor > 1: + output.div_(state.gradient_postdivide_factor) + + +def _low_precision_hook( + prec: torch.dtype, + state: LowPrecisionState, + grad: torch.Tensor, + output: torch.Tensor | None, +): + if grad.dtype != prec: + grad.data = grad.data.to(prec) + if output is not None: + if output.dtype != prec: + output.data = output.data.to(prec) + reduce_scatter_hook(state, grad, output) + _decompress(state, output) + else: + allreduce_hook(state, grad) + _decompress(state, grad) + + +def fp16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None +): + r""" + Implement FSDP communication hook for a simple gradient compression approach. + Casts ``grad`` to half-precision floating-point format (``torch.float16``). + + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + fp16_hook = functools.partial(_low_precision_hook, torch.float16) + return fp16_hook(state, grad, output) + + +def bf16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor | None = None +): + r""" + Implement FSDP communication hook for a simple gradient compression approach . + Casts ``grad`` to half-precision floating-point format. + + It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a + ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``) + gradients are averaged by a ``state.gradient_postdivide_factor``. + Once post-division is done, compressed gradients are casted back to parameters' precision. + + Args: + state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision. + grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision. + output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``. + """ + bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16) + return bf16_hook(state, grad, output) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba62bfb68f42a136dcfa27bcf378d3892cf6751a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py @@ -0,0 +1 @@ +from .optimizer_overlap import _as_overlapped_optim diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dc97c38452c475547db69d809715c7465a08fec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ea5687a65acf62d51f778a95f0d6c5f3102c738 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py new file mode 100644 index 0000000000000000000000000000000000000000..569a42ffe7643bb6b6403dfb323a4dfd28493e1b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import inspect +from abc import ABC, abstractmethod + +from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook +from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( + _hook_then_optimizer, + _OptimizerHookState, +) +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.optim import as_functional_optim +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer + + +# Contains the mappings between the regular and overlapped optimizer types. +_registered_overlapped_optims: dict[type, type] = {} + + +def register_overlapped(optim_cls): + def decorator(target_overlapped_optim_cls): + if target_overlapped_optim_cls in _registered_overlapped_optims: + raise ValueError( + f"{target_overlapped_optim_cls} already registered with optim_cls " + f"{_registered_overlapped_optims[optim_cls]} {optim_cls}, trying to" + f"re-register it for {optim_cls} is not supported." + ) + _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls + return target_overlapped_optim_cls + + return decorator + + +class OverlappedOptimizer(ABC): + def __init__(self, optim_cls: type) -> None: + """ + Initialize the OverlappedOptimizer. + + Overlappedoptimizer is a base class that child classes can implement to + specify how different optimizers will register themselves with DDP. + """ + self.optim_cls = optim_cls + + @abstractmethod + def register_ddp(self, ddp: DistributedDataParallel) -> None: + """Registers the overlapped optimizer with DDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped DDP." + ) + + @abstractmethod + def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: + """Registers the overlapped optimizer with FSDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped FSDP." + ) + + +@register_overlapped(Optimizer) +class _OverlappedStandardOptimizer(OverlappedOptimizer): + """Overlaps a regular ``Optimizer``.""" + + def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None: + super().__init__(optim_cls) + f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs) + self._opt_hook_state = _OptimizerHookState(f_optim, params) + + def register_ddp(self, ddp_inst: DistributedDataParallel): + # NOTE: using a custom communication hook and fused optimizer is not + # yet supported. + ddp_inst.register_comm_hook( # type: ignore[operator] + None, # wrapped hook state + _hook_then_optimizer(allreduce_hook, self._opt_hook_state), + ) + + # TODO: register_fsdp once FSDP supports communication hook. + def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: + """Register the overlapped optimizer with FSDP.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support overlapped FSDP." + ) + + +def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs): + """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.""" + for clz in inspect.getmro(optim_cls): + try: + return _registered_overlapped_optims[clz]( + optim_cls, params, *args, **kwargs + ) + except KeyError: + pass + + # Fallback to standard overlapped optimizer, which will raise errors if user + # is attempting to use an unsupported optimizer. + return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21dca1e3873b48c7074bea84340e76563bf6e092 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd90361a1550186fb95bd47da23b812d7f83e372 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/quantization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..69d88604561355b344b43129108d276e398e0f9f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/_quantization/quantization.py @@ -0,0 +1,151 @@ +# mypy: allow-untyped-defs +import functools +from enum import Enum + +import torch +import torch.distributed as dist + + +TORCH_HALF_MIN = torch.finfo(torch.float16).min +TORCH_HALF_MAX = torch.finfo(torch.float16).max + + +class DQuantType(Enum): + """ + Different quantization methods for auto_quantize API are identified here. + + auto_quantize API currently supports fp16 and bfp16 methods. + """ + + FP16 = ("fp16",) + BFP16 = "bfp16" + + def __str__(self) -> str: + return self.value + + +def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor: + return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half() + + +def _quantize_tensor(tensor, qtype): + if not isinstance(tensor, torch.Tensor): + raise RuntimeError( + f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}" + ) + if qtype == DQuantType.FP16: + return _fp32_to_fp16_with_clamp(tensor) + elif qtype == DQuantType.BFP16: + return torch.ops.quantization._FloatToBfloat16Quantized(tensor) + else: + raise RuntimeError(f"Quantization type {qtype} is not supported") + + +def _quantize_tensor_list(tensor_list, qtype): + if not isinstance(tensor_list, list) or not all( + isinstance(p, torch.Tensor) for p in tensor_list + ): + raise RuntimeError( + f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}" + ) + quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list] + return quantized_tensor_list + + +def _dequantize_tensor(tensor, qtype, quant_loss=None): + if not isinstance(tensor, torch.Tensor): + raise RuntimeError( + f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}" + ) + if qtype == DQuantType.FP16: + if tensor.dtype != torch.float16: + raise RuntimeError( + f"tensor dtype is {tensor.dtype} while expected to be FP16." + ) + elif tensor.dtype == torch.float16 and quant_loss is None: + return tensor.float() + else: + # pyrefly: ignore [unsupported-operation] + return tensor.float() / quant_loss + elif qtype == DQuantType.BFP16: + if tensor.dtype != torch.float16: + raise RuntimeError( + f"tensor dtype is {tensor.dtype} while expected to be FP16." + ) + else: + return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor) + else: + raise RuntimeError(f"Quantization type {qtype} is not supported") + + +def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None): + if not isinstance(tensor_list, list) or not all( + isinstance(p, torch.Tensor) for p in tensor_list + ): + raise RuntimeError( + f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}" + ) + dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list] + return dequantized_tensor_list + + +def auto_quantize(func, qtype, quant_loss=None): + """ + Quantize the input tensors, choose the precision types, and pass other necessary arguments and then dequantizes the output. + + Currently it only supports: + . FP16 and BFP16 quantization method supported for gloo and nccl backends + . all_gather, all_to_all collective ops + Note: BFP16 only supports 2D tensors. + Args: + func (Callable): A function representing collective operations. + qtype (QuantType): Quantization method + quant_loss (float, optional): This can be used to improve accuracy in the dequantization. + Returns: + (Callable): the same collective as func but enables automatic quantization/dequantization. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + group = kwargs.get("group") + async_op = kwargs.get("async_op", False) + if async_op is True: + raise RuntimeError("The async_op=True mode is not supported yet.") + if func is dist.all_gather: + tensors = args[0] + input_tensors = _quantize_tensor(args[1], qtype) + out_tensors = _quantize_tensor_list(tensors, qtype) + dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op) + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + + elif func is dist.all_to_all: + tensors = args[0] + input_tensors = _quantize_tensor_list(args[1], qtype) + out_tensors = _quantize_tensor_list(tensors, qtype) + dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op) + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + + elif func is dist.all_to_all_single: + tensors = args[0] + out_splits = kwargs.get("out_splits") + in_splits = kwargs.get("in_splits") + # Quantizing the input/output tensor + input_tensors = _quantize_tensor(args[1], qtype) + out_tensors = _quantize_tensor(tensors, qtype) + dist.all_to_all_single( + out_tensors, input_tensors, out_splits, in_splits, group=group + ) + for i, t in enumerate( + _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss) + ): + tensors[i] = t + else: + raise RuntimeError(f"The collective op {func} is not supported yet") + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9cc6d12785cc760ae77039d5403bd36c94fcdb8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +import sys +from enum import Enum +from functools import partial + + +# To suppress FutureWarning from partial since 3.13 +if sys.version_info >= (3, 13): + from enum import member + + def _enum_member(x): + return member(x) +else: + + def _enum_member(x): + return x + + +import torch.distributed as dist + +from . import ( + debugging_hooks as debugging, + default_hooks as default, + optimizer_overlap_hooks as optimizer_overlap, + powerSGD_hook as powerSGD, + quantization_hooks as quantization, +) + + +__all__ = ["DDPCommHookType", "register_ddp_comm_hook"] + + +def _ddp_comm_hook_wrapper(comm_hook, model, state): + model.register_comm_hook(state, comm_hook) + + +def _powerSGD_comm_hook_wrapper( + comm_hook, + model, + state, + matrix_approximation_rank, + start_powerSGD_iter=1_000, +): + """ + Wrap PowerSGD communication hook. + + To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group, + which will be wrapped up with other state info. + """ + powerSGD_state = powerSGD.PowerSGDState( + process_group=state, + matrix_approximation_rank=matrix_approximation_rank, + start_powerSGD_iter=start_powerSGD_iter, + ) + model.register_comm_hook(powerSGD_state, comm_hook) + + +class DDPCommHookType(Enum): + """ + Enumerate ``ddp_comm_hooks`` and ``ddp_comm_hook_wrapper`` communucation hook types. + + DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` + as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example, + you can register allreduce hook by + ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``. + """ + + ALLREDUCE = _enum_member( + partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook) + ) + FP16_COMPRESS = _enum_member( + partial(_ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook) + ) + BF16_COMPRESS = _enum_member( + partial(_ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook) + ) + QUANTIZE_PER_TENSOR = _enum_member( + partial( + _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook + ) + ) + QUANTIZE_PER_CHANNEL = _enum_member( + partial( + _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook + ) + ) + POWER_SGD = _enum_member( + partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=1, + ) + ) + # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version, + # but it runs slower and consumes more memory. + POWER_SGD_RANK2 = _enum_member( + partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.powerSGD_hook, + matrix_approximation_rank=2, + ) + ) + # Batching can lead to a faster training at the cost of accuracy. + BATCHED_POWER_SGD = _enum_member( + partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=1, + ) + ) + BATCHED_POWER_SGD_RANK2 = _enum_member( + partial( + _powerSGD_comm_hook_wrapper, + comm_hook=powerSGD.batched_powerSGD_hook, + matrix_approximation_rank=2, + ) + ) + NOOP = _enum_member( + partial( + _ddp_comm_hook_wrapper, + comm_hook=debugging.noop_hook, + ) + ) + + +def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None): + """ + Register ``ddp_comm_hooks`` to DDP model. + + Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` + to the DDP model. User can specify the type of hook as an enum + ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will + be passed to the model. + Uses Python comm hook implementations. + + Example:: + >>> # xdoctest: +SKIP + >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state) + """ + comm_hook_type.value(model=model, state=state) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c549c494a1083977778eb4e3a2984cb1cb2a7819 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95112017ed5ee86d1a7605d3b434162792e71e56 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dc37375a7bc01195026daff3876959aa07593e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d1ecc71c2e2bde6925356da076914bdc9dd5dfd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..823335ee896434e10ddeaf1cd0dced3f0bdb0fc1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8cef7473fed6ed6aa806fb5f3b9f9deefa2cd20 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a09c24624d2564dd7c2b6a0453d3d361a9b2716 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21598bb687f79950a7d5a8bf61ae708c82edad3b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..690b73ec09ed465682b160fc42421292544c7d3d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..fa8c865c89151033b379d6cd4785fd15e002cd66 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -0,0 +1,457 @@ +# mypy: allow-untyped-defs +import weakref +from collections.abc import Callable +from typing import Any + +import torch +import torch.distributed as dist +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus +from torch.nn.parallel.distributed import DistributedDataParallel + + +__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"] + +# Functional optimizers require passing a list of gradients to their `step()` +# method, and ZeRO requires a functional optimizer to overlap with DDP +# Passing a `None` instead of an actual gradient indicates to the optimizer +# to not update the corresponding parameter +_NO_PARAM_UPDATE: None = None + + +def _perform_local_step( + bucket: dist.GradBucket, + zero: ZeroRedundancyOptimizer, + rank: int, +): + r""" + Perform a local optimizer step using the gradients provided by ``bucket``. + + Arguments: + bucket (dist.GradBucket): the bucket providing the gradients. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to perform the :meth:`_local_step`. + rank (int): the calling process's rank. + + .. warning:: + This function assumes that appropriate synchronization has taken place + so that the bucket's gradients can be used. + """ + overlap_info = zero._overlap_info + bucket_index = bucket.index() + assert len(zero.optim.param_groups) == 1, ( + "Overlapping DDP with ZeRO only supports a single parameter group" + ) + + # Construct the `gradients` input for the local optimizer step, which + # expects `None` in a list position to indicate that the corresponding + # parameter should not be updated + num_local_optim_params = len(zero.optim.param_groups[0]["params"]) + gradients: list[torch.Tensor | None] = [ + _NO_PARAM_UPDATE for _ in range(num_local_optim_params) + ] + assert bucket_index in overlap_info.offsets, ( + f"Bucket index {bucket_index} was not assigned to rank {rank}" + ) + gradients_offset = overlap_info.offsets[bucket_index] + bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] + bucket_offset = bucket_assignment.offset + length = len(bucket_assignment.parameters) + bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length] + for i, grad in enumerate(bucket_gradients): + gradients[gradients_offset + i] = grad + + zero._local_step(gradients) + + +def _broadcast_bucket( + bucket_index: int, + zero: ZeroRedundancyOptimizer, +): + r""" + Broadcasts a bucket's parameters. + + Arguments: + bucket_index (int): the index of the bucket corresponding to the + parameters to broadcast. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + """ + overlap_info = zero._overlap_info + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, ( + "`assigned_ranks_per_bucket` is not fully constructed" + ) + # Sort to ensure the same ordering across ranks + assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) + assert len(assigned_ranks) > 0, ( + f"Bucket {bucket_index} should be assigned to at least one rank" + ) + for assigned_rank in assigned_ranks: + bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] + if bucket_index in bucket_assignments: + send_tensor = bucket_assignments[bucket_index].tensor + assert send_tensor is not None + overlap_info.broadcast_handles.append( + dist.broadcast( + send_tensor, + src=dist.get_global_rank(zero.process_group, assigned_rank), + group=zero.process_group, + async_op=True, + ) + ) + + +def _save_ddp_bucket_info( + bucket: dist.GradBucket, + zero: ZeroRedundancyOptimizer, +): + r""" + Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``. + + In particular, this function is meant to be called upon seeing each + gradient bucket to use when overlapping, meaning it does not save or compute any global + information. + + Arguments: + bucket (dist.GradBucket): the current gradient bucket. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + """ + overlap_info = zero._overlap_info + bucket_params = bucket.parameters() + assert len(bucket_params) > 0, "Empty bucket" + + # Save the parameters in the bucket + overlap_info.params_per_bucket.append(bucket_params) + if overlap_info.shard_buckets: + # Additionally save the bucket size for the assignment heuristic to use + bucket_size = 0 + for param in bucket_params: + bucket_size += param.numel() + assert overlap_info.total_size is not None + overlap_info.total_size += bucket_size + + +def _hook_with_zero_step_setup( + ddp_ref: weakref.ReferenceType, + zero: ZeroRedundancyOptimizer, + bucket: dist.GradBucket, +): + r""" + Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`. + + This means the logic to run in the + hook before the backward pass and optimizer step can actually be + overlapped. This is factored out since it is common to both + :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`. + + Arguments: + ddp_ref (weakref.ReferenceType): weak reference to the process's + :class:`DistributedDataParallel` instance. + zero (ZeroRedundancyOptimizer): the calling process's + :class:`ZeroRedundancyOptimizer` instance. + bucket (dist.GradBucket): the current gradient bucket. + """ + # Proceed as normal until the DDP buckets have been rebuilt + if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr] + assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED + return + + bucket_index = bucket.index() + overlap_info = zero._overlap_info + if overlap_info.status == _OverlapStatus.UNINITIALIZED: + overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS + + if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS: + if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0: + # This corresponds to the first bucket of the backward pass + # immediately after all information has been saved, so we + # can perform the delayed ZeRO initialization + zero._init_zero_for_overlap() + else: + # Once DDP buckets have been rebuilt but ZeRO has not been + # properly initialized yet, save the information needed + _save_ddp_bucket_info(bucket, zero) + + +def hook_with_zero_step( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future], + ddp: DistributedDataParallel, + zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r""" + Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass. + + This approach overlaps the optimizer computation and communication with the + backward communication. In particular, the backward computation proceeds + contiguously, and the optimizer computation follows, overlapping with + outstanding backward communication (i.e. all-reduces) and possibly other + optimizer communication (i.e. broadcasts). + The optimizer step computation begins after the last gradient bucket computation has finished. + + This approach may be preferred over :meth:`hook_with_zero_step_interleaved` + if communication is relatively slow compared to computation. + + Arguments: + hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook + to modify. + ddp (DistributedDataParallel): the :class:`DistributedDataParallel` + instance to use. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to use. + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). + + Returns: + The modified hook. + + Raises: + ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``. + RuntimeError: if using any backend other than NCCL/HCCL since currently + Gloo may hang. + + .. warning:: + Given the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. + """ + if not zero._overlap_with_ddp: + raise ValueError( + "ZeroRedundancyOptimizer must be constructed with " + "`overlap_with_ddp=True` to use this hook properly" + ) + ddp_ref = weakref.ref(ddp) + + # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if pg == dist.Backend.GLOO: + raise RuntimeError( + "Gloo backend using Overlapping DDP with ZeRO may meet hangs" + ) + + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + + def hook_with_zero_fn( + state: Any, + bucket: dist.GradBucket, + ) -> torch.futures.Future[torch.Tensor]: + r""" + Return :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket. + + Perform equivalent of :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is last gradient bucket. + The function gives a gradient bucket tensor and + performs additional computation on the iteration that + the :class:`DistributedDataParallel` buckets are rebuilt to collect + information used to implement the modified hook. + + Arguments: + state (Any): any state for the hook. + bucket (dist.GradBucket): the :class:`DistributedDataParallel` + gradient bucket. + """ + fut = hook(state, bucket) + _hook_with_zero_step_setup(ddp_ref, zero, bucket) + if zero._overlap_info.status != _OverlapStatus.INITIALIZED: + return fut + + overlap_info = zero._overlap_info + bucket_index = bucket.index() + rank = zero.global_rank + + assert overlap_info.status == _OverlapStatus.INITIALIZED + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, ( + "`assigned_ranks_per_bucket` is not fully constructed" + ) + assigned_to_bucket = ( + rank in overlap_info.assigned_ranks_per_bucket[bucket_index] + ) + + # Save the bucket reference and all-reduce future for the final bucket + if assigned_to_bucket: + overlap_info.bucket_index_to_bucket[bucket_index] = bucket + overlap_info.bucket_index_to_future[bucket_index] = fut + + # Check that buckets are indexed incrementally starting from 0 in the + # order of their autograd hooks firing + if len(overlap_info.bucket_indices_seen) > 0: + assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, ( + "Bucket indices are not in incremental order" + ) + else: + assert bucket_index == 0, "Bucket indices do not start from 0" + overlap_info.bucket_indices_seen.append(bucket_index) + + # Directly return the future without any optimizer computation if this + # is not the last bucket + num_buckets = len(overlap_info.params_per_bucket) + is_last_bucket = bucket_index == num_buckets - 1 + if not is_last_bucket: + return fut + + # Perform partial optimizer step on all buckets after the final + # bucket has been computed + # NOTE: This should not be chained as a callback to the last bucket's + # all-reduce future since that would add synchronization that delays + # all optimizer computation to wait for that last all-reduce + for bucket_index in range(num_buckets): + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] + if rank in assigned_ranks: + # Wait on the bucket's all-reduce future to ensure correct + # gradients + assert bucket_index in overlap_info.bucket_index_to_future, ( + f"All-reduce future for bucket {bucket_index} not saved " + f"on rank {rank}" + ) + allreduce_future = overlap_info.bucket_index_to_future[bucket_index] + allreduce_future.wait() + + # Perform the partial optimizer step + curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index] + _perform_local_step(curr_bucket, zero, rank) + + _broadcast_bucket(bucket_index, zero) + + # Ensure that all parameter updates are finished before the + # next forward pass + overlap_info.wait_for_broadcasts() + overlap_info.clear_per_iter_info() + + return fut + + return hook_with_zero_fn + + +def hook_with_zero_step_interleaved( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future], + ddp: DistributedDataParallel, + zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r""" + Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass + + This approach overlaps the optimizer computation and communication with the + backward computation and communication. In particular, once a bucket's + gradients have been computed, the optimizer computation using those + gradients is launched (though the actual computation must wait for the + bucket's all-reduce to complete). This yields an interleaving of all- + reduces and broadcasts in the communication stream. + + This approach may be preferred over :meth:`hook_with_zero_step` if + communication is relatively fast compared to computation. + + Arguments: + hook (Any * dist.GradBucket -> torch.futures.Future): the hook to + modify. + ddp (DistributedDataParallel): the :class:`DistributedDataParallel` + instance to use. + zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` + instance to use. + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). + + Returns: + The modified hook. + + Raises: + ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``. + RuntimeError: if using any backend other than NCCL since currently + Gloo may hang. + + .. warning:: + Given the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. + """ + if not zero._overlap_with_ddp: + raise ValueError( + "ZeroRedundancyOptimizer must be constructed with " + "`overlap_with_ddp=True` to use this hook properly" + ) + ddp_ref = weakref.ref(ddp) + + # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if pg == dist.Backend.GLOO: + raise RuntimeError( + "Gloo backend using Overlapping DDP with ZeRO may meet hangs" + ) + + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + + def hook_with_zero_interleaved_fn( + state, + bucket: dist.GradBucket, + ) -> torch.futures.Future[torch.Tensor]: + r""" + Return :class:`Future` that gives gradient bucket tensor and performs partial :class:`ZeroRedundancyOptimizer` :meth:`step`. + + This function uses the gradients in gradient in given bucket to perform a partial + :class:`ZeroRedundancyOptimizer` :meth:`step` + + Arguments: + state: any state for the hook. + bucket (dist.GradBucket): the :class:`DistributedDataParallel` + gradient bucket. + """ + fut = hook(state, bucket) + _hook_with_zero_step_setup(ddp_ref, zero, bucket) + if zero._overlap_info.status != _OverlapStatus.INITIALIZED: + return fut + + def zero_step(fut: torch.futures.Future) -> torch.Tensor: + r""" + Perform partial :class:`ZeroRedundancyOptimizer` :meth:`step` using gradients in the :class:`DistributedDataParallel`. + + Returns: + A :class:`torch.Tensor` representing the contents of the + gradient bucket. + """ + overlap_info = zero._overlap_info + bucket_index = bucket.index() + rank = zero.global_rank + + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] + overlap_info.bucket_indices_seen.append(bucket_index) + if rank in assigned_ranks: + _perform_local_step(bucket, zero, rank) + + _broadcast_bucket(bucket_index, zero) + + num_buckets = len(overlap_info.params_per_bucket) + if len(overlap_info.bucket_indices_seen) == num_buckets: + # Ensure that all parameter updates are finished before the + # next forward pass + overlap_info.wait_for_broadcasts() + overlap_info.clear_per_iter_info() + + return bucket.buffer() + + return fut.then(zero_step) + + return hook_with_zero_interleaved_fn diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..53a184839a06f4787471f14f48137f4aa344fd91 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py @@ -0,0 +1,29 @@ +from typing import Any + +import torch +from torch.distributed import GradBucket + + +__all__ = ["noop_hook"] + + +def noop_hook(_: Any, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]: + """ + Return a future that wraps the input, so it is a no-op that does not incur any communication overheads. + + This hook should **only** be used for headroom analysis of allreduce optimization, + instead of the normal gradient synchronization. + For example, if only less than 10% speedup of training time can be observed after this hook is registered, + it usually implies that allreduce is not a performance bottleneck for this case. + Such instrumentation can be particularly useful + if GPU traces cannot be easily retrieved or the trace analysis is complicated + some factors such as the overlap between allreduce and computation or the desynchronization across ranks. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(None, noop_hook) + """ + fut: torch.futures.Future[torch.Tensor] = torch.futures.Future() + fut.set_result(bucket.buffer()) + + return fut diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..20a0de7ef318c10f3b5bbdaf98483d9fd19b2691 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -0,0 +1,211 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from typing import Any, cast + +import torch +import torch.distributed as dist + + +__all__ = [ + "allreduce_hook", + "fp16_compress_hook", + "bf16_compress_hook", + "fp16_compress_wrapper", + "bf16_compress_wrapper", +] + + +def _allreduce_fut( + process_group: dist.ProcessGroup, tensor: torch.Tensor +) -> torch.futures.Future[torch.Tensor]: + """Average the input gradient tensor by allreduce and returns a future.""" + group_to_use = process_group if process_group is not None else dist.group.WORLD + + # Apply the division first to avoid overflow, especially for FP16. + # pyrefly: ignore [missing-attribute] + tensor.div_(group_to_use.size()) + + return ( + dist.all_reduce(tensor, group=group_to_use, async_op=True) + .get_future() + .then(lambda fut: fut.value()[0]) + ) + + +def allreduce_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Call ``allreduce`` using ``GradBucket`` tensors. + + Once gradient tensors are aggregated across all workers, its ``then`` + callback takes the mean and returns the result. + + If user registers this DDP communication hook, + DDP results is expected to be same as the case where no hook was registered. + Hence, this won't change behavior of DDP and user can use this as a reference + or modify this hook to log useful information or any other purposes while + unaffecting DDP behavior. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, allreduce_hook) + """ + return _allreduce_fut(process_group, bucket.buffer()) + + +def _compress_hook( + dtype: torch.dtype, + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> torch.futures.Future[torch.Tensor]: + group_to_use = process_group if process_group is not None else dist.group.WORLD + # pyrefly: ignore [missing-attribute] + world_size = group_to_use.size() + + buffer = ( + cast(tuple[torch.Tensor, ...], bucket)[0] + if isinstance(bucket, tuple) + else bucket.buffer() + ) + compressed_tensor = buffer.to(dtype).div_(world_size) + + def decompress(fut): + decompressed_tensor = buffer + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + value = fut if isinstance(fut, torch.Tensor) else fut.value()[0] + decompressed_tensor.copy_(value) + return decompressed_tensor + + if torch.compiler.is_compiling(): + grad = dist._functional_collectives.all_reduce( + compressed_tensor, + "sum", + # pyrefly: ignore [bad-argument-type] + group_to_use, + ) + return decompress(grad) + else: + fut = dist.all_reduce( + compressed_tensor, group=group_to_use, async_op=True + ).get_future() + return fut.then(decompress) + + +def fp16_compress_hook( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> torch.futures.Future[torch.Tensor]: + """ + Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``) + and then divides it by the process group size. + It allreduces those ``float16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) + """ + return _compress_hook(torch.float16, process_group, bucket) + + +def bf16_compress_hook( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> torch.futures.Future[torch.Tensor]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision + `Brain floating point format `_ (``torch.bfloat16``) + and then divides it by the process group size. + It allreduces those ``bfloat16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook) + """ + return _compress_hook(torch.bfloat16, process_group, bucket) + + +def fp16_compress_wrapper( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + """ + Cast input tensor to ``torch.float16``, cast result of hook back to input dtype. + + This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision + floating point format (``torch.float16``), and casts the resulting tensor of the given hook back to + the input data type, such as ``float32``. + Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) + >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook)) + """ + + def fp16_compress_wrapper_hook( + hook_state, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + # Cast bucket tensor to FP16. + bucket.set_buffer(bucket.buffer().to(torch.float16)) + + fut = hook(hook_state, bucket) + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()) + return decompressed_tensor + + # Decompress after hook has run. + return fut.then(decompress) + + return fp16_compress_wrapper_hook + + +def bf16_compress_wrapper( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision + `Brain floating point format `_ (``torch.bfloat16``), + and casts the resulting tensor of the given hook back to the input data type, such as ``float32``. + + Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) + >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook)) + """ + + def bf16_compress_wrapper_hook( + hook_state, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + # Cast bucket tensor to BF16. + bucket.set_buffer(bucket.buffer().to(torch.bfloat16)) + + fut = hook(hook_state, bucket) + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()) + return decompressed_tensor + + # Decompress after hook has run. + return fut.then(decompress) + + return bf16_compress_wrapper_hook diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..f1968042e5e21aa1b6714f78356b43896cccdf60 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass +from typing import Any, no_type_check + +import torch +import torch.distributed as dist +from torch.autograd import Variable +from torch.distributed.utils import _free_storage + + +@dataclass +class _AllreduceUpcastHookState: + """ + State to manage DDP mixed precision in backward / gradient communication. + + This contains a weakref to the DDP module for access to reducer and process + group, and a stream to run parameter and gradient upcasts. + """ + + ddp_weakref: Any + upcast_stream: torch.Stream + wait_for_stream_enqueued: bool = False + + +@no_type_check +def _reducer_allreduce_and_upcast_hook( + hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Perform allreduce in precision ``reduce_dtype``, upcast to prepare for optimizer. + + Performs allreduce in the reduced precision given by DDP's mixed precision + reduce_dtype, and upcasts parameters and gradients to fp32 in preparation + to run the optimizer. + """ + ddp_weakref = hook_state.ddp_weakref + reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group + # Cast bucket if different than param_dtype. + if ( + ddp_weakref().mixed_precision.param_dtype + != ddp_weakref().mixed_precision.reduce_dtype + ): + # Cast bucket tensor to reduce_dtype + bucket.set_buffer( + bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype) + ) + fut = reducer._run_allreduce_hook(bucket) + ret_fut = torch.futures.Future() + stream = hook_state.upcast_stream + with stream: + fut.wait() + bucket.buffer().div_(process_group.size()) + ret_fut.set_result(bucket.buffer()) + + # Upcast parameters and gradients so optimizer step can run in fp32. + for p in bucket.parameters(): + p.data = p._fp_param + # free storage for mp param as it will be allocated again in next + # forward pass. + _free_storage(p._mp_param) + p.grad.data = p.grad.to(p.data.dtype) + + # enqueue a callback to wait for this stream at end of backward + def wait_for_stream_cb(): + torch.accelerator.current_stream().wait_stream(stream) + # Remove post-backward hooks since they are re-installed in next + # iteration, similar to FSDP. + # Parameters that don't require grad still needed to be casted since + # they may participate in computation. However, they would not be recast + # by hook above as they don't have a grad hook installed, so cast them + # back here. + for _, p in ddp_weakref().module.named_parameters(): + if hasattr(p, "_ddp_mp_hook_state"): + p._ddp_mp_hook_state[1].remove() + delattr(p, "_ddp_mp_hook_state") + if not p.requires_grad and not hasattr(p, "_ddp_ignored"): + p.data = p._fp_param + + # reset for next backward pass + hook_state.wait_for_stream_enqueued = False + + if not hook_state.wait_for_stream_enqueued: + Variable._execution_engine.queue_callback(wait_for_stream_cb) + # mark that the callback is enqueued + hook_state.wait_for_stream_enqueued = True + + return ret_fut diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..162160e394ad0b634365f941f3a9f216bf1aa2d8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -0,0 +1,163 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from dataclasses import dataclass +from functools import partial +from typing import Any, no_type_check + +import torch +import torch.distributed as dist +from torch.autograd import Variable + + +__all__: list[str] = [] + +_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param" + + +class _OptimizerHookState: + """ + Holds state for running optimizer in-line after DDP communication hook. + + Currently contains only optimizer class which must have a method `step_param`. + """ + + __slots__ = ["functional_optimizer", "params_to_optimize"] + + def __init__(self, functional_optim, params=None): + self.functional_optimizer = functional_optim + self._check_valid_functional_optim() + self._set_params_to_optimize(params) + + def _set_params_to_optimize(self, params): + if params is not None: + self.params_to_optimize = set(params) + + def _check_valid_functional_optim(self): + if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME): + raise ValueError( + f"Class {type(self.functional_optimizer)} must implement method " + f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}." + ) + + +@dataclass +class _OptimInBackwardHookState: + optim_stream: torch.Stream + wait_for_optim_stream_enqueued: bool + + +@no_type_check +def _apply_optim_in_backward_hook( + gradient_is_bucket_view: bool, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r""" + Register hook to apply the optimizer in backward. + + If torch.distributed.optim._apply_optimizer_in_backward is used to overlap + optimizer with backward pass, DDP will run the below hook to run optimizer + step for parameters after gradient communication has taken place. + """ + optim_in_bwd_state = _OptimInBackwardHookState( + optim_stream=torch.Stream(), + wait_for_optim_stream_enqueued=False, + ) + + def apply_optim_in_backward_hook( + hook_state: Any, + bucket: dist.GradBucket, + optim_stream_state, + ) -> torch.futures.Future[torch.Tensor]: + # Run original hook + ddp_weakref = hook_state + ddp_inst = ddp_weakref() + reducer, process_group = ddp_inst.reducer, ddp_inst.process_group + fut = reducer._run_allreduce_hook(bucket) + optimizer_stream = optim_stream_state.optim_stream + with optimizer_stream: + fut.wait() + # Apply gradient division since C++ side only allreduces and does + # not average. TODO: (rohan-varma) the div factor may be different + # when running with join hook + bucket.buffer().div_(process_group.size()) + model_params = bucket.parameters() + grads = bucket.gradients() + # TODO (rohan-varma): upcast as needed for DDP mixed precision, + # once optimizer in backward + DDP mixed precision is supported. + for p, g in zip(model_params, grads): + if hasattr(p, "_in_backward_optimizers"): + # Note: need to set grad to the bucket's grad, because + # running allreduce results in the bucket's grad being + # reduced, but not grad field. + if not gradient_is_bucket_view: + p.grad = g + for optim in p._in_backward_optimizers: + optim.step() + + # Need to return a Future[Tensor] to obey comm hook API contract. + ret_fut = torch.futures.Future() + ret_fut.set_result(bucket.buffer()) + + # enqueue a callback to wait for this optimizer stream at the end of + # backward and set all DDP managed grads to None. + def wait_for_optim_stream_callback(): + torch.accelerator.current_stream().wait_stream( + optim_stream_state.optim_stream + ) + # Set DDP managed grads to None + for param in ddp_inst._get_data_parallel_params(ddp_inst.module): + if hasattr(param, "_in_backward_optimizers"): + param.grad = None + + # reset for the next backwards pass + optim_stream_state.wait_for_optim_stream_enqueued = False + + if not optim_stream_state.wait_for_optim_stream_enqueued: + Variable._execution_engine.queue_callback(wait_for_optim_stream_callback) + # mark that the callback is enqueued + optim_stream_state.wait_for_optim_stream_enqueued = True + + return ret_fut + + comm_hook = partial( + apply_optim_in_backward_hook, optim_stream_state=optim_in_bwd_state + ) + # These are needed for DDP's logging of comm hooks + comm_hook.__name__ = apply_optim_in_backward_hook.__name__ + comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__ + + return comm_hook + + +def _hook_then_optimizer( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], + optimizer_state: _OptimizerHookState, +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + r"""Run optimizer in a functional fashion after DDP communication hook.""" + has_set_params = ( + hasattr(optimizer_state, "params_to_optimize") + and optimizer_state.params_to_optimize is not None + ) + + def hook_then_optimizer_wrapper( + hook_state, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + # Run original hook + fut = hook(hook_state, bucket) + + def optimizer_step(fut): + gradient_tensors = bucket.gradients() + model_params = bucket.parameters() + for grad_tensor, model_param in zip(gradient_tensors, model_params): + if ( + not has_set_params + or model_param in optimizer_state.params_to_optimize + ): + optimizer_state.functional_optimizer.step_param( + model_param, + grad_tensor, + ) + return bucket.buffer() + + return fut.then(optimizer_step) + + return hook_then_optimizer_wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..ff513f62183c516b96c62ca89eee51d2b1793e85 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -0,0 +1,124 @@ +# mypy: allow-untyped-defs +import logging + +import torch +import torch.distributed as dist + +from . import default_hooks as default + + +logger = logging.getLogger(__name__) + + +class PostLocalSGDState: + r""" + Store state for all-reducing gradients globally until given step, then locally after. + + Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``, + and all-reducing gradients locally using ``subgroup`` afterwards. + + If ``process_group`` is ``None``, the global process group will be used. + If ``subgroup`` is ``None``, the intra-node process group on each machine will be used. + + Additionally, ``post_local_gradient_allreduce`` may be worth tuning, + because both true and false may give a faster convergence. + """ + + __slots__ = [ + "process_group", + "subgroup", + "start_localSGD_iter", + "post_local_gradient_allreduce", + "iter", + ] + + def __init__( + self, + process_group, + subgroup, + start_localSGD_iter, + post_local_gradient_allreduce=True, + ): + """Initialize state object with given parameters and log when localSGD start.""" + logger.info( + "Local SGD will be started after %s iterations", start_localSGD_iter + ) + + # The group used for all-reducing gradients globally. + self.process_group = process_group + # The group used for all-reducing gradients locally. + self.subgroup = subgroup + self.start_localSGD_iter = start_localSGD_iter + # Allreduce gradients locally since iteration `start_localSGD_iter`. + # This may help with the convergence efficiency at the cost of relatively cheap intra-subgroup communication. + self.post_local_gradient_allreduce = post_local_gradient_allreduce + # Iteration/step in the training loop. + self.iter = 0 + + def maybe_increase_iter(self, bucket): + """Track iterations and trigger log message at start of local SGD.""" + # Since bucket 0 is the last bucket to allreduce in an iteration. + # Only increase `iter` when bucket 0 is processed. + if bucket.is_last(): + self.iter += 1 + + if self.iter == self.start_localSGD_iter: + logger.info("Start to apply local SGD after %s iterations.", self.iter) + + +def post_localSGD_hook( + state: PostLocalSGDState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Run post-localSGD algorithm. + + This DDP communication hook is used for running post-localSGD algorithm, + by combining with a model averaging component (e.g., + :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`) + that runs after the optimizer step. + + Args: + state (PostLocalSGDState): State information to run post-localSGD. + Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup, + start_localSGD_iter=10) + >>> ddp_model.register_comm_hook(state, post_localSGD_hook) + >>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``. + >>> # Please refer to the examples in ``torch.distributed.algorithms.model_averaging.averagers`` module. + """ + global_group_to_use = ( + state.process_group if state.process_group is not None else dist.group.WORLD + ) + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations. + if state.iter < state.start_localSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(global_group_to_use, input_tensor) # type: ignore[arg-type] + + # If `post_local_gradient_allreduce` is not set, + # then no gradient synchronization after the first `start_localSGD_iter` iterations. + if not state.post_local_gradient_allreduce: + fut: torch.futures.Future[torch.Tensor] = torch.futures.Future() + fut.set_result(input_tensor) + return fut + + # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations. + # Note that by default, a separate subgroup for each node is created which + # causes an intra-node allreduce to be done at each training step. + # From this moment, model averaging should run after the optimizer step, + # to globally allreduce all the parameters. + if state.subgroup is None: + state.subgroup, _ = dist.new_subgroups() + return default._allreduce_fut(state.subgroup, input_tensor) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e95d12514eda18b52ae07a44a68e1678bd27a9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -0,0 +1,862 @@ +# mypy: allow-untyped-defs +import logging +import math +from collections import defaultdict + +import torch +import torch.distributed as dist +from torch.distributed import distributed_c10d +from torch.utils._typing_utils import not_none + +from . import default_hooks as default + + +__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"] + +logger = logging.getLogger(__name__) + + +def _orthogonalize(matrices, epsilon=0): + """ + Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices. + + QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2. + """ + assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1] + + num_matrices = matrices.shape[0] + rank = matrices.shape[2] + dtype = matrices.dtype + if rank <= 2 or dtype in [torch.float16, torch.bfloat16]: + _orthogonalize_gram_schmidt(matrices, epsilon=epsilon) + else: + torch.linalg.qr( + matrices, + out=( + matrices, + torch.empty( + num_matrices, rank, rank, device=matrices.device, dtype=dtype + ), + ), + ) + + +def _orthogonalize_gram_schmidt(matrices, epsilon=0): + """ + Apply Gram-Schmidt procedure to orthogonalize a batch of matrices. + + If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`, + """ + num_cols = matrices.shape[2] + for i in range(num_cols): + # Normalize the i'th column. + col = matrices[:, :, i : i + 1] + # If no epsilon is added here, division by zero may be caused by vanishing gradients. + # This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer + # in the neural network. + if epsilon == 0: + # Note that col ** 2 can underflow/overflow if we use FP16. + # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead. + try: + col /= torch.norm(col, dim=1, keepdim=True) + except ZeroDivisionError: + logger.error( + "The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 " + "as `orthogonalization_epsilon` in PowerSGD state." + ) + # Recover the values from NaNs to 0s. + col.fill_(0.0) + else: + col /= torch.norm(col, dim=1, keepdim=True) + epsilon + # Project it on the rest and remove it. + if i + 1 < num_cols: + rest = matrices[:, :, i + 1 :] + rest -= torch.sum(col * rest, dim=1, keepdim=True) * col + + +def _should_compress( + num_rows, num_cols, matrix_approximation_rank, min_compression_rate +): + """ + Recommend if tensor given is worth compressing. + + Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing, + including statistics describing the expected savings from compression. We consider a tensor worth + compressing when ``min_compression_rate`` < uncompressed size / compressed size, where + uncompressed size = ``num_rows`` * ``num_cols``, + and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``. + + The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where: + + compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above); + + uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and, + + compress_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``. + """ # noqa: B950 + uncompressed_size = num_rows * num_cols + compressed_size = (num_rows + num_cols) * matrix_approximation_rank + return ( + compressed_size * min_compression_rate < uncompressed_size, + uncompressed_size, + compressed_size, + ) + + +def _report_compression_stats(bucket, state): + """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state.""" + if bucket.is_last() and state.iter >= state.next_stats_report: + stats = state.compression_stats() + logger.info( + "Compression stats: iter %s, total before compression %s, total after compression %s, " + "rate %s", + state.iter, + stats[1], + stats[2], + stats[0], + ) + state.next_stats_report = state.iter + state.compression_stats_logging_frequency + + +class PowerSGDState: + r""" + Store both the algorithm's hyperparameters and internal state for all gradients during training. + + Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user. + For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on. + + 1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression. + + 1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy. + + 1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold. + + To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32. + + 2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even a relatively small ``matrix_approximation_rank`` is used. This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy. + + To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps. + + 3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression. + + Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts. + + 4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy. + + 5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck. + + .. warning :: + If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2. + This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, + and this can conflict with any tensor memorized before the rebuild process. + """ # noqa: B950 + + __slots__ = [ + "process_group", + # The fields below are the hyperparameters that often need to be tuned by the user. + "matrix_approximation_rank", + "start_powerSGD_iter", + # The fields below are the hyperparameters that seldom need be tuned by the user. + "min_compression_rate", + "orthogonalization_epsilon", + # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy. + "use_error_feedback", + "warm_start", + "batch_tensors_with_same_shape", + # The fields below are internal state. + "rng", + "error_dict", + "p_memory_dict", + "q_memory_dict", + "iter", + # The fields below are for recording compression stats. + "total_numel_before_compression", + "total_numel_after_compression", + "compression_stats_logging_frequency", + "next_stats_report", + ] + + def __init__( + self, + process_group, + matrix_approximation_rank=1, + start_powerSGD_iter=1_000, + min_compression_rate=2, + use_error_feedback=True, + warm_start=True, + orthogonalization_epsilon=0, + random_seed=0, + compression_stats_logging_frequency=10_000, + batch_tensors_with_same_shape: bool = False, + ): + logger.info( + "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; " + "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; " + "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s", + matrix_approximation_rank, + start_powerSGD_iter, + min_compression_rate, + orthogonalization_epsilon, + use_error_feedback, + warm_start, + random_seed, + compression_stats_logging_frequency, + batch_tensors_with_same_shape, + ) + + self.process_group = process_group + self.matrix_approximation_rank = matrix_approximation_rank + # Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages: + # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss, + # even if the matrix approximation rank is increased to a large value. + # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce + # (or a more conservative compression such as FP16 compression) with PowerSGD. + # 2) There is an internal optimization of rebuilding buckets process in DDP, + # in order to save the memory space. + # This step takes place after the first iteration. + # However, this means that the shape of input bucketized tensors is subject to change, + # which will complicate the implementations of error feedback and warm-up. + # Running vanilla allreduce in the first few iterations can avoid this complexity. + if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1: + raise ValueError( + "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " + "because PowerSGD can only be applied after the first two iterations in DDP." + ) + self.start_powerSGD_iter = start_powerSGD_iter + self.min_compression_rate = min_compression_rate + # Error feedback is usually crucial for both for convergence and generalization, + # because PowerSGD is a biased compressor, + # i.e., compressing and decompressing a random gradient does not yield the original in expectation. + # This mechanism requires a temporary copy of the input gradients, + # so it increases the peak memory consumption by the size of the gradient tensor. + # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank), + # sometimes it is possible to converge to the optima without error feedback. + # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf + self.use_error_feedback = use_error_feedback + # Warm-start reuses P(s) and Q(s) from the previous iteration. + # This can improve the approximation quality and hence improve the accuracy. + # Additionally, by avoiding the initialization of these low-rank tensors at every step, + # this can also accelerate training. + # However, this is at the cost of extra memory. + self.warm_start = warm_start + # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients. + self.orthogonalization_epsilon = orthogonalization_epsilon + # The purpose of this RNG is to generate different random seeds for initializing Q across iterations, + # but in the same order for all the DDP replicas. + # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps. + # If the same random projection is used, + # there will be differences between the gradients that are never synchronized. + import numpy as np + + self.rng = np.random.RandomState(random_seed) + # Since there is only a single state instance for all the input buckets, + # need to maintain a dictionary that maps each bucket index to the local error. + self.error_dict: dict[int, torch.Tensor] = {} + self.p_memory_dict: dict[int, torch.Tensor] = {} + self.q_memory_dict: dict[int, torch.Tensor] = {} + # Iteration/step in the training loop. + self.iter = 0 + # Compression stats accumulators + self.total_numel_before_compression = 0 + self.total_numel_after_compression = 0 + # We'll report compression stats every 'compression_stats_logging_frequency' iterations + # Note that we always report compression stats at least once. + self.compression_stats_logging_frequency = max( + 1, compression_stats_logging_frequency + ) + self.next_stats_report = 0 + # Batching tensors with same shape can increase parallelism in compression / decompression computation. + # This requires a larger bucket size to make more same-shaped tensor to appear in one bucket, however + # this may reduce the overlap between computation and communication, and increase the memory footprint + # due to stacking tensors. + # Turn on if compression / decompression computation is a bottleneck. + self.batch_tensors_with_same_shape = batch_tensors_with_same_shape + + def __getstate__(self): + r""" + Return a ``Dict[str, Any]`` which will be pickled and saved. + + ``process_group`` is not serializable and excluded from + a returned state. + """ + logger.warning( + "NOTE: Process group is not serializable and excluded from a saved state." + ) + return { + slot: getattr(self, slot) + for slot in self.__slots__ + if slot != "process_group" + } + + def __setstate__(self, state): + r""" + Take a provided ``state`` and set to this ``PowerSGDState`` instance. + + ``process_group`` is set to default. + """ + self.process_group = distributed_c10d._get_default_group() + logger.warning( + "NOTE: Process group will be set to a default group (i.e. the world size).\ + If a different group is desired, please set `self.process_group` after PowerSGD state is loaded." + ) + for slot, value in state.items(): + setattr(self, slot, value) + + def maybe_increase_iter(self, bucket): + """Track iterations and trigger log message at start of local SGD.""" + # Since bucket 0 is the last bucket to allreduce in an iteration. + # Only increase `iter` when bucket 0 is processed. + if bucket.is_last(): + self.iter += 1 + + if self.iter == self.start_powerSGD_iter: + logger.info("Start to apply PowerSGD after %s iterations.", self.iter) + + def compression_stats(self): + r""" + Return latest compression statistics as tuple. + + Returns tuple of form (compress_rate, numel_before_compression, numel_after_compression) where: + + compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression); + + numel_before_compression is the total number of elements before compression was applied; and, + + numel_after_compression is the total number of elements after compression was applied. + """ # noqa: B950 + compress_rate = ( + self.total_numel_before_compression / self.total_numel_after_compression + if self.total_numel_after_compression > 0 + else 0 + ) + return ( + compress_rate, + self.total_numel_before_compression, + self.total_numel_after_compression, + ) + + +def powerSGD_hook( + state: PowerSGDState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + r""" + Implement PowerSGD algorithm. + + This DDP communication hook implements PowerSGD gradient compression + algorithm described in the `paper `_. + Once gradient tensors are aggregated across all workers, this hook applies + compression as follows: + + 1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups: + + 1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth. + + 1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases). + + 2. Handles uncompressed tensors: + + 2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression; + + 2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor. + + 3. Handles the tensors that should be compressed by PowerSGD compression: + + 3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M, + such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + + 3.2. Computes each P in Ps, which is equal to MQ; + + 3.3. Allreduces Ps as a batch; + + 3.4. Orthogonalizes each P in Ps; + + 3.5. Computes each Q in Qs, which is approximately equal to M^TP; + + 3.6. Allreduces Qs as a batch; + + 3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T. + + Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. + This not only gives the user more control over the tradeoff between speedup and accuracy, + but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter`` + and ``min_compression_rate``. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, + start_powerSGD_iter=10, min_compression_rate=0.5) + >>> ddp_model.register_comm_hook(state, powerSGD_hook) + """ # noqa: B950 + process_group = state.process_group + group_to_use = ( + process_group if process_group is not None else not_none(dist.group.WORLD) + ) + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run vanilla allreduce in the first `start_powerSGD_iter` iterations. + if state.iter < state.start_powerSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(group_to_use, input_tensor) + + # Apply PowerSGD after `start_powerSGD_iter` iterations. + device = input_tensor.device + dtype = input_tensor.dtype + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.index() + input_tensor_cp = None + total_length = input_tensor.shape[0] + if state.use_error_feedback: + if bucket_index in state.error_dict: + input_tensor.add_(state.error_dict[bucket_index]) + else: + logger.info( + "A zero tensor of length %s that represents local error is created.", + total_length, + ) + state.error_dict[bucket_index] = torch.zeros( + total_length, device=device, dtype=dtype + ) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. + input_tensor_cp = input_tensor.detach().clone() + + # Unflatten the input tensor into per-parameter tensors, for layer-wise compression. + tensors = bucket.gradients() + + # Step I: Divide all the tensors into two groups, + # one will be compressed before allreduce and the other will be directly allreduced without compression. + tensors_to_compress, uncompressed_tensors = [], [] + total_Ps_size = 0 + total_Qs_size = 0 + for tensor in tensors: + matrix = tensor.view(tensor.shape[0], -1) + n, m = matrix.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + compress_test = _should_compress( + n, m, matrix_approximation_rank, state.min_compression_rate + ) + state.total_numel_before_compression += compress_test[1] + if compress_test[0]: + tensors_to_compress.append(matrix) + total_Ps_size += n * matrix_approximation_rank + total_Qs_size += m * matrix_approximation_rank + state.total_numel_after_compression += compress_test[2] + else: + uncompressed_tensors.append(tensor) + state.total_numel_after_compression += compress_test[1] + + _report_compression_stats(bucket, state) + + # Step II: Handle uncompressed tensors. + # Allocate contiguous memory for these tensors to allreduce efficiently. + uncompressed_tensors_memory = ( + torch.cat([tensor.view(-1) for tensor in uncompressed_tensors]) + if uncompressed_tensors + else torch.tensor([], device=device, dtype=dtype) + ) + + # Step III: Handle the tensors that should be compressed. + # Allocate contiguous memory for Ps and Qs to allreduce efficiently. + # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible. + # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied. + need_randomize_qs = False + if not state.warm_start or bucket_index not in state.p_memory_dict: + need_randomize_qs = True + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. + if state.warm_start: + logger.info( + "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.", + total_Ps_size, + total_Qs_size, + ) + state.p_memory_dict[bucket_index] = torch.empty( + total_Ps_size, device=device, dtype=dtype + ) + state.q_memory_dict[bucket_index] = torch.empty( + total_Qs_size, device=device, dtype=dtype + ) + + # Batch tensors to compress by shape. + shape_to_tensors = defaultdict(list) + for tensor in tensors_to_compress: + shape_to_tensors[tensor.shape].append(tensor) + + # This function decides whether to batch tensors with same shape or not according to the argument, + # so the following process could share the same code. + def maybe_batched_tensors_to_compress(): + for tensors in shape_to_tensors.values(): + if state.batch_tensors_with_same_shape: + batch_size = len(tensors) + if batch_size == 1: + # Use the original tensor to avoid copy. + yield tensors[0].unsqueeze(0) + else: + yield torch.stack(tensors) + else: + for tensor in tensors: + yield tensor.unsqueeze(0) + + # Create Ps and Qs that point to the allocated memory. + tensors_to_compress = [] + ps = [] + qs = [] + p_idx = 0 + q_idx = 0 + for tensor in maybe_batched_tensors_to_compress(): + batch_size, n, m = tensor.shape + matrix_approximation_rank = min(n, m, state.matrix_approximation_rank) + tensors_to_compress.append(tensor) + ps.append( + state.p_memory_dict[bucket_index][ + p_idx : p_idx + batch_size * n * matrix_approximation_rank + ].view(batch_size, n, matrix_approximation_rank) + ) + qs.append( + state.q_memory_dict[bucket_index][ + q_idx : q_idx + batch_size * m * matrix_approximation_rank + ].view(batch_size, m, matrix_approximation_rank) + ) + p_idx += batch_size * n * matrix_approximation_rank + q_idx += batch_size * m * matrix_approximation_rank + + # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values. + # The exception is the first iteration when PowerSGD is applied. + if not need_randomize_qs: + for q in qs: + _orthogonalize(q, state.orthogonalization_epsilon) + else: + with torch.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # This seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q). + torch.manual_seed(state.rng.randint(1_000_000_000)) + for q in qs: + q.copy_( + torch.randn( + *q.shape, + device="cpu", + dtype=dtype, + ) + ) + _orthogonalize(q, state.orthogonalization_epsilon) + + # Compute Ps. + for tensor, q, p in zip(tensors_to_compress, qs, ps): + torch.bmm(tensor, q, out=p) + + # This allreduce is only applied to uncompressed tensors, + # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs. + # However, this somehow requires a separate future chain at this time. + allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce( + uncompressed_tensors_memory, group=group_to_use, async_op=True + ).get_future() + + def unpack_uncompressed_tensors_and_allreduce_ps(fut): + uncompressed_tensors_memory = fut.value()[0].div_(world_size) + idx = 0 + for tensor in uncompressed_tensors: + tensor.copy_( + uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor) + ) + idx += tensor.numel() + + # Since these Ps will be orthogonalized later, no need to divide them by world size. + return ( + dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def compute_qs(fut): + state.p_memory_dict[bucket_index] = fut.value() + for p in ps: + _orthogonalize(p, state.orthogonalization_epsilon) + + # Compute Qs. + for tensor, p, q in zip(tensors_to_compress, ps, qs): + torch.bmm(tensor.transpose(1, 2), p, out=q) + + # TODO: The above procedure does two matmul+allreduce steps per iteration -- + # one left multiplication and one right multiplication. + # For warm-start, can take one such step at a time, and alternate between them. + + # Allreduce Qs. + return ( + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value().div_(world_size) + + for p, q, tensor in zip(ps, qs, tensors_to_compress): + torch.bmm(p, q.transpose(1, 2), out=tensor) + + # Copy batched tensors back to original buffer. + if state.batch_tensors_with_same_shape: + for tensor in tensors_to_compress: + if tensor.shape[0] == 1: + # Skip tensor with batch_size == 1 since itself is the original tensor. + continue + original_tensors = shape_to_tensors[tensor.shape[1:]] + for i, original_tensor in enumerate(original_tensors): + original_tensor.copy_(tensor[i]) + + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + + if state.use_error_feedback: + # Memorize the local errors. + assert input_tensor_cp is not None + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + + state.maybe_increase_iter(bucket) + + return input_tensor + + return ( + allreduce_contiguous_uncompressed_tensors_fut.then( + unpack_uncompressed_tensors_and_allreduce_ps + ) + .then(compute_qs) + .then(decompress) + ) + + +def batched_powerSGD_hook( + state: PowerSGDState, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + r""" + Implement simplified PowerSGD algorithm. + + This DDP communication hook implements a simplified PowerSGD gradient compression + algorithm described in the `paper `_. + This variant does not compress the gradients layer by layer, + but instead compresses the flattened input tensor that batches all the gradients. + Therefore, it is **faster** than :meth:`powerSGD_hook`, + but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1. + + .. warning :: + Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy, + because batching per-parameter tensors without column/row alignment can destroy low-rank structure. + Therefore, the user should always consider :meth:`powerSGD_hook` first, + and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1. + + Once gradient tensors are aggregated across all workers, this hook applies + compression as follows: + + 1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings; + + 2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; + + 3. Computes P, which is equal to MQ; + + 4. Allreduces P; + + 5. Orthogonalizes P; + + 6. Computes Q, which is approximately equal to M^TP; + + 7. Allreduces Q; + + 8. Computes M, which is approximately equal to PQ^T. + + 9. Truncates the input tensor to the original length. + + Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. + This not only gives the user more control over the tradeoff between speedup and accuracy, + but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. + + Args: + state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. + To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``. + bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. + Note that since DDP comm hook only supports single process single device mode, + only exactly one tensor is stored in this bucket. + + Returns: + Future handler of the communication, which updates the gradients in place. + + Example:: + >>> # xdoctest: +SKIP + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) + >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook) + """ # noqa: B950 + process_group = state.process_group + group_to_use = ( + process_group if process_group is not None else not_none(dist.group.WORLD) + ) + world_size = group_to_use.size() + + # The input tensor is a flattened 1D tensor. + input_tensor = bucket.buffer() + + # Run vanilla allreduce in the first `start_powerSGD_iter` iterations. + if state.iter < state.start_powerSGD_iter: + state.maybe_increase_iter(bucket) + return default._allreduce_fut(group_to_use, input_tensor) + + # Apply PowerSGD after `start_powerSGD_iter` iterations. + device = input_tensor.device + total_length = input_tensor.shape[0] + state.total_numel_before_compression += total_length + + # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary. + square_side_length = math.ceil(math.sqrt(total_length)) + state.total_numel_after_compression += ( + square_side_length * state.matrix_approximation_rank * 2 + ) + padded_total_length = square_side_length**2 + input_tensor.resize_(padded_total_length) + input_tensor[total_length:padded_total_length].fill_(0) + + _report_compression_stats(bucket, state) + + # Incorporate the error from the previous state into the gradients. + bucket_index = bucket.index() + input_tensor_cp = None + if state.use_error_feedback: + if bucket_index in state.error_dict: + input_tensor.add_(state.error_dict[bucket_index]) + else: + logger.info( + "A zero tensor of length %s that represents local error is created.", + padded_total_length, + ) + state.error_dict[bucket_index] = torch.zeros( + padded_total_length, device=device, dtype=input_tensor.dtype + ) + + # Keep a copy of the input tensor, + # so that we can compute the local error caused by compression later, + # by comparing this copy and the input tensor updated after decompression. + input_tensor_cp = input_tensor.detach().clone() + matrix = input_tensor.view(square_side_length, square_side_length) + + # Reuse P and Q from the previous iteration if possible. + # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied. + if not state.warm_start or bucket_index not in state.p_memory_dict: + # If warm-start is disabled, low-rank tensors will be initialized at every step. + # Only log this if warm-start to avoid spamming. + if state.warm_start: + logger.info( + "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.", + square_side_length, + state.matrix_approximation_rank, + ) + + def create_low_rank_tensor(fill_random_values, rng): + """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank.""" + if fill_random_values: + with torch.random.fork_rng(devices=[]): + # Fork this RNG to avoid changing the seed globally and affecting the random sampling + # anywhere else in the training. + # The seed makes sure that the initial random values are the same across all the DDP replicas. + # This seed should differ at every step. + # Since it is very slow to fork RNG state across all the CUDA devices, + # only fork on CPU and then move the generated tensor to the CUDA device. + torch.manual_seed(rng.randint(1_000_000_000)) + return torch.randn( + square_side_length, + state.matrix_approximation_rank, + device="cpu", + dtype=input_tensor.dtype, + ).to(device) + else: + return torch.empty( + square_side_length, + state.matrix_approximation_rank, + device=device, + dtype=input_tensor.dtype, + ) + + state.p_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=False, rng=state.rng + ) + state.q_memory_dict[bucket_index] = create_low_rank_tensor( + fill_random_values=True, rng=state.rng + ) + _orthogonalize(state.q_memory_dict[bucket_index]) + + torch.matmul( + matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index] + ) + allreduce_p_fut = dist.all_reduce( + state.p_memory_dict[bucket_index], group=group_to_use, async_op=True + ).get_future() + + def compute_q(fut): + state.p_memory_dict[bucket_index] = fut.value()[0] + _orthogonalize(state.p_memory_dict[bucket_index]) + + torch.matmul( + matrix.t(), + state.p_memory_dict[bucket_index], + out=state.q_memory_dict[bucket_index], + ) + + # TODO: The above procedure does two matmul+allreduce steps per iteration -- + # one left multiplication and one right multiplication. + # For warm-start, can take one such step at a time, and alternate between them. + + return ( + dist.all_reduce( + state.q_memory_dict[bucket_index], group=group_to_use, async_op=True + ) + .get_future() + .wait()[0] + ) + + def decompress(fut): + state.q_memory_dict[bucket_index] = fut.value().div_(world_size) + torch.matmul( + state.p_memory_dict[bucket_index], + state.q_memory_dict[bucket_index].t(), + out=matrix, + ) + + if state.use_error_feedback: + # Memorize the local errors. + assert input_tensor_cp is not None + state.error_dict[bucket_index] = input_tensor_cp - input_tensor + # Removing this seemingly unnecessary sync somehow may cause failures. + # See: https://github.com/pytorch/pytorch/pull/54838 + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + if not state.warm_start: + state.p_memory_dict.clear() + state.q_memory_dict.clear() + ret = input_tensor.resize_(total_length) + + state.maybe_increase_iter(bucket) + + return ret + + return allreduce_p_fut.then(compute_q).then(decompress) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..886155908e1a702972184a550082c33c677eacdc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -0,0 +1,220 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +from torch import nn + + +def _quantize_per_tensor_backend(x, scale, zero_point): + y = torch.round(x / scale) + zero_point + y = torch.clamp(y, 0, 255).to(torch.uint8) + return y + + +def _dequantize_per_tensor_backend(y, scale, zero_point): + x = scale * (y.to(torch.float32) - zero_point) + return x + + +def _quantize_per_channel_backend(x, scale, zero_point): + y = torch.zeros(x.size(), device=x.device) + for i in range(x.size()[0]): + y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i] + y = torch.clamp(y, 0, 255).to(torch.uint8) + return y + + +def _dequantize_per_channel_backend(y, scale, zero_point): + y = y.to(torch.float32).to(y.device) + x = torch.zeros_like(y, device=y.device) + for i in range(x.size()[0]): + x[i, :] = scale[i] * (y[i, :] - zero_point[i]) + return x + + +def _get_allgather_out_list(all_gather_in_list, world_size): + out_list = [ + torch.zeros_like( + all_gather_in_list, + device=all_gather_in_list.device, + dtype=all_gather_in_list.dtype, + ) + for _ in range(world_size) + ] + return out_list + + +def quantization_pertensor_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Apply ``torch.quantize_per_tensor`` logic to DDP using ``allgather`` protocol. + + Workers first allgather the scale and zero point of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensor, and uses ``allgather`` to communicate these across all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes and + aggregates each quantized gradient tensor locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + # pyrefly: ignore [missing-attribute] + world_size = group_to_use.size() + + tensor = bucket.buffer() + + myObserver = torch.ao.quantization.MinMaxObserver().to(tensor.device) + myObserver(tensor) + + s, z = myObserver.calculate_qparams() + s_and_z = torch.FloatTensor([s, z]).to(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + + # First, allgather scale and zeros. + fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros across all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their own ``GradBucket`` tensors. + quantized_tensor = _quantize_per_tensor_backend( + tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1] + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = torch.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_tensor_backend( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return aggregated_dequantized_tensor / world_size + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) + + +def quantization_perchannel_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512 +) -> torch.futures.Future[torch.Tensor]: + """ + Apply``torch.quantize_per_channel`` logic to DDP using ``allgather`` protocol. + + Compared to per-tensor, the main motivation of per-channel is + for considerably large tensors such as a tensor that contains 6 million + elements quantizing per a bucket size of 512 (or 128) elements may significantly + increase the resolution. + + It first splits ``GradBucket`` tensor into multiple chunks (channels) of ``bucket_size`` + elements. Then, workers allgather the scales and zero points of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensor, and uses ``allgather`` to communicate these across all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes, flattens, and + aggregates each quantized gradient tensor locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + # pyrefly: ignore [missing-attribute] + world_size = group_to_use.size() + + tensor = bucket.buffer() + + tensor_in_channels = ( + nn.functional.pad( + input=tensor, + pad=(0, bucket_size - len(tensor) % bucket_size), + mode="constant", + value=0, + ) + .view(-1, bucket_size) + .to(tensor.device) + ) + + myPerChannelObserver = torch.ao.quantization.PerChannelMinMaxObserver().to( + tensor.device + ) + myPerChannelObserver(tensor_in_channels) + + s_ch, z_ch = myPerChannelObserver.calculate_qparams() + s_and_z = torch.stack((s_ch, z_ch)).to(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + # First, allgather scale and zeros. + fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros across all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their corresponding ``GradBucket`` tensors. + quantized_tensor = _quantize_per_channel_backend( + tensor_in_channels, + all_ranks_s_and_z[rank, 0, :], + all_ranks_s_and_z[rank, 1, :], + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = torch.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_channel_backend( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return ( + torch.flatten(aggregated_dequantized_tensor).to(tensor.device)[ + : tensor.size()[0] + ] + / world_size + ) + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/join.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/join.py new file mode 100644 index 0000000000000000000000000000000000000000..52d0c52fbfb59d3c906bd282db51a76886206c96 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/join.py @@ -0,0 +1,350 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, NamedTuple + +import torch +import torch.distributed as dist + + +__all__ = ["JoinHook", "Joinable", "Join"] + + +class JoinHook: + r""" + This defines a join hook, which provides two entry points in the join context manager. + + Entry points : a main hook, which is called repeatedly while there exists a non-joined + process, and a post-hook, which is called once all processes have joined. + + To implement a join hook for the generic join context manager, define a + class that inherits from :class:`JoinHook` and override ``main_hook()`` and + ``post_hook()`` as appropriate. + """ + + def main_hook(self) -> None: + r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration. + + Training iteration i.e., in one forward pass, backward pass, and optimizer step. + """ + + def post_hook(self, is_last_joiner: bool) -> None: + r""" + Call hook after all processes have joined. + + It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join. + + Arguments: + is_last_joiner (bool): ``True`` if the rank is one of the last to + join; ``False`` otherwise. + """ + + +class Joinable(ABC): + r""" + This defines an abstract base class for joinable classes. + + A joinable class + (inheriting from :class:`Joinable`) should implement :meth:`join_hook`, + which returns a :class:`JoinHook` instance, in addition to + :meth:`join_device` and :meth:`join_process_group` that return device and + process group information, respectively. + """ + + @abstractmethod + def __init__(self) -> None: + super().__init__() + self._join_config = _JoinConfig.construct_disabled_join_config() + + @abstractmethod + def join_hook(self, **kwargs) -> JoinHook: + r""" + Return a :class:`JoinHook` instance for the given :class:`Joinable`. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + """ + ... + + @property + @abstractmethod + def join_device(self) -> torch.device: + r"""Return the device from which to perform collective communications needed by the join context manager.""" + ... + + @property + @abstractmethod + def join_process_group(self) -> Any: + r"""Returns the process group for the collective communications needed by the join context manager itself.""" + ... + + +class _JoinConfig(NamedTuple): + r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side.""" + + enable: bool + throw_on_early_termination: bool + is_first_joinable: bool + + @staticmethod + def construct_disabled_join_config(): + r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled. + + e.g. if the caller is not in a join context manager. + """ + return _JoinConfig( + enable=False, throw_on_early_termination=False, is_first_joinable=False + ) + + +class Join: + r""" + This class defines the generic join context manager, which allows custom hooks to be called after a process joins. + + These hooks should shadow the + collective communications of non-joined processes to prevent hanging and + erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook` + for details about the hook definition. + + .. warning:: + The context manager requires each participating :class:`Joinable` to + call the method :meth:`notify_join_context()` before its own per- + iteration collective communications to ensure correctness. + + .. warning:: + The context manager requires that all ``process_group`` attributes in + the :class:`JoinHook` objects are the same. If there are multiple + :class:`JoinHook` objects, then the ``device`` of the first is used. + The process group and device information is used for checking for non- + joined processes and for notifying processes to throw an exception if + ``throw_on_early_termination`` is enabled, both of which using an all- + reduce. + + Arguments: + joinables (List[Joinable]): a list of the participating + :class:`Joinable` s; their hooks are iterated over in the given + order. + + enable (bool): a flag enabling uneven input detection; setting to + ``False`` disables the context manager's functionality and should + only be set when the user knows the inputs will not be uneven + (default: ``True``). + + throw_on_early_termination (bool): a flag controlling whether to throw an + exception upon detecting uneven inputs (default: ``False``). + + Example:: + + >>> import os + >>> import torch + >>> import torch.distributed as dist + >>> import torch.multiprocessing as mp + >>> # xdoctest: +SKIP + >>> import torch.nn.parallel.DistributedDataParallel as DDP + >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO + >>> from torch.distributed.algorithms.join import Join + >>> + >>> # On each spawned worker + >>> def worker(rank): + >>> dist.init_process_group("nccl", rank=rank, world_size=2) + >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) + >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) + >>> # Rank 1 gets one more input than rank 0 + >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] + >>> with Join([model, optim]): + >>> for input in inputs: + >>> loss = model(input).sum() + >>> loss.backward() + >>> optim.step() + >>> # All ranks reach here without hanging/erroring + """ + + def __init__( + self, + joinables: list[Joinable], + enable: bool = True, + throw_on_early_termination: bool = False, + **kwargs, + ): + if len(joinables) == 0: + raise ValueError("The join context manager requires at least one joinable") + self._joinables = joinables + self._join_hooks = [ + joinable.join_hook(**kwargs) for joinable in self._joinables + ] + self._enable = enable + self._throw_on_early_termination = throw_on_early_termination + self._set_joinable_configs() + self._extract_dist_info() + + def _set_joinable_configs(self) -> None: + r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`.""" + assert len(self._joinables) > 0 + is_first_joinable = True + for joinable in self._joinables: + joinable._join_config = _JoinConfig( + enable=self._enable, + throw_on_early_termination=self._throw_on_early_termination, + is_first_joinable=is_first_joinable, + ) + is_first_joinable = False + + def _extract_dist_info(self) -> None: + r""" + Extract the process group and device information from the joinables. + + If there are multiple joinables, then the context manager uses the + first specified device. + + Preconditions: + ``self._joinables`` is not ``None`` and is non-empty. + + Raises: + ValueError + If there are multiple conflicting ``process_group`` attributes + among the ``Joinable`` objects. + """ + process_group = None + device = None + # pyrefly: ignore [bad-assignment] + for joinable in self._joinables: + if process_group is None: + process_group = joinable.join_process_group + elif process_group != joinable.join_process_group: + raise ValueError( + "Using join context manager with multiple process groups" + ) + if device is None: + device = joinable.join_device + self._process_group = process_group + self._rank = dist.get_rank(self._process_group) + self._device = device + + def __enter__(self): ... + + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ): + r""" + Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. + + Raises: + RuntimeError + If ``throw_on_early_termination=True``. + """ + if not self._enable or type: + return # propagate the exception directly if one was raised + + all_procs_joined = False + is_last_joiner = True + + i = 0 + WARN_THRESHOLD = 1000 + warnings.simplefilter("once") + + while not all_procs_joined: + if i > WARN_THRESHOLD: + warnings.warn( + "Detected uneven input skew of greater than " + f"{WARN_THRESHOLD}. This means that rank " + f"{self._rank} has at least {WARN_THRESHOLD} " + f"fewer inputs than other currently-active ranks. " + "This level of skew could lead to performance " + "degradation during training.", + stacklevel=2, + ) + # Shadow the all-reduce in non-joined processes + num_nonjoined_procs = self._get_num_nonjoined_procs() + if num_nonjoined_procs == 0: + all_procs_joined = True + else: + if self._throw_on_early_termination: + self._notify_procs_to_terminate() + + # Run main hooks + for join_hook in self._join_hooks: + join_hook.main_hook() + + is_last_joiner = False + i += 1 + + # Run post-hooks + for join_hook in self._join_hooks: + join_hook.post_hook(is_last_joiner) + + def _get_num_nonjoined_procs(self): + r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes.""" + num_nonjoined_procs = torch.zeros(1, device=self._device) + dist.all_reduce(num_nonjoined_procs, group=self._process_group) + return num_nonjoined_procs.item() + + def _notify_procs_to_terminate(self): + r"""Schedule an all-reduce to notify non-joined processes to terminate. + + Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs. + """ + ones = torch.ones(1, device=self._device) + dist.all_reduce(ones, group=self._process_group) + raise RuntimeError(f"Rank {self._rank} exhausted all inputs.") + + @staticmethod + def notify_join_context(joinable: Joinable): + r""" + Notifies the join context manager that the calling process has not yet joined. + + Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected + (i.e. if one process has already joined) and throws an exception if so. + + This method should be called from a :class:`Joinable` object before + its per-iteration collective communications. For example, this should + be called at the beginning of the forward pass in + :class:`DistributedDataParallel`. + + Only the first :class:`Joinable` object passed into the context + manager performs the collective communications in this method, and + for the others, this method is vacuous. + + Arguments: + joinable (Joinable): the :class:`Joinable` object calling this + method. + + Returns: + An async work handle for the all-reduce meant to notify the context + manager that the process has not yet joined if ``joinable`` is the + first one passed into the context manager; ``None`` otherwise. + """ + assert hasattr(joinable, "_join_config"), ( + f"Check that the {type(joinable)} constructor calls the " + "``Joinable`` constructor" + ) + + join_config = joinable._join_config + # First joinable is responsible for the collective communications + if not join_config.is_first_joinable or not join_config.enable: + return None + + device = joinable.join_device + process_group = joinable.join_process_group + + # Schedule an all-reduce to indicate that the caller has not yet joined + ones = torch.ones(1, device=device) + work = dist.all_reduce(ones, group=process_group, async_op=True) + + if join_config.throw_on_early_termination: + # Check if uneven inputs have been detected + zeros = torch.zeros(1, device=device) + dist.all_reduce(zeros, group=process_group) + should_throw = zeros.item() + if should_throw: + raise RuntimeError( + "Detected at least one rank that exhausted inputs. " + "Throwing across all ranks." + ) + return work diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb76249df2dbea13b732ce48ec70be8db990b483 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51a4adda6e2b49261456eb51450c0207b2eacc61 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4541f95a95f1b1859681a10a629190ef5b869b99 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02294edf3e66d28941e4f3ad15555efea357c41d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/averagers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/averagers.py new file mode 100644 index 0000000000000000000000000000000000000000..5d669d4ea592250556ed5188b21ae265bb3b2c9c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/averagers.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterable + +import torch +import torch.distributed as dist +import torch.distributed.algorithms.model_averaging.utils as utils +from torch.utils._typing_utils import not_none as _not_none + + +__all__ = ["ModelAverager", "PeriodicModelAverager"] + + +class ModelAverager(ABC): + r"""Base class for all model averagers. + + Args: + process_group: The process group to be used for all-reduce. + If ``None``, the default process group, which + is created by :func:`torch.distributed.init_process_group`, + will be used. (default: ``None``) + """ + + def __init__(self, process_group: dist.ProcessGroup | None = None): + self.process_group = ( + process_group if process_group is not None else _not_none(dist.group.WORLD) + ) + self.step = 0 + + @abstractmethod + def average_parameters(self, params): + raise NotImplementedError + + +class PeriodicModelAverager(ModelAverager): + r""" + Averages parameters periodically after the warm-up stage. + + This can be used for running `post-local SGD `_, + by running :class:`~torch.nn.DistributedDataParallel` (DDP) + using the subgroups created by :meth:`~torch.distributed.new_subgroups`. + + Args: + period (int): The number of steps per model averaging. + Usually the period should be greater than ``1`` to reduce the communication cost. + Otherwise, only DDP needs to be used. + warmup_steps (int): The number of warm-up steps. During this stage, + model averaging is skipped. + process_group: The process group to be used for all-reduce. + If ``None``, the default process group, which + is created by :func:`torch.distributed.init_process_group`, + will be used. (default: ``None``) + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> import torch.distributed as dist + >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD + >>> import torch.distributed.algorithms.model_averaging.averagers as averagers + >>> import torch.nn as nn + >>> + >>> dist.init_process_group("nccl", rank=rank, world_size=16) + >>> torch.cuda.set_device(rank) + >>> module = nn.Linear(1, 1, bias=False).cuda() + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> # Register a post-localSGD communication hook. + >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. + >>> # After 100 steps, run model averaging every 4 steps. + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100) + >>> for step in range(0, 200): + >>> optimizer.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> optimizer.step() + >>> # Will average model parameters globally every 4 steps. Thus, + >>> # inter-node communication only occurs every 4 iterations after + >>> # the initial ``warmup_steps`` period. + >>> averager.average_parameters(model.parameters()) + """ + + def __init__( + self, period, warmup_steps=0, process_group: dist.ProcessGroup | None = None + ): + super().__init__(process_group) + if warmup_steps < 0: + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") + self.warmup_steps = warmup_steps + if period < 1: + raise ValueError("Arg ``period`` must be a positive value.") + elif period == 1: + warnings.warn( + "When period is 1, no need to use model averaging because the communication cost " + "of all-reducing parameters will be no less than the cost of all-reducing gradients " + "by DistributedDataParallel in the backward pass. Therefore, only " + "DistributedDataParallel should be used for this case.", + stacklevel=2, + ) + self.period = period + + def average_parameters( + self, + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], + ): + """ + Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``. + + Can be divided by ``period``, where ``step`` is increased by 1 + at each iteration in the training loop. + Args: + params: The parameters of a model or parameter groups of an optimizer. + + """ + if ( + self.step >= self.warmup_steps + and (self.step - self.warmup_steps) % self.period == 0 + ): + utils.average_parameters_or_parameter_groups( + params, _not_none(self.process_group) + ) + self.step += 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7edc447d1089e2c09ba10764bc0fbfce9a1770 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -0,0 +1,179 @@ +# mypy: allow-untyped-defs +# Copyright 2022 Cruise LLC +import logging +import warnings +from collections import OrderedDict +from collections.abc import Iterable + +import torch +import torch.distributed as dist +import torch.distributed.algorithms.model_averaging.averagers as averagers +import torch.distributed.algorithms.model_averaging.utils as utils + + +logger = logging.getLogger(__name__) + + +class HierarchicalModelAverager(averagers.ModelAverager): + r""" + Runs hierarchical model averaging (`hierarchical SGD `_). + + Process groups of different sizes are organized in a hierarchy, and they average parameters + by using different periods concurrently after the warm-up stage. + This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager` + that supports `post-local SGD `_, which essentially only supports + a two-level hierarchy: the intra-machine level and the global level, where the intra-machine + level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`. + Similarly, the process groups within this class do not have such an intra-machine process + subgroup, which should be embedded by the post-local SGD communication hook instead. + + Args: + period_group_size_dict: An ordered dict mapping keys of model averaging period to + process group size, used for initializing process groups of + different sizes in a hierarchy to average parameters concurrently. + Particularly, at each iteration, there will be at most a single + process group that runs averaging -- the period of such group should + have the largest period which the current step can be divided by. + For example, if the dict has three keys: 2, 4, and 8, + then this means totally three process groups will be created to + average parameters every 2, 4, and 8 iterations, respectively. + At the 4th iteration, only the second process group will run + averaging, because the first process group should be a + subset of the second process group, and no need to execute the first + process group redundantly. + On the other hand, the third process group can only be triggered + every 8 iterations, so it will not be triggered at the 4th iteration. + warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped. + process_group (ProcessGroup, optional): The overall process group containing all the processes that runs model averaging. + If ``None``, the default process group, which is created + by :func:`torch.distributed.init_process_group`, will be used. + (default: ``None``) + + Example:: + >>> # xdoctest: +SKIP('undefined rank') + >>> from collections import OrderedDict + >>> import torch + >>> import torch.distributed as dist + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD + >>> import torch.nn as nn + >>> + >>> dist.init_process_group("nccl", rank=rank, world_size=16) + >>> torch.cuda.set_device(rank) + >>> module = nn.Linear(1, 1, bias=False).to(rank) + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> # Register a post-localSGD communication hook. + >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4. + >>> subgroup, _ = dist.new_subgroups() + >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Average parameters among each group of 8 processes every 4 iterations, and among all + >>> # the 16 processes every 16 iterations. + >>> averager = hierarchicalSGD.HierarchicalModelAverager( + >>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100) + >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step. + >>> # After 100 steps, run model averaging at two levels. + >>> for step in range(0, 200): + >>> optimizer.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> optimizer.step() + >>> # Average parameters after ``optimizer.step()``. + >>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``. + >>> averager.average_parameters(model.parameters()) + + .. warning :: + The last group size in the dict must be the size of the provided ``process_group``, + which indicates model averaging at the highest level of the hierarchy. + If ``process_group`` is not provided, then the last group size should be equal to the world size. + + .. warning :: + `HierarchicalModelAverager` is experimental and subject to change. + """ + + def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None): + super().__init__(process_group) + if not period_group_size_dict: + raise ValueError("Arg ``period_group_size_dict`` must not be empty.") + self._periods = list(period_group_size_dict.keys()) + if self._periods[0] <= 0: + raise ValueError( + "The minimum period in arg ``period_group_size_dict`` must be a positive value." + ) + elif self._periods[-1] == 1: + warnings.warn( + "When the maximum period in arg ``period_group_size_dict`` is 1, " + "no need to use model averaging because the communication cost " + "of all-reducing parameters will be no less than the cost of all-reducing gradients " + "by DistributedDataParallel in the backward pass. Therefore, only " + "DistributedDataParallel should be used for this case.", + stacklevel=2, + ) + overall_group_size = dist.get_world_size(group=self.process_group) + if list(period_group_size_dict.values())[-1] != overall_group_size: + raise ValueError( + f"The last value in arg ``period_process_group_dict`` {list(period_group_size_dict.values())[-1]} " + f"must be equal to the size of arg ``process_group`` {overall_group_size}." + ) + + self.period_process_group_dict = OrderedDict() + logger.info("Model averaging hierarchy:") + for period, group_size in period_group_size_dict.items(): + logger.info( + "\tEach group that has %s processes average parameters every %s iterations, " + "if no higher-level averaging.", + group_size, + period, + ) + if group_size != overall_group_size: + self.period_process_group_dict[period], _ = dist.new_subgroups( + group_size=group_size, group=self.process_group + ) + else: + self.period_process_group_dict[period] = self.process_group + + if warmup_steps < 0: + raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") + self.warmup_steps = warmup_steps + + def _find_process_group(self): + """ + Return a process group as the value of an ``period_process_group_dict`` entry. + + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + then the returned process group is the one corresponding to the largest period, + since this process group will be used for averaging parameters at this ``step``. + Returns ``None`` if not found. + """ + for period in reversed(self._periods): + if self.step % period == 0: + return self.period_process_group_dict[period] + return None + + def average_parameters( + self, + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], + ): + """ + Averages parameters or parameter groups of an optimizer. + + Averaging only occurs if ``step`` is no less than ``warmup_steps`` + and it can be divided by a period in the keys of ``period_process_group_dict``, + where ``step`` is increased by 1 at each iteration in the training loop. + If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``, + only the largest period is used, and the corresponding process group is used for averaging parameters. + Args: + params: The parameters of a model or parameter groups of an optimizer. + """ + if self.step >= self.warmup_steps: + group = self._find_process_group() + if group is not None: + utils.average_parameters_or_parameter_groups(params, group) + self.step += 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a61c036913edd6cf7fbcde6b77bc6ee5970065e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/algorithms/model_averaging/utils.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Iterable, Iterator + +import torch +import torch.distributed as dist + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +from torch.distributed import group, ProcessGroup + + +__all__ = [ + "average_parameters", + "get_params_to_average", + "average_parameters_or_parameter_groups", +] + + +def average_parameters( + params: Iterator[torch.nn.Parameter], process_group: ProcessGroup +): + """ + Averages all the given parameters. + + For allreduce efficiency, all the parameters are flattened into a contiguous buffer. + Thus, it requires extra memory of the same size as the given parameters. + """ + group_to_use = process_group if process_group is not None else group.WORLD + # Do not update any parameter if not in the process group. + if dist._rank_not_in_group(group_to_use): + return + + params_it1, params_it2 = itertools.tee(params) + # If the input parameters have different data types, + # packing these parameters will trigger an implicit type up-casting. + # The original parameter data types will be restored during the subsequent unpacking. + flat_params = torch.cat([p.data.reshape(-1) for p in params_it1]) + flat_params /= dist.get_world_size(group_to_use) + # Make sure the allreduce will not conflict with any other ongoing process group. + if torch.accelerator.is_available(): + torch.accelerator.synchronize() + dist.all_reduce(flat_params, group=group_to_use) + + offset = 0 + for p in params_it2: + p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p) + offset += p.numel() + + +def get_params_to_average( + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], +): + """ + Return a list of parameters that need to average. + + This filters out the parameters that do not contain any gradients. + Args: + params: The parameters of a model or parameter groups of an optimizer. + """ + filtered_params = [] + for param in params: + if isinstance(param, torch.nn.Parameter): + # model.parameters() input + param_data = param + if param_data.grad is not None: + filtered_params.append(param_data) + elif isinstance(param, dict): + # optimizer.param_groups input + for param_data in param["params"]: + if param_data.grad is not None: + filtered_params.append(param_data) + else: + raise NotImplementedError( + f"Parameter input of type {type(param)} is not supported" + ) + return filtered_params + + +def average_parameters_or_parameter_groups( + params: Iterable[torch.nn.Parameter] | Iterable[dict[str, torch.nn.Parameter]], + process_group: ProcessGroup, +): + """Averages parameters of a model or parameter groups of an optimizer.""" + average_parameters(iter(get_params_to_average(params)), process_group) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/autograd/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a52c36942e48e389a7e344abeb929febdb62c6c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/autograd/__init__.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from types import TracebackType + + +def is_available() -> bool: + return hasattr(torch._C, "_dist_autograd_init") + + +if is_available() and not torch._C._dist_autograd_init(): + raise RuntimeError("Failed to initialize torch.distributed.autograd") + +if is_available(): + from torch._C._distributed_autograd import ( + _current_context, + _get_debug_info, + _get_max_id, + _init, + _is_valid_context, + _new_context, + _release_context, + _retrieve_context, + backward, + DistAutogradContext, + get_gradients, + ) + +__all__ = ["context", "is_available"] + + +class context: + """ + Context object to wrap forward and backward passes when using + distributed autograd. The ``context_id`` generated in the ``with`` + statement is required to uniquely identify a distributed backward pass + on all workers. Each worker stores metadata associated with this + ``context_id``, which is required to correctly execute a distributed + autograd pass. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.distributed.autograd as dist_autograd + >>> with dist_autograd.context() as context_id: + >>> t1 = torch.rand((3, 3), requires_grad=True) + >>> t2 = torch.rand((3, 3), requires_grad=True) + >>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum() + >>> dist_autograd.backward(context_id, [loss]) + """ + + def __enter__(self) -> int: + self.autograd_context = _new_context() + return self.autograd_context._context_id() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + _release_context(self.autograd_context._context_id()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7953fea62805c3f1756aec6106eb47c0751f0a9b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8104a8df99f0b5c4a4f1db57ac98602a61666626 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__init__.py @@ -0,0 +1,21 @@ +from . import _extension +from .api import CheckpointException +from .default_planner import DefaultLoadPlanner, DefaultSavePlanner +from .filesystem import FileSystemReader, FileSystemWriter +from .hf_storage import HuggingFaceStorageReader, HuggingFaceStorageWriter +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + TensorStorageMetadata, +) +from .optimizer import load_sharded_optimizer_state_dict +from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem +from .quantized_hf_storage import QuantizedHuggingFaceStorageReader + +# pyrefly: ignore [deprecated] +from .state_dict_loader import load, load_state_dict + +# pyrefly: ignore [deprecated] +from .state_dict_saver import async_save, save, save_state_dict +from .storage import StorageReader, StorageWriter diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d8d70f86c382c2e773becb898ad4cfdf019d8b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72619e612903324bb0235d6ca51d7779dec5d014 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_process_executor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_process_executor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793f61e562223f160ae1ede694ba76d03b49f12a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_process_executor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_thread_executor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_thread_executor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e44683387504f6bbe6532888b03a4e8073cfbe08 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_thread_executor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a82465738bb0871ec44210768ca31fb23c766a7f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_consolidate_hf_safetensors.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_consolidate_hf_safetensors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eb0f573991e93dd93ac9cd852e36bc8b263173d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_consolidate_hf_safetensors.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cac23cac055b2241bf0df76bc07489407a51195 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a3770ac3f33d08a029e812e1d196506b0723f86 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0415f9db2bb317c67031e693dfbc39470567f83a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82fd15d2777f9da73fab3a9538271e492316a2e2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41b7043210cedecf4f73deaaf79cb9431760a595 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79f549484a7bc60a56591dc64da5dfc8aa452e66 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_pg_transport.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_pg_transport.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e30a578b82d2c534fc1775ab3a20e600955eb38f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_pg_transport.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7aa102c8e6a4ee5f7c46e59f4e366c5f9b39b29e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22687f4a7fe9c487f016b1745d4e080276c19c8e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d62934bb615fa953e4f69eea851f7afbd42f8f1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ea2ffddb8f1b91932fc356a336e37ecec715863 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_version.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1b5453717e810fde41f2c9fe4bab4bc72bd1d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_version.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e52ed2af8d009637ada67840139503d9c0a6a05d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3e89e542244e8d79f0352895a712b3d0486db5d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e87763cc9fc7c0e819983a54fa7a96d486ee900 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247b23e9787ddc4d471de29ce39dc8731dcc57b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bf3127797c770a50a03b7870b19f7c72c78c509 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8997028cebe4c2c2006d4fe882fa68f98e60dcd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e9ef11dc01cb3927a3b33d3b799617a54cd9ca0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b14208b939c8b2a2bfa93233d33ece90925bbab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..382d771857bdda6dda651c675c6e7f6f81c91d6a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac92cc30722fa95412af768e971aa81ab74025b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11b0a0c262edae563923d07c719e1943dfc95ec8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/quantized_hf_storage.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/quantized_hf_storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a639b1f6f58d39a24615ca19c906f3600e4f45a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/quantized_hf_storage.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8781f897d1bd2f754d5c5076b6589aaef75eaa30 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aada4324050c40b84f747b4aa45bcebc53a7523f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a25ab2c00fcfbe28dd38ab0e313124741eec50d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7c32c567d5759d086b9cf239ee9fca4a56fd1f9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a3b7c8d6a494202cdc65c8570e84e1e3cc9cb29 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9358fde8c76584f9dd6a1acb884991d46063f45b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d4d50555d936f2ed018df63633debc45e64c2d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7a326bf1f68ddaa1773d98dd4f0471c05fac6f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_executor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..428c697b91e9b567e99d52714a8248d322798073 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_executor.py @@ -0,0 +1,34 @@ +# pyre-strict +# mypy: allow-untyped-defs +import abc +import os +from concurrent.futures import Future +from typing import Optional, Union + +import torch.distributed as dist +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter + + +class _AsyncCheckpointExecutor(abc.ABC): + @abc.abstractmethod + def execute_save( + self, + staging_future_or_state_dict: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Future: + """ + Execute the checkpoint save request asynchronously. + + This method is intended to be used as an abstraction for + implementing async checkpointing. The actual checkpoint save + operation is executed in a separate thread or process depending + on the implementation of this interface. + """ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_process_executor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_process_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..48390253c302a5acc9806ecac587a24022262565 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_process_executor.py @@ -0,0 +1,455 @@ +# pyre-strict +# mypy: allow-untyped-defs +import gc +import logging +import os +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional, Union +from uuid import uuid4 + +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.distributed import PrefixStore, TCPStore +from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor +from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter +from torch.distributed.checkpoint.utils import _DistWrapper +from torch.distributed.elastic.agent.server.api import _get_fq_hostname +from torch.distributed.elastic.utils.distributed import get_free_port + + +logger = logging.getLogger() + + +class _CheckpointSaveProcessControlOpts(Enum): + INIT_COMPLETE = "init_complete" + TERMINATE = "terminate" + + +@dataclass(init=False, unsafe_hash=True) +class _CheckpointRequestIdentifier: + checkpoint_id: Union[str, os.PathLike, None] + uuid: str + + def __init__(self, checkpoint_id: Union[str, os.PathLike, None]): + self.checkpoint_id = checkpoint_id + self.uuid = str(uuid4()) + + +@dataclass +class _AsyncCheckpointRequest: + staged_state_dict: STATE_DICT_TYPE + checkpoint_request_id: _CheckpointRequestIdentifier + storage_writer: Optional[StorageWriter] = None + planner: Optional[SavePlanner] = None + no_dist: bool = False + use_collectives: bool = True + + +@dataclass(init=False) +class _ProcessGroupInitInfo: + local_rank: int + global_rank: int + world_size: int + tcp_store_master_addr: str + tcp_store_master_port: int + use_prefix_store: bool + disable_automatic_gc: bool + disable_manual_gc: bool + + def __init__(self, process_group: Optional[dist.ProcessGroup] = None): + self.local_rank = dist.get_node_local_rank(fallback_rank=0) + self.global_rank = dist.get_rank(process_group) + self.world_size = dist.get_world_size(process_group) + self.use_prefix_store = os.environ.get("DCP_USE_PREFIX_STORE", "0") == "1" + self.disable_automatic_gc = ( + os.environ.get("DCP_DISABLE_AUTOMATIC_GC", "0") == "1" + ) + self.disable_manual_gc = os.environ.get("DCP_DISABLE_MANUAL_GC", "0") == "1" + + # Let coordinator rank find a port on the localhost. + # Broadcast the (master_addr, port) to all ranks; each rank in the + # checkpoint daemon process will use TCPStore (master_addr, port) + # for collective communication. + dist_wrapper: _DistWrapper = _DistWrapper( + group=process_group, + use_dist=True, + coordinator_rank=0, + ) + + def get_master_addr_and_port() -> tuple[str, int]: + if self.use_prefix_store: + master_addr = os.environ.get("MASTER_ADDR") + master_port = os.environ.get("MASTER_PORT") + assert master_addr is not None, ( + "DCP needs MASTER_ADDR to use prefix store" + ) + assert master_port is not None, ( + "DCP needs MASTER_PORT to use prefix store" + ) + master_port = int(master_port) + else: + master_addr = os.environ.get("MASTER_ADDR") + if master_addr is None: + master_addr = _get_fq_hostname() + master_port = get_free_port() + + return master_addr, master_port + + self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast( + step="get_master_addr_and_port", + map_fun=get_master_addr_and_port, + ) + + +class _AsyncCheckpointProcess: + def __init__( + self, + pg_init_info: _ProcessGroupInitInfo, + ): + self.ctx = mp.get_context("spawn") + self._process_pipe, child_end = self.ctx.Pipe() + + self._save_process = self.ctx.Process( + target=self._checkpointing_subprocess, + args=( + pg_init_info, + child_end, + ), + daemon=True, + ) + + self._save_process.start() + + # Close the parent's copy of child end after we pass it into the child, + # so the recv()s on it will fail-fast if the child process dies. + child_end.close() + + # Wait for the checkpoint background process to initialize. + # Using default GLOO init timeout. + response = self._wait_for_response(timeout=1800) + if not response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE: + raise AssertionError(f"Expected INIT_COMPLETE response, got {response}") + + def __del__(self) -> None: + if self._save_process.is_alive(): + try: + logger.info("Terminating the checkpoint background process.") + self._send(_CheckpointSaveProcessControlOpts.TERMINATE) + self._save_process.join(timeout=5) + finally: + if self._save_process.is_alive(): + logger.warning( + "Checkpoint background process is still alive after termination request. Sending SIGTERM." + ) + self._save_process.terminate() + + def _send(self, data: Any) -> None: + self._process_pipe.send(data) + + def _wait_for_response(self, timeout: Optional[float] = None) -> Any: + if not self._save_process.is_alive(): + logger.info("Checkpoint background process is dead calling join()...") + self._save_process.join() + raise RuntimeError( + f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}" + ) + + if timeout is not None and not self._process_pipe.poll(timeout=timeout): + raise RuntimeError( + f"Timed out after {timeout}s while waiting for response from checkpointer process pid: {self._save_process.pid}" + ) + + try: + response = self._process_pipe.recv() + except EOFError: + raise RuntimeError( # noqa: B904 + f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}" + ) + + if isinstance(response, BaseException): + raise response + + return response + + def save( + self, + staged_state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Metadata: + # Create a unique identifier to locate requests/responses + # from the checkpoint daemon process. + checkpoint_request_id = _CheckpointRequestIdentifier(checkpoint_id) + async_cp_request = _AsyncCheckpointRequest( + staged_state_dict=staged_state_dict, + checkpoint_request_id=checkpoint_request_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + self._send(async_cp_request) + result = self._wait_for_response() + if not isinstance(result, Metadata): + raise AssertionError(f"Expected Metadata response, got {type(result)}") + return result + + @staticmethod + def _execute_save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_request_id: _CheckpointRequestIdentifier, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Metadata: + from torch.distributed.checkpoint.state_dict_saver import save + + metadata = save( + state_dict, + checkpoint_id=checkpoint_request_id.checkpoint_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + return metadata + + @staticmethod + def _checkpointing_subprocess( + pg_init_info: _ProcessGroupInitInfo, + parent_conn, + ) -> None: + # Phase 1: Process Group Initialization + # Only needs to execute once during the lifetime of the checkpoint background process. + try: + _init_logger(pg_init_info.global_rank) + + # Setup environment variables for process group initialization. + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" + os.environ["MASTER_ADDR"] = pg_init_info.tcp_store_master_addr + os.environ["MASTER_PORT"] = str(pg_init_info.tcp_store_master_port) + os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank) + os.environ["RANK"] = str(pg_init_info.global_rank) + os.environ["WORLD_SIZE"] = str(pg_init_info.world_size) + + logger.info( + "Initializing dist.ProcessGroup in checkpoint background process on port %s", + pg_init_info.tcp_store_master_port, + ) + # NOTE: GLOO backend is enforced here. + if pg_init_info.use_prefix_store: + logger.info( + "Initializing dist.ProcessGroup in checkpoint background process with prefix store" + ) + store = PrefixStore( + "AsyncCheckpointProcess/", + TCPStore( + pg_init_info.tcp_store_master_addr, + pg_init_info.tcp_store_master_port, + ), + ) + dist.init_process_group( + backend=dist.Backend.GLOO, + store=store, + world_size=pg_init_info.world_size, + rank=pg_init_info.global_rank, + ) + else: + dist.init_process_group(backend=dist.Backend.GLOO) + dist.barrier() + + logger.info("Checkpoint background process is running...") + parent_conn.send(_CheckpointSaveProcessControlOpts.INIT_COMPLETE) + + if pg_init_info.disable_automatic_gc: + # Disable automatic garbage collection + # GC can optionally be called manually after each checkpoint + gc.disable() + logger.info("Disabled automatic garbage collection") + except BaseException as e: # noqa: B036 + logger.error( + f"Checkpoint background process failed during initialization: {e}" # noqa: G004 + ) + parent_conn.send(e) + return + + # Phase 2: Serving Loop + try: + first_request = True + while True: + logger.info("Waiting for checkpoint save request...") + obj = parent_conn.recv() + if ( + isinstance(obj, _CheckpointSaveProcessControlOpts) + and obj == _CheckpointSaveProcessControlOpts.TERMINATE + ): + logger.info("Terminating the checkpoint background process.") + return + if not isinstance(obj, _AsyncCheckpointRequest): + raise AssertionError( + f"Expected _AsyncCheckpointRequest, got {type(obj)}" + ) + logger.info( + f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}" # noqa: G004 + ) + + try: + response = _AsyncCheckpointProcess._execute_save( + obj.staged_state_dict, + checkpoint_request_id=obj.checkpoint_request_id, + storage_writer=obj.storage_writer, + planner=obj.planner, + no_dist=obj.no_dist, + use_collectives=obj.use_collectives, + ) + parent_conn.send(response) + logger.info( + f"Completed checkpoint save request for checkpoint_id={obj.checkpoint_request_id}" # noqa: G004 + ) + + # in theory this manual gc should not be needed as we shouldn't be leaking anything from checkpointing process + if ( + pg_init_info.disable_automatic_gc + and not pg_init_info.disable_manual_gc + ): + del obj + + collected_objects = gc.collect() + + logger.info( + f"Manual garbage collection completed - collected {collected_objects} objects." # noqa: G004 + ) + if first_request: + # Freeze GC to not check GC for large checkpoint save plans + # After freezing, subsequent gc.collect() calls will only scan + # NEW objects created after this point, not the frozen save plan + logger.info( + "First checkpoint request completed - freezing gc" + ) + gc.freeze() + first_request = False + except BaseException as e: # noqa: B036 + logger.error( + f"Checkpoint save failed for checkpoint_id={obj.checkpoint_request_id.checkpoint_id}: {e}" # noqa: G004 + ) + parent_conn.send(e) + # Continue serving loop - don't exit process + finally: + logger.info("Checkpoint background process is shutting down...") + dist.destroy_process_group() + parent_conn.close() + + +_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None + + +class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=1) + + @staticmethod + def _execute_save_impl( + *, + pg_init_info: Optional[_ProcessGroupInitInfo], + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Metadata: + global _CHECKPOINT_PROCESS + if _CHECKPOINT_PROCESS is None: + if pg_init_info is None: + raise AssertionError( + "pg_init_info must not be None when _CHECKPOINT_PROCESS is None" + ) + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = process_group + + @_dcp_method_logger(**ckpt_kwargs) + def create_checkpoint_daemon_process() -> None: + global _CHECKPOINT_PROCESS + # pyrefly: ignore [bad-argument-type] + _CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info) + + create_checkpoint_daemon_process() + + if _CHECKPOINT_PROCESS is None: + raise AssertionError( + "_CHECKPOINT_PROCESS must not be None after initialization" + ) + staged_state_dict = ( + staging_future_or_state_dict.result() + if isinstance(staging_future_or_state_dict, Future) + else staging_future_or_state_dict + ) + return _CHECKPOINT_PROCESS.save( + staged_state_dict=staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + + def execute_save( + self, + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Future: + """ + NOTE: + + - Checkpoint process is implemented as a daemon process. + The AsyncCheckpointProcess' lifetime is tied to the lifetime of the + main process (e.g. trainer process). + + - The first call to execute_save_in_process() will initialize the checkpoint + daemon process. Subsequent async checkpoint requests will not need process + initialization. Therefore, the first async checkpoint request will take longer to complete. + + - Process initialization can have significant overhead, dominated by latency for all ranks to spawn + a background process + process group initialization in the background process. + """ + + global _CHECKPOINT_PROCESS + pg_init_info: Optional[_ProcessGroupInitInfo] = None + if _CHECKPOINT_PROCESS is None: + # Find a port on coordinator rank and broadcast + # to all ranks. + pg_init_info = _ProcessGroupInitInfo(process_group) + + f: Future = self._executor.submit( + self._execute_save_impl, + pg_init_info=pg_init_info, + staging_future_or_state_dict=staging_future_or_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_thread_executor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_thread_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..8dfe63413d433c75a012916f65628f2bd4e57f20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_thread_executor.py @@ -0,0 +1,71 @@ +# pyre-strict +# mypy: allow-untyped-defs +import os +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Optional, Union + +import torch.distributed as dist +from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter + + +def save_wrapper( + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, +) -> Future: + from torch.distributed.checkpoint.state_dict_saver import save + + staged_dict = ( + staging_future_or_state_dict.result() + if isinstance(staging_future_or_state_dict, Future) + else staging_future_or_state_dict + ) + return save( + staged_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + no_dist=no_dist, + use_collectives=use_collectives, + ) + + +class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="AsyncCheckpointExecutor" + ) + + def execute_save( + self, + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Future: + f: Future = self._executor.submit( + save_wrapper, + staging_future_or_state_dict=staging_future_or_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + no_dist=no_dist, + use_collectives=use_collectives, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_checkpointer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..13b0d627a36cc0fedc75695932260ecec718bcde --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_checkpointer.py @@ -0,0 +1,103 @@ +from concurrent.futures import Future +from typing import Any, Optional + +import torch.distributed as dist +import torch.distributed.checkpoint.state_dict_loader as loader +import torch.distributed.checkpoint.state_dict_saver as saver +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.storage import ( + LoadPlanner, + SavePlanner, + StorageReader, + StorageWriter, +) + + +__all__: list[str] = [] + + +class _Checkpointer: + """This base class specifies a high level API for saving and loading + distributed `state_dict` 's. It provides an abstraction over the low-level APIs + provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling + :py:meth: `torch.distributed.state_dict_saver.save` and + :py:meth: `torch.distributed.state_dict_loader.load` with the provided storage + readers and writers. + + .. warning:: + This feature is experimental and subject to removal/change. + + """ + + def __init__( + self, + storage_writer: StorageWriter, + storage_reader: StorageReader, + *, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + load_planner: Optional[LoadPlanner] = None, + save_planner: Optional[SavePlanner] = None, + ): + """Initializes the Checkpointer instance. + + Args: + storage_writer: Instance of StorageWrite use to perform writes. + storage_reader: StorageReader used to load data from. + process_group: ProcessGroup to be used for cross-rank synchronization. + coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default. + no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``) + loader_planner: Instance of LoadPlanner to use when loading. + save_planner: Instance of SavePlanner to use when saving. + """ + self.storage_writer = storage_writer + self.storage_reader = storage_reader + self.process_group = process_group + self.coordinator_rank = coordinator_rank + self.no_dist = no_dist + self.load_planner = load_planner + self.save_planner = save_planner + + def save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Metadata: + """Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization.""" + return saver.save( + state_dict, + self.storage_writer, + process_group=self.process_group, + coordinator_rank=self.coordinator_rank, + no_dist=self.no_dist, + planner=self.save_planner, + ) + + def async_save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Future: + """ + Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization. + + Returns: + Future: A future holding the resultant Metadata object from `save`. + """ + response = saver.async_save( + state_dict, + storage_writer=self.storage_writer, + process_group=self.process_group, + planner=self.save_planner, + ) + if not isinstance(response, Future): + raise AssertionError("response should be a Future instance") + return response + + def load(self, state_dict: dict[str, Any]) -> None: + """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" + loader.load( + state_dict, + storage_reader=self.storage_reader, + process_group=self.process_group, + planner=self.load_planner, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_consolidate_hf_safetensors.py new file mode 100644 index 0000000000000000000000000000000000000000..32d81fb1ea7213e7672a9e7fe23b030962a354f0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -0,0 +1,716 @@ +# pyre-strict + +import concurrent.futures +import glob +import json +import logging +import math +import mmap +import os +import struct +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from torch import distributed as dist +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _get_dcp_custom_metadata, + _get_safetensors_file_metadata, + _metadata_fn, + DATA_OFFSETS_KEY, + DEFAULT_EXTRA_METADATA_KEY, + DTYPE_KEY, + SAVED_OFFSETS_KEY, + SHAPE_KEY, + SUFFIX, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass +class _FqnData: + """ + Dataclass to store information about a tensor (identified by its fully qualified name). + + Attributes: + offset_in_file: Byte offset where this tensor's data begins in the output file + shape_in_file: Shape of the tensor in the output file + dtype_size: Size of the tensor's data type in bytes + dtype_str: String representation of the tensor's data type + """ + + offset_in_file: int = 0 + shape_in_file: list[int] = field(default_factory=list) + dtype_size: int = 0 + dtype_str: str = "" + + +@dataclass +class _OutputFileData: + """ + Dataclass to store information about an output safetensors file. + + Attributes: + metadata_size: Size of the metadata section in bytes + fqn_data: Dictionary mapping tensor names to their metadata + """ + + metadata_size: int = 0 + fqn_data: dict[str, _FqnData] = field(default_factory=dict) + + +@dataclass +class _InputFileData: + """ + Dataclass to store information about an input safetensors file. + + Attributes: + metadata_size: Size of the metadata section in bytes + metadata: Json metadata from the safetensors file + """ + + metadata_size: int = 0 + metadata: Any = None + + +def _parse_input_metadata( + input_files_data: dict[str, _InputFileData], + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Parse metadata from input safetensors files to determine the full tensor shapes and types. + + This function analyzes the metadata from all input files to determine the complete shape + of each tensor after consolidation. It updates the output_files_data with this information. + + Args: + input_files_data: dict of metadata from input safetensors files + output_files_data: Dictionary mapping output file paths to their metadata + + Raises: + ValueError: If no DCP custom metadata is found in a safetensors file + """ + + from safetensors.torch import _getdtype # type: ignore[import] + + # Dictionary to track the full size of each tensor across all shards + fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {} + + for file_data in input_files_data.values(): + safetensors_metadata = file_data.metadata + dcp_sharding_info = _get_dcp_custom_metadata(safetensors_metadata) + if not dcp_sharding_info: + raise ValueError( + "No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated." + ) + + for key, val in safetensors_metadata.items(): + if key == DEFAULT_EXTRA_METADATA_KEY: + continue + + # Get the shape of this tensor shard and its offset in the full tensor + sizes = val[SHAPE_KEY] + offsets = dcp_sharding_info[key][SAVED_OFFSETS_KEY] + + if key not in fqn_to_size_mapping: + # First time seeing this tensor - calculate its full size by adding offsets to dimensions + cur_size = [size + offset for size, offset in zip(sizes, offsets)] + fqn_to_size_mapping[key] = (cur_size, val[DTYPE_KEY]) + else: + # We've seen this tensor before - update its size if this shard extends beyond current known dimensions + cur_size = fqn_to_size_mapping[key][0] + for i in range(len(sizes)): + cur_size[i] = max(cur_size[i], sizes[i] + offsets[i]) + + # Now that we know the full size of each tensor, populate the output file data + for fqn, tensor_info in fqn_to_size_mapping.items(): + tensor_size = tensor_info[0] + dtype_str = tensor_info[1] + for output_data in output_files_data.values(): + # Add this tensor to the output file if it's already assigned there + if fqn in output_data.fqn_data: + output_data.fqn_data[fqn] = _FqnData( + shape_in_file=tensor_size, + dtype_size=torch.finfo(_getdtype(dtype_str)).bits + // 8, # Convert bits to bytes + dtype_str=dtype_str, + ) + + +def _write_metadata( + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Write metadata to the beginning of each output safetensors file. + + This function writes the metadata section to each output file, including information + about tensor shapes, data types, and offsets. It also updates the offset_in_file + field for each tensor in the output_files_data. + + Args: + output_files_data: Dictionary mapping output file paths to their metadata + """ + # Process each output file + for file_path, output_data in output_files_data.items(): + with open(file_path, "wb") as f: + metadata = {} + curr_offset = 0 + + # Calculate offsets for each tensor in the file + for fqn, fqn_data in output_data.fqn_data.items(): + # Calculate the end offset by multiplying all dimensions and the data type size + end_offset = ( + curr_offset + + math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size + ) + + # Store metadata for this tensor + metadata[fqn] = { + SHAPE_KEY: fqn_data.shape_in_file, + DTYPE_KEY: fqn_data.dtype_str, + DATA_OFFSETS_KEY: [ + curr_offset, + end_offset, + ], # Start and end byte offsets + } + # Store the offset for later use when writing the actual tensor data + fqn_data.offset_in_file = curr_offset + + # Update current offset for the next tensor + curr_offset = end_offset + + # Convert metadata to JSON and encode as bytes + json_metadata = json.dumps(metadata) + json_bytes = json_metadata.encode("utf-8") + + # Write the metadata size as an 8-byte unsigned integer (little-endian) + size_in_bytes = len(json_bytes) + header_len = struct.pack(" bytes: + """ + Read tensor data from a safetensors file using memory mapping for efficiency. + + Args: + file_path: Path to the safetensors file + start_offset: Start offset of tensor data within the data section + end_offset: End offset of tensor data within the data section + metadata_size: Size of the metadata header + + Returns: + Raw tensor data as bytes + """ + # Use mmap for efficient access + with open(file_path, "rb") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + absolute_start = metadata_size + start_offset + absolute_end = metadata_size + end_offset + return bytes(mm[absolute_start:absolute_end]) + + +def _process_output_file( + output_file: str, + output_data: _OutputFileData, + input_files_data: dict[str, _InputFileData], +) -> None: + """ + Process a single output file by writing tensor data from input files using memory mapping. + + This function is designed to be run in parallel for different output files. + + Args: + output_file: Path to the output file + output_data: Metadata for the output file + input_files_data: Dictionary mapping input file paths to their metadata + """ + + sorted_tensors = sorted( + output_data.fqn_data.items(), key=lambda x: x[1].offset_in_file + ) + + with open(output_file, "r+b") as output_stream: + output_stream.seek(0, os.SEEK_END) + # Process each tensor in sequential output order + for tensor_fqn, tensor_fqn_data in sorted_tensors: + full_tensor_mv = memoryview( + bytearray( + math.prod(tensor_fqn_data.shape_in_file) + * tensor_fqn_data.dtype_size + ) + ) + + # Process each input safetensors file + for safetensors_file in input_files_data: + file_metadata = input_files_data[safetensors_file].metadata + input_metadata_size = input_files_data[safetensors_file].metadata_size + + if tensor_fqn not in file_metadata: + continue + + metadata = file_metadata[tensor_fqn] + + data_offsets = metadata[DATA_OFFSETS_KEY] + + # Use memory mapping to read tensor data efficiently + data_to_write = _read_tensor_data_mmap( + safetensors_file, + data_offsets[0], + data_offsets[1], + input_metadata_size, + ) + + # Get the offsets of this tensor shard within the full tensor + fqn_custom_metadata = _get_dcp_custom_metadata(file_metadata)[ + tensor_fqn + ] # type: ignore[index] + offsets_of_tensor_being_read = fqn_custom_metadata[SAVED_OFFSETS_KEY] # type: ignore[index] + + # Write this tensor shard to the appropriate position in the output file + _write_sub_tensor_to_file_optimized( + full_tensor_mv, + data_to_write, + tensor_fqn_data.dtype_size, # Size of each element in bytes + tensor_fqn_data.shape_in_file, # Full tensor shape + offsets_of_tensor_being_read, # Where this shard belongs in the full tensor + metadata[SHAPE_KEY], # Shape of this shard + ) + + output_stream.write(full_tensor_mv) + + +def _write_data( + input_files_data: dict[str, _InputFileData], + output_files_data: dict[str, _OutputFileData], + num_threads: int = 1, +) -> None: + """ + Write tensor data from input files to the output files using memory mapping. + + This function reads tensor data from each input file and writes it to the appropriate + position in the output files based on the tensor's offsets. When num_threads > 1, + the work is split across threads with each thread handling a different output file. + + Args: + input_files_data: Dictionary mapping input file paths to their metadata + output_files_data: Dictionary mapping output file paths to their metadata + num_threads: Number of threads to use for parallel processing + """ + if num_threads <= 1 or len(output_files_data) <= 1: + # Sequential processing + for output_file, output_data in output_files_data.items(): + _process_output_file(output_file, output_data, input_files_data) + else: + # Parallel processing with ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(num_threads, len(output_files_data)) + ) as executor: + futures = [] + for output_file, output_data in output_files_data.items(): + futures.append( + executor.submit( + _process_output_file, + output_file, + output_data, + input_files_data, + ) + ) + + # Wait for all futures to complete + for future in concurrent.futures.as_completed(futures): + # Handle any exceptions that might have occurred + try: + future.result() + except Exception as e: + print(f"Error processing output file: {e}") + raise + + +def _write_sub_tensor_to_file_optimized( + full_tensor_mv: memoryview, + sub_tensor_bytes: bytes, + element_size: int, + tensor_shape: list[int], + sub_tensor_offsets: list[int], + sub_tensor_shape: list[int], +) -> None: + """ + Optimized version that writes the maximum number of contiguous bytes possible. + + Uses a unified algorithm that calculates the maximum contiguous bytes that can be + written in each iteration and continues until the entire subtensor is written. + Handles all sharding patterns efficiently: + - Full sub-tensor at once for row-wise sharding + - Row-by-row for column-wise sharding + - Optimized chunks for other patterns + + Args: + full_tensor_mv: Buffer to write the full tensor to + sub_tensor_bytes: Raw tensor data as bytes + element_size: Size of each element in bytes + tensor_shape: Shape of the full tensor + sub_tensor_offsets: Starting offsets of the sub-tensor within the full tensor + sub_tensor_shape: Shape of the sub-tensor + """ + # Handle empty tensors + if not tensor_shape or not sub_tensor_shape: + return + + # Calculate tensor strides for efficient indexing + tensor_strides = [1] + for i in range(len(tensor_shape) - 1, 0, -1): + tensor_strides.insert(0, tensor_strides[0] * tensor_shape[i]) + + sub_tensor_strides = [1] + for i in range(len(sub_tensor_shape) - 1, 0, -1): + sub_tensor_strides.insert(0, sub_tensor_strides[0] * sub_tensor_shape[i]) + + total_elements = math.prod(sub_tensor_shape) + + elements_written = 0 + while elements_written < total_elements: + # Convert linear index to multi-dimensional indices + temp_idx = elements_written + indices = [] + for dim_size in reversed(sub_tensor_shape): + indices.append(temp_idx % dim_size) + temp_idx //= dim_size + indices.reverse() + + # Calculate maximum contiguous elements we can write from this position + max_contiguous = _calculate_max_contiguous_elements( + indices, sub_tensor_shape, tensor_shape + ) + + # Calculate source position in bytes + src_pos = sum(idx * stride for idx, stride in zip(indices, sub_tensor_strides)) + src_byte_offset = src_pos * element_size + + # Calculate destination position in bytes + dest_indices = [ + idx + offset for idx, offset in zip(indices, sub_tensor_offsets) + ] + dest_pos = sum( + idx * stride for idx, stride in zip(dest_indices, tensor_strides) + ) + dest_byte_offset = dest_pos * element_size + + # Write the contiguous chunk + bytes_to_write = max_contiguous * element_size + chunk_data = sub_tensor_bytes[ + src_byte_offset : src_byte_offset + bytes_to_write + ] + full_tensor_mv[dest_byte_offset : dest_byte_offset + bytes_to_write] = ( + chunk_data + ) + + elements_written += max_contiguous + + +def _calculate_max_contiguous_elements( + indices: list[int], + sub_tensor_shape: list[int], + tensor_shape: list[int], +) -> int: + """ + Calculate the maximum number of contiguous elements that can be written from current position. + + This determines the largest chunk by checking how elements are laid out in memory + and finding natural boundaries where contiguity breaks. + + Args: + indices: Current position indices in the sub-tensor + sub_tensor_shape: Shape of the sub-tensor being written + tensor_shape: Shape of the full tensor + + Raises: + ValueError: If input lists are empty, have mismatched lengths, or contain invalid values + """ + # Validate input lists are not empty + if not indices or not sub_tensor_shape or not tensor_shape: + raise ValueError("Input lists cannot be empty") + + # Validate all lists have the same length (same number of dimensions) + if not (len(indices) == len(sub_tensor_shape) == len(tensor_shape)): + raise ValueError( + f"All input lists must have the same length. Got indices: {len(indices)}, " + f"sub_tensor_shape: {len(sub_tensor_shape)}, tensor_shape: {len(tensor_shape)}" + ) + + # Validate indices are within bounds of sub_tensor_shape + for i, (idx, sub_dim) in enumerate(zip(indices, sub_tensor_shape)): + if idx >= sub_dim: + raise ValueError( + f"Index {idx} at dimension {i} is out of bounds for sub-tensor shape {sub_tensor_shape}" + ) + + # Validate sub_tensor dimensions don't exceed tensor dimensions + for i, (sub_dim, tensor_dim) in enumerate(zip(sub_tensor_shape, tensor_shape)): + if sub_dim > tensor_dim: + raise ValueError( + f"Sub-tensor dimension {sub_dim} at position {i} exceeds tensor dimension {tensor_dim}" + ) + + # Start with elements remaining in the last dimension + max_contiguous = sub_tensor_shape[-1] - indices[-1] + + # Check if we can extend across multiple dimensions + # We can write across dimension boundaries if we're writing complete "rows" + # and the layout in destination tensor maintains contiguity + + # For 2D case: check if we can write multiple complete rows + if len(sub_tensor_shape) >= 2: + # If we're at the start of a row and can write complete rows + if indices[-1] == 0: # At start of last dimension (column) + rows_remaining = sub_tensor_shape[-2] - indices[-2] # Rows left to write + + # Check if writing complete rows maintains contiguity in destination + # This is true for row-wise sharding or when sub-tensor spans full width + if sub_tensor_shape[-1] == tensor_shape[-1]: # Full width + max_contiguous = rows_remaining * sub_tensor_shape[-1] + + # For higher dimensions, check if we can extend further + if len(sub_tensor_shape) >= 3 and indices[-2] == 0: + # Check if we can write complete 2D slices + remaining_in_dim = sub_tensor_shape[-3] - indices[-3] + if ( + sub_tensor_shape[-1] == tensor_shape[-1] + and sub_tensor_shape[-2] == tensor_shape[-2] + ): + max_contiguous = ( + remaining_in_dim * sub_tensor_shape[-2] * sub_tensor_shape[-1] + ) + + return max_contiguous + + +def _write_overall_metadata_file( + output_dir: str, + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Write the overall metadata file that maps tensor names to their file locations. + + This creates a model.safetensors.index.json file that HuggingFace models use + to locate tensors across multiple files. + + Args: + output_dir: Directory where the metadata file will be written + output_files_data: Dictionary mapping output file paths to their metadata + """ + total_size = 0 + weight_map = {} + for output_path, value in output_files_data.items(): + for fqn, fqn_data in value.fqn_data.items(): + total_size += math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size + weight_map[fqn] = os.path.basename(output_path) + + metadata_to_write: dict[str, Any] = {} + metadata_to_write["metadata"] = {"total_size": total_size} + metadata_to_write["weight_map"] = weight_map + + metadata_path = os.path.join(output_dir, f"{_metadata_fn}") + with open(metadata_path, "w") as metadata_file: + json.dump(metadata_to_write, metadata_file, indent=2) + + +def _consolidate_safetensors_files( + input_dir: str, + output_dir: str, + fqn_to_file_mapping: dict[str, str], + num_threads: int, +) -> dict[str, _OutputFileData]: + output_files_data: dict[str, _OutputFileData] = {} + # Create multiple output files based on the provided mapping + for fqn, filename in fqn_to_file_mapping.items(): + output_path = os.path.join(output_dir, filename) + + if output_path not in output_files_data: + output_files_data[output_path] = _OutputFileData(fqn_data={fqn: _FqnData()}) + else: + output_files_data[output_path].fqn_data[fqn] = _FqnData() + + # Find all safetensors files in the input directory + safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}")) + + # Read metadata from all input files + input_files_data: dict[str, _InputFileData] = {} + for safetensor_file in safetensors_files: + with open(safetensor_file, "rb") as f: + metadata, size = _get_safetensors_file_metadata(f) + input_files_data[safetensor_file] = _InputFileData( + metadata_size=size, metadata=metadata + ) + # Step 1: Parse metadata to determine tensor shapes and types + _parse_input_metadata(input_files_data, output_files_data) + + # Step 2: Write metadata headers to output files + _write_metadata(output_files_data) + # Step 3: Write actual tensor data from input files to output files + _write_data(input_files_data, output_files_data, num_threads) + + return output_files_data + + +def consolidate_safetensors_files( + input_dir: str, + output_dir: str, + fqn_to_index_mapping: dict[str, int], + num_threads: int = 1, +) -> None: + """ + Main function to consolidate sharded safetensors files into one or more output files. + + This function orchestrates the entire consolidation process: + 1. Sets up the output file structure based on the fqn_to_index_mapping + 2. Finds all safetensors files in the input directory + 3. Parses metadata from all input files + 4. Writes metadata to the output files + 5. Writes tensor data from input files to output files + 6. Writes overall model.index.safetensors.json file with weight map + + Args: + input_dir: Directory containing sharded safetensors files + output_dir: Directory where consolidated files will be written + fqn_to_index_mapping: Optional mapping of tensor names to output file indices. + If None, all tensors will be consolidated into a single file. + num_threads: Number of threads to use for parallel processing of saving data to output files. + """ + start_time = time.time() + logger.info( + "Consolidating safetensors files from %s to %s. Beginning at time %f", + input_dir, + output_dir, + start_time, + ) + + max_index = max(fqn_to_index_mapping.values()) + fqn_to_file_mapping = { + fqn: _gen_file_name(idx, max_index) for fqn, idx in fqn_to_index_mapping.items() + } + + output_files_data = _consolidate_safetensors_files( + input_dir, output_dir, fqn_to_file_mapping, num_threads + ) + + # Step 4: Write overall model.index.safetensors.json file with weight map + _write_overall_metadata_file(output_dir, output_files_data) + + logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time) + + +def consolidate_safetensors_files_on_every_rank( + input_dir: str, + output_dir: str, + fqn_to_index_mapping: dict[str, int], + num_threads: int = 1, + process_group: Optional[dist.ProcessGroup] = None, +) -> None: + """ + Consolidate sharded safetensors files across multiple ranks, with each rank handling a subset of output files. + + This function distributes the consolidation work by assigning output files to different ranks. + All tensors with the same index in fqn_to_index_mapping are processed by the same rank, + as they belong to the same output file. + + If process_group is provided, rank and world_size will be derived from it. Otherwise, + they will be automatically detected from the distributed environment if available. + + Args: + input_dir: Directory containing sharded safetensors files + output_dir: Directory where consolidated files will be written + fqn_to_index_mapping: Mapping of tensor names to output file indices + num_threads: Number of threads to use for parallel processing on each rank + process_group: PyTorch distributed process group (default: None, will use default group) + """ + + start_time = time.time() + # Derive rank and world_size from process_group or default distributed environment + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + else: + # Default to single process mode if distributed is not initialized + rank = 0 + world_size = 1 + logger.warning( + "Distributed environment not initialized. Running in single process mode." + ) + logger.info( + "Rank %d/%d: Consolidating safetensors files from %s to %s", + rank, + world_size, + input_dir, + output_dir, + ) + + # Find all unique indices in the mapping + unique_indices = set(fqn_to_index_mapping.values()) + + # Distribute indices across ranks + indices_for_this_rank = [] + for idx in unique_indices: + # Simple distribution: index % world_size == rank + if idx % world_size == rank: + indices_for_this_rank.append(idx) + + logger.info( + "Rank %d: Assigned %d output files out of %d total files", + rank, + len(indices_for_this_rank), + len(unique_indices), + ) + + # Filter the fqn_to_index_mapping to only include tensors for this rank + filtered_mapping = { + fqn: idx + for fqn, idx in fqn_to_index_mapping.items() + if idx in indices_for_this_rank + } + + if filtered_mapping: + # Convert index mapping to filename mapping + max_index = max(unique_indices) + filtered_filename_mapping = {} + for fqn, idx in filtered_mapping.items(): + filename = _gen_file_name(idx, max_index) + filtered_filename_mapping[fqn] = filename + + # Call the existing consolidation function with the filtered mapping + _consolidate_safetensors_files( + input_dir=input_dir, + output_dir=output_dir, + fqn_to_file_mapping=filtered_filename_mapping, + num_threads=num_threads, + ) + + logger.info( + "Rank %d: Done consolidating. Processed %d unique indices in %.2f secs.", + rank, + len(indices_for_this_rank), + time.time() - start_time, + ) + + # Wait for all ranks to complete + if dist.is_available() and dist.is_initialized(): + logger.info("Rank %d: Waiting for all ranks to complete...", rank) + dist.barrier() + logger.info("Rank %d: All ranks have completed.", rank) + if rank == 0: + logger.info("Total time taken: %.2f secs.", time.time() - start_time) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py new file mode 100644 index 0000000000000000000000000000000000000000..acb81c41862852320cdc1d412ddaffdd48e73841 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +from collections import defaultdict +from typing import TYPE_CHECKING + +from torch.distributed.checkpoint.planner import SavePlan, WriteItem + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_save_plans"] + + +def dedup_save_plans( + all_plans: list[SavePlan], + save_to_lowest_rank: bool = False, +) -> list[SavePlan]: + """ + Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across + a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry. + + Please note that this function does not modify the original SavePlans, but rather returns + """ + + # Map to query the plan indices that a write item is duplicated in + write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set) + # Map to query the write item from its index + write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {} + # Set of write item indices that are present in each plan + # After deduplication, this will be the set of write item indices that are present in the final plans + plan_to_item_indices: list[set[MetadataIndex]] = [ + {item.index for item in plan.items} for plan in all_plans + ] + + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + # map each write item to its plan + write_item_to_plan_indices[write_item.index].add(plan_idx) + write_item_idx_to_write_item[write_item.index] = write_item + plan_to_size = [0] * len(all_plans) + for write_item_idx, plan_indices in write_item_to_plan_indices.items(): + if save_to_lowest_rank: + select_plan_idx = min(plan_indices) + else: + select_plan_idx = min( + plan_indices, key=lambda plan_idx: plan_to_size[plan_idx] + ) + + write_item = write_item_idx_to_write_item[write_item_idx] + # Ignore the storage size of anything that is not a tensor, since + # we don't know how much storage they represent + plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1 + for plan_idx in plan_indices - {select_plan_idx}: + plan_to_item_indices[plan_idx].discard(write_item_idx) + # Sanity check + if len(all_plans) != len(plan_to_item_indices): + raise AssertionError("len(all_plans) != len(plan_to_item_indices)") + # Create new plans with the updated write items post deduplication + return [ + dataclasses.replace( + plan, items=[item for item in plan.items if item.index in item_indexes] + ) + for plan, item_indexes in zip(all_plans, plan_to_item_indices) + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_tensors.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c57b2e149106abbac66522aa571d1a462db4157d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_tensors.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +import logging +from typing import TYPE_CHECKING + +from torch.distributed.checkpoint.planner import SavePlan + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_tensors"] + + +def init_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + level = logging.INFO + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + console.setFormatter(formatter) + console.setLevel(level) + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = init_logger() + + +# TODO add docstring for dedup_tensors +def dedup_tensors(all_plans: list[SavePlan]) -> list[SavePlan]: + all_plans = list(all_plans) + key_to_plan: dict[MetadataIndex, list[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + + # Remove duplicates by always keeping the first entry. + # Compute the per-rank remove set. + plan_to_keys: dict[int, list[MetadataIndex]] = {} + for key, plans in replicated_items.items(): + for plan_idx in plans[1:]: + plan_to_keys.setdefault(plan_idx, []).append(key) + if len(plan_to_keys) > 0: + logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + key_set = set(keys) + # rewrite items and remove elements + new_items = [ + write_item + for write_item in all_plans[plan_idx].items + if write_item.index not in key_set + ] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8361362eb3a5ed5abae10d39c1db54c3e8739b46 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__init__.py @@ -0,0 +1,53 @@ +""" +Checkpoint functionality for machine learning models. + +This module provides classes for saving and loading model checkpoints in a distributed +training environment. It includes functionality for coordinating checkpoint operations +across multiple processes and customizing the checkpoint process through hooks. + +Key components: +- Checkpointer: Main class for orchestrating checkpoint operations (save, load) +- CheckpointWriter: Handles writing state dictionaries to storage +- CheckpointReader: Handles reading state dictionaries from storage read +- Barrier: Synchronization mechanism for distributed checkpointing +- RankInfo: Information about the current rank in a distributed environment +""" + +from .barriers import ( + Barrier, + BarrierConfig, + create_barrier_from_config, + TCPStoreBarrier, +) +from .builder import make_async_checkpointer, make_sync_checkpointer +from .checkpoint_reader import CheckpointReader +from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook +from .checkpointer import AsyncCheckpointer, Checkpointer, SyncCheckpointer +from .config import CheckpointerConfig +from .staging import CheckpointStager, CheckpointStagerConfig, DefaultStager +from .types import RankInfo, STATE_DICT +from .utils import wrap_future + + +__all__ = [ + "Barrier", + "TCPStoreBarrier", + "CheckpointReader", + "CheckpointWriter", + "CheckpointWriterConfig", + "WriterHook", + "Checkpointer", + "SyncCheckpointer", + "AsyncCheckpointer", + "CheckpointerConfig", + "BarrierConfig", + "create_barrier_from_config", + "CheckpointStager", + "CheckpointStagerConfig", + "DefaultStager", + "RankInfo", + "STATE_DICT", + "wrap_future", + "make_sync_checkpointer", + "make_async_checkpointer", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..996ad03b72f22ee3506c9673804c1b71b8f8b3a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/barriers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/barriers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9a593160808fe0a64d5026fb66fd6d5c2835f69 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/barriers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/builder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06bc889048b837abb53929b9431c5d426c7fa1b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/builder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_process.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_process.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab606078d61b2e82aa3f21e7a876af4005872146 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_process.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_reader.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_reader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d8391df539c40b4a1cd363477d782a4418b9d13 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_reader.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_writer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_writer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2331cbcfc96829796679761cbdf95cbac2b10433 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_writer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpointer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpointer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbb4183736105e61c880caa0a166d6d34a5277c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpointer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef37015d18c8cbc8c8cbf10fe81d4d98ef20b05d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/staging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/staging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50c420107782eb5215839071ba29affc2485fc54 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/staging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b42e51c144930e2c387fc6b588b51414c95bf1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0f29f37abfc7a372e504d3cbaa8220ac9908979 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/barriers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/barriers.py new file mode 100644 index 0000000000000000000000000000000000000000..bcea8ad91401e50f4e9f39ace06f0bafe4f0d6a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/barriers.py @@ -0,0 +1,267 @@ +""" +Barrier implementations for synchronizing distributed checkpoint operations. + +This module provides abstract and concrete barrier implementations that ensure +all ranks in a distributed training environment complete their checkpoint operations +before proceeding, which is essential for data consistency. +""" + +import abc +import logging +from collections import Counter +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any, Optional + +import torch.distributed as dist +import torch.distributed.elastic.utils.store as store_util + + +logger = logging.getLogger() + + +# Registry of barrier types +BARRIER_REGISTRY: dict[str, type] = {} + + +def register_barrier(barrier_class: type) -> type: + """Register a barrier class in the global registry.""" + if hasattr(barrier_class, "barrier_type"): + BARRIER_REGISTRY[barrier_class.barrier_type] = barrier_class + return barrier_class + + +@dataclass +class BarrierConfig: + """ + Configuration for barrier construction. + + This class provides a flexible way to configure different barrier implementations + with their specific constructor arguments. The barrier type will be looked up + from a registry and instantiated with rank_info and barrier_args. + + Attributes: + barrier_type: A string identifying the barrier type (e.g., "tcp_store"). + If None, no barrier will be used. + barrier_args: Dictionary of arguments to pass to the barrier constructor. + rank_info will be automatically injected as the first argument. + + Examples: + # No barrier + BarrierConfig() + + # TCPStore barrier + BarrierConfig( + barrier_type="tcp_store", + barrier_args={ + 'timeout_barrier_init_secs': 30, + 'barrier_prefix_list': ['checkpoint'], + 'use_checkpoint_barrier_tcpstore_libuv': False, + 'tcpstore_port': 12345, + 'master_address': 'localhost' + } + ) + """ + + barrier_type: Optional[str] = None + barrier_args: dict[str, Any] = field(default_factory=dict) + + +def create_barrier_from_config( + barrier_config: BarrierConfig, +) -> Optional["Barrier"]: + """ + Create a barrier instance from BarrierConfig. + + Args: + barrier_config: Configuration for barrier construction. + + Returns: + Barrier instance or None if no barrier type is configured. + + Raises: + ValueError: If the barrier_type is not found in the registry. + """ + if barrier_config.barrier_type is None: + return None + + if barrier_config.barrier_type not in BARRIER_REGISTRY: + raise ValueError( + f"Unknown barrier type: {barrier_config.barrier_type}. " + f"Available types: {list(BARRIER_REGISTRY.keys())}" + ) + + barrier_class = BARRIER_REGISTRY[barrier_config.barrier_type] + return barrier_class(**barrier_config.barrier_args) + + +class Barrier(abc.ABC): + """ + Abstract base class for synchronization barriers. + + A barrier ensures that all ranks in a distributed environment reach a certain + point in execution before any rank proceeds further, which is essential for + coordinating operations like checkpointing across multiple processes. + """ + + @abc.abstractmethod + def __init__(self, **kwargs: dict[str, Any]): + """ + Initialize a barrier. + + Args: + **kwargs: Keyword arguments for specific barrier implementations. + Common arguments may include rank information, barrier prefixes, + timeout settings, and other barrier-specific configuration. + """ + # No implementation needed in the abstract base class + + @abc.abstractmethod + def execute_barrier(self) -> None: + """ + Execute a synchronization barrier. + + This method uses the barrier_prefix provided during initialization to + coordinate synchronization across processes. + """ + + +@register_barrier +class DistBarrier(Barrier): + """ + A barrier implementation using PyTorch's distributed barrier for synchronization. + + This barrier uses the built-in torch.distributed.barrier() function to coordinate + synchronization across multiple processes. It's simpler than TCPStoreBarrier but + requires an initialized process group. + """ + + barrier_type = "dist_barrier" + + def __init__( + self, + ) -> None: + """ + Initialize a DistBarrier. + + This barrier requires an initialized PyTorch distributed process group. + No additional arguments are needed as it uses the current process group. + + Raises: + AssertionError: If the distributed process group is not initialized. + """ + if not dist.is_initialized(): + raise AssertionError("DistBarrier requires an initialized process group.") + + def execute_barrier(self) -> None: + """ + Execute a synchronization barrier using the prefix provided during initialization. + """ + # Note: dist.barrier() doesn't support explicit timeouts + # The timeout is handled by the underlying implementation + dist.barrier() + + +@register_barrier +class TCPStoreBarrier(Barrier): + """ + A barrier implementation using PyTorch's TCPStore for synchronization. + + This barrier uses a TCP-based distributed key-value store to coordinate + synchronization across multiple processes. It uses a single TCP store + for all barrier operations, with different prefixes to distinguish between + different barrier types. + """ + + barrier_type = "tcp_store" + + def __init__( + self, + global_rank: int, + global_world_size: int, + barrier_prefix: str, + timeout_barrier_init_secs: int, + use_checkpoint_barrier_tcpstore_libuv: bool, + tcpstore_port: int, + master_address: str, + timeout_secs: int, + ): + """ + Initialize a TCPStoreBarrier. + + Args: + global_rank: The rank of the current process in the distributed environment. + global_world_size: The total number of processes in the distributed environment. + barrier_prefix: A string prefix to identify this specific barrier. + timeout_barrier_init_secs: Timeout in seconds for initializing the TCPStore. + use_checkpoint_barrier_tcpstore_libuv: Whether to use libuv for the TCPStore. + tcpstore_port: Port number for the TCPStore. + master_address: Address of the master node for the TCPStore. + timeout_secs: Maximum time in seconds to wait for all ranks to reach the barrier. + """ + logger.info( + "Initializing TCPStore master_address=%s tcpstore_port=%s rank=%s " + "world_size=%s barrier_prefix=%s timeout_barrier_init_secs=%s " + "use_checkpoint_barrier_tcpstore_libuv=%s timeout_secs=%s", + master_address, + tcpstore_port, + global_rank, + global_world_size, + barrier_prefix, + timeout_barrier_init_secs, + use_checkpoint_barrier_tcpstore_libuv, + timeout_secs, + ) + + # Counter collection to track barrier seq on a per barrier prefix basis. + self._tcp_store_barrier_seq: Counter = Counter() + self._barrier_prefix = barrier_prefix + + # Store rank and world size for barrier operations + self._global_rank = global_rank + self._global_world_size = global_world_size + self._timeout_secs = timeout_secs + + # Create a single TCP store for all barrier operations + self._tcp_store = dist.TCPStore( + master_address, + int(tcpstore_port), + world_size=self._global_world_size, + timeout=timedelta(seconds=timeout_barrier_init_secs), + is_master=(self._global_rank == 0), + ) + + def execute_barrier(self) -> None: + """ + Execute a synchronization barrier using the prefix provided during initialization. + + The implementation uses a sequence number that is incremented every time + a barrier is reached. The sequence number is per barrier prefix to allow + different barriers to operate concurrently. + """ + barrier_prefix = self._barrier_prefix + + logger.info( + "Executing barrier barrier_prefix=%s timeout_secs=%s", + barrier_prefix, + self._timeout_secs, + ) + + def _rank_key(rank: int) -> str: + return f"rank{rank}" + + # Track which barrier sequence this rank is joining. + self._tcp_store.set( + _rank_key(self._global_rank), + str(self._tcp_store_barrier_seq[barrier_prefix]), + ) + + # Execute barrier for that sequence number (for the specific prefix). + store_util.barrier( + store=self._tcp_store, + world_size=self._global_world_size, + key_prefix=( + barrier_prefix + str(self._tcp_store_barrier_seq[barrier_prefix]) + ), + ) + self._tcp_store_barrier_seq[barrier_prefix] += 1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/builder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7f5fa9e71268b665f0863e85c6775b2391b6c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/builder.py @@ -0,0 +1,174 @@ +""" +Factory functions for creating checkpointer instances with sensible defaults. + +This module provides high-level factory functions that simplify the creation +of checkpointer instances by automatically handling component initialization +and configuration with reasonable defaults. +""" + +from collections.abc import Callable +from typing import Any, Optional + +import torch.distributed as dist + +from .barriers import create_barrier_from_config +from .checkpoint_process import CheckpointProcess +from .checkpoint_reader import CheckpointReader +from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook +from .checkpointer import AsyncCheckpointer, SyncCheckpointer +from .config import CheckpointerConfig +from .staging import DefaultStager +from .types import RankInfo + + +def _get_default_rank_info() -> RankInfo: + """ + Get default rank information from the current distributed environment. + + Returns: + RankInfo: Rank information from the default process group if initialized, + otherwise single-rank fallback. + """ + if dist.is_initialized(): + return RankInfo( + global_world_size=dist.get_world_size(), + global_rank=dist.get_rank(), + ) + else: + # Single-rank fallback + return RankInfo(global_world_size=1, global_rank=0) + + +def default_subprocess_init_fn(*_: Any) -> None: + """Default subprocess initialization function (no-op).""" + + +def default_writer_init_fn(rank_info: RankInfo) -> CheckpointWriter: + """Default checkpoint writer initialization function.""" + return CheckpointWriter( + config=CheckpointWriterConfig(), + rank_info=rank_info, + ) + + +def make_sync_checkpointer( + config: CheckpointerConfig = CheckpointerConfig(), + rank_info: Optional[RankInfo] = None, + commit_hook: Optional[WriterHook] = None, +) -> SyncCheckpointer: + """ + Factory function to create a SyncCheckpointer instance with sensible defaults. + + This function creates a synchronous checkpointer with default components, automatically + detecting rank information from the default process group if available, and using the + provided component configurations. + + Args: + config: CheckpointerConfig containing component-specific configurations + (writer_config, staging_config, process_config). Defaults to CheckpointerConfig(). + rank_info: RankInfo for distributed training. Defaults to auto-detection from + the default PyTorch distributed process group if initialized, otherwise + falls back to single-rank (world_size=1, rank=0). + commit_hook: Optional hook for custom actions before and after checkpoint commits. + + Returns: + SyncCheckpointer: A configured synchronous checkpointer instance. + + Examples: + # Simplest usage - auto-detect rank, default config + checkpointer = make_sync_checkpointer() + + # Explicit rank configuration + checkpointer = make_sync_checkpointer( + rank_info=RankInfo(global_world_size=4, global_rank=0) + ) + + # Disable barrier + from .barriers import BarrierConfig + config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None)) + checkpointer = make_sync_checkpointer(config=config) + """ + if rank_info is None: + rank_info = _get_default_rank_info() + + reader = CheckpointReader( + rank_info=rank_info, + ) + + barrier = create_barrier_from_config(config.barrier_config) + + writer = CheckpointWriter( + config=config.writer_config, + rank_info=rank_info, + barrier=barrier, + commit_hook=commit_hook, + ) + + return SyncCheckpointer( + writer=writer, + reader=reader, + ) + + +def make_async_checkpointer( + config: CheckpointerConfig = CheckpointerConfig(), + rank_info: Optional[RankInfo] = None, + subprocess_init_fn: Callable[..., None] = default_subprocess_init_fn, + subprocess_init_args: tuple[Any, ...] = (), + checkpoint_writer_init_fn: Callable[..., CheckpointWriter] = default_writer_init_fn, + checkpoint_writer_init_args: Optional[dict[str, Any]] = None, +) -> AsyncCheckpointer: + """ + Factory function to create an AsyncCheckpointer instance with sensible defaults. + + This function creates an asynchronous checkpointer using the provided configuration, + automatically detecting rank information if not provided. + + Args: + config: CheckpointerConfig containing component-specific configurations. + rank_info: RankInfo for distributed training. Defaults to auto-detection. + subprocess_init_fn: Function to initialize the subprocess. Defaults to no-op. + subprocess_init_args: Arguments to pass to subprocess_init_fn. + checkpoint_writer_init_fn: Function to create CheckpointWriter instance. + checkpoint_writer_init_args: Arguments to pass to checkpoint_writer_init_fn. + + Returns: + AsyncCheckpointer: A configured asynchronous checkpointer instance. + + Examples: + # Create with default config + checkpointer = make_async_checkpointer() + + # Create with custom init functions + checkpointer = make_async_checkpointer( + subprocess_init_fn=my_subprocess_init_fn, + checkpoint_writer_init_fn=my_writer_init_fn + ) + """ + if rank_info is None: + rank_info = _get_default_rank_info() + + reader = CheckpointReader( + rank_info=rank_info, + ) + + checkpoint_stager = DefaultStager( + config=config.staging_config, + ) + + checkpoint_writer_init_args = checkpoint_writer_init_args or {} + + checkpoint_process = CheckpointProcess( + rank_info=rank_info, + config=config.process_config, + subprocess_init_fn=subprocess_init_fn, + subprocess_init_args=subprocess_init_args, + checkpoint_writer_init_fn=checkpoint_writer_init_fn, + checkpoint_writer_init_args=checkpoint_writer_init_args, + ) + + return AsyncCheckpointer( + checkpoint_stager=checkpoint_stager, + checkpoint_process=checkpoint_process, + reader=reader, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_process.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_process.py new file mode 100644 index 0000000000000000000000000000000000000000..c71210aaa54690ca894c9ded0fd54c335b7c2f0b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_process.py @@ -0,0 +1,361 @@ +import logging +import os +import traceback +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from multiprocessing.connection import Connection +from typing import Any, Optional, Union + +import torch.multiprocessing as mp +from torch.multiprocessing.spawn import ProcessExitedException + +from .checkpoint_writer import CheckpointWriter +from .types import RankInfo, STATE_DICT + + +logger = logging.getLogger(__name__) + + +@dataclass +class CheckpointProcessConfig: + """ + Configuration options for the CheckpointProcess. + + This class provides configuration options for the checkpoint process, + including initialization functions, timeouts, and writer configuration. + + Attributes: + subprocess_init_timeout_secs: Maximum time in seconds to wait for subprocess initialization. + subprocess_shutdown_timeout_secs: Maximum time in seconds to wait for subprocess shutdown. + """ + + subprocess_init_timeout_secs: int = 30 + subprocess_shutdown_timeout_secs: int = 60 + + +class RequestType(Enum): + PING = "ping" + WRITE_CHECKPOINT = "write_checkpoint" + TERMINATE_PROCESS = "exit" + + +@dataclass +class WorkerRequest: + """ + A dataclass for storing the command to be sent to the worker process. + Note: This relies on pickling to send the command to the worker process. Handle + backward compatibility accordingly. + """ + + request_type: RequestType + payload: dict[str, Any] + + +@dataclass +class WorkerResponse: + request_type: RequestType + success: bool + error_msg: Optional[str] = None + payload: Optional[dict[str, Any]] = None + + +class CheckpointProcess: + """ + A checkpoint writer that writes checkpoints to a remote process. + """ + + def __init__( + self, + rank_info: RankInfo, + config: CheckpointProcessConfig, + subprocess_init_fn: Callable[[Any], None], + subprocess_init_args: tuple[Any, ...], + checkpoint_writer_init_fn: Callable[..., CheckpointWriter], + checkpoint_writer_init_args: dict[str, Any], + ): + self._executor = ThreadPoolExecutor(max_workers=1) + self._rank_info = rank_info + self._config = config + self._subprocess_init_fn = subprocess_init_fn + self._subprocess_init_args = subprocess_init_args + self._checkpoint_writer_init_fn = checkpoint_writer_init_fn + self._checkpoint_writer_init_args = checkpoint_writer_init_args + self.process = None + self._parent_end: Optional[Connection] = None + self._child_end: Optional[Connection] = None + + self.process_creation_future = self._executor.submit( + self._create_subprocess, + config, + ) + + def _create_subprocess( + self, + config: CheckpointProcessConfig, + ) -> None: + logger.info( + "Creating checkpoint subprocess for rank %d", self._rank_info.global_rank + ) + + spawn_context = mp.get_context("spawn") + self._parent_end, child_end = spawn_context.Pipe() + + # Known workaround for https://github.com/pytorch/pytorch/issues/37377 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU" + + logger.debug("Spawning subprocess for rank_info=%s", self._rank_info) + self.process = mp.spawn( + fn=CheckpointProcess._subprocess, + args=( + self._rank_info, + child_end, + self._subprocess_init_fn, + self._subprocess_init_args, + self._checkpoint_writer_init_fn, + self._checkpoint_writer_init_args, + ), + nprocs=1, + join=False, + daemon=True, + ) + + # close the child end of the pipe so recv on it will fail + # fast when the child process is terminated unexpectedly. + child_end.close() + self._send( + request_type=RequestType.PING, + payload={}, + ) + + logger.debug( + "Waiting for checkpoint subprocess to initialize (timeout: %ds)", + config.subprocess_init_timeout_secs, + ) + + # wait for the timeout or a response from subprocess + if self._parent_end is None: + raise AssertionError("Parent end of pipe should be initialized") + if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs): + msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize" + logger.error(msg) + raise TimeoutError(msg) + + self._recv() + logger.info("Checkpoint subprocess initialized successfully") + + @staticmethod + def _subprocess( + sub_rank: int, + rank_info: RankInfo, + parent_pipe: Connection, + subprocess_init_fn: Callable[[Any], None], + subprocess_init_args: tuple[Any, ...], + checkpoint_writer_init_fn: Callable[..., CheckpointWriter], + checkpoint_writer_init_args: dict[str, Any], + ) -> None: + logger.debug( + "Checkpoint subprocess started for rank %d/%d (PID: %d)", + rank_info.global_rank, + rank_info.global_world_size, + os.getpid(), + ) + + if sub_rank != 0: + raise AssertionError("We need only one checkpointer per parent training") + request = WorkerRequest(request_type=RequestType.PING, payload={}) + + try: + # Calling initialize callback, so we can perform app-specific initialization of the subprocess. + subprocess_init_fn(*subprocess_init_args) + + # Initialize checkpoint writer - automatically include rank_info in init_args + writer_init_args = dict(checkpoint_writer_init_args) + if "rank_info" not in writer_init_args: + writer_init_args["rank_info"] = rank_info + checkpoint_writer = checkpoint_writer_init_fn(**writer_init_args) + + while True: + request = parent_pipe.recv() + + if request.request_type == RequestType.PING: + parent_pipe.send( + WorkerResponse(request_type=RequestType.PING, success=True) + ) + elif request.request_type == RequestType.WRITE_CHECKPOINT: + path = request.payload["path"] + logger.info("Writing checkpoint to %s", path) + + checkpoint_writer.write( + path=path, + state_dict=request.payload["state_dict"], + **request.payload["kwargs"], + ) + + logger.info("Checkpoint written successfully to %s", path) + parent_pipe.send( + WorkerResponse(RequestType.WRITE_CHECKPOINT, success=True) + ) + elif request.request_type == RequestType.TERMINATE_PROCESS: + logger.debug("Received termination request.") + parent_pipe.send( + WorkerResponse(RequestType.TERMINATE_PROCESS, success=True) + ) + logger.info("Subprocess terminated gracefully") + break + else: + error_msg = f"Unknown request type: {request.request_type}" + logger.error(error_msg) + raise ValueError(error_msg) + + except Exception as e: + error_text = traceback.format_exc() + logger.error( + "Exception in subprocess (%s): %s", type(e).__name__, error_text + ) + + # Communicating exception via the queue to the main process + parent_pipe.send( + WorkerResponse( + request_type=request.request_type, + success=False, + error_msg=error_text, + ) + ) + parent_pipe.close() + logger.exception("Subprocess terminated due to exception") + + def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None: + try: + if self._parent_end is None: + raise AssertionError("Parent end of pipe should be initialized") + self._parent_end.send( + WorkerRequest( + request_type=request_type, + payload=payload, + ) + ) + except OSError as e: + error_msg = "Child process terminated unexpectedly" + logger.exception( + "Communication failed during %s request", request_type.value + ) + raise RuntimeError(error_msg) from e + + def _recv(self) -> Optional[dict[str, Any]]: + try: + if self._parent_end is None: + raise AssertionError("Parent end of pipe should be initialized") + response = self._parent_end.recv() + if response.success is False: + error_msg = ( + f"Unexpected response from worker process: {response.error_msg}" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + return response.payload + except (EOFError, BrokenPipeError, ConnectionResetError) as e: + error_msg = f"Child process terminated unexpectedly: {e}" + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + def write( + self, + state_dict: Union[STATE_DICT, Future[STATE_DICT]], + path: str, + **kwargs: Any, + ) -> Optional[Future[None]]: + logger.debug("Waiting for subprocess initialization to complete") + + # wait until the process is started + self.process_creation_future.result() + + return self._executor.submit( + self._write, + state_dict, + path, + **kwargs, + ) + + def _write( + self, + state_dict: Union[STATE_DICT, Future[STATE_DICT]], + path: str, + **kwargs: Any, + ) -> None: + logger.debug("Starting checkpoint write to %s", path) + + # wait for staging state_dict to be available + if isinstance(state_dict, Future): + logger.debug("Waiting for state_dict Future to resolve") + sd = state_dict.result() + else: + sd = state_dict + + # Log state_dict info only if debug logging is enabled (performance-conscious) + if logger.isEnabledFor(logging.DEBUG): + if hasattr(sd, "keys"): + logger.debug("State_dict contains %d keys", len(sd.keys())) + + self._send( + request_type=RequestType.WRITE_CHECKPOINT, + payload={ + "state_dict": sd, + "path": path, + "kwargs": kwargs, + }, + ) + + logger.debug("Waiting for write completion response") + # wait for response + self._recv() + logger.debug("Checkpoint write to %s completed successfully", path) + + def close(self) -> None: + logger.debug( + "Closing CheckpointProcess for rank %d", self._rank_info.global_rank + ) + self._executor.shutdown(wait=True, cancel_futures=True) + + if self.process and self.process.processes[0].is_alive(): + subprocess_pid = self.process.processes[0].pid + # send graceful termination to sub process + try: + # pyrefly: ignore [missing-attribute] + self._parent_end.send( + WorkerRequest( + request_type=RequestType.TERMINATE_PROCESS, + payload={}, + ) + ) + except BrokenPipeError: + logger.warning( + "BrokenPipeError when sending termination request - subprocess (PID: %d) may have already terminated", + subprocess_pid, + ) + # subprocess terminated unexpectedly and below code will raise a + # ProcessExitedException. + + logger.debug( + "Waiting for subprocess to terminate gracefully (timeout: %ds)", + self._config.subprocess_shutdown_timeout_secs, + ) + + try: + if not self.process.join( + timeout=self._config.subprocess_shutdown_timeout_secs + ): + # graceful shutdown failed, kill the process. + logger.warning( + "Subprocess (PID: %d) did not terminate gracefully within %ds, killing it", + subprocess_pid, + self._config.subprocess_shutdown_timeout_secs, + ) + self.process.processes[0].kill() + logger.info("Subprocess killed forcefully") + except ProcessExitedException: + logger.exception("ProcessExitedException during subprocess termination") + raise + + logger.debug("CheckpointProcess closed successfully") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_reader.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..7be55938cfde162028c201476a5834533de41009 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_reader.py @@ -0,0 +1,223 @@ +""" +Checkpoint reader functionality for machine learning models. + +This module provides classes for reading checkpoints from storage, including +determining checkpoint layout and configuring the reader. +""" + +import logging +import os +from itertools import zip_longest +from pathlib import Path +from typing import Any, Optional + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode + +from .types import RankInfo, STATE_DICT + + +logger = logging.getLogger(__name__) + + +class CheckpointReader: + """ + Handles reading state dictionaries from storage. + + This class is responsible for reading model state dictionaries from storage according + to the specified checkpoint layout. It supports synchronization barriers to ensure + all ranks in a distributed setting complete their checkpoint operations. + """ + + def __init__( + self, + rank_info: RankInfo, + ): + """ + Initialize a CheckpointReader. + + Args: + rank_info: Information about the current rank in a distributed setting. + """ + + self._rank_info = rank_info + + def read( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + map_location: Any = None, + **kwargs: dict[str, Any], + ) -> tuple[STATE_DICT, list[str]]: + """ + Reads a state dictionary from storage. + + Args: + path (str): The path from which to read the checkpoint. + map_location (Any): Device mapping function or device name for relocating tensors. + **kwargs: Additional keyword arguments passed to torch.load. + + Returns: + STATE_DICT: The loaded state dictionary. + list[str]: List of missing keys. + """ + logger.debug( + "Reading checkpoint from %s for rank %s", + path, + self._rank_info.global_rank, + ) + + dir_path = Path(path) + file_path = dir_path / f"checkpoint_{self._rank_info.global_rank}.pt" + + # Check if the file exists + if not os.path.exists(file_path): + logger.error("Checkpoint file not found at %s", file_path) + raise FileNotFoundError(f"Checkpoint file not found at {file_path}") + + if state_dict is None: + result: tuple[STATE_DICT, list[str]] = ( + torch.load(file_path, map_location=map_location), + [], + ) + else: + result = self._partial_read( + file_path, state_dict, map_location=map_location, **kwargs + ) + logger.debug("Successfully read checkpoint file from %s", file_path) + return result + + def _partial_read( + self, + file_path: Path, + state_dict: STATE_DICT, + *, + map_location: Any = None, + **kwargs: dict[str, Any], + ) -> tuple[STATE_DICT, list[str]]: + """ + Reads only the keys present in state_dict from the checkpoint file. + + This method optimizes checkpoint loading by only loading the tensors that + are actually needed, based on the keys present in the input state_dict. + This can significantly reduce memory usage and loading time for large checkpoints + when only a subset of the model needs to be loaded. + + Args: + file_path (str): The path to the checkpoint file. + state_dict (STATE_DICT): The state dictionary containing keys to load. + map_location (Any): Device mapping function or device name for relocating tensors. + **kwargs: Additional keyword arguments passed to torch.load. + + Returns: + tuple[STATE_DICT, list[str]]: The updated state dictionary with loaded values and a list of missing keys. + """ + + with FakeTensorMode(): + metadata_dict = torch.load(file_path, map_location=map_location) + + missing_keys = [] + + with open(file_path, "rb") as file: + # Helper function to load tensor data from file + def load_tensor( + target: Optional[torch.Tensor], source: torch.Tensor, full_key: str + ) -> torch.Tensor: + if target is not None and ( + target.size() != source.size() or target.dtype != source.dtype + ): + raise RuntimeError( + f"Target tensor size={target.size()} dtype={target.dtype} does not match " + f"source tensor size={source.size()} dtype={source.dtype} for key {full_key}" + ) + + tensor_offset = source.untyped_storage()._checkpoint_offset + + if tensor_offset is None: + raise AssertionError( + "checkpoint_offset for tensor in torch serialized file is not set. This could " + "happen if the checkpoint was saved with a older version of Pytorch. " + "Please make sure that the checkpoint was saved with Pytorch 2.7 or later." + ) + + tensor_len = source.nelement() * source.element_size() + file.seek( + tensor_offset + source.element_size() * int(source.storage_offset()) + ) + if target is None: + target = torch.empty( + source.size(), dtype=source.dtype, device=source.device + ) + + buffer = file.read(tensor_len) + cpu_tensor = torch.frombuffer(buffer, dtype=source.dtype) + tensor = cpu_tensor.view(source.size()) + target.copy_(tensor) + return target + + # Helper function to recursively process nested structures + def process_value( + target_value: Any, source_value: Any, key_path: str + ) -> Any: + source_type = type(source_value) + if source_type is torch._subclasses.fake_tensor.FakeTensor: + source_type = torch.Tensor + if target_value is not None and not isinstance( + target_value, source_type + ): + raise RuntimeError( + f"Target value {key_path} is set to {type(target_value)}, but source value is {type(source_value)}" + ) + if isinstance(source_value, torch.Tensor): + return load_tensor(target_value, source_value, key_path) + elif isinstance(source_value, dict): + if target_value is None: + # create a new map with all the keys present in source_value + target_value = dict.fromkeys(source_value.keys()) + + # pyrefly: ignore [missing-attribute] + for key in list(target_value.keys()): + current_path = f"{key_path}.{key}" if key_path else key + if key in source_value: + target_value[key] = process_value( + target_value[key], source_value[key], current_path + ) + else: + missing_keys.append(current_path) + + return target_value + elif isinstance(source_value, list): + if target_value is None: + target_value = [None] * len(source_value) + result = [] + for i, (target_item, source_item) in enumerate( + zip_longest(target_value, source_value, fillvalue=None) + ): + current_path = f"{key_path}[{i}]" if key_path else f"[{i}]" + result.append( + process_value(target_item, source_item, current_path) + ) + return result + else: + return source_value + + # Start recursive processing from the root of the state dictionary + updated_state_dict = process_value(state_dict, metadata_dict, "") + + if missing_keys: + if len(missing_keys) > 10: + logger.warning( + "Missing %s keys from checkpoint: %s... (and %s more)", + len(missing_keys), + missing_keys[:10], + len(missing_keys) - 10, + ) + else: + logger.warning( + "Missing %s keys from checkpoint: %s", + len(missing_keys), + missing_keys, + ) + + return updated_state_dict, missing_keys diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_writer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b0041fbf292bd8b9c38fc2e395b17251fb67089 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpoint_writer.py @@ -0,0 +1,163 @@ +""" +Checkpoint writer functionality for machine learning models. + +This module provides classes for writing checkpoints to storage, including +determining checkpoint layout, configuring the writer, and defining hooks +for custom actions during the checkpoint writing process. +""" + +import abc +import logging +import os +from concurrent.futures import Future +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import torch + +from .barriers import Barrier +from .types import RankInfo, STATE_DICT + + +logger = logging.getLogger(__name__) + + +class WriterHook(abc.ABC): + """ + Abstract base class for checkpoint commit hooks. + + A commit hook provides callbacks that are executed before and after a checkpoint + is committed to storage. This allows for custom actions to be performed at specific + points in the checkpoint writing process, such as metadata updates, cleanup operations, + or notifications. + """ + + @abc.abstractmethod + def pre_commit(self, path: str, **kwargs: dict[str, Any]) -> None: + """ + Performs actions before committing the checkpoint. + """ + + @abc.abstractmethod + def post_commit(self, path: str, **kwargs: dict[str, Any]) -> None: + """ + Performs actions after committing the checkpoint. + """ + + +@dataclass +class CheckpointWriterConfig: + """ + Configuration options for the CheckpointWriter. + + Attributes: + write_barrier_timeout_secs: Maximum time in seconds to wait for all ranks + to reach the checkpoint barrier before timing out. Default is 600 seconds. + """ + + write_barrier_timeout_secs: int = 600 + + +class CheckpointWriter: + """ + Handles writing state dictionaries to storage. + + This class is responsible for writing model state dictionaries to storage according + to the specified checkpoint layout. It supports synchronization barriers to ensure + all ranks in a distributed setting complete their checkpoint operations. + """ + + def __init__( + self, + config: CheckpointWriterConfig, + rank_info: RankInfo, + barrier: Optional[Barrier] = None, + commit_hook: Optional[WriterHook] = None, + ): + """ + Initialize a CheckpointWriter. + + Args: + config: Configuration options for the checkpoint writer. + rank_info: Information about the current rank in a distributed setting. + barrier: Optional synchronization barrier for distributed checkpointing. + Note: The barrier should be initialized with the appropriate barrier_prefix + and timeout_secs parameters. + commit_hook: Optional hook for custom actions before and after checkpoint commits. + """ + + self._config = config + self._rank_info = rank_info + self._commit_hook = commit_hook + self._barrier = barrier + + def write( + self, + path: str, + state_dict: STATE_DICT, + **kwargs: dict[str, Any], + ) -> Optional[Future[None]]: + """ + Writes the state_dict to storage. + + Args: + path (str): The path to write the checkpoint to. + state_dict (STATE_DICT): The state_dict to write. + **kwargs: Additional keyword arguments passed to hooks. + + Returns: + Optional[Future[None]]: A future for tracking the write operation, if applicable. + """ + logger.debug( + "Writing checkpoint to %s for rank %s", + path, + self._rank_info.global_rank, + ) + dir_path = Path(path) + full_path = dir_path / f"checkpoint_{self._rank_info.global_rank}.pt" + os.makedirs( + os.path.dirname(full_path), + exist_ok=True, + ) + torch.save(state_dict, full_path) + logger.debug("Successfully saved checkpoint file to %s", full_path) + + # Execute pre-commit hook if available + commit_hook = self._commit_hook + if commit_hook is not None: + logger.debug("Executing pre-commit hook for %s", path) + commit_hook.pre_commit(path, **kwargs) + + # Wait for all ranks to finish writing if barrier is available + barrier = self._barrier + if barrier is not None: + logger.info( + "Waiting for all ranks at barrier with timeout %ss", + self._config.write_barrier_timeout_secs, + ) + barrier.execute_barrier() + logger.info("All ranks passed barrier") + else: + logger.info("No barrier configured, skipping synchronization") + + # Execute commit hook if available + if commit_hook is not None: + logger.debug("Executing commit hook for %s", path) + commit_hook.post_commit(path, **kwargs) + + logger.info( + "Successfully wrote checkpoint to %s for rank %s", + path, + self._rank_info.global_rank, + ) + return None + + def close(self) -> None: + """ + Close the writer and release any resources. + + This is a no-op for the base CheckpointWriter but may be overridden + by subclasses that need to perform cleanup. + """ + logger.debug("Closing checkpoint writer") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpointer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..2609bd9c4af428ecb3883435db4b738676b9b540 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/checkpointer.py @@ -0,0 +1,341 @@ +import abc +import logging +from concurrent.futures import Future +from typing import Any, Optional, TypeVar + +from .checkpoint_process import CheckpointProcess +from .checkpoint_reader import CheckpointReader +from .checkpoint_writer import CheckpointWriter +from .staging import CheckpointStager +from .types import STATE_DICT +from .utils import wrap_future + + +logger = logging.getLogger(__name__) + +LOG_INTERVAL = 60 +T = TypeVar("T") + + +class Checkpointer(abc.ABC): + """ + WARNING: This class is experimental, and is created to validate certain ideas, + and is subjected to change or deprecation and we strong discourage any usages at + this time. + + Abstract base class that defines the API for checkpointing. + + This class defines the interface for coordinating the writing and loading of model + state dictionaries to and from storage. It provides abstract methods to save and load model states + with support for both synchronous and asynchronous operations. + + Concrete implementations of this class must implement all the abstract methods. + """ + + @abc.abstractmethod + def save( + self, + path: str, + state_dict: STATE_DICT, + **kwargs: dict[str, Any], + ) -> Optional[tuple[Future, Future]]: + """ + Save a state dictionary to storage. + + Args: + path: The path where the checkpoint should be saved. + state_dict: The state dictionary to save. + **kwargs: Additional keyword arguments to pass to the writer. + + Returns: + For synchronous implementations: None + For asynchronous implementations: tuple of (stage_future, write_future) + representing the staging and writing operations. + """ + + @abc.abstractmethod + def load( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + default_map_location: Any = None, + strict: bool = False, + **kwargs: dict[str, Any], + ) -> STATE_DICT: + """ + Load a state dictionary from storage. + + Args: + path: The path from which to load the checkpoint. + state_dict: Optional state dictionary to update with loaded values. + If provided, only keys in this dictionary will be loaded. + default_map_location: Device mapping function or device name for relocating tensors. + strict: If True, raises an error when there are missing keys in the checkpoint. + **kwargs: Additional keyword arguments to pass to the reader. + + Returns: + The loaded state dictionary. + """ + + @abc.abstractmethod + def close(self) -> None: + """ + Close the checkpointer and release any resources. + + This method should be called when the checkpointer is no longer needed to ensure + proper cleanup of resources. + """ + + +class SyncCheckpointer(Checkpointer): + """ + Synchronous implementation of Checkpointer. + + This class coordinates the writing and loading of model state dictionaries to and from storage + using only synchronous operations. It provides a simple, efficient interface for checkpoint + operations without async overhead. + + Attributes: + _writer: CheckpointWriter for writing state dictionaries to storage. + _reader: CheckpointReader for reading state dictionaries from storage. + + Example: + checkpointer = SyncCheckpointer(writer=writer, reader=reader) + checkpointer.save(state_dict, path) + loaded_state_dict = checkpointer.load(path) + """ + + def __init__( + self, + writer: CheckpointWriter, + reader: CheckpointReader, + ): + """ + Initialize a synchronous checkpointer. + + Args: + writer: CheckpointWriter for writing checkpoints to storage. + reader: CheckpointReader for reading checkpoints from storage. + """ + self._writer = writer + self._reader = reader + + def save( + self, + path: str, + state_dict: STATE_DICT, + **kwargs: dict[str, Any], + ) -> Optional[tuple[Future, Future]]: + """ + Save a state dictionary to storage synchronously. + + Args: + path: The path where the checkpoint should be saved. + state_dict: The state dictionary to save. + **kwargs: Additional keyword arguments to pass to the writer. + + Returns: + Always returns None as operations are synchronous. + + Example: + checkpointer.save("/path/to/checkpoint", state_dict) + """ + logger.debug("Saving checkpoint synchronously to %s", path) + self._writer.write(path, state_dict, **kwargs) + return None + + def load( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + default_map_location: Any = None, + strict: bool = False, + **kwargs: dict[str, Any], + ) -> STATE_DICT: + """ + Load a state dictionary from storage. + + Args: + path: The path from which to load the checkpoint. + state_dict: Optional state dictionary to update with loaded values. + If provided, only keys in this dictionary will be loaded. + default_map_location: Device mapping function or device name for relocating tensors. + strict: If True, raises an error when there are missing keys in the checkpoint. + **kwargs: Additional keyword arguments to pass to the reader. + + Returns: + The loaded state dictionary. + + Raises: + RuntimeError: If strict=True and there are missing keys in the checkpoint. + FileNotFoundError: If the checkpoint file is not found. + """ + logger.info("Loading checkpoint from %s", path) + + loaded_state_dict, missing_keys = self._reader.read( + path=path, + state_dict=state_dict, + map_location=default_map_location, + **kwargs, + ) + if strict and missing_keys is not None and missing_keys != []: + raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}") + return loaded_state_dict + + def close(self) -> None: + """ + Close the checkpointer and release any resources. + + This method should be called when the checkpointer is no longer needed to ensure + proper cleanup of resources. + """ + self._writer.close() + logger.info("SyncCheckpointer closed") + + +class AsyncCheckpointer(Checkpointer): + """ + Asynchronous implementation of Checkpointer. + + This class coordinates the writing and loading of model state dictionaries to and from storage + using asynchronous operations for saving. It provides efficient async checkpoint operations + with staging and background writing capabilities. + + Attributes: + _reader: CheckpointReader for reading state dictionaries from storage. + _checkpoint_stager: Stager for async operations. + _checkpoint_process: Process for async operations. + _write_future: Future representing the ongoing async write operation. + + Example: + checkpointer = AsyncCheckpointer( + reader=reader, + checkpoint_stager=stager, + checkpoint_process=process + ) + stage_future, write_future = checkpointer.save(state_dict, path) + # ... do other work ... + write_future.result() # Wait for completion + """ + + def __init__( + self, + checkpoint_stager: CheckpointStager, + checkpoint_process: CheckpointProcess, + reader: CheckpointReader, + ): + """ + Initialize an asynchronous checkpointer. + + Args: + checkpoint_stager: Stager for async operations. + checkpoint_process: Process for async operations. + reader: CheckpointReader for reading checkpoints from storage. + """ + self._reader = reader + self._checkpoint_stager = checkpoint_stager + self._checkpoint_process = checkpoint_process + self._write_future: Optional[Future[Any]] = None + + def save( + self, + path: str, + state_dict: STATE_DICT, + **kwargs: Any, + ) -> Optional[tuple[Future, Future]]: + """ + Save a state dictionary to storage asynchronously. + + Args: + path: The path where the checkpoint should be saved. + state_dict: The state dictionary to save. + **kwargs: Additional keyword arguments to pass to the stager and writer. + + Returns: + A tuple of (stage_future, write_future) representing the staging and writing operations. + + Example: + stage_future, write_future = checkpointer.save("/path/to/checkpoint", state_dict) + # ... do other work ... + write_future.result() # Wait for completion + """ + logger.info( + "Initiating checkpoint save to %s. Will wait for prev checkpoints to complete.", + path, + ) + # Wait for previous checkpoint ops to finish and verify they are successful + if self._write_future is not None: + self._write_future.result() + + logger.debug("Starting state dictionary staging") + staging_result = self._checkpoint_stager.stage( + state_dict=state_dict, + **kwargs, + ) + + logger.debug("Starting checkpoint write to %s", path) + self._write_future = self._checkpoint_process.write( + staging_result, path, **kwargs + ) + logger.info("Checkpoint save to %s initiated", path) + + # Return futures for the staging and writing operations + if self._write_future is not None: + return wrap_future(staging_result), self._write_future + else: + # This should not happen since we just assigned _write_future above + raise RuntimeError("Write future is unexpectedly None") + + def load( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + default_map_location: Any = None, + strict: bool = False, + **kwargs: Any, + ) -> STATE_DICT: + """ + Load a state dictionary from storage. + + Loading is always performed synchronously, even in AsyncCheckpointer. + + Args: + path: The path from which to load the checkpoint. + state_dict: Optional state dictionary to update with loaded values. + If provided, only keys in this dictionary will be loaded. + default_map_location: Device mapping function or device name for relocating tensors. + strict: If True, raises an error when there are missing keys in the checkpoint. + **kwargs: Additional keyword arguments to pass to the reader. + + Returns: + The loaded state dictionary. + + Raises: + RuntimeError: If strict=True and there are missing keys in the checkpoint. + FileNotFoundError: If the checkpoint file is not found. + """ + logger.info("Loading checkpoint from %s", path) + + loaded_state_dict, missing_keys = self._reader.read( + path=path, + state_dict=state_dict, + map_location=default_map_location, + **kwargs, + ) + if strict and missing_keys is not None and missing_keys != []: + raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}") + return loaded_state_dict + + def close(self) -> None: + """ + Close the checkpointer and release any resources. + + This method should be called when the checkpointer is no longer needed to ensure + proper cleanup of async resources. + """ + self._checkpoint_stager.close() + self._checkpoint_process.close() + logger.info("AsyncCheckpointer closed") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/config.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a81156e3929cac9edb13135b925c8096dc4e702a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/config.py @@ -0,0 +1,44 @@ +""" +Configuration classes for checkpointer construction. + +This module provides configuration dataclasses that consolidate all +configuration options needed to construct checkpointers. +""" + +from dataclasses import dataclass, field + +from .barriers import BarrierConfig +from .checkpoint_process import CheckpointProcessConfig +from .checkpoint_writer import CheckpointWriterConfig +from .staging import CheckpointStagerConfig + + +@dataclass +class CheckpointerConfig: + """ + Configuration class for checkpointer construction. + + This class consolidates the core component configuration options needed to construct + a checkpointer, providing a clean separation of concerns where each component + manages its own configuration. + + Attributes: + writer_config: Configuration options for the checkpoint writer component. + barrier_config: Configuration for barrier construction and arguments. + staging_config: Configuration options for the async staging component. + process_config: Configuration options for the async checkpoint process component. + + """ + + writer_config: CheckpointWriterConfig = field( + default_factory=CheckpointWriterConfig + ) + barrier_config: BarrierConfig = field(default_factory=BarrierConfig) + + # Below configs are used for async checkpointing + staging_config: CheckpointStagerConfig = field( + default_factory=CheckpointStagerConfig + ) + process_config: CheckpointProcessConfig = field( + default_factory=CheckpointProcessConfig + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/staging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/staging.py new file mode 100644 index 0000000000000000000000000000000000000000..2d83278e13197b40ff31807a71b442621ebe5f51 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/staging.py @@ -0,0 +1,218 @@ +""" +Experimental staging module for PyTorch Distributed Checkpointing. + +This module provides advanced staging capabilities for checkpoints including: +- Asynchronous staging using ThreadPoolExecutor +- Pinned memory allocation for faster CPU-GPU transfers +- Shared memory support for multi-process scenarios +- Non-blocking CUDA operations with stream synchronization +- Caching of frequently used storages for efficient memory management +- Automatic resource cleanup and memory management + +Classes: + CheckpointStager: Abstract base class defining the staging interface + StagingOptions: Configuration dataclass for staging behavior + DefaultStager: Default implementation with comprehensive staging features +""" + +import abc +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, TypeVar, Union + +import torch +from torch.distributed.checkpoint._state_dict_stager import StateDictStager + +from .types import STATE_DICT + + +T = TypeVar("T") + + +class CheckpointStager(abc.ABC): + """ + Abstract base class for checkpoint staging implementations. + + CheckpointStager defines the interface that all staging implementations + must follow. Staging is the process of offloading state dictionaries + for async checkpointing. + """ + + @abc.abstractmethod + def stage( + self, + state_dict: STATE_DICT, + **kwargs: Any, + ) -> Union[STATE_DICT, Future[STATE_DICT]]: + """ + Stage a state dictionary for checkpointing. + + Args: + state_dict: The state dictionary to stage + **kwargs: Additional staging parameters + + Returns: + Either a staged state dictionary (synchronous) or a Future + that will resolve to the staged state dictionary (asynchronous) + """ + + @abc.abstractmethod + def close(self) -> None: + """ + Clean up all resources used by the stager. + """ + + +@dataclass +class CheckpointStagerConfig: + """ + Configuration options for checkpoint staging behavior. + + Attributes: + use_pinned_memory (bool): Enable pinned memory allocation for faster + CPU-GPU transfers. Requires CUDA to be available. Default: True + use_shared_memory (bool): Enable shared memory for multi-process + scenarios. Useful when multiple processes need access to the + same staged data. Default: True + use_async_staging (bool): Enable asynchronous staging using a + background thread pool. Allows overlapping computation with + staging operations. Requires CUDA. Default: True + use_non_blocking_copy (bool): Use non-blocking device memory + copies with stream synchronization. Improves performance by + allowing CPU work to continue during GPU transfers. Default: True + + Note: + CUDA-dependent features will raise exception if CUDA is not available. + """ + + use_pinned_memory: bool = True + use_shared_memory: bool = True + use_async_staging: bool = True + use_non_blocking_copy: bool = True + + +class DefaultStager(CheckpointStager): + """ + DefaultStager provides a full-featured staging implementation that combines + multiple optimization techniques for efficient checkpoint preparation. + + The staging process works as follows: + 1. State dictionary is submitted for staging (sync or async) + 2. Tensors are copied from GPU to optimized CPU storage + 3. CUDA operations are synchronized if non-blocking copies are used + 4. Staged state dictionary is returned or made available via Future + + NOTE: state_dict should be deep-copyable object as staging will create a + copy of it. + + Usage Patterns: + # Synchronous staging + stager = DefaultStager(CheckpointStagerConfig(use_async_staging=False)) + staged_dict = stager.stage(state_dict) + stager.close() + + # Asynchronous staging + stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True)) + future = stager.stage(state_dict) + # ... do other work ... + staged_dict = future.result() + stager.close() + + # Context manager pattern (recommended) + with DefaultStager(config) as stager: + result = stager.stage(state_dict) + # Automatic cleanup on exit + + Performance Considerations: + - Async staging provides best performance when model computation + can overlap with staging operations + - Pinned memory improves CPU-GPU transfer speeds but uses more memory + - Shared memory allows efficient IPC to checkpoint process + - Non-blocking copies reduce GPU idle time during memory transfers + + Thread Safety: + DefaultStager is not thread-safe. Each thread should use its own + instance, or external synchronization should be provided. + """ + + def __init__( + self, + config: CheckpointStagerConfig = CheckpointStagerConfig(), + ): + self._config = config + self._state_dict_stager = StateDictStager( + pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory + ) + self._staging_executor = None + self._staging_stream = None + + if self._config.use_async_staging: + # pyrefly: ignore [bad-assignment] + self._staging_executor = ThreadPoolExecutor(max_workers=1) + if torch.accelerator.is_available(): + # Note: stream needs to be initialized on the main thread after default cuda + # stream is setup/used to avoid the risk of accidentally reusing the main + # compute stream or in other cases kernels actually launching from the + # main thread. + # pyrefly: ignore [bad-assignment] + self._staging_stream = torch.Stream() + + if self._config.use_non_blocking_copy: + if not torch.accelerator.is_available(): + raise AssertionError( + "Non-blocking copy requires that the current accelerator is available." + ) + + def stage( + self, + state_dict: STATE_DICT, + **kwargs: Any, + ) -> Union[STATE_DICT, Future[STATE_DICT]]: + if self._config.use_async_staging: + if self._staging_executor is None: + raise AssertionError( + "Staging executor should be initialized for async staging" + ) + return self._staging_executor.submit( + self._stage, + state_dict, + **kwargs, + ) + else: + return self._stage(state_dict, **kwargs) + + def _stage(self, state_dict: STATE_DICT, **kwargs: Any) -> STATE_DICT: + state_dict = self._state_dict_stager.stage( + state_dict, non_blocking=self._config.use_non_blocking_copy, **kwargs + ) + + if self._config.use_non_blocking_copy: + if not (self._staging_stream or not self._config.use_async_staging): + raise AssertionError( + "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized." + ) + + # waits for the enqued copy operations to finish. + self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize() + + return state_dict + + def close(self) -> None: + """ + Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor + used for async staging operations and cleans up the underlying StateDictStager's + cached storages. Should be called when the stager is no longer needed to prevent + resource leaks, especially in long-running applications. After calling close(), + the stager should not be used for further staging operations. + + state_dict should be deep-copyable object. + + Example: + stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True)) + # ... do staging operations ... + stager.close() # Clean up all resources + """ + if self._staging_executor: + self._staging_executor.shutdown(wait=True) + + self._state_dict_stager.close() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/types.py new file mode 100644 index 0000000000000000000000000000000000000000..61268fd5b14a89e6ee8dc21fdac300a50a2dd21f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/types.py @@ -0,0 +1,28 @@ +""" +Type definitions for distributed training and checkpointing. + +This module provides type definitions and classes for managing rank information +in distributed training environments, which is essential for proper checkpoint +saving and loading. +""" + +from dataclasses import dataclass +from typing import Any, TypeAlias + + +# Type alias for state dictionaries used in checkpointing +STATE_DICT: TypeAlias = dict[str, Any] + + +@dataclass +class RankInfo: + """ + Information about the current rank in a distributed training environment. + + Attributes: + global_rank: The global rank ID of the current process. + global_world_size: The total number of processes in the distributed environment. + """ + + global_rank: int + global_world_size: int diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..271e9aa112f682c8b56393c13db9eeefdeb37aa7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/utils.py @@ -0,0 +1,42 @@ +""" +Utility functions for the experimental checkpoint module. + +This module contains helper functions and utilities used across the experimental +checkpoint functionality. +""" + +from concurrent.futures import Future +from typing import Any + + +def wrap_future(original_result: Any) -> Future[None]: + """ + Wraps a result (Future or not) to return a Future with None result. + + If the input is a Future, returns a new Future that completes with None when + the original Future completes successfully, or propagates any exception. + If the input is not a Future, returns a completed Future with None result. + + Args: + original_result: The result to wrap (Future or any other value). + + Returns: + A Future that completes with None on success or propagates exceptions. + """ + masked_future: Future[None] = Future() + + if isinstance(original_result, Future): + + def on_complete(_: Future[Any]) -> None: + try: + original_result.result() + masked_future.set_result(None) + except Exception as e: + masked_future.set_exception(e) + + original_result.add_done_callback(on_complete) + else: + # Return a completed future with None result + masked_future.set_result(None) + + return masked_future diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_extension.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..663caa8a857263e3fc2924e3e1ec80d13a9ae6b0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_extension.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import abc +import io +from collections.abc import Sequence +from typing import cast, IO, Optional + +# introduced as collections.abc.Buffer in Python 3.12 +from typing_extensions import Buffer + +from torch._utils import try_import + + +# NOTE: everything in this file is experimental, and subject to +# change. Feedback and bug fixes are always welcome. + +pyzstd_module_name = "pyzstd" +pyzstd = try_import(pyzstd_module_name) +zstandard_module_name = "zstandard" +zstandard = try_import(zstandard_module_name) + + +__all__ = [ + "Extension", + "StreamTransformExtension", + "ZStandard", + "ExtensionRegistry", +] + + +class Extension(abc.ABC): + """ + Extensions provide modular additions to functionality within distributed checkpointing, + which affect the layout or format of the written artifacts. Extensions may be + built into pytorch, or provided externally. + + When writing, the caller provides a list of extension instances of the appropriate + type. Each extension can output a descriptor which is used to reconstitute the + extension at read-time. + """ + + @staticmethod + @abc.abstractmethod + def registry_name() -> str: + """ + See ExtensionRegistry.from_descriptor_list + """ + + @staticmethod + @abc.abstractmethod + def from_descriptor(version: str) -> "Extension": + """ + See ExtensionRegistry.from_descriptor_list + """ + + @abc.abstractmethod + def get_descriptor(self) -> str: + """ + Return descriptor name to be included in metadata. The form should be + "extension_name[@local-domain][/version]". + """ + + +class StreamTransformExtension(Extension): + """ + An extension which performs transformation on a byte stream, such as compression + or encryption. + + Implementations should try to be memory friendly and performant. For example, don't + read the whole input, then transform it, and write it back. If at all possible, do it in + chunks. But, don't read/transform/write one byte at a time, either. + """ + + @abc.abstractmethod + def transform_to(self, output: IO[bytes]) -> IO[bytes]: + """ + Takes a writeable output stream, and generates a new stream which implements the + output transform. Input data written to the returned stream will be transformed + and written to the `output` argument stream. + """ + + @abc.abstractmethod + def transform_from(self, input: IO[bytes]) -> IO[bytes]: + """ + Takes a readable input stream, and generates a new stream which implements the + input transform. When the returned stream is read, data will be read from the + 'input' stream, transformed, and returned. + """ + + +class ZStandard(StreamTransformExtension): + @staticmethod + def is_available() -> bool: + return zstandard is not None or pyzstd is not None + + @staticmethod + # pyrefly: ignore [bad-override] + def from_descriptor(version: str) -> "ZStandard": + if version.partition(".")[0] != "1": + raise ValueError(f"Unknown extension {version=}") + if not ZStandard.is_available(): + raise ValueError( + f"Stream with ZStandard compression cannot be processed because " + f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" + ) + return ZStandard() + + @staticmethod + def registry_name() -> str: + return "stream.zstd" + + def __init__(self) -> None: + super().__init__() + if not ZStandard.is_available(): + raise ValueError( + f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" + ) + + def get_descriptor(self) -> str: + return f"{self.registry_name()}/1" + + def transform_to(self, output: IO[bytes]) -> IO[bytes]: + if zstandard is not None: + compressor = zstandard.ZstdCompressor() # type: ignore[union-attr] + return compressor.stream_writer(output) + + class Writer(io.RawIOBase): + def __init__(self, output: IO[bytes]) -> None: + self.output = output + self.compressor = pyzstd.ZstdCompressor() # type: ignore[union-attr] + + def writeable(self) -> bool: + return True + + def write(self, b: Buffer) -> Optional[int]: + outdata = self.compressor.compress(b) + if outdata: + self.output.write(outdata) + return len(memoryview(b)) + + def flush(self) -> None: + outdata = self.compressor.flush() + if outdata: + self.output.write(outdata) + self.output.flush() + + return cast(IO[bytes], Writer(output)) + + def transform_from(self, input: IO[bytes]) -> IO[bytes]: + if zstandard is not None: + decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr] + return decompressor.stream_reader(input) + + class Reader(io.RawIOBase): + def __init__(self, input: IO[bytes]) -> None: + self.input = input + self.decompressor = pyzstd.EndlessZstdDecompressor() # type: ignore[union-attr] + + def readable(self) -> bool: + return True + + def readinto(self, b: Buffer) -> Optional[int]: + # This needs to read enough so it can decompress + # something so the output doesn't look like EOF. This + # means reading at least one block. The max block + # size is 128KB, so we read that plus some + # overhead to be sure. + + if self.decompressor.needs_input: + indata = self.input.read((128 + 6) * 1024) + else: + indata = b"" + + bview = memoryview(b) + blen = len(bview) + outdata = self.decompressor.decompress(indata, blen) + if outdata is None: + return None + + count = len(outdata) + bview[:count] = outdata + return count + + def seekable(self) -> bool: + return False + + return cast(IO[bytes], Reader(input)) + + +class ExtensionRegistry: + def __init__(self) -> None: + # Populate default registry contents + self.extensions: dict[str, type[Extension]] = { + cls.registry_name(): cls for cls in (ZStandard,) + } + + def register(self, cls: type[Extension]) -> None: + self.extensions[cls.registry_name()] = cls + + def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]: + """ + Given a seuquence of descriptor strings as returned by + Extension.get_descriptor at save time, creates a sequence of + Extension instances. The name[@local-domain] preceding the + version number is used to look up an implementation class in + the registry, and the version is passed to the class's + from_descriptor static method. If the registry contains no + match, this will throw ValueError. If the from_descriptor + method raises an exception, that will pass through to the + caller. + """ + + def from_descriptor(desc: str) -> Extension: + name, _, version = desc.partition("/") + if version is None: + version = 0 + ext = self.extensions.get(name) + if not ext: + raise ValueError(f"Unknown extension {name=}") + # pyrefly: ignore [bad-argument-type] + return ext.from_descriptor(version) + + return [from_descriptor(desc) for desc in descriptors] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..e239bbe891fb95b374479fdafaab9a0d16604147 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -0,0 +1,168 @@ +# Mypy will not try inferring the types of any 3rd party libraries installed. +# mypy: ignore-errors + +import io +import os +from collections.abc import Generator, Sequence +from contextlib import contextmanager +from pathlib import Path +from typing import Optional, TYPE_CHECKING, Union + +from fsspec.core import url_to_fs + +from torch.distributed.checkpoint._extension import StreamTransformExtension +from torch.distributed.checkpoint.filesystem import ( + FileSystemBase, + FileSystemReader, + FileSystemWriter, + SerializationFormat, +) + + +if TYPE_CHECKING: + from fsspec import AbstractFileSystem + + +__all__ = [ + "FsspecWriter", + "FsspecReader", +] + + +class FileSystem(FileSystemBase): + def __init__(self) -> None: + self.fs: Optional[AbstractFileSystem] = None + + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + if self.fs is None: + raise AssertionError("fs should not be None") + path = os.fspath(path) + + # fsspec does not support concurrent transactions, and not all + # AbstractFileSystem have working rollback implementations, so + # just manually delete the file if necessary on errors. + with self.fs.open(path, mode) as stream: + try: + yield stream + except: # noqa: B001,E722 + if any(ch in mode for ch in "w+a"): # cleanup file if not read-only + try: + self.rm_file(path) + except: # noqa: B001,E722 + pass + raise + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + return os.path.join(path, suffix) + + def init_path( + self, path: Union[str, os.PathLike], **kwargs + ) -> Union[str, os.PathLike]: + self.fs, _ = url_to_fs(path, **kwargs) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + self.fs.rename(path, new_path) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + self.fs.makedirs(path, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return False + + try: + url_to_fs(checkpoint_id) + except ValueError: + return False + + return True + + def exists(self, path: Union[str, os.PathLike]) -> bool: + return self.fs.exists(path) + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + self.fs.rm(path) + + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + # setting detail to False explicitly to keep the list[str] return type, + # instead of the list[Dict] return type when detail=True + return self.fs.ls(path, detail=False) + + +# TODO: add the dcp.async_save mixin +class FsspecWriter(FileSystemWriter): + """ + Basic implementation of StorageWriter using FFspec. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + **kwargs, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__( + path, + single_file_per_rank, + sync_files, + thread_count, + per_thread_copy_ahead, + overwrite=overwrite, + _extensions=_extensions, + serialization_format=serialization_format, + ) + self.fs = FileSystem() + self.path = self.fs.init_path(path, **kwargs) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FsspecReader(FileSystemReader): + def __init__(self, path: Union[str, os.PathLike], **kwargs) -> None: + super().__init__(path) + self.fs = FileSystem() + self.path = self.fs.init_path(path, **kwargs) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_hf_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d14229b7f8ccfe5a51211d0fb6a4c332af6066b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_hf_utils.py @@ -0,0 +1,106 @@ +import io +import json +import struct +from dataclasses import dataclass +from typing import Any, Optional + +import torch + + +_metadata_fn: str = "model.safetensors.index.json" + +FILE_NAME = "model-{cpt_idx}-of-{num_files}" +SHARDED_FILE_NAME = "shard-{shard_idx}-model-{cpt_idx}-of-{num_files}" +SUFFIX = ".safetensors" + +# metadata keys +CUSTOM_METADATA_KEY = "DCP_SHARDING_INFO" +DEFAULT_EXTRA_METADATA_KEY = "__metadata__" +SAVED_OFFSETS_KEY = "saved_offsets" +SHAPE_KEY = "shape" +DATA_KEY = "data" +DTYPE_KEY = "dtype" +DATA_OFFSETS_KEY = "data_offsets" + +DTYPE_MAP = { + "F16": torch.float16, + "F32": torch.float32, + "F64": torch.float64, + "I8": torch.int8, + "U8": torch.uint8, + "I16": torch.int16, + "I32": torch.int32, + "I64": torch.int64, + "BF16": torch.bfloat16, +} + +HF_DCP_VERSION: float = 1.0 +DCP_VERSION_KEY = "DCP_VERSION" +DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO" + +FORMAT_KEY = "format" +FORMAT_VALUE = "pt" + +NUM_BYTES_FOR_HEADER_LEN = 8 + +SHARDED_DIR_NAME = "sharded" + + +@dataclass +class _HFStorageInfo: + """This is the per entry storage info.""" + + relative_path: str + shape: torch.Size + dtype: torch.dtype + + +def _gen_file_name( + index: int, largest_index: int, shard_index: Optional[int] = None +) -> str: + if shard_index is not None: + return ( + SHARDED_FILE_NAME.format( + shard_idx=f"{shard_index}".zfill(5), + cpt_idx=f"{index}".zfill(5), + num_files=f"{largest_index}".zfill(5), + ) + + SUFFIX + ) + else: + return ( + FILE_NAME.format( + cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5) + ) + + SUFFIX + ) + + +def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]: + # this uses the same logic that's done in HF code base + # https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308 + # and follows their documentation on how their files are serialized + # https://huggingface.co/docs/safetensors/index#format + + header_len_bytes = file_bytes.read(NUM_BYTES_FOR_HEADER_LEN) + header_len = struct.unpack(" torch.dtype: + try: + dtype = DTYPE_MAP[dtype_str] + except KeyError: + dtype = torch.get_default_dtype() + + return dtype + + +def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]: + if DEFAULT_EXTRA_METADATA_KEY in metadata: + custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY] + if CUSTOM_METADATA_KEY in custom_metadata: + return json.loads(custom_metadata[CUSTOM_METADATA_KEY]) + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_nested_dict.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_nested_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..eb26058370f766fbb96e4a5f1530577234eed62a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_nested_dict.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + +from . import _version +from ._traverse import ( + OBJ_PATH, + set_element, + STATE_DICT_ITEM, + traverse_state_dict, + traverse_state_dict_v_2_3, +) + + +""" +TODO: +Need to add ability to handle tuple, OrderedDict, NamedTuple. +Update mappings from dict to a class. +Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple. +""" + + +FLATTEN_MAPPING = dict[str, OBJ_PATH] + + +# TODO: Update Docstring for nested_dict.py +def flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + # We started to flatten dictionary since v2.4. But in order to not break + # the checkpoints that were saved before v2.4, we need to keep the old + # traversal so that we can reconstruct those checkpoints. + use_v_2_3 = ( + _version._derived_version is not None and _version._derived_version == "2_3" + ) + if use_v_2_3: + traverse_state_dict_v_2_3(state_dict, flat_copy) + else: + traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + set_element(nested, mapping[key], value) + return nested diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_pg_transport.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_pg_transport.py new file mode 100644 index 0000000000000000000000000000000000000000..b258517bdcebaa553c4acf7f5511b29432304ed9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_pg_transport.py @@ -0,0 +1,387 @@ +import logging +import pickle +import time +from collections.abc import Callable, Generator +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import cast, Optional, TypeVar, Union + +import torch +from torch.distributed import ProcessGroup, Work +from torch.distributed._shard.sharded_tensor import ( + Shard as ShardedTensorShard, + ShardedTensor, + ShardMetadata, +) +from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata +from torch.distributed.tensor import _DTensorSpec, DTensor +from torch.utils._pytree import ( + KeyPath, + tree_flatten_with_path, + tree_unflatten, + TreeSpec, +) + + +logger: logging.Logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass +class _TensorMeta: + """ + This is the metadata for a tensor that is used to transfer checkpoints. + It contains the shape, the dtype, the storage offset and the stride of the + tensor. + + This must be pickleable so that it can be sent over the wire. + """ + + shape: torch.Size + dtype: torch.dtype + storage_offset: int + stride: tuple[int, ...] + nbytes: int + + +@dataclass +class _DTensorMeta: + """ + This is the metadata for a DTensor that is used to transfer checkpoints. + It contains the metadata for the local tensor and the spec of the DTensor. + + This must be pickleable so that it can be sent over the wire. + """ + + local: _TensorMeta + spec: _DTensorSpec + + +@dataclass +class _ShardedTensorMeta: + """ + This is the metadata for a ShardedTensor that is used to transfer checkpoints. + It contains the metadata for all local shards and the global tensor metadata. + + This must be pickleable so that it can be sent over the wire. + """ + + local_shards_meta: list[_TensorMeta] + local_shards_shard_metadata: list[ + ShardMetadata + ] # Original shard metadata for each local shard + sharded_tensor_metadata: ShardedTensorMetadata + + +@dataclass +class _StateDictMeta: + """ + This is the metadata for a state dict that is used to transfer checkpoints. + It contains the step, the pytree spec of the state dict and the metadata for + each tensor in the state dict. + + This must be pickleable so that it can be sent over the wire. + + Args: + step: the step of the checkpoint to verify consistency + treespec: the pytree spec of the state dict + paths: the path of each leaf in the state dict + non_tensor_leaves: the metadata for each tensor in the state dict and any + non-tensor leaves in the state dict + """ + + treespec: TreeSpec + paths: list[KeyPath] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] + + +@contextmanager +def _timeit(name: str) -> Generator[None, None, None]: + start = time.perf_counter() + yield + dur = time.perf_counter() - start + logger.info("%s took %ss", name, dur) + + +def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]: + return ( + _cast_tensor(tensor, torch.uint8), + _TensorMeta( + shape=tensor.shape, + dtype=tensor.dtype, + storage_offset=cast(int, tensor.storage_offset()), + stride=tensor.stride(), + nbytes=tensor.untyped_storage().nbytes(), + ), + ) + + +def _prepare_state_dict( + state_dict: object, + device: torch.device, +) -> tuple[_StateDictMeta, list[torch.Tensor]]: + leaves: list[tuple[KeyPath, object]] + leaves, treespec = tree_flatten_with_path(state_dict) + + paths: list[KeyPath] = [] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] = [] + tensors: list[torch.Tensor] = [] + for key_path, v in leaves: + paths.append(key_path) + + if isinstance(v, DTensor): + tensor, tensor_meta = _prepare_tensor(v._local_tensor) + + tensors.append(tensor) + + non_tensor_leaves.append( + _DTensorMeta( + local=tensor_meta, + spec=v._spec, + ) + ) + elif isinstance(v, ShardedTensor): + # Handle ShardedTensor by extracting all local shards + local_shards = v.local_shards() + + # Prepare metadata for all local shards + local_shards_meta = [] + local_shards_shard_metadata = [] + for shard in local_shards: + tensor, tensor_meta = _prepare_tensor(shard.tensor) + tensors.append(tensor) + local_shards_meta.append(tensor_meta) + local_shards_shard_metadata.append(shard.metadata) + + non_tensor_leaves.append( + _ShardedTensorMeta( + local_shards_meta=local_shards_meta, + local_shards_shard_metadata=local_shards_shard_metadata, + sharded_tensor_metadata=v.metadata(), # Complete metadata + ) + ) + elif isinstance(v, torch.Tensor): + tensor, tensor_meta = _prepare_tensor(v) + tensors.append(tensor) + non_tensor_leaves.append(tensor_meta) + else: + non_tensor_leaves.append(v) + + return ( + _StateDictMeta( + treespec=treespec, + paths=paths, + non_tensor_leaves=non_tensor_leaves, + ), + tensors, + ) + + +def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """ + Casts the underlying storage to a tensor of the given dtype. + + The returned tensor will be of size ``storage.nbytes``. + + This works for all datatypes and supports strided/offset tensors with the + caveat that the cast tensor may be larger than the original tensor due to + the differences in striding. + """ + if type(tensor) is not torch.Tensor: + raise AssertionError(f"can only cast standard tensors not {type(tensor)}") + storage = tensor.untyped_storage() + ret = torch.tensor(storage, dtype=dtype, device=tensor.device) + if ret.untyped_storage() is not storage: + raise AssertionError("storage should be the same") + return ret + + +class PGTransport: + """ + This is a checkpoint transport that uses the process group to transfer checkpoints. + This allows for fast recovery of workers by fetching the current weights + from an existing worker. + + Args: + pg: the process group to use for communication + timeout: the timeout for communication + device: the device to use for tensors + state_dict: if specified this function will be called to do an inplace + receive into the returned state_dict. This is much faster than + having to allocate new tensors and transferring them to the CPU. + """ + + def __init__( + self, + pg: ProcessGroup, + timeout: timedelta, + device: torch.device, + state_dict: Optional[Callable[[], object]] = None, + ) -> None: + self._work: list[Work] = [] + self._pg = pg + self._timeout = timeout + # pyrefly: ignore [read-only] + self._device = device + self._state_dict = state_dict + + def send_checkpoint(self, dst_ranks: list[int], state_dict: object) -> None: + """ + Send a checkpoint to multiple destination ranks. + + The process: + 1. Prepares the state dict by converting tensors to a serializable format + 2. Sends metadata as pickled data + 3. Sends each tensor sequentially to all destination ranks + + Args: + dst_ranks: List of destination ranks to send the checkpoint to + state_dict: The state dictionary containing model parameters + """ + with _timeit("preparing state_dict"): + meta, tensors = _prepare_state_dict(state_dict, device=self._device) + + work = [] + + with _timeit("send meta"): + buf = pickle.dumps(meta) + len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device) + buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([len_t], dst_rank, tag=1)) + work.append(self._pg.send([buf_t], dst_rank, tag=2)) + + with _timeit("send tensors"): + for i, t in enumerate(tensors): + original_device = t.device + t = t.to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([t], dst_rank, tag=3 + i)) + + # if we did a copy we should wait for the work to complete so we + # can free the memory to avoid OOMs + if original_device == torch.device("cpu"): + for w in work: + w.wait() + work = [] + + for w in work: + w.wait() + + def recv_checkpoint(self, src_rank: int) -> object: + """ + Receive a checkpoint from a source rank. + + The process: + 1. Receives metadata about the checkpoint structure + 2. Receives each tensor, potentially reusing existing tensors for in-place updates + 3. Reconstructs the original state dict structure + + Args: + src_rank: The source rank to receive the checkpoint from + + Returns: + The reconstructed state dictionary with model parameters + """ + state_dict = self._state_dict() if self._state_dict else {} + state_dict_leaves, _ = tree_flatten_with_path(state_dict) + + dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves) + + len_t = torch.zeros(1, dtype=torch.int64, device=self._device) + self._pg.recv([len_t], src_rank, tag=1).wait() + length = cast(int, len_t.item()) + + buf = torch.empty(length, dtype=torch.uint8, device=self._device) + self._pg.recv([buf], src_rank, tag=2).wait() + + meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes()) + + i: int = 0 + works: list[Work] = [] + + def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor: + nonlocal i + + inplace = dst_tensors.get(path) + if ( + isinstance(inplace, torch.Tensor) + and inplace.device.type == self._device.type + ): + if isinstance(inplace, DTensor): + inplace = inplace._local_tensor + t = _cast_tensor(inplace, torch.uint8) + if t.nbytes != v.nbytes: + raise AssertionError("inplace tensor storage must be the same size") + else: + t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device) + + work = self._pg.recv([t], src_rank, tag=3 + i) + i += 1 + + if inplace is None: + # if not inplace we need to copy it to CPU to avoid OOMing + work.wait() + t = t.cpu() + else: + works.append(work) + + return torch.as_strided( + t.view(v.dtype), + size=v.shape, + stride=v.stride, + storage_offset=v.storage_offset, + ) + + values: list[object] = [] + for path, v in zip(meta.paths, meta.non_tensor_leaves): + if isinstance(v, _TensorMeta): + values.append(recv(path, v)) + elif isinstance(v, _DTensorMeta): + tensor = recv(path, v.local) + # pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword] + values.append(DTensor(tensor, v.spec, requires_grad=False)) + elif isinstance(v, _ShardedTensorMeta): + # Receive all local shards that were sent to us + local_shards = [] + current_rank = self._pg.rank() + + # Receive tensors for each local shard that was sent + for j, shard_meta in enumerate(v.local_shards_meta): + tensor = recv(path, shard_meta) + + # Use the original shard metadata that was stored during preparation + # but update the placement to reflect the current rank/device + original_shard_metadata = v.local_shards_shard_metadata[j] + updated_shard_metadata = ShardMetadata( + shard_offsets=original_shard_metadata.shard_offsets, + shard_sizes=original_shard_metadata.shard_sizes, + placement=f"rank:{current_rank}/{tensor.device.type}", + ) + + local_shard = ShardedTensorShard( + tensor=tensor, metadata=updated_shard_metadata + ) + local_shards.append(local_shard) + + # Use complete metadata to reconstruct ShardedTensor + sharded_tensor = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=v.sharded_tensor_metadata, + ) + ) + values.append(sharded_tensor) + else: + values.append(v) + + for work in works: + work.wait() + + return tree_unflatten(values, meta.treespec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a68bcddeb7f9d9ffe6f89056dfe1ccc30cc12eb5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import copy +from typing import TYPE_CHECKING + +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.remote_device import _remote_device + +from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict +from .utils import _element_wise_add, _normalize_device_info + + +if TYPE_CHECKING: + from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata + + +# TODO: We need to refactor this code. +def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + r""" + Transform ``state_dict`` by flattening all nested ShardedTensor instances found. + + The resulting ShardedTensor instances are only correct regarding the local shard and + MUST not be used for any other purpose but checkpointing, as no operator will work with them. + + This function should be used in conjunction with a state_dict produced by FSDP's + StateDictType.SHARDED_STATE_DICT methods. + """ + new_state_dict: STATE_DICT_TYPE = {} + + def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if not isinstance(value, ShardedTensor): + set_element(new_state_dict, path, value) + return + shards = value.local_shards() + + if len(shards) == 0: + return + if len(shards) != 1: + set_element(new_state_dict, path, value) + return + + outer_shard = shards[0] + + inner_st = outer_shard.tensor + if not isinstance(inner_st, ShardedTensor): + set_element(new_state_dict, path, value) + return + + if len(inner_st.local_shards()) != 1: + raise ValueError("Cannot handle inner tensor with more than 1 shard") + inner_shard = inner_st.local_shards()[0] + + local_shards = [ + Shard( + tensor=inner_shard.tensor, + metadata=ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_shard.metadata.shard_offsets, + ), + shard_sizes=inner_shard.metadata.shard_sizes, + placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}", + ), + ) + ] + + st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata()) + other_rank = 0 if dist.get_rank() > 0 else 1 + device_info = _normalize_device_info(inner_shard.tensor.device.type, 0) + + # Remove the outer ST shard the inner ST covers + for i, shard_md in enumerate(st_meta.shards_metadata): + if shard_md.shard_offsets == outer_shard.metadata.shard_offsets: + st_meta.shards_metadata.pop(i) + break + + # Attribute other rank for the other shards + for shard_md in st_meta.shards_metadata: + shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}") + + # Add other inner shards from the inner tensor + for inner_md in inner_st.metadata().shards_metadata: + if inner_md.shard_offsets != inner_shard.metadata.shard_offsets: + st_meta.shards_metadata.append( + ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_md.shard_offsets, + ), + shard_sizes=inner_md.shard_sizes, + placement=f"rank:{other_rank}/{device_info}", + ) + ) + + # Finally add this shard + st_meta.shards_metadata.append(local_shards[0].metadata) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=st_meta, + ) + set_element(new_state_dict, path, st) + + traverse_state_dict(state_dict, rewrite_dict) + return new_state_dict diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_state_dict_stager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_state_dict_stager.py new file mode 100644 index 0000000000000000000000000000000000000000..155a87b9dec5bcd1f532d17ee2b8ef56454e37ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_state_dict_stager.py @@ -0,0 +1,467 @@ +# mypy: allow-untyped-defs +import types +import warnings +import weakref +from copyreg import dispatch_table +from typing import Any + +import torch +import torch.cuda._pin_memory_utils as pin_memory_utils +from torch.storage import UntypedStorage +from torch.utils.weak import WeakIdKeyDictionary + + +class StateDictStager: + """ + A class for optimizing storage objects during staging for async checkpointing. + + StateDictStager stages the state_dict to CPU DRAM while applying optimizations + like memory sharing and pinning to improve performance. It caches storage objects + to avoid redundant copies and can be configured to automatically share memory + (for multi-process usage) and pin memory (for faster CPU-GPU transfers). + + Attributes: + pin_memory (bool): Whether to pin CPU memory for faster CPU-GPU transfers + share_memory (bool): Whether to share memory across processes + pin_memory_min_bytes (int): Minimum tensor size in bytes to pin memory (default: 5) + _cached_storage_mapping (WeakIdKeyDictionary): Maps storage objects to optimized CPU storages using weak references + """ + + def __init__( + self, + pin_memory: bool = False, + share_memory: bool = False, + pin_memory_min_bytes: int = 5, + ): + if pin_memory and not torch.cuda.is_available(): + warnings.warn( + "Ignoring pin_memory flag for checkpoint staging as pinning memory" + "requires CUDA, but CUDA is not available. ", + stacklevel=2, + ) + self.pin_memory = False + else: + self.pin_memory = pin_memory + self.share_memory = share_memory + # Mapping from original storage objects to CPU storages using weak references + self._cached_storage_mapping = WeakIdKeyDictionary() + self.pin_memory_min_bytes = pin_memory_min_bytes + + def _deepcopy_atomic(x, _): + return x + + def _deepcopy_list(x, memo, non_blocking=False): + y: list = [] + memo[id(x)] = y + append = y.append + for a in x: + append( + self.deepcopy_with_tensor_offload( + a, memo, non_blocking=non_blocking + ) + ) + return y + + def _deepcopy_tuple(x, memo, non_blocking=False): + y = [ + self.deepcopy_with_tensor_offload(a, memo, non_blocking=non_blocking) + for a in x + ] + # We're not going to put the tuple in the memo, but it's still important we + # check for it, in case the tuple contains recursive mutable structures. + try: + return memo[id(x)] + except KeyError: + pass + + # Check if any elements changed during deepcopy + for k, j in zip(x, y): + if k is not j: + # At least one element changed, create new tuple + return tuple(y) + + # No elements changed, return original tuple + return x + + def _deepcopy_dict(x, memo, non_blocking=False): + y: dict = {} + memo[id(x)] = y + for key, value in x.items(): + y[ + self.deepcopy_with_tensor_offload( + key, memo, non_blocking=non_blocking + ) + ] = self.deepcopy_with_tensor_offload( + value, memo, non_blocking=non_blocking + ) + return y + + def _deepcopy_method(x, memo, non_blocking=False): # Copy instance methods + return type(x)( + x.__func__, + self.deepcopy_with_tensor_offload( + x.__self__, memo, non_blocking=non_blocking + ), + ) + + d: dict[Any, Any] = {} + self._deepcopy_dispatch = d + d[type(None)] = _deepcopy_atomic + d[int] = _deepcopy_atomic + d[float] = _deepcopy_atomic + d[bool] = _deepcopy_atomic + d[complex] = _deepcopy_atomic + d[bytes] = _deepcopy_atomic + d[str] = _deepcopy_atomic + d[types.CodeType] = _deepcopy_atomic + d[type] = _deepcopy_atomic + d[range] = _deepcopy_atomic + d[types.BuiltinFunctionType] = _deepcopy_atomic + d[types.FunctionType] = _deepcopy_atomic + d[weakref.ref] = _deepcopy_atomic + d[property] = _deepcopy_atomic + d[types.MethodType] = _deepcopy_method + d[dict] = _deepcopy_dict + d[tuple] = _deepcopy_tuple + d[list] = _deepcopy_list + + def _stage_untyped_storage( + self, + storage: UntypedStorage, + non_blocking: bool = False, + ): + """ + Called from the hooked storage_deepcopy function in torch.Tensor.__deepcopy__. + + This method handles the storage optimization logic for the StagingStateDict class. + It checks if the storage has already been cached, and if so, reuses it. + Otherwise, it creates a new CPU storage and applies memory optimizations. + + Args: + storage: The storage to optimize + + Returns: + The optimized storage + """ + # Check if we've already cached this storage + if storage in self._cached_storage_mapping: + cached_storage = self._cached_storage_mapping[storage] + assert cached_storage.size() == storage.size(), ( + "For async checkpointing, We cache storages in DRAM and reuse them." + "Cached storage size does not match original storage size." + "This should never happen as we track the original storage weakref " + "and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing." + ) + # Reuse cached storage but update with new data + cached_storage.copy_(storage, non_blocking=non_blocking) + return cached_storage + + # Create new CPU storage + if self.share_memory: + new_storage = type(storage)._new_shared(storage.size(), device="cpu") + else: + new_storage = type(storage)(storage.size(), device="cpu") + + # Skip pinning for tensors below the minimum size threshold + # Small tensors (e.g., optimizer step counters, scalars) have negligible + # transfer time improvement from pinning, but pinning overhead is significant + if self.pin_memory and new_storage.nbytes() >= self.pin_memory_min_bytes: + pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes()) + # Set up a weak reference to unpin when cpu storage is garbage collected + f = weakref.finalize( + new_storage, pin_memory_utils.unpin_memory, new_storage.data_ptr() + ) + # This makes sure that the finalizer is not called after + # cuda context is destroyed. + f.atexit = False + + new_storage.copy_(storage, non_blocking=non_blocking) + + # Cache the storage - WeakIdKeyDictionary will automatically clean up when storage is garbage collected + self._cached_storage_mapping[storage] = new_storage + return new_storage + + @torch.no_grad() + def stage( + self, + state_dict: Any, + non_blocking: bool = False, + ) -> Any: + return self.deepcopy_with_tensor_offload(state_dict, None, [], non_blocking) + + def _offload_tensor(self, x, memo, non_blocking=False): + """ + Deep copy a PyTorch tensor with optimized storage handling. + + This method creates a CPU copy of a tensor while applying memory optimizations + like sharing and pinning based on the StateDictStager configuration. + + Args: + x: The tensor to copy + memo: Memo dictionary for tracking already copied objects + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A CPU copy of the tensor with optimized storage + """ + # if data_ptr is not 0, we allocate a new storage below. so we can skip + # memory allocation by using [] for size. + y = x.new_empty([] if x.data_ptr() != 0 else x.size(), device="cpu") + + # Store in memo dict early to handle recursive references + d = id(x) + memo[d] = y + + if type(x) is torch.Tensor or x.data_ptr() != 0: + # Get the untyped storage + untyped_storage = x.untyped_storage() + storage_id = id(untyped_storage) + + # Check if this storage has already been staged in this deepcopy operation + # This handles the case where different tensors share the same storage + # (e.g., FSDP state_dict where norm.weight and norm_weight reference same storage) + # PyTorch caches untyped_storage() calls, so same storage -> same id + if storage_id in memo: + copied_storage = memo[storage_id] + else: + # Storage not seen before in this operation, stage it + copied_storage = self._stage_untyped_storage( + untyped_storage, non_blocking=non_blocking + ) + # Add to memo to avoid re-staging if we see this storage again + memo[storage_id] = copied_storage + + # Set the tensor data using the staged storage + y.set_(copied_storage, x.storage_offset(), x.size(), x.stride()) + + # Copy any attributes the tensor might have + if hasattr(x, "__dict__"): + for attr_name, attr_value in x.__dict__.items(): + setattr( + y, + attr_name, + self.deepcopy_with_tensor_offload( + attr_value, memo, non_blocking=non_blocking + ), + ) + + if hasattr(x, "__slots__"): + for slot in x.__slots__: + if hasattr(x, slot): + setattr( + y, + slot, + self.deepcopy_with_tensor_offload( + getattr(x, slot), memo, non_blocking=non_blocking + ), + ) + + return y + + def close(self): + """ + Clean up all cached storages and release associated resources. + + This method clears the internal storage cache, allowing garbage collection + of cached CPU storages. Any pinned memory associated with cached storages + will be automatically unpinned through weak reference finalizers. + """ + self._cached_storage_mapping.clear() + + @torch.no_grad() + def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False): # noqa: B006 + """Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors. + + This implementation extends the standard deepcopy functionality to handle PyTorch tensors + and their storages in a way that optimizes memory usage and performance, similar to the + stage method. It applies memory sharing and pinning optimizations based on the StateDictStager + configuration. + + Args: + x: The object to deep copy + memo: Memo dictionary for tracking already copied objects + _nil: Sentinel value for memo dictionary + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A deep copy of the input object with optimized tensor storage handling + """ + if memo is None: + memo = {} + + d = id(x) + y = memo.get(d, _nil) + if y is not _nil: + return y + + cls = type(x) + + # tensors and subclasses of tensors are handled separately + if isinstance(x, torch.Tensor): + y = self._offload_tensor(x, memo, non_blocking=non_blocking) + else: + # Use the dispatch table for standard types + copier = self._deepcopy_dispatch.get(cls) + if copier is not None: + # Check if this is an atomic copier (only accepts x and memo) + if copier.__name__ == "_deepcopy_atomic": + y = copier(x, memo) + else: + y = copier(x, memo, non_blocking=non_blocking) + else: + if issubclass(cls, type): + # type copier is also atomic + y = self._deepcopy_dispatch[type](x, memo) + else: + copier = getattr(x, "__deepcopy__", None) + if copier is not None: + y = copier(memo) + else: + reductor = dispatch_table.get(cls) + if reductor: + rv = reductor(x) + else: + reductor = getattr(x, "__reduce_ex__", None) + if reductor is not None: + rv = reductor(4) + else: + reductor = getattr(x, "__reduce__", None) + if reductor: + rv = reductor() + else: + raise RuntimeError( + f"un(deep)copyable object of type {cls}" + ) + if isinstance(rv, str): + y = x + else: + # Unpack rv tuple elements (up to 5 from pickle protocol) + # and explicitly pass non_blocking as keyword arg + if len(rv) == 2: + func, args = rv + y = self._reconstruct( + x, memo, func, args, non_blocking=non_blocking + ) + elif len(rv) == 3: + func, args, state = rv + y = self._reconstruct( + x, + memo, + func, + args, + state, + non_blocking=non_blocking, + ) + elif len(rv) == 4: + func, args, state, listiter = rv + y = self._reconstruct( + x, + memo, + func, + args, + state, + listiter, + non_blocking=non_blocking, + ) + elif len(rv) == 5: + func, args, state, listiter, dictiter = rv + y = self._reconstruct( + x, + memo, + func, + args, + state, + listiter, + dictiter, + non_blocking=non_blocking, + ) + else: + raise RuntimeError( + f"Unexpected pickle protocol return value length: {len(rv)}" + ) + + # If is its own copy, don't memoize. + if y is not x: + memo[d] = y + self._keep_alive(x, memo) # Make sure x lives at least as long as d + return y + + def _keep_alive(self, x, memo): + """Keeps a reference to the object x in the memo. + + Because we remember objects by their id, we have + to assure that possibly temporary objects are kept + alive by referencing them. + We store a reference at the id of the memo, which should + normally not be used unless someone tries to deepcopy + the memo itself... + """ + try: + memo[id(memo)].append(x) + except KeyError: + # aha, this is the first one :-) + memo[id(memo)] = [x] + + def _reconstruct( + self, + x, + memo, + func, + args, + state=None, + listiter=None, + dictiter=None, + non_blocking=False, + ): + deep = memo is not None + if deep and args: + args = tuple( + self.deepcopy_with_tensor_offload(arg, memo, non_blocking=non_blocking) + for arg in args + ) + y = func(*args) + if deep: + memo[id(x)] = y + + if state is not None: + if deep: + state = self.deepcopy_with_tensor_offload( + state, memo, non_blocking=non_blocking + ) + if hasattr(y, "__setstate__"): + y.__setstate__(state) + else: + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + else: + slotstate = None + if state is not None: + y.__dict__.update(state) + if slotstate is not None: + for key, value in slotstate.items(): + setattr(y, key, value) + + if listiter is not None: + if deep: + for item in listiter: + item = self.deepcopy_with_tensor_offload( + item, memo, non_blocking=non_blocking + ) + y.append(item) + else: + for item in listiter: + y.append(item) + if dictiter is not None: + if deep: + for key, value in dictiter: + key = self.deepcopy_with_tensor_offload( + key, memo, non_blocking=non_blocking + ) + value = self.deepcopy_with_tensor_offload( + value, memo, non_blocking=non_blocking + ) + y[key] = value + else: + for key, value in dictiter: + y[key] = value + return y diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_storage_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_storage_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73acc628342a058f659042b2d41c8245c86c2c42 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_storage_utils.py @@ -0,0 +1,49 @@ +import os +from typing import Union + +from .filesystem import FileSystemReader, FileSystemWriter +from .storage import StorageReader, StorageWriter + + +def _storage_setup( + storage: Union[StorageReader, StorageWriter, None], + checkpoint_id: Union[str, os.PathLike, None], + reader: bool = False, +) -> Union[None, StorageReader, StorageWriter]: + if storage: + if checkpoint_id is not None: + storage.reset(checkpoint_id) + return storage + + if not checkpoint_id: + raise RuntimeError( + "`checkpoint_id` must be specified if " + "storage_reader/storage_writer is None." + ) + + targets: list[type[Union[StorageReader, StorageWriter]]] = [] + if reader: + targets = [ + FileSystemReader, + ] + else: + targets = [ + FileSystemWriter, + ] + try: + from ._fsspec_filesystem import FsspecReader, FsspecWriter + + targets.append(FsspecReader if reader else FsspecWriter) + except Exception: + pass + + for target in targets: + if target.validate_checkpoint_id(checkpoint_id): + storage = target(checkpoint_id) # type: ignore[call-arg] + storage.reset(checkpoint_id) + return storage + + raise RuntimeError( + "Cannot detect which StorageReader or StorageWriter to use. " + "Please specify the storage_reader/storage_writer." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_traverse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_traverse.py new file mode 100644 index 0000000000000000000000000000000000000000..48eb67b4f7621b1aa3a4d6b2d7c56c5503337eb7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_traverse.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Callable, Collection, Mapping, MutableMapping +from typing import cast, Optional, TypeVar, Union + +import torch +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.tensor import DTensor + + +PATH_ITEM = Union[str, int] +OBJ_PATH = tuple[PATH_ITEM, ...] +T = TypeVar("T") + +STATE_DICT_ITEM = object +CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] + +__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] + + +def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: + return isinstance(value, torch.Tensor) + + +# TODO: update docstring for traverse.py +def traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping will be traversed and ``visitor`` will be applied to the leaf elements. + ``visitor`` will only be applied to elements in a list or a tuple, if the + container contains tensors or mappings. + """ + + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + return False + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif _is_terminal(value): + visitor(path, value) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def traverse_state_dict_v_2_3( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates + to false for all elements. + By default, all collections with at least one ``torch.Tensor`` element are traversed. + Visitor takes a path argument that is a tuple of the keys used to reach it. + """ + + # a value is terminal if it has no other containers values inside it + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + values = value.values() + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if _is_terminal(value): + visitor(path, value) + elif isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, list): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def set_element( + root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM +) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else []) + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + # pyrefly: ignore [bad-argument-type] + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) is int: + extend_list(cast(list[STATE_DICT_ITEM], cur_container), key) + + cur_container[key] = value + + +def get_element( + root_dict: STATE_DICT_TYPE, + path: OBJ_PATH, + default_value: Optional[T] = None, +) -> Optional[T]: + """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found.""" + cur_value = cast(CONTAINER_TYPE, root_dict) + for part in path: + if type(part) is int: + if not isinstance(cur_value, list) or len(cur_value) < part: + return default_value + elif not isinstance(cur_value, Mapping) or part not in cur_value: + return default_value + + # pyrefly: ignore [index-error] + cur_value = cast(CONTAINER_TYPE, cur_value[part]) + return cast(Optional[T], cur_value) + + +def _print_nested( + value: STATE_DICT_ITEM, + prefix: str = "", + print_fun: Callable[[str], None] = print, +) -> None: + if type(value) is ShardedTensor: + print_fun(f"{prefix} ShardedTensor size: {value.size()}") + for shard in value.local_shards(): + _print_nested( + shard.tensor, + f"{shard.metadata.shard_offsets} ", + print_fun=print_fun, + ) + elif type(value) is (DTensor): + print_fun(f"{prefix} DistributedTensor size: {value.size()}") + # TODO: add local offset for _local_tensor in print_nested. + _print_nested( + value._local_tensor, + print_fun=print_fun, + ) + elif isinstance(value, torch.Tensor): + print_fun(f"{prefix} Tensor size: {value.size()}") + else: + print_fun(f"{prefix} Type: {type(value)}") + + +def print_tensor( + path: OBJ_PATH, + value: STATE_DICT_ITEM, + print_fun: Callable[[str], None] = print, +) -> None: + """ + Use this callback with traverse_state_dict to print its content. + + By default the content is printed using the builtin ``print`` but this can + be change by passing a different ``print_fun` callable. + """ + _print_nested(value, prefix=str(path), print_fun=print_fun) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_version.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..b3065bdfd6a2c141a959ef0ffe30aeafdc2dc54f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/_version.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from typing import Optional + + +_derived_version: Optional[str] = None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/api.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa4854db2358ae4361403d37d59563ab8963fbd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/api.py @@ -0,0 +1,42 @@ +import traceback as tb +from typing import Any + + +WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary] + +__all__ = ["CheckpointException"] + + +def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: + return (exc, tb.extract_tb(exc.__traceback__)) + + +def _is_wrapped_exception(obj: Any) -> bool: + if not isinstance(obj, tuple): + return False + if len(obj) != 2: + return False + return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary) + + +class CheckpointException(BaseException): + """Exception raised if failure was detected as part of a checkpoint load or save.""" + + def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]): + super().__init__(msg, failures) + self._failures = failures + + @property + def failures(self) -> dict[int, WRAPPED_EXCEPTION]: + """Return a dictionary mapping node ranks to their associated exceptions in case of failure.""" + return self._failures + + def __str__(self) -> str: + str = f"CheckpointException ranks:{self._failures.keys()}\n" + for rank, exc_pair in self._failures.items(): + exc, trace = exc_pair + str += f"Traceback (most recent call last): (RANK {rank})\n" + if trace is not None: + str += "".join(tb.format_list(trace)) + str += "".join(tb.format_exception_only(type(exc), value=exc)) + return str diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py new file mode 100644 index 0000000000000000000000000000000000000000..716cb90a996534e4388a42545935ebee894eeb1a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py @@ -0,0 +1,702 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import dataclasses +import io +import logging +import math +import sys +from bisect import bisect_right, insort +from collections import ChainMap +from typing import Any, cast, Optional, Union + +import torch +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans +from torch.distributed.checkpoint._nested_dict import ( + FLATTEN_MAPPING, + flatten_state_dict, +) +from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors +from torch.distributed.checkpoint._traverse import set_element +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + StorageMeta, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.planner_helpers import ( + _compare_save_plans, + _contains_usable_plan, + _create_default_metadata_only_plan, + _create_read_items, + _create_write_items, + _init_state_dict, + _merge_delta_local_plans, +) +from torch.distributed.checkpoint.utils import find_state_dict_object +from torch.distributed.tensor import DTensor + +from . import _version + + +logger: logging.Logger = logging.getLogger(__name__) + + +__all__ = [ + "DefaultSavePlanner", + "DefaultLoadPlanner", + "create_default_local_load_plan", + "create_default_global_load_plan", + "create_default_local_save_plan", + "create_default_global_save_plan", +] + + +# TODO: Update docstrings for default_planner.py +class DefaultSavePlanner(SavePlanner): + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + dedup_replicated_tensors: Optional[bool] = None, + dedup_save_to_lowest_rank: bool = False, + enable_plan_caching: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.mappings = {} + self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank + if dedup_replicated_tensors is not None: + logger.warning( + "DefaultSavePlanner's `dedup_replicated_tensors` argument is being " + "deprecated, and no longer has any effect. Please remove this argument " + "from your call." + ) + self._cached_plans_key: str = self.__class__.__name__ + self._enable_plan_caching = enable_plan_caching + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + self.state_dict = state_dict + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> SavePlan: + plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) + if self.flatten_state_dict: + plan = dataclasses.replace(plan, planner_data=self.mappings) + self.plan = plan + + if self._enable_plan_caching: + # If plans are equal, we can skip sending the plan to the coordinator. + if ( + self._cached_plans_key in SavePlanner._cached_save_plan + and _compare_save_plans( + plan, SavePlanner._cached_save_plan[self._cached_plans_key] + ) + ): + logger.info( + "No change in the local plan. Skipping sending the plan to the coordinator" + ) + return SavePlan([], usable=False) + else: + SavePlanner._cached_save_plan[self._cached_plans_key] = plan + + return self.plan + + def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]: + return dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank) + + def _create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + deduped_plans = self._dedup_save_plans(all_plans) + + global_plan, metadata = create_default_global_save_plan(deduped_plans) + + if self.flatten_state_dict: + # | does not work for Python 3.8 or older version. + # merged_mappings = reduce( + # lambda x, y: x | y, (p.planner_data for p in global_plan) + # ) + planner_data_dict = [p.planner_data for p in global_plan] + merged_mappings = dict(ChainMap(*planner_data_dict)) + metadata = dataclasses.replace(metadata, planner_data=merged_mappings) + + if not _validate_global_plan(global_plan, metadata): + raise ValueError("Failed to validate global plan") + + return global_plan, metadata + + def _create_global_plan_with_caching( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], list[SavePlan], Metadata]: + """ + Create global plan with caching. + Returns a tuple of global_plan_delta, global_plan, metadata. + """ + global_plan_delta: list[SavePlan] = [] + + if self._cached_plans_key not in SavePlanner._cached_all_plans: + # Case 1: If the plans are not cached, the cache will be hydrated with the + # all_plans, global_plans (Deduped), and metadata. + + # Cache the original all_plans + SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans + global_plan, metadata = self._create_global_plan(all_plans) + # Cache the deduped and validated global_plan + SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan + # Cache the metadata + SavePlanner._cached_metadata[self._cached_plans_key] = metadata + # If plans are not cached, global_plan delta will be the same as global plan. + return global_plan, global_plan, metadata + + # Case 2: Plans are cached + if not _contains_usable_plan(all_plans): + # Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans). + # Global plan delta will be empty plans to avoid the collective overhead. + # We can reuse the deduped global plan and metadata from the cache directly. + global_plan_delta = [SavePlan([], usable=False)] * len(all_plans) + global_plan = SavePlanner._cached_global_plan[self._cached_plans_key] + metadata = SavePlanner._cached_metadata[self._cached_plans_key] + else: + # Case 2.2: Plans are cached but the local plans have changed. + # We will merge the changed local plans with the cached local plans. + # Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached. + # Global plan delta will be created by comparing the new global plan with the cached global plan. + # Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead. + merged_plans = _merge_delta_local_plans( + SavePlanner._cached_all_plans[self._cached_plans_key], all_plans + ) + # Cache the updated local plans + SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans + global_plan, metadata = self._create_global_plan(merged_plans) + + if self._cached_plans_key in self._cached_global_plan: + for cached_plan, new_plan in zip( + SavePlanner._cached_global_plan[self._cached_plans_key], global_plan + ): + if _compare_save_plans(cached_plan, new_plan): + global_plan_delta.append(SavePlan([], usable=False)) + else: + global_plan_delta.append(new_plan) + + # Cache the new global plan and the metadata + SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan + SavePlanner._cached_metadata[self._cached_plans_key] = metadata + + return global_plan_delta, global_plan, metadata + + def create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + global_plan_delta: list[SavePlan] = [] + if self._enable_plan_caching: + # If the plans are cached, we only need to send the global plan delta to be scattered + # across ranks. Ranks will use the cached final plans instead. + ( + global_plan_delta, + global_plan, + metadata, + ) = self._create_global_plan_with_caching(all_plans) + else: + global_plan, metadata = self._create_global_plan(all_plans) + # If the caching is not enabled, global delta plan will always be same as the new global plan. + global_plan_delta = global_plan + + self.global_plan = global_plan + self.metadata = metadata + + return global_plan_delta, self.metadata + + def _finish_plan_with_caching(self, new_plan: SavePlan) -> SavePlan: + finished_plan: SavePlan = new_plan + + if not new_plan.usable: + finished_plan = SavePlanner._cached_final_save_plan[self._cached_plans_key] + else: + finished_plan = new_plan + SavePlanner._cached_final_save_plan[self._cached_plans_key] = new_plan + return finished_plan + + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + finished_plan: SavePlan = new_plan + + if self._enable_plan_caching: + finished_plan = self._finish_plan_with_caching(new_plan) + + self.plan = finished_plan + return self.plan + + def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + object = self.lookup_object(write_item.index) + return self.transform_object(write_item, object) + + def lookup_object(self, index: MetadataIndex) -> Any: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_object(self, write_item: WriteItem, object: Any): + """Extension from the planner interface to make it easy to extend the default planner.""" + if write_item.type == WriteItemType.BYTE_IO: + bytes = io.BytesIO() + torch.save(object, bytes) + object = bytes + return object + + +class DefaultLoadPlanner(LoadPlanner): + """ + DefaultLoadPlanner that adds multiple features on top of LoadPlanner. + + In particular it adds the following: + + flatten_state_dict: Handle state_dict with nested dicts + flatten_sharded_tensors: For FSDP in 2D parallel mode + allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint. + """ + + original_state_dict: STATE_DICT_TYPE + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + allow_partial_load: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.original_state_dict = {} + self.mappings = {} + self.allow_partial_load = allow_partial_load + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + _init_state_dict(state_dict) + self.original_state_dict = state_dict + + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + + self.state_dict = state_dict + self.metadata = metadata + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> LoadPlan: + if self.metadata is None: + raise AssertionError("self.metadata is not None") + if self.flatten_state_dict: + # To support checkpoints that are saved before v2.4, we have to + # differentiate if the missing keys are due to old checkpoints. + # The contracts are: + # 1. There are 3 cases when we found a missing key. + # 1.1 Actual missing key, but allow_partial_load is False + # 1.2 Actual missing key, but allow_partial load is True + # 1.3 Old checkpoint, but allow_partial_load is False + # 1.4 Old checkpoint, but allow_partial_load is True + # 2. If we found a missing key, we first convert the keys back to + # the key format of v2.3 + # 3. If the previous missing keys are in the v2.3 keys, we assume + # this is a old checkpoint. + # 4. Pass the state_dict to `create_default_local_load_plan()`, + # which has the logic to check missing for allow_partial_load. + # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to + # `create_default_local_load_plan()`. The logic here is to determine + # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after). + current_keys = set(self.state_dict.keys()) + load_keys = set(self.metadata.state_dict_metadata.keys()) + missing_keys = load_keys - current_keys + if missing_keys: + _version._derived_version = "2_3" + old_state_dict, old_mappings = flatten_state_dict( + self.original_state_dict + ) + old_keys = set(old_state_dict.keys()) + if old_keys & missing_keys: + self.state_dict, self.mappings = old_state_dict, old_mappings + # _derived_version is only used by flatten_state_dict now. + # Set it back to None so that later we can save to a new version. + _version._derived_version = None + + return create_default_local_load_plan( + self.state_dict, self.metadata, not self.allow_partial_load + ) + + def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + return create_default_global_load_plan(global_plan) + + def finish_plan(self, new_plan: LoadPlan) -> LoadPlan: + return new_plan + + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + if self.flatten_state_dict: + set_element( + self.original_state_dict, + self.mappings[read_item.dest_index.fqn], + torch.load(value, weights_only=False), + ) + else: + self.state_dict[read_item.dest_index.fqn] = torch.load( + value, weights_only=False + ) + + def resolve_tensor(self, read_item: ReadItem): + tensor = self.lookup_tensor(read_item.dest_index) + return self.transform_tensor(read_item, tensor) + + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + pass + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor): + """Extension from the planner interface to make it easy to extend the default planner.""" + return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths) + + +class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata. + Useful for loading in state_dict without first initializing a model, such as + when converting a DCP checkpoint into a Torch save file. + + . N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner + + .. warning:: + Because the entire state dict is initialized, It's recommended to only utilize + this LoadPlanner on a single rank or process to avoid OOM. + + """ + + def __init__(self, keys=None, *args, **kwargs): + self.keys = keys + super().__init__(*args, **kwargs) + + def _should_include_key(self, key: str, metadata: Metadata) -> bool: + if self.keys is None: + return True + + if key in self.keys: + return True + + unflattened_keys: list[str] = [] + planner_data = metadata.planner_data.get(key) + for unflattened_key in planner_data: + if unflattened_keys: + unflattened_keys.append( + ".".join([unflattened_keys[-1], str(unflattened_key)]) + ) + + else: + unflattened_keys.append(unflattened_key) + + if any(unflattened_key in self.keys for unflattened_key in unflattened_keys): + return True + + return False + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + if state_dict: + raise AssertionError("not state_dict") + if metadata is None: + raise AssertionError("metadata is not None") + + # rebuild the state dict from the metadata + for k, v in metadata.state_dict_metadata.items(): + if not self._should_include_key(k, metadata): + continue + + if isinstance(v, TensorStorageMetadata): + v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment] + if metadata.planner_data is not None and k in metadata.planner_data: + set_element(state_dict, metadata.planner_data[k], v) + else: + state_dict[k] = v + + super().set_up_planner(state_dict, metadata, is_coordinator) + + +def create_default_local_load_plan( + state_dict: dict[str, Any], metadata: Metadata, strict: bool = True +) -> LoadPlan: + requests = [] + """ + Create the ``LoadPlan`` used by DefaultLoadPlanner. + + It produces one read item per value in ``state_dict`` using the metadata in ``metadata``. + + The default behavior is to match key exactly between state_dict and metadata. + It handles resharding by issuing multiple read requests against storage in order to match + load requirements. + """ + + for fqn, obj in state_dict.items(): + # ignore state_dict keys which do not exist in `state_dict` if strict=False + if fqn not in metadata.state_dict_metadata: + if strict: + raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.") + else: + continue + + md = metadata.state_dict_metadata[fqn] + if ( + isinstance(md, TensorStorageMetadata) + and getattr(obj, "size", None) is not None + and md.size != obj.size() + ): + raise ValueError( + f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}", + ) + # Since DTensor supports submesh, adding extra check to ensure _create_read_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_read_items(fqn, md, obj) + else: + requests += _create_read_items(fqn, md, obj) + + return LoadPlan(requests) + + +def create_default_global_load_plan( + all_plans: list[LoadPlan], +) -> list[LoadPlan]: + """ + Create global load plan used by DefaultLoadPlanner. + + The default load behavior involved no global coordination and this function + currently doesn't change the local plans. + """ + return all_plans + + +def create_default_local_save_plan( + state_dict: dict[str, Any], is_coordinator: bool +) -> SavePlan: + """ + Create the ``SavePlan`` used by DefaultSavePlanner. + + On non-coordinator ranks, this function ignores tensors and non-tensor objects, + only producing writes for ShardedTensor objects. + + On the coordinator rank, produce writes for all values. + """ + requests = [] + for fqn, obj in state_dict.items(): + # Since DTensor supports submesh, adding extra check to ensure _create_write_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_write_items(fqn, obj) + else: + # For the plain tensor and non-tensor values, add the request for all + # the ranks. Coordinator will decides whether to deduplicate the + # values based on the keys. + requests += _create_write_items(fqn, obj) + + return SavePlan(requests) + + +def create_default_global_save_plan( + all_plans: list[SavePlan], + rewrite_index_hints: bool = True, +) -> tuple[list[SavePlan], Metadata]: + """ + Create the global plan and metadata used by DefaultSavePlanner. + + Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans. + + The only global planning change is to update index hints in all ``MetadataIndex`` objects if + ``rewrite_index_hints`` is True. + """ + md: dict[str, STORAGE_TYPES] = {} + new_plans = [] + for plan in all_plans: + new_items = [] + for item in plan.items: + if item.type != WriteItemType.SHARD: + if item.index.fqn in md: + raise AssertionError("item.index.fqn not in md") + + if item.type == WriteItemType.BYTE_IO: + md[item.index.fqn] = BytesStorageMetadata() + new_items.append(item) + else: + if item.tensor_data is None: + raise AssertionError("item.tensor_data is not None") + tensor_md = cast( + TensorStorageMetadata, + md.setdefault( + item.index.fqn, + TensorStorageMetadata( + properties=item.tensor_data.properties, + size=item.tensor_data.size, + chunks=[], + ), + ), + ) + new_item = item + if rewrite_index_hints: + new_index = dataclasses.replace( + item.index, index=len(tensor_md.chunks) + ) + new_item = dataclasses.replace(item, index=new_index) + new_items.append(new_item) + + if item.tensor_data.chunk is None: + raise AssertionError(f""" + Cannot create MD for tensor without bounds. + FQN: {item.index.fqn} + """) + tensor_md.chunks.append(item.tensor_data.chunk) + new_plans.append(dataclasses.replace(plan, items=new_items)) + return (new_plans, Metadata(md)) + + +def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata: + """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``.""" + plan = _create_default_metadata_only_plan(state_dict) + _, md = create_default_global_save_plan([plan]) + return md + + +def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool: + """Check if two boxes overlap. Tuples are (offset, lengths).""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(box0.offsets) + for i in range(ndims): + if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]: + return False + if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]: + return False + + return True + + +def _check_box_bounds( + outer_box_size: torch.Size, inner_box: ChunkStorageMetadata +) -> bool: + for i in range(len(outer_box_size)): + if inner_box.offsets[i] < 0: + return False + if inner_box.sizes[i] < 0: + return False + if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]: + return False + + return True + + +def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool: + all_good = True + for key, value in metadata.state_dict_metadata.items(): + if isinstance(value, BytesStorageMetadata): + continue + if len(value.size) == 0: + continue + chunks = value.chunks + chunks_volume = 0 + for chunk in chunks: + # Compute the volume + if not _check_box_bounds(value.size, chunk): + logger.warning( + """ + key:%s has out of bounds chunk: + tensor-size:%s chunk: %s + """, + key, + value.size, + chunk, + ) + all_good = False + chunks_volume += math.prod(chunk.sizes) + + if len(chunks) > 1: + dims = len(value.size) + sweep_dim = max(range(dims), default=0, key=lambda d: value.size[d]) + sorted_indices = sorted( + range(len(chunks)), + key=lambda idx: ( + chunks[idx].offsets[sweep_dim], + *(chunks[idx].offsets[d] for d in range(dims)), + ), + ) + active: list[tuple[int, int]] = [] + for idx in sorted_indices: + current = chunks[idx] + start = current.offsets[sweep_dim] + end = start + current.sizes[sweep_dim] + + cutoff = bisect_right(active, (start, sys.maxsize)) + if cutoff: + del active[:cutoff] + + for _, other_idx in active: + other = chunks[other_idx] + if _check_box_overlap(current, other): + logger.warning( + "key:%s has overlapping chunks: %s %s", + key, + current, + other, + ) + all_good = False + + insort(active, (end, idx)) + + # Check whether combined chunk cover the whole tensor + tensor_volume = math.prod(value.size) + if len(global_plan) > 1 and chunks_volume != tensor_volume: + logger.warning( + """ + key:%s invalid fill tensor-volume: + %s chunks-volume: %s + """, + key, + tensor_volume, + chunks_volume, + ) + all_good = False + + return all_good diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/filesystem.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..b21cac12ff90522f075b7b32029eae01e7a92169 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/filesystem.py @@ -0,0 +1,1035 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import io +import json +import operator +import os +import pickle +import queue +import threading +import uuid +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum +from io import UnsupportedOperation +from pathlib import Path +from typing import Any, cast, Final, IO, Optional, Union + +# introduced as collections.abc.Buffer in Python 3.12 +from typing_extensions import Buffer + +import torch +from torch import Tensor +from torch._utils import _get_available_device_type, _get_device_module +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._extension import ( + ExtensionRegistry, + StreamTransformExtension, +) +from torch.distributed.checkpoint._hf_utils import ( + CUSTOM_METADATA_KEY, + DCP_VERSION_KEY, + FORMAT_KEY, + FORMAT_VALUE, + HF_DCP_VERSION, +) +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta +from torch.distributed.checkpoint.planner import ( + LoadItemType, + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.staging import BlockingAsyncStager +from torch.distributed.checkpoint.storage import ( + StorageReader, + StorageWriter, + WriteResult, +) +from torch.distributed.checkpoint.utils import _create_file_view +from torch.futures import Future + + +__all__ = [ + "FileSystemWriter", + "FileSystemReader", + "FileSystem", + "FileSystemBase", + "SerializationFormat", +] + +_metadata_fn: str = ".metadata" + +CURRENT_DCP_VERSION: Final[str] = "1.0.0" + + +@dataclass +class _StorageInfo: + """This is the per entry storage info.""" + + relative_path: str + offset: int + length: int + transform_descriptors: Optional[Sequence[str]] = None + + def __getstate__(self): + return {k: v for k, v in self.__dict__.items() if v is not None} + + +@dataclass +class _StoragePrefix: + prefix: str + + +class SerializationFormat(Enum): + TORCH_SAVE = "torch_save" + SAFETENSORS = "safetensors" + + +DEFAULT_SUFFIX = ".distcp" + + +def _generate_uuid() -> str: + return str(uuid.uuid4()) + + +class _TensorLoader(ABC): + @abstractmethod + def add(self, size: int, obj: object) -> None: + pass + + @abstractmethod + def start_loading(self) -> None: + pass + + @abstractmethod + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + pass + + +class _SerialCpuLoader(_TensorLoader): + def __init__(self, resolve_fun: Callable) -> None: + self.resolve_fun = resolve_fun + self.items: list[tuple[int, object]] = [] + + def add(self, size: int, obj: object) -> None: + self.items.append((size, obj)) + + def start_loading(self) -> None: + pass + + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + for _, obj in self.items: + tensor = self.resolve_fun(obj).detach() + tensor = tensor.cpu() + if tensor.storage().size() != tensor.numel(): + tensor = tensor.clone() + yield ( + tensor, + obj, + ) + + +class _OverlappingCpuLoader(_TensorLoader): + def __init__( + self, + resolve_fun: Callable, + stream: Optional[torch.Stream] = None, + inflight_threshhold: int = 1_000_000, + ) -> None: + self.resolve_fun = resolve_fun + self.items: list[tuple[int, object]] = [] + self.inflight_threshhold = inflight_threshhold + self.in_flight_data = 0 + self.current_items: collections.deque = collections.deque() + self.idx = 0 + self.started = False + self.device_type = ( + stream.device_type if stream else _get_available_device_type() + ) + self.device_module = _get_device_module(self.device_type) + self.stream = cast( + torch.cuda.Stream, stream or self.device_module.current_stream() + ) + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) + + @property + def _done(self) -> bool: + return self.idx >= len(self.items) + + def _drain(self) -> list[tuple[torch.Tensor, object]]: + drained = [] + if self.in_flight_data >= self.inflight_threshhold: + self.stream.synchronize() + while self.in_flight_data >= self.inflight_threshhold: + val = self.current_items.popleft() + self.in_flight_data -= val[0].numel() * val[0].element_size() + drained.append(val) + return drained + + def _refill(self) -> None: + with self.device_module.stream(self.stream): + while not self._done and self.in_flight_data < self.inflight_threshhold: + _, obj = self.items[self.idx] + self.idx += 1 + tensor = self.resolve_fun(obj).detach() + if tensor.device.type == self.device_type: + tensor = tensor.to(device="cpu", non_blocking=True) + elif tensor.device == torch.device("cpu"): + if ( + tensor.untyped_storage().size() + != tensor.numel() * tensor.itemsize + ): + # this forces the tensor to be both contiguous and with minimal storage + tensor = tensor.clone() + + self.current_items.append( + ( + tensor, + obj, + ) + ) + self.in_flight_data += tensor.numel() * tensor.element_size() + + def _finish(self) -> Iterable[tuple[torch.Tensor, object]]: + if not self._done: + raise AssertionError("_finish called before all items were processed") + if len(self.current_items) > 0: + self.stream.synchronize() + return self.current_items + + def add(self, size: int, obj: object) -> None: + if self.started: + raise RuntimeError("cannot add items after loading started") + self.items.append((size, obj)) + + def start_loading(self) -> None: + if self.started: + return + self.started = True + self.items.sort(key=operator.itemgetter(0)) + self._refill() + + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + self.start_loading() + while not self._done: + drained = self._drain() + self._refill() + yield from drained + + yield from self._finish() + + +class _StorageWriterTransforms: + """ + This is experimental, and will likely move elsewhere in the + future. It lives here to minimize changes while we are still + learning and gathering feedback. + """ + + def __init__( + self, extensions: Optional[Sequence[StreamTransformExtension]] = None + ) -> None: + """ + If the extensions arg is None, this means the implementation + should provide whatever defaults it chooses. An empty + sequence indicates no extensions should be used. At this + time, the default extensions sequence is empty. + """ + self.extensions = () if extensions is None else extensions + + def transform_save_stream( + self, write_item: WriteItem, raw_stream: io.IOBase + ) -> tuple[IO[bytes], list[str]]: + # In order to avoid leaking fds, transformers' close must + # cascade to wrapped streams, but since this function can + # append to the raw stream, we can't close the actual stream. + # So, we use this to put a wrapper around the raw stream's + # close() to make it a noop, and it gets closed once all files + # are appended. + + class NoCloseWriter(io.IOBase): + def __init__(self, raw: io.IOBase): + self.raw = raw + + def writeable(self) -> bool: + return True + + def write(self, b: Buffer) -> int: + return self.raw.write(b) + + def close(self): + self.flush() + self.raw.flush() + # but not close. + + transform_to = cast(IO[bytes], NoCloseWriter(raw_stream)) + + for ex in self.extensions: + transform_to = ex.transform_to(transform_to) + + return (transform_to, [ex.get_descriptor() for ex in reversed(self.extensions)]) + + +def _item_size(item: WriteItem) -> int: + size = 1 + if item.tensor_data is None: + raise AssertionError("WriteItem tensor_data must not be None") + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.size: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _split_by_size_and_type(bins: int, items: list[WriteItem]) -> list[list[WriteItem]]: + if bins == 1: + return [items] + + bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] + + buckets: list[list[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + tensor_w.sort(key=_item_size, reverse=True) + + for i, wi in enumerate(bytes_w): + buckets[i % bins].append(wi) + + for wi in tensor_w: + # TODO replace with headq + idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0] + buckets[idx].append(wi) + bucket_sizes[idx] += _item_size(wi) + + return buckets + + +def _write_item( + transforms: _StorageWriterTransforms, + stream: io.IOBase, + data: Union[io.BytesIO, torch.Tensor], + write_item: WriteItem, + storage_key: str, + serialization_format: SerializationFormat, +) -> WriteResult: + offset = stream.tell() + + (transform_to, transform_descriptors) = transforms.transform_save_stream( + write_item, stream + ) + + if write_item.type == WriteItemType.BYTE_IO: + if not isinstance(data, io.BytesIO): + raise AssertionError("Data must be io.BytesIO for BYTE_IO write items") + transform_to.write(data.getbuffer()) + else: + if not isinstance(data, torch.Tensor): + raise AssertionError( + "Data must be torch.Tensor for non-BYTE_IO write items" + ) + if data.device != torch.device("cpu"): + raise AssertionError("Tensor must be on CPU device") + if serialization_format == SerializationFormat.TORCH_SAVE: + torch.save(data, transform_to) + + transform_to.close() + + if serialization_format == SerializationFormat.TORCH_SAVE or isinstance( + data, io.BytesIO + ): + length = stream.tell() - offset + else: + length = data.numel() * data.element_size() + + # For consistency with earlier versions, leave this field out of the + # metadata if there are no extensions. + info_transform_descriptors = ( + None if len(transform_descriptors) == 0 else transform_descriptors + ) + + return WriteResult( + index=write_item.index, + size_in_bytes=length, + storage_data=_StorageInfo( + storage_key, + offset, + length, + transform_descriptors=info_transform_descriptors, + ), + ) + + +def _write_files_from_queue( + create_stream: Callable, + file_queue: queue.Queue, + result_queue: queue.Queue, + planner: SavePlanner, + transforms: _StorageWriterTransforms, + inflight_threshhold: int, + use_fsync: bool, + thread_count: int, + serialization_format: SerializationFormat, +) -> None: + try: + while True: + file_name, storage_key, write_items = file_queue.get_nowait() + loader: _TensorLoader + + custom_backend_name = torch._C._get_privateuse1_backend_name() + custom_device_mod = getattr(torch, custom_backend_name, None) + + # TODO: Using the OverlappingCpuLoader with multiple threads creates significant + # performance degradation, observed as being related to cuda stream syncs. We + # should try to fix this and use _OverlappingCpuLoader for all threaded cases + if ( + thread_count == 1 + and ( + torch.cuda.is_available() + or (custom_device_mod and custom_device_mod.is_available()) + ) + and inflight_threshhold > 0 + ): + loader = _OverlappingCpuLoader( + planner.resolve_data, + inflight_threshhold=inflight_threshhold, + ) + else: + loader = _SerialCpuLoader( + planner.resolve_data, + ) + + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + for write_item in tensor_w: + loader.add(_item_size(write_item), write_item) + loader.start_loading() + + bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + write_results = [] + + with create_stream(file_name, "wb") as stream: + for write_item in bytes_w: + data = planner.resolve_data(write_item) + write_results.append( + _write_item( + transforms, + stream, + data, + write_item, + storage_key, + serialization_format, + ) + ) + + tensor_dict = {} + metadata_dict = {} + for tensor, write_item in loader.values(): + if not tensor.is_cpu: + raise AssertionError("Tensor must be on CPU") + write_results.append( + _write_item( + transforms, + stream, + tensor, + write_item, # type: ignore[arg-type] + storage_key, + serialization_format, + ) + ) + tensor_dict[write_item.index.fqn] = tensor # type: ignore[attr-defined] + metadata_dict[write_item.index.fqn] = { # type: ignore[attr-defined] + "saved_offsets": write_item.tensor_data.chunk.offsets # type: ignore[attr-defined] + } + + if serialization_format == SerializationFormat.SAFETENSORS: + from safetensors.torch import save # type: ignore[import-not-found] + + stream.write( + save( + tensor_dict, + metadata={ + CUSTOM_METADATA_KEY: json.dumps(metadata_dict), + DCP_VERSION_KEY: str(HF_DCP_VERSION), + FORMAT_KEY: FORMAT_VALUE, + }, + ) + ) + + if use_fsync: + try: + os.fsync(stream.fileno()) + except (AttributeError, UnsupportedOperation): + os.sync() + stream.close() + result_queue.put(write_results) + except queue.Empty: + pass + + +class FileSystemBase(ABC): + @contextmanager + @abstractmethod + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: ... + + @abstractmethod + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: ... + + @abstractmethod + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: ... + + @abstractmethod + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ... + + @abstractmethod + def mkdir(self, path: Union[str, os.PathLike]) -> None: ... + + @classmethod + @abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ... + + @abstractmethod + def exists(self, path: Union[str, os.PathLike]) -> bool: ... + + @abstractmethod + def rm_file(self, path: Union[str, os.PathLike]) -> None: ... + + +class FileSystem(FileSystemBase): + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + if not isinstance(path, Path): + path = Path(path) + with path.open(mode) as stream: + yield cast(io.IOBase, stream) + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + if not isinstance(path, Path): + path = Path(path) + return path / suffix + + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: + if not isinstance(path, Path): + path = Path(path) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + if not isinstance(path, Path): + path = Path(path) + + path.rename(cast(Path, new_path)) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + if not isinstance(path, Path): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return True + + if "://" in str(checkpoint_id): + return False + + for p in Path(checkpoint_id).parents: + if p.exists() and os.access(str(p), os.W_OK): + return True + + return False + + def exists(self, path: Union[str, os.PathLike]) -> bool: + if not isinstance(path, Path): + path = Path(path) + return path.exists() + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + if not isinstance(path, Path): + path = Path(path) + path.unlink() + + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + if not isinstance(path, Path): + path = Path(path) + return [str(p) for p in path.iterdir()] + + +class _FileSystemWriter(StorageWriter): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.single_file_per_rank = single_file_per_rank + self.sync_files = sync_files + self.thread_count = thread_count + self.per_thread_copy_ahead = per_thread_copy_ahead + self.save_id = _generate_uuid() + self.overwrite = overwrite + self.transforms = _StorageWriterTransforms(_extensions) + self.serialization_format = serialization_format + self.rank: Optional[int] = None + self.use_collectives: bool = True + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.save_id = _generate_uuid() + + def set_up_storage_writer( + self, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + self.rank = kwargs.get("rank") + self.use_collectives = kwargs.get("use_collectives", True) + + def _metadata_exists(self) -> bool: + if self.use_collectives: + # A global checkpoint metadata file + metadata_path = self._get_metadata_path(rank=None) + else: + # A rank 0 specific metadata file if every rank has written its own metadata + # Just looking for lowest rank metadata file is sufficient + metadata_path = self._get_metadata_path(rank=0) + + return self.fs.exists(metadata_path) + + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + self.fs.mkdir(self.path) + if self._metadata_exists(): + if self.overwrite: + warnings.warn( + f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite=}." + " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to" + " maintain this functionality or False to raise when an existing checkpoint is found.", + stacklevel=2, + ) + else: + raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.") + + if self.rank is not None and not self.use_collectives: + plan = dataclasses.replace( + plan, storage_data=_StoragePrefix(f"__{self.rank}_") + ) + + return plan + + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + new_plans = [ + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) + if plan.storage_data is None + else plan + for i, plan in enumerate(plans) + ] + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ) -> Future[list[WriteResult]]: + storage_plan: _StoragePrefix = plan.storage_data + file_count = 0 + + def gen_file(): + nonlocal file_count + file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + file_queue: queue.Queue = queue.Queue() + if self.single_file_per_rank: + for bucket in _split_by_size_and_type(self.thread_count, plan.items): + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, bucket)) + else: + for item in plan.items: + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, [item])) + + return self._write_data(planner, file_queue) + + def _write_data( + self, + planner: SavePlanner, + file_queue: queue.Queue, + ) -> Future[list[WriteResult]]: + result_queue: queue.Queue = queue.Queue() + + threads = [] + for _ in range(1, self.thread_count): + t = threading.Thread( + target=_write_files_from_queue, + args=( + self.fs.create_stream, + file_queue, + result_queue, + planner, + self.transforms, + self.per_thread_copy_ahead, + self.sync_files, + self.thread_count, + self.serialization_format, + ), + ) + t.start() + threads.append(t) + + _write_files_from_queue( + create_stream=self.fs.create_stream, + file_queue=file_queue, + result_queue=result_queue, + planner=planner, + transforms=self.transforms, + inflight_threshhold=self.per_thread_copy_ahead, + use_fsync=self.sync_files, + thread_count=self.thread_count, + serialization_format=self.serialization_format, + ) + + for t in threads: + t.join() + + res = [] + try: + while True: + res += result_queue.get_nowait() + except queue.Empty: + fut: Future[list[WriteResult]] = Future() + fut.set_result(res) + return fut + + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + metadata = dataclasses.replace(metadata, version=CURRENT_DCP_VERSION) + + storage_md = {} + for wr_list in results: + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) + metadata.storage_data = storage_md + + metadata.storage_meta = self.storage_meta() + tmp_filename = ( + f"__{self.rank}{_metadata_fn}.tmp" + if not self.use_collectives and self.rank is not None + else f"{_metadata_fn}.tmp" + ) + tmp_path = cast(Path, self.fs.concat_path(self.path, tmp_filename)) + with self.fs.create_stream(tmp_path, "wb") as metadata_file: + pickle.dump(metadata, metadata_file) + if self.sync_files: + try: + os.fsync(metadata_file.fileno()) + except (AttributeError, UnsupportedOperation): + os.sync() + + # delete in-case other checkpoints were present. + if not self.use_collectives and self.rank is not None: + metadata_path = self._get_metadata_path(self.rank) + else: + metadata_path = self._get_metadata_path() + + if self.fs.exists(metadata_path): + self.fs.rm_file(metadata_path) + + self.fs.rename(tmp_path, metadata_path) + + def storage_meta(self) -> Optional[StorageMeta]: + return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id) + + def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike: + filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}" + return cast(Path, self.fs.concat_path(self.path, filename)) + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to save the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class _StorageReaderTransforms: + """ + This is experimental, and will likely move elsewhere in the + future. It lives here to minimize changes while we are still + learning and gathering feedback. + """ + + def __init__(self, extension_registry: Optional[ExtensionRegistry] = None) -> None: + self.extension_registry = ( + ExtensionRegistry() if extension_registry is None else extension_registry + ) + + def transform_load_stream( + self, + read_item: ReadItem, + transform_descriptors: Sequence[str], + raw_stream: IO[bytes], + ) -> IO[bytes]: + extensions = self.extension_registry.from_descriptor_list(transform_descriptors) + transform_from = raw_stream + for ex in extensions: + if isinstance(ex, StreamTransformExtension): + transform_from = ex.transform_from(transform_from) + return transform_from + + +class FileSystemReader(StorageReader): + def __init__( + self, + path: Union[str, os.PathLike], + _extension_registry: Optional[ExtensionRegistry] = None, # EXPERIMENTAL + ) -> None: + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.storage_data: dict[Any, Any] = {} + self.load_id = _generate_uuid() + self.transforms = _StorageReaderTransforms(_extension_registry) + self.rank = None + self.use_collectives = True + + def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]: + return cast(IO[bytes], _create_file_view(file, sinfo.offset, sinfo.length)) + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + self.storage_data = {} + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.load_id = _generate_uuid() + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + # group requests by file + per_file: dict[str, list[ReadItem]] = {} + for read_item in plan.items: + item_md: _StorageInfo = self.storage_data[read_item.storage_index] + path = item_md.relative_path + per_file.setdefault(path, []).append(read_item) + + for relative_path, reqs in per_file.items(): + new_path = self.fs.concat_path(self.path, relative_path) + with self.fs.create_stream(new_path, "rb") as stream: + # TODO sort by offset and cache the reading + for req in reqs: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(stream, item_md) + transform_from = self.transforms.transform_load_stream( + req, + # This field wasn't present in older + # implementations so provide a fallback. + item_md.transform_descriptors or (), + file_slice, + ) + + if req.type == LoadItemType.BYTE_IO: + read_bytes = io.BytesIO(transform_from.read(-1)) + read_bytes.seek(0) + planner.load_bytes(req, read_bytes) + else: + if transform_from.seekable(): + seekable = transform_from + else: + # torch.load requires a seekable input, so read the transform + # stream now and store the output if needed + seekable = io.BytesIO(transform_from.read(-1)) + seekable.seek(0) + + tensor = cast( + Tensor, + torch.load( + seekable, + map_location="cpu", + weights_only=True, + ), + ) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) + target_tensor = planner.resolve_tensor(req).detach() + + if target_tensor.size() != tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike: + filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}" + return cast(Path, self.fs.concat_path(self.path, filename)) + + # Implementing the abstract function in StorageReader + def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: + rank = kwargs.get("rank") + path = self._get_metadata_path(rank) + with self.fs.create_stream(path, "rb") as metadata_file: + metadata = pickle.load(metadata_file) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id + + return metadata + + def set_up_storage_reader( + self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + self.storage_data = metadata.storage_data + self.rank = kwargs.get("rank") + self.use_collectives = kwargs.get("use_collectives", True) + if self.storage_data is None: + raise AssertionError("storage_data must not be None in metadata") + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + return plan + + def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]: + return plans + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to load the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a global `.metadata` file with the serialized metadata if rank coordination is enabled. + a rank local `__{rank}.metadata` file with the serialized metadata if rank coordination is NOT enabled. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + cache_staged_state_dict: bool = False, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + _FileSystemWriter.__init__( + self, + path=path, + single_file_per_rank=single_file_per_rank, + sync_files=sync_files, + thread_count=thread_count, + per_thread_copy_ahead=per_thread_copy_ahead, + overwrite=overwrite, + _extensions=_extensions, + serialization_format=serialization_format, + ) + BlockingAsyncStager.__init__( + self, + cache_staged_state_dict=cache_staged_state_dict, + ) + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Override of AsyncStager.stage""" + # in the async case, the state dict is already on CPU, so maintaining this + # buffer makes no sense + self.per_thread_copy_ahead = 0 + return super().stage(state_dict) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/format_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/format_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..912f983fe2a7ce9267ce74940d42f9bd2b3969ca --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/format_utils.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-defs +import argparse +import os +from enum import Enum +from typing import cast, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter +from torch.distributed.checkpoint._nested_dict import flatten_state_dict +from torch.distributed.checkpoint.default_planner import ( + _EmptyStateDictLoadPlanner, + DefaultLoadPlanner, +) +from torch.distributed.checkpoint.metadata import ( + Metadata, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner +from torch.distributed.checkpoint.planner_helpers import _create_chunk_list +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict +from torch.distributed.checkpoint.state_dict_saver import _save_state_dict +from torch.distributed.checkpoint.storage import StorageReader +from torch.futures import Future + + +__all__ = [ + "dcp_to_torch_save", + "torch_save_to_dcp", + "BroadcastingTorchSaveReader", + "DynamicMetaLoadPlanner", +] + + +class BroadcastingTorchSaveReader(StorageReader): + """ + StorageReader for reading a Torch Save file. This reader will read the entire checkpoint + on the coordinator rank, and then broadcast and shard each tensor to all ranks. + + . N.B. Intended to be used with DynamicMetaLoadPlanner + + .. warning:: + Current implementation only supports loading Tensors. + + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def __init__( + self, + checkpoint_id: Optional[Union[str, os.PathLike]] = None, + coordinator_rank: int = 0, + ) -> None: + self.checkpoint_id = checkpoint_id + self.coordinator_rank = coordinator_rank + + # pyrefly: ignore [bad-override] + def read_metadata(self) -> Metadata: + """Extends the default StorageReader to support building the metadata file""" + # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from + # the disk + return Metadata(state_dict_metadata={}) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + """ + Reads torch save data on the coordinator rank, and broadcast afterwards + this incurrs a communication cost, but avoids having to load + the entire checkpoint on each rank, hopefully preventing OOM issues + """ + planner = cast(DefaultLoadPlanner, planner) + + # data is read in on the coordinator rank, and broadcast afterwards + # this incurs a communication cost, but it avoids having to load + # the entire checkpoint on each rank, hopefully preventing OOM issues + # TODO: read on each host, instead of only the coordinator + if self.is_coordinator: + if self.checkpoint_id is None: + raise AssertionError("checkpoint_id must be set before reading data") + torch_state_dict = torch.load( + self.checkpoint_id, map_location="cpu", weights_only=False + ) + if planner.flatten_state_dict: + torch_state_dict, _ = flatten_state_dict(torch_state_dict) + else: + torch_state_dict = None + + for req in plan.items: + if req.type == LoadItemType.BYTE_IO: + raise RuntimeError( + f"Non-tensor value identified at {req.storage_index.fqn}. " + f"At this time {type(self).__name__} only supports loading Tensors." + ) + + # Broadcast the tensor from the coordinator rank + if self.is_coordinator: + pg_device = dist.distributed_c10d._get_pg_default_device() + # pyrefly: ignore [unsupported-operation] + tensor = torch_state_dict[req.storage_index.fqn].to(pg_device) + else: + tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn]) + + dist.broadcast(tensor, src=self.coordinator_rank, async_op=False) + + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + if not target_tensor.size() == tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes, " + f"{target_tensor.size()} vs {tensor.size()}" + ) + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + # pyrefly: ignore [bad-override] + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + """Implementation of the StorageReader method""" + self.is_coordinator = is_coordinator + if self.is_coordinator: + if not dist.get_rank() == self.coordinator_rank: + raise AssertionError( + f"Coordinator rank mismatch: expected {self.coordinator_rank}, " + f"got {dist.get_rank()}" + ) + + if self.checkpoint_id is None: + raise AssertionError( + "checkpoint_id must be set before setting up storage reader" + ) + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """Implementation of the StorageReader method""" + return plan + + def prepare_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + """Implementation of the StorageReader method""" + return global_plan + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """Implementation of the StorageReader method""" + self.checkpoint_id = checkpoint_id + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """Implementation of the StorageReader method""" + return os.path.isfile(checkpoint_id) + + +class DynamicMetaLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict, + avoiding the need to read metadata from disk. This is useful when reading formats which don't have a + metadata file, like Torch Save files. + + . N.B. Intended to be used with BroadcastingTorchSaveReader + + .. warning:: + Current implementation only supports loading Tensors. + + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """Setups of the planner, extnding default behavior by creating the Metadata object from the state dict""" + super().set_up_planner(state_dict, metadata, is_coordinator) + + state_dict_metadata: dict[str, STORAGE_TYPES] = {} + for key, tensor in self.state_dict.items(): + if not torch.is_tensor(tensor): + raise RuntimeError( + f"Non-tensor value identified at {key}. " + f"At this time {type(self).__name__} only supports loading Tensors." + ) + + state_dict_metadata[key] = TensorStorageMetadata( + TensorProperties(dtype=tensor.dtype), + tensor.size(), + _create_chunk_list(tensor), + ) + self.metadata = Metadata(state_dict_metadata=state_dict_metadata) + + +def dcp_to_torch_save( + dcp_checkpoint_dir: Union[str, os.PathLike], + torch_save_path: Union[str, os.PathLike], +): + """ + Given a directory containing a DCP checkpoint, this function will convert it into a + Torch save file. + + Args: + dcp_checkpoint_dir: Directory containing the DCP checkpoint. + torch_save_path: Filename to store the converted Torch save file. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + sd: STATE_DICT_TYPE = {} + _load_state_dict( + sd, + storage_reader=FileSystemReader(dcp_checkpoint_dir), + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + torch.save(sd, torch_save_path) + + +def torch_save_to_dcp( + torch_save_path: Union[str, os.PathLike], + dcp_checkpoint_dir: Union[str, os.PathLike], +): + """ + Given the location of a torch save file, converts it into a DCP checkpoint. + + Args: + torch_save_path: Filename of the Torch save file. + dcp_checkpoint_dir: Directory to store the DCP checkpoint. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + + state_dict = torch.load(torch_save_path, weights_only=False) + # we don't need stateful behavior here because the expectation is anything loaded by + # torch.load would not contain stateful objects. + _save_state_dict( + state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True + ) + + +if __name__ == "__main__": + + class FormatMode(Enum): + TORCH_TO_DCP = "torch_to_dcp" + DCP_TO_TORCH = "dcp_to_torch" + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument( + "mode", + type=str, + help="Conversion mode", + choices=[m.value for m in FormatMode], + default=FormatMode.TORCH_TO_DCP, + ) + parser.add_argument("src", type=str, help="Path to the source model") + parser.add_argument("dst", type=str, help="Path to the destination model") + args = parser.parse_args() + + print( + f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'" + ) + checkpoint_missing_warning = ( + f"No checkpoint found at {args.src}. Skipping conversion." + ) + if args.mode == FormatMode.TORCH_TO_DCP.value: + if os.path.isfile(args.src): + torch_save_to_dcp(args.src, args.dst) + else: + print(checkpoint_missing_warning) + elif args.mode == FormatMode.DCP_TO_TORCH.value: + if os.path.isdir(args.src): + dcp_to_torch_save(args.src, args.dst) + else: + print(checkpoint_missing_warning) + else: + raise ValueError(f"Unknown conversion mode: {args.mode}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/hf_storage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/hf_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..52f9209da0ec58826cfa3c445e2b2070c5dee60f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/hf_storage.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +import dataclasses +import json +import logging +import queue +import threading +from typing import Any, Optional + +import torch +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter +from torch.distributed.checkpoint._consolidate_hf_safetensors import ( + consolidate_safetensors_files, +) +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _HFStorageInfo, + _metadata_fn, + CUSTOM_METADATA_KEY, + SAVED_OFFSETS_KEY, + SHARDED_DIR_NAME, + SUFFIX, +) +from torch.distributed.checkpoint.filesystem import SerializationFormat +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + StorageMeta, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, +) +from torch.distributed.checkpoint.storage import WriteResult +from torch.futures import Future + + +logger: logging.Logger = logging.getLogger(__name__) + +__all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] + + +class HuggingFaceStorageWriter(FileSystemWriter): + """ + A writer that writes to storage in the huggingface safetensors format. + """ + + def __init__( + self, + path: str, + fqn_to_index_mapping: Optional[dict[str, int]] = None, + thread_count: int = 1, + save_distributed: bool = False, + enable_consolidation: bool = False, + thread_count_consolidation: int = 1, + ) -> None: + """ + Initialize the huggingface writer pointing to path. + + Args: + path: directory where the checkpoint will be read from. + fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. + Indices are from 1 to N, where N is the number of files. If not provided, + the tensors will be written to a single file. If none, then all the tensors on the + same rank will be written to the same file. + thread_count: Number of threads to use to write distributed checkpoint. Default to 1. + save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. + Default is False which assumes rank-0 checkpointing of the full state_dict. + enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded tensors will be + saved to path/sharded and the full tensors will be saved to path. Default to False. + thread_count_consolidation: Number of threads to use for parallel processing of saving data + to consolidated output files. Default to 1. + """ + + super().__init__( + path=path, + serialization_format=SerializationFormat.SAFETENSORS, + thread_count=thread_count, + ) + self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping + self.save_distributed: bool = save_distributed + self.enable_consolidation: bool = enable_consolidation + self.consolidated_output_path: Optional[str] = None + if self.enable_consolidation: + self.consolidated_output_path = str(self.path) + self.path = self.fs.concat_path(self.path, SHARDED_DIR_NAME) + self.thread_count_consolidation = thread_count_consolidation + + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + new_plans = [] + for i, plan in enumerate(plans, start=1): + storage_data: dict[str, Any] = {} + if self.fqn_to_index_mapping is not None: + storage_data["fqn_to_index_mapping"] = self.fqn_to_index_mapping + if self.save_distributed: + storage_data["shard_index"] = i + + new_plans.append(dataclasses.replace(plan, storage_data=storage_data)) + + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ) -> Future[list[WriteResult]]: + if len(plan.items) == 0: + fut: Future = Future() + fut.set_result([]) + return fut + + # storage_plan is a map from key to file index + storage_data: dict[str, Any] = plan.storage_data + storage_plan: Optional[dict[str, int]] = None + shard_index: Optional[int] = None + if "fqn_to_index_mapping" in storage_data: + storage_plan = storage_data["fqn_to_index_mapping"] + if "shard_index" in storage_data: + shard_index = storage_data["shard_index"] + + buckets = self._split_by_storage_plan(storage_plan, plan.items) + highest_index = max(storage_plan.values()) if storage_plan is not None else 1 + + file_queue: queue.Queue = queue.Queue() + for file_index, write_items in buckets.items(): + file_name = _gen_file_name(file_index, highest_index, shard_index) + file_queue.put( + (self.fs.concat_path(self.path, file_name), file_name, write_items) + ) + + return super()._write_data(planner, file_queue) + + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + if self.save_distributed and not self.enable_consolidation: + # if we are saving distributed, without consolidating, + # then we have no metadata to write because a metadata + # file with fqn to file mapping doesn't make sense + # in this case, because fqns will be in multiple files + logger.info("Not consolidating sharded checkpoint in finish step.") + return + if self.save_distributed: + fqn_to_index_mapping: dict[str, int] = ( + self.fqn_to_index_mapping + if self.fqn_to_index_mapping is not None + else dict.fromkeys(metadata.state_dict_metadata.keys(), 1) + ) + + return consolidate_safetensors_files( + input_dir=str(self.path), + output_dir=self.consolidated_output_path, # type: ignore[arg-type] + num_threads=self.thread_count_consolidation, + fqn_to_index_mapping=fqn_to_index_mapping, + ) + + # writing a model.index.safetensors.json file with fqn to file mapping + # for the rank-0 checkpointing case + metadata_to_write = {} + storage_md = {} + total_size = 0 + for wr_list in results: + storage_md.update( + {wr.index.fqn: wr.storage_data.relative_path for wr in wr_list} + ) + total_size += sum([wr.storage_data.length for wr in wr_list]) + metadata_to_write["metadata"] = {"total_size": total_size} + metadata_to_write["weight_map"] = storage_md + + metadata_path = self.fs.concat_path(self.path, f"{_metadata_fn}") + with self.fs.create_stream(metadata_path, "w") as metadata_file: + json.dump(metadata_to_write, metadata_file, indent=2) + + def _split_by_storage_plan( + self, storage_plan: Optional[dict[str, int]], items: list[WriteItem] + ) -> dict[int, list[WriteItem]]: + # storage_plan is a map from key to index + if storage_plan is None: + return {1: items} + + buckets = {} + for item in items: + key = item.index.fqn + + idx = storage_plan[key] + if idx not in buckets: + buckets[idx] = [item] + else: + buckets[idx].append(item) + + return buckets + + @property + def metadata_path(self) -> str: + return _metadata_fn + + +class HuggingFaceStorageReader(FileSystemReader): + """ + A reader that reads a checkpoint in the huggingface safetensors format. + """ + + def __init__(self, path: str, thread_count: int = 1) -> None: + """ + Initialize the huggingface reader pointing to path. + + Args: + path: directory where the checkpoint will be read from. + thread_count: Number of threads to use to read distributed checkpoint. Default to 1. + """ + + super().__init__(path=path) + self.thread_count = thread_count + + def _process_read_request(self, f, req: ReadItem, planner: LoadPlanner) -> None: + """Helper function to process a single read request.""" + # Create slices for each dimension based on offsets and lengths + slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) + ) + tensor = f.get_slice(req.storage_index.fqn)[slices] + target_tensor = planner.resolve_tensor(req).detach() + + if target_tensor.size() != tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + def _read_files_from_queue( + self, + file_queue: queue.Queue, + result_queue: queue.Queue, + planner: LoadPlanner, + ) -> None: + from safetensors import safe_open # type: ignore[import] + + try: + while True: + file_name, reqs = file_queue.get_nowait() + with safe_open(filename=file_name, framework="pt") as f: + for req in reqs: + self._process_read_request(f, req, planner) + result_queue.put(True) # Signal that this file has been processed + except queue.Empty: + pass + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + from safetensors import safe_open # type: ignore[import] + + per_file: dict[str, list[ReadItem]] = {} + + for read_item in plan.items: + item_md: _HFStorageInfo = self.storage_data[read_item.storage_index] + file_name = item_md.relative_path + per_file.setdefault(file_name, []).append(read_item) + + if self.thread_count <= 1 or len(per_file) <= 1: + for file_name, reqs in per_file.items(): + with safe_open(filename=file_name, framework="pt") as f: + for req in reqs: + self._process_read_request(f, req, planner) + else: + # Use parallel implementation with thread pool + file_queue: queue.Queue = queue.Queue() + result_queue: queue.Queue = queue.Queue() + + # Fill the queue with files to process + for file_name, reqs in per_file.items(): + file_queue.put((file_name, reqs)) + + # Create and start worker threads + threads = [] + num_threads = min(self.thread_count, len(per_file)) + for _ in range(num_threads): + t = threading.Thread( + target=self._read_files_from_queue, + args=(file_queue, result_queue, planner), + ) + t.start() + threads.append(t) + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check if all files were processed + processed_count = 0 + try: + while True: + result_queue.get_nowait() + processed_count += 1 + except queue.Empty: + pass + + if processed_count != len(per_file): + raise AssertionError( + f"Not all files were processed: {processed_count} out of {len(per_file)}" + ) + + fut: Future = Future() + fut.set_result(None) + return fut + + # pyrefly: ignore [bad-override] + def read_metadata(self) -> Metadata: + from safetensors import safe_open # type: ignore[import] + from safetensors.torch import _getdtype # type: ignore[import] + + state_dict_metadata: dict[str, TensorStorageMetadata] = {} + storage_data: dict[MetadataIndex, _HFStorageInfo] = {} + + safetensors_files = [] + for file in self.fs.ls(self.path): + if file.endswith(SUFFIX): + safetensors_files.append(file) + + for safetensor_file in safetensors_files: + with safe_open(safetensor_file, framework="pt") as f: + keys = f.keys() + extra_metadata = f.metadata() + + dcp_sharding_info = None + if extra_metadata and extra_metadata.get(CUSTOM_METADATA_KEY): + dcp_sharding_info = json.loads( + extra_metadata.get(CUSTOM_METADATA_KEY) + ) + + for key in keys: + shape = f.get_slice(key).get_shape() + dtype = f.get_slice(key).get_dtype() + # construct state_dict_metadata + if dcp_sharding_info is not None: + offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] + else: + offset = [0] * len(shape) + + if key not in state_dict_metadata: + state_dict_metadata[key] = TensorStorageMetadata( + properties=TensorProperties(dtype=_getdtype(dtype)), + size=torch.Size( + [saved + offset for saved, offset in zip(shape, offset)] + ), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=torch.Size(shape), + ) + ], + ) + else: + state_dict_metadata[key].chunks.append( + ChunkStorageMetadata( + torch.Size(offset), sizes=torch.Size(shape) + ) + ) + size = list(state_dict_metadata[key].size) + for i in range(len(size)): + size[i] = max(size[i], shape[i] + offset[i]) + state_dict_metadata[key].size = torch.Size(size) + + # construct storage data + if dcp_sharding_info is not None: + metadata_index = MetadataIndex( + fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY] + ) + else: + metadata_index = MetadataIndex(fqn=key, offset=[0] * len(shape)) + storage_data[metadata_index] = _HFStorageInfo( + relative_path=safetensor_file, + shape=torch.Size(shape), + dtype=_getdtype(dtype), + ) + + metadata = Metadata( + state_dict_metadata=state_dict_metadata, # type: ignore[arg-type] + storage_data=storage_data, + ) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id # type: ignore[union-attr] + + return metadata diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..677cac0339cb9fab60c77f75da04bc7ef06504f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +import functools +import logging +import time +from collections.abc import Callable +from typing import Any, TypeVar +from typing_extensions import ParamSpec +from uuid import uuid4 + +import torch.distributed.c10d_logger as c10d_logger +from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME + + +logger = logging.getLogger() + + +__all__: list[str] = [] + +# pyrefly: ignore [unknown-name] +global _dcp_logger +_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME) + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]: + """ + Extracts log data from dcp method args + """ + msg_dict = {} + + # checkpoint ID can be passed in through the serializer or through the checkpoint id directly + storage_writer = kwargs.get("storage_writer") + storage_reader = kwargs.get("storage_reader") + planner = kwargs.get("planner") + + checkpoint_id = kwargs.get("checkpoint_id") + if not checkpoint_id and (serializer := storage_writer or storage_reader): + checkpoint_id = getattr(serializer, "checkpoint_id", None) + + msg_dict["checkpoint_id"] = ( + # pyrefly: ignore [unsupported-operation] + str(checkpoint_id) if checkpoint_id is not None else checkpoint_id + ) + + # Uniquely identify a _dcp_method_logger wrapped function call. + msg_dict["uuid"] = str(uuid4().int) + + if storage_writer: + msg_dict["storage_writer"] = storage_writer.__class__.__name__ + + if storage_reader: + msg_dict["storage_reader"] = storage_reader.__class__.__name__ + + if planner: + msg_dict["planner"] = planner.__class__.__name__ + + return msg_dict + + +def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]: + msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs) + msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs)) + + return msg_dict + + +def _dcp_method_logger( + log_exceptions: bool = False, **wrapper_kwargs: Any +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore + """This method decorator logs the start, end, and exception of wrapped events.""" + + def decorator(func: Callable[_P, _T]): + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + msg_dict = _get_msg_dict( + func.__name__, *args, **{**wrapper_kwargs, **kwargs} + ) + + # log start event + msg_dict["event"] = "start" + t0 = time.time_ns() + msg_dict["time"] = t0 + msg_dict["log_exceptions"] = log_exceptions + _dcp_logger.debug(msg_dict) + + # exceptions + try: + result = func(*args, **kwargs) + except BaseException as error: + if log_exceptions: + msg_dict["event"] = "exception" + msg_dict["error"] = f"{error}" + msg_dict["time"] = time.time_ns() + _dcp_logger.error(msg_dict) + raise + + # end event + msg_dict["event"] = "end" + t1 = time.time_ns() + msg_dict["time"] = time.time_ns() + msg_dict["times_spent"] = t1 - t0 + _dcp_logger.debug(msg_dict) + + return result + + return wrapper + + return decorator + + +def _init_logger(rank: int): + logger.setLevel(logging.INFO) + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + formatter = logging.Formatter( + f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch.setFormatter(formatter) + logger.addHandler(ch) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/logging_handlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..99c3ee4156ce340e37a2723106df5ea64b19170d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/logging_handlers.py @@ -0,0 +1,14 @@ +import logging + +from torch.distributed.logging_handlers import _log_handlers + + +__all__: list[str] = [] + +DCP_LOGGER_NAME = "dcp_logger" + +_log_handlers.update( + { + DCP_LOGGER_NAME: logging.NullHandler(), + } +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/metadata.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..36864b6bf3ad60778ad008fcbb4c10002933c4c6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/metadata.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +import os +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional, Union + +import torch +from torch.distributed.checkpoint.stateful import StatefulT + + +__all__ = [ + "ChunkStorageMetadata", + "TensorStorageMetadata", + "BytesStorageMetadata", + "Metadata", + "MetadataIndex", + "TensorProperties", + "StorageMeta", +] + + +@dataclass +class ChunkStorageMetadata: + """ + Each chunk is expected to have the same properties of the TensorStorageMetadata + that includes it. + """ + + offsets: torch.Size + sizes: torch.Size + + +class _MEM_FORMAT_ENCODING(Enum): + """Describe the memory format of a tensor.""" + + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: torch.dtype = field(default_factory=torch.get_default_dtype) + # This field is deprecated. + layout: torch.layout = field(default=torch.strided) + # This field is deprecated. + requires_grad: bool = False + # This field is deprecated. + memory_format: torch.memory_format = field(default=torch.contiguous_format) + # This field is deprecated. + pin_memory: bool = False + + def __getstate__(self): + # Since torch.memory_format cannot be pickled! + memory_format = self.memory_format + if memory_format == torch.contiguous_format: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + elif memory_format == torch.channels_last: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + elif memory_format == torch.preserve_format: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + else: + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") + + return ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state + + if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + memory_format = torch.contiguous_format + elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + memory_format = torch.channels_last + elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + memory_format = torch.preserve_format + else: + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) + + self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class TensorStorageMetadata: + properties: TensorProperties + size: torch.Size + chunks: list[ChunkStorageMetadata] + + +@dataclass +class BytesStorageMetadata: + pass + + +STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata] +STATE_DICT_TYPE = dict[str, Union[StatefulT, Any]] + + +@dataclass +class StorageMeta: + checkpoint_id: Union[str, os.PathLike, None] = None + save_id: Optional[str] = None + load_id: Optional[str] = None + modules: list[str] = field(default_factory=list) + + +@dataclass +class Metadata: + """This class represents the metadata of the checkpoint.""" + + # Keys are the same from the `state_dict` used. + state_dict_metadata: dict[str, STORAGE_TYPES] + # It is the responsibility of the planner and storage plugins to ensure + # backward compatibility of the planner_data and storage_data. DCP will + # also ensure the backward compatibility of the metadata in this file and + # the metadata of the built-in planner and storage plugins. + planner_data: Any = None + storage_data: Any = None + storage_meta: Optional[StorageMeta] = None + version: Optional[str] = None + + +@dataclass(frozen=True) +class MetadataIndex: + """This class represents a lookup key for items in a state dict or Metadata.""" + + fqn: str + """Fully Qualified Name of the object""" + + offset: Optional[torch.Size] = None + """If the object is a tensor, offset into the tensor we're looking for""" + + index: Optional[int] = field(hash=False, compare=False, default=None) + """ + Index hint when searching for tensor chunk to speedup lookups (optional) + + A common representation of a sharded tensor is as a list of chunks so to + find the index in such a list you need to linear search it. + + When constructing an instance of MetadataIndex that points to that list, + one can provide the index as a hint and it will be probed first before + the linear search and thus making it significantly faster. + """ + + def __init__( + self, + fqn: str, + offset: Optional[Sequence[int]] = None, + index: Optional[int] = None, + ): + # We must use object.__setattr__ due to frozen=True + object.__setattr__(self, "fqn", fqn) + object.__setattr__(self, "index", index) + if offset is not None: + object.__setattr__(self, "offset", torch.Size(offset)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/optimizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..343497da0aa21f35a081a7ca9063d4dcbbf41ccc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/optimizer.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import dataclasses +from collections.abc import Sequence +from typing import cast, Optional, Union + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._shard.sharded_tensor.metadata import ( + TensorProperties as ShardTensorProperties, +) +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec +from torch.distributed.checkpoint._nested_dict import unflatten_state_dict +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner +from torch.distributed.checkpoint.planner_helpers import ( + _create_read_items, + create_read_items_for_chunk_list, +) + +# pyrefly: ignore [deprecated] +from torch.distributed.checkpoint.state_dict_loader import load_state_dict +from torch.distributed.checkpoint.storage import StorageReader +from torch.distributed.checkpoint.utils import ( + _element_wise_add, + _element_wise_sub, + _normalize_device_info, +) +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DTensor + + +STATE_DICT_2D_LAYOUT = dict[str, tuple[Optional[Sequence[int]], Sequence[int]]] + + +# TODO: Update docstrings for optimizer.py +__all__ = [ + "load_sharded_optimizer_state_dict", +] + + +def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str: + if device_type == "cpu": + return "cpu" + device_module = _get_device_module(device_type) + if device_module.is_available(): + return _normalize_device_info( + device_type, global_rank % device_module.device_count() + ) + return "cpu" + + +def _create_colwise_spec( + pg: Optional[dist.ProcessGroup] = None, +) -> ChunkShardingSpec: + pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type + if pg is None: + placements = [ + f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}" + for idx in range(dist.get_world_size()) + ] + else: + placements = [ + f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}" + for idx in range(pg.size()) + ] + return ChunkShardingSpec( + dim=0, + placements=cast(list[Union[_remote_device, str]], placements), + ) + + +def _is_nested_tensor(val: torch.Tensor) -> bool: + if type(val) is ShardedTensor: + if len(val.local_shards()) == 0: + return False + if type(val.local_shards()[0].tensor) is ShardedTensor: + return True + if type(val.local_shards()[0].tensor) is DTensor: + raise ValueError("Cannot handle DTensor nested inside ShardedTensor") + elif type(val) is DTensor and ( + type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor + ): + raise ValueError("Cannot handle nested DTensor") + return False + + +def _alloc_tensor( + props: TensorProperties, size: Sequence[int], device_type: str = "cuda" +) -> torch.Tensor: + if device_type == "cpu": + device = cast(torch.device, _get_device_module(device_type).current_device()) + else: + device = torch.device( + device_type, _get_device_module(device_type).current_device() + ) + + return torch.empty( + size=size, + dtype=props.dtype, + layout=props.layout, + requires_grad=props.requires_grad, + pin_memory=props.pin_memory, + device=device, + ) + + +def _get_state_dict_2d_layout( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]: + """ + Load the right TP slice of the optimizer state. + + This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata. + We take advantage of the model state_dict producing a sliced ST to figure out what we need to load. + This is pretty fragile and it might be easier for FSDP to compute this info for us. + Returns a dictionary where keys are the same of the state_dict and the value is a tuple of + (offset, size) for the current rank TP slice. + N.B. The state_dict *MUST* come from FSDP.sharded_state_dict. + """ + specs: STATE_DICT_2D_LAYOUT = {} + dp_pg: Optional[dist.ProcessGroup] = None + for key, value in state_dict.items(): + specs[key] = (None, value.size()) + if _is_nested_tensor(value): + if not len(value.local_shards()) == 1: + raise AssertionError("Cannot handle ST with multiple shards") + if not isinstance(value, ShardedTensor): + raise AssertionError("Can only handle nested ShardedTensor") + shard = value.local_shards()[0] + specs[key] = ( + shard.metadata.shard_offsets, + shard.metadata.shard_sizes, + ) + dp_pg = shard.tensor._process_group # type: ignore[attr-defined] + + return ( + specs, + dp_pg, + ) + + +class _ReaderWithOffset(DefaultLoadPlanner): + translation: dict[MetadataIndex, MetadataIndex] + state_dict: STATE_DICT_TYPE + # pyrefly: ignore [bad-override] + metadata: Metadata + + def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None: + super().__init__() + self.fqn_to_offset = fqn_to_offset + self.metadata = Metadata({}) + self.state_dict = {} + self.translation = {} + + def create_local_plan(self) -> LoadPlan: + requests = [] + self.translation = {} + for fqn, obj in self.state_dict.items(): + md = self.metadata.state_dict_metadata[fqn] + if not isinstance(obj, ShardedTensor): + requests += _create_read_items(fqn, md, obj) + continue + + if fqn not in self.fqn_to_offset: + requests += _create_read_items(fqn, md, obj) + continue + + offset = self.fqn_to_offset[fqn] + + if not len(obj.local_shards()) == 1: + raise AssertionError("Expected exactly one local shard") + original_shard = obj.local_shards()[0] + local_chunks = [ + ChunkStorageMetadata( + offsets=torch.Size( + _element_wise_add(original_shard.metadata.shard_offsets, offset) + ), + sizes=torch.Size(original_shard.metadata.shard_sizes), + ) + ] + + reqs = create_read_items_for_chunk_list( + fqn, cast(TensorStorageMetadata, md), local_chunks + ) + # TODO: The ReadItems will have a displaced MetadataIndex, fix it. + # TODO: we should change _create_sharded_read_items to have more ergonomic API + for ri in reqs: + if ri.dest_index.offset is None: + raise AssertionError("dest_index.offset must not be None") + original_offset = _element_wise_sub(ri.dest_index.offset, offset) + original_index = dataclasses.replace( + ri.dest_index, offset=torch.Size(original_offset) + ) + self.translation[ri.dest_index] = original_index + + requests += reqs + return LoadPlan(requests) + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + return super().lookup_tensor(self.translation.get(index, index)) + + +def load_sharded_optimizer_state_dict( + model_state_dict: STATE_DICT_TYPE, + optimizer_key: str, + storage_reader: StorageReader, + planner: Optional[LoadPlanner] = None, +) -> STATE_DICT_TYPE: + """ + Load a state_dict in conjunction with FSDP sharded optimizer state. + + This is the current recommended way to checkpoint FSDP. + >>> # xdoctest: +SKIP + >>> import torch.distributed.checkpoint as dist_cp + >>> # Save + >>> model: torch.nn.Model + >>> optim_params = model.parameters() + >>> optim = torch.optim.SGD(optim_params, lr=0.01) + >>> # Save + >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + >>> state_dict = { + >>> "optimizer": FSDP.optim_state_dict(model, optim), + >>> "model": model.state_dict() + >>> } + >>> dist_cp.save_state_dict( + >>> state_dict=optim_state, + >>> storage_writer=dist_cp.FileSystemWriter("checkpoint"), + >>> planner=dist_cp.DefaultSavePlanner(), + >>> ) + >>> + >>> # Load + >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): + >>> model_state_dict = model_tp.state_dict() + >>> checkpoint = { + >>> "model": model_state_dict + >>> } + >>> dist_cp.load_state_dict( + >>> state_dict=checkpoint, + >>> storage_reader=dist_cp.FileSystemReader(checkpoint_file), + >>> planner=dist_cp.DefaultLoadPlanner(), + >>> ) + >>> model.load_state_dict(checkpoint["model_state"]) + >>> + >>> optim_state = dist_cp.load_sharded_optimizer_state_dict( + >>> model_state_dict, + >>> optimizer_key="optimizer", + >>> storage_reader=dist_cp.FileSystemReader("checkpoint"), + >>> ) + >>> + >>> flattened_osd = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state["optimizer"] + >>> ) + >>> + >>> optim.load_state_dict(flattened_osd) + """ + metadata = storage_reader.read_metadata() + + layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict) + dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type + device_module = _get_device_module(dp_pg_device_type) + + if dp_pg is None: + placements = [] + for i in range(dist.get_world_size()): + device_info = _normalize_device_info( + dp_pg_device_type, i % device_module.device_count() + ) + placements.append(f"rank:{i}/{device_info}") + sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type] + else: + sharding_spec = _create_colwise_spec(dp_pg) + + # Create a state_dict for optimizer state + state_dict: STATE_DICT_TYPE = {} + + fqn_to_offset: dict[str, Sequence[int]] = {} + for key, value in metadata.state_dict_metadata.items(): + key_path = metadata.planner_data[key] + if key_path[0] != optimizer_key: + continue + + if isinstance(value, BytesStorageMetadata): + state_dict[key] = "" + continue + + # value: TensorStorageMetadata + if value.size.numel() == 1: + state_dict[key] = _alloc_tensor( + value.properties, value.size, dp_pg_device_type + ) + elif dp_pg is None: + state_dict[key] = _create_chunk_sharded_tensor( + _alloc_tensor(value.properties, value.size, dp_pg_device_type), + rank=dist.get_rank(), + world_size=dist.get_world_size(), + num_devices_per_node=device_module.device_count(), + pg=_get_default_group(), + ) + else: + spec_key = key_path[2] + alloc_size = layout_specs.get(spec_key, (None, value.size))[1] + + properties = ShardTensorProperties( + dtype=value.properties.dtype, + layout=value.properties.layout, + requires_grad=value.properties.requires_grad, + memory_format=value.properties.memory_format, + pin_memory=value.properties.pin_memory, + ) + + st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties) + local_shards = [] + current_rank = dist.get_rank(dp_pg) + for shard_md in st_md.shards_metadata: + if cast(_remote_device, shard_md.placement).rank() != current_rank: + continue + local_shards.append( + Shard( + tensor=_alloc_tensor( + value.properties, shard_md.shard_sizes, dp_pg_device_type + ), + metadata=shard_md, + ) + ) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, st_md, process_group=dp_pg + ) + + if spec_key in layout_specs and layout_specs[spec_key][0] is not None: + fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0]) + + state_dict[key] = st + + # Whether we unflatten before or after doesn't matter + load_state_dict( + state_dict=state_dict, + storage_reader=storage_reader, + # FIXME the type of planner is wrong in load_state_dict + planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner, + ) + + state_dict = unflatten_state_dict(state_dict, metadata.planner_data) + + return state_dict diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/planner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/planner.py new file mode 100644 index 0000000000000000000000000000000000000000..8c97dc0379b109dd3a9706176390720a88128851 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/planner.py @@ -0,0 +1,450 @@ +import abc +import io +import operator +from dataclasses import dataclass +from enum import auto, Enum +from functools import reduce +from typing import Any, Optional, Union + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + StorageMeta, + TensorProperties, +) + + +__all__ = [ + "WriteItemType", + "LoadItemType", + "BytesIOWriteData", + "TensorWriteData", + "WriteItem", + "ReadItem", + "SavePlan", + "LoadPlan", + "SavePlanner", + "LoadPlanner", +] + + +class WriteItemType(Enum): + TENSOR = auto() + SHARD = auto() + BYTE_IO = auto() + + +class LoadItemType(Enum): + TENSOR = auto() + BYTE_IO = auto() + + +@dataclass(frozen=True) +class BytesIOWriteData: + nbytes: int + + +@dataclass(frozen=True) +class TensorWriteData: + chunk: ChunkStorageMetadata + properties: TensorProperties + size: torch.Size + + +@dataclass(frozen=True) +class WriteItem: + """Dataclass which holds information about what needs to be written to storage.""" + + index: MetadataIndex + type: WriteItemType + + # Size of bytesIO data to be written. + bytes_io_data: Optional[BytesIOWriteData] = None + + # Value present if it's a tensor write + tensor_data: Optional[TensorWriteData] = None + + def tensor_storage_size(self) -> Optional[int]: + """ + Calculates the storage size of the underlying tensor, or None if this is not a tensor write. + + Returns: + Optional[int] storage size, in bytes of underlying tensor if any. + """ + if self.tensor_data is None: + return None + + numels = reduce(operator.mul, self.tensor_data.size, 1) + dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype) + return numels * dtype_size + + +@dataclass(frozen=True) +class ReadItem: + # Read Item + type: LoadItemType + + # Index into the state_dict + dest_index: MetadataIndex + # Offsets into destination tensor + dest_offsets: torch.Size + + # Index into the checkpoint + storage_index: MetadataIndex + # Offset into the checkpoint data + storage_offsets: torch.Size + + # Size of the hypercube to copy + lengths: torch.Size + + +@dataclass(frozen=True) +class SavePlan: + items: list[WriteItem] + storage_data: Any = None + planner_data: Any = None + # This is used to indicate that the ranks should + # use the cached plans to write data instead. + usable: bool = True + + +@dataclass +class LoadPlan: + items: list[ReadItem] + storage_data: Any = None + planner_data: Any = None + + +class SavePlanner(abc.ABC): + """ + Abstract class defining the protocol used by save_state_dict to plan the save process. + + SavePlanners are stateful objects that can be used to customize the whole save process. + + SavePlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during save_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of a checkpoint save. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `SavePlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the SavePlan from all ranks and make any global decision. + + 4) finish_plan - called on all ranks. + This gives each rank a chance to adjust to global planning decisions. + + 5) resolve_data - called multiple times on each rank + Lookups a value on the `state_dict` for the storage layer to write. + + Users are recommended to extend DefaultSavePlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are 3 usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the save process as it + doesn't requite understanding the intrincacies of how SavePlan works: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultSavePlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> storage_meta: Optional[StorageMeta], + >>> is_coordinator: bool, + >>> ) -> None: + >>> # prefix all keys with `foo_`` + >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator) + + Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted + + >>> # xdoctest: +SKIP("undefined vars") + >>> class FP16Planner(DefaultSavePlanner): + >>> def create_local_plan(self): + >>> plan = super().create_local_plan() + >>> for p in plan: + >>> if p.tensor_data is not None: + >>> p.tensor_data.properties.dtype = torch.float16 + >>> return plan + >>> + >>> def resolve_data(self, write_item): + >>> item = super().resolve_data(write_item) + >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16) + + Using the global planning step to make central decisions that can't be made individually by each rank + + >>> # xdoctest: +SKIP("undefined vars") + >>> from itertools import zip_longest + >>> from dataclasses import replace + >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): + >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 + >>> # This sample doesn't handle ShardedTensors + >>> def create_global_plan(self, all_plans): + >>> iters = [iter(all_plans[0].items)] * len(all_plans) + >>> items_per_rank = [ + >>> [item for item in items if item is not None] + >>> for items in zip(*zip_longest(*iters), strict=True) + >>> ] + >>> all_plans = [ + >>> replace(plan, items=items) + >>> for plan, items in zip(all_plans, items_per_rank, strict=True) + >>> ] + >>> return super().create_global_plan(all_plans) + + Finally, some planners need to save additional metadata in the checkpoint, this is + accomplished by having each rank contribute their data items in the local plan and + the global planner aggregate them: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class SaveExtraDataPlanner(DefaultSavePlanner): + >>> def create_local_plan(self) -> SavePlan: + >>> plan = super().create_local_plan() + >>> return replace(plan, planner_data="per-rank-data") + >>> + >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + >>> global_plan, metadata = super().create_global_plan(all_plans) + >>> merged_data = [p.planner_data for p in global_plan] + >>> metadata = replace(metadata, planner_data=merged_data) + >>> return global_plan, metadata + """ + + # Save plan for the current rank as computed by `create_local_plan` API + # Cached on the local rank. + _cached_save_plan: dict[str, SavePlan] = {} + # Final save plan for the current rank. + # This is created by merging the plan created by `create_local_plan` API + # and the result of `create_global_plan` for the given rank. + # This is the final plan computed by the `finish_plan` API that gets + # sent to the `write_data`. + # Cached on the local rank. + _cached_final_save_plan: dict[str, SavePlan] = {} + # Collection of all the local plans from all the ranks. + # This is the input to the `create_global_plan` API. + # Cached on the coordinator rank. + _cached_all_plans: dict[str, list[SavePlan]] = {} + # Global checkpoint plan as computed by `create_global_plan` API. + # Cached on the coordinator rank. + _cached_global_plan: dict[str, list[SavePlan]] = {} + # Metadata for the global checkpoint plan as computed by `create_global_plan` API. + # Cached on the coordinator rank. + _cached_metadata: dict[str, Metadata] = {} + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this planner to save ``state_dict``. + + Implementations should save those values as they won't be provided lated in the save process. + + This is called on all ranks. + """ + + @abc.abstractmethod + def create_local_plan(self) -> SavePlan: + """ + Compute the save plan for the current rank. + + This will be aggregated and passed to create_global_plan. + Planner specific data can be passed through SavePlan::planner_data. + + This is called on all ranks. + """ + + @abc.abstractmethod + def create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + """ + Compute the global checkpoint plan and return the local plan of each rank. + + This is called on the coordinator rank only. + """ + + @abc.abstractmethod + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + """ + Merge the plan created by `create_local_plan` and the result of `create_global_plan`. + + This is called on all ranks. + """ + + @abc.abstractmethod + def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + """ + Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety. + + Lookup the object associated with ``write_item`` in ``state_dict`` and apply any + transformation (such as serialization) prior to the storage layer consuming it. + + Called on each rank multiple times, at least once per WriteItem in the final SavePlan. + + This method should be idempotent and thread-save. StorageWriter implementations + are free to call it as frequently as they need. + + Any transformation that allocates memory should be lazily done when his method + is called in order to reduce peak memory required by checkpointing. + + When returning tensors, they can be on any device or format, they can be views too. + It's the storage layer responsibility to figure out how to save them. + """ + + +class LoadPlanner: + """ + Abstract class defining the protocol used by load_state_dict to plan the load process. + + LoadPlanner are stateful objects that can be used to customize the whole load process. + + LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during load_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of loading a checkpoint. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `LoadPlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the LoadPlan from all ranks and make any global decision. + + 4) load_bytes - called multiple times on each rank + This is called once per non-tensor value in state_dict. + + 5) resolve_tensor and commit_tensor - called multiple times on each rank + They are called in pair for each Tensor value in state_dict. + + Users are recommended to extend DefaultLoadPlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are two usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the load process as it + doesn't requite understanding the intrincacies of how LoadPlan works. We need + to keep a reference to the original state_dict as load happens in place so + we need to be able to perform it in place + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultLoadPlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> metadata: Metadata, + >>> is_coordinator: bool, + >>> ) -> None: + >>> self.original_state_dict = state_dict + >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} + >>> + >>> if self.flatten_sharded_tensors: + >>> state_dict = _flatten_sharded_tensors(state_dict) + >>> + >>> if self.flatten_state_dict: + >>> state_dict, self.mappings = flatten_state_dict(state_dict) + >>> + >>> self.state_dict = state_dict + >>> self.metadata = metadata + >>> self.is_coordinator = is_coordinator + >>> + >>> def load_bytes(self, read_item, value): + >>> # Remove the "foo_" prefix + >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False) + + + Modifying resolve_tensor and commit_tensor to handle load time transformation. + + >>> # xdoctest: +SKIP("undefined vars") + >>> class MetaModelMaterialize(DefaultSavePlanner): + >>> def resolve_tensor(self, read_item): + >>> tensor = super().resolve_tensor(read_item) + >>> return torch.empty_like(tensor, device="cpu") + >>> + >>> def commit_tensor(self, read_item, tensor): + >>> self.state_dict[read_item.dest_index.fqn] = tensor + """ + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this instance to load data into ``state_dict``. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_local_plan(self) -> LoadPlan: + """ + Create a LoadPlan based on state_dict and metadata provided by set_up_planner. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + """ + Compute the global load plan and return plans for each rank. + + . N.B. This is called on the coordinator rank only + """ + + @abc.abstractmethod + def finish_plan(self, central_plan: LoadPlan) -> LoadPlan: + """Accept the plan from coordinator and return final LoadPlan.""" + + @abc.abstractmethod + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + """ + Load the item described by ``read_item``and ``value``. + + This method is expected to modify in-place the underlying state_dict. + + The contents of ``value`` are defined by the SavePlanner used to produce + the checkpoint being loaded. + """ + + def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO: + """ + Return the BytesIO to be used by the StorageReader to load `read_item`. + + The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents. + """ + raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented") + + @abc.abstractmethod + def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor: + """ + Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`. + + The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. + If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data + back to the one in state_dict. + """ + + @abc.abstractmethod + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + """ + Call once the StorageReader finished loading data into ``tensor``. + + The provided tensor is the same one returned by the call to ``resolve_tensor``. + This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to + copying it back to the one in the state_dict. + + The contents of tensor will follow its device synchronization model. + """ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/planner_helpers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/planner_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7af7d7a821b541cf66044a28d828d863624da2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/planner_helpers.py @@ -0,0 +1,491 @@ +# mypy: allow-untyped-defs +import io +from collections.abc import Callable +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from .planner import ( + LoadItemType, + ReadItem, + SavePlan, + TensorWriteData, + WriteItem, + WriteItemType, +) +from .resharding import ( + _check_shard_metadata_pair_overlap, + _shards_get_overlap_region_wrt_saved_tensor, +) + + +__all__: list[str] = ["create_read_items_for_chunk_list"] + + +def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool: + """ + Compare the two Save plans and return True if they are equal. + + Args: + plan (SavePlan): First SavePlan to compare. + other_plan (SavePlan): Second SavePlan to compare. + + Returns: + True if the two plans are equal, False otherwise. + """ + if plan.usable != other_plan.usable: + return False + + # Both the plans should have the same number of items + if len(plan.items) != len(other_plan.items): + return False + + # Both the plans should have the same write items. + for plan_item, other_plan_item in zip(plan.items, other_plan.items): + # Write item type should be same + if plan_item.type != other_plan_item.type: + return False + + plan_metadata_index = plan_item.index + other_plan_metadata_index = other_plan_item.index + + # Write item metadata_index should be same + if ( + plan_metadata_index.fqn != other_plan_metadata_index.fqn + or plan_metadata_index.offset != other_plan_metadata_index.offset + or plan_metadata_index.index != other_plan_metadata_index.index + ): + return False + + # Write item tensor_data should be present in both the write items plans, if it exists in either of them. + tensor_data = plan_item.tensor_data + other_tensor_data = other_plan_item.tensor_data + if (tensor_data and not other_tensor_data) or ( + not tensor_data and other_tensor_data + ): + return False + + if tensor_data and other_tensor_data: + # Write item tensor_data size should be same + if tensor_data.size != other_tensor_data.size: + return False + + # Write item tensor_data chunk should be present in both the write items, if it exists in either of them. + chunk = tensor_data.chunk + other_chunk = other_tensor_data.chunk + if (chunk and not other_chunk) or (not chunk and other_chunk): + return False + + # Write item tensor_data chunk offsets and sizes should be same + if chunk and other_chunk: + if ( + chunk.offsets != other_chunk.offsets + or chunk.sizes != other_chunk.sizes + ): + return False + + return True + + +def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool: + """ + Check if any delta plan is usable, indicating the plan has changed. + + Args: + delta_plans (List[SavePlan]): A list of delta plans to check. + Returns: + True if any delta plan is usable, False otherwise. + """ + return any(delta_plan and delta_plan.usable for delta_plan in delta_plans) + + +def _merge_delta_local_plans( + cached_plans: list[SavePlan], + delta_plans: list[SavePlan], +) -> list[SavePlan]: + """ + Merge a list of delta plans into a single plan. + + Args: + cached_plans (List[SavePlan]): A list of cached plans. + delta_plans (List[SavePlan]): A list of delta plans to merge. It can contain empty plans + + Returns: + A single merged plan. If a delta plan is not usable, use the cached plan. Otherwise, use the delta plan. + """ + merged_plans = [] + + for cached_plan, delta_plan in zip(cached_plans, delta_plans): + if delta_plan and not delta_plan.usable: + merged_plans.append(cached_plan) + else: + merged_plans.append(delta_plan) + + return merged_plans + + +def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size() + ) + + +def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size(shard_md.shard_offsets), + sizes=torch.Size(shard_md.shard_sizes), + ) + + +def _sharded_tensor_metadata( + sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> TensorWriteData: + shard_properties = sharded_tensor.metadata().tensor_properties + + properties = TensorProperties( + dtype=shard_properties.dtype, + layout=shard_properties.layout, + requires_grad=shard_properties.requires_grad, + memory_format=shard_properties.memory_format, + pin_memory=shard_properties.pin_memory, + ) + + return TensorWriteData( + chunk=_chunk_for_shard(shard_md), + properties=properties, + size=sharded_tensor.metadata().size, + ) + + +def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem: + sizes, offsets = compute_local_shape_and_global_offset( + tensor.shape, tensor.device_mesh, tensor.placements + ) + sizes, offsets = torch.Size(sizes), torch.Size(offsets) + + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=offsets, + sizes=sizes, + ), + properties=TensorProperties.create_from_tensor(tensor.to_local()), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_shard( + fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> WriteItem: + offsets = torch.Size(shard_md.shard_offsets) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md), + ) + + +def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: + offsets = torch.Size([0] * len(tensor.size())) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.TENSOR, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), + properties=TensorProperties.create_from_tensor(tensor), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_bytesio(fqn: str, bytes: Any): + return WriteItem( + index=MetadataIndex(fqn), + type=WriteItemType.BYTE_IO, + ) + + +def _create_read_item_for_byteio( + dest_index, dest_offset, storage_index, storage_offset, length +): + return ReadItem( + type=LoadItemType.BYTE_IO, + dest_index=dest_index, + dest_offsets=torch.Size((dest_offset,)), + storage_index=storage_index, + storage_offsets=torch.Size((storage_offset,)), + lengths=torch.Size((length,)), + ) + + +def _create_read_item_for_tensor( + dest_index, dest_offsets, storage_index, storage_offsets, lengths +): + return ReadItem( + type=LoadItemType.TENSOR, + dest_index=dest_index, + dest_offsets=torch.Size(dest_offsets), + storage_index=storage_index, + storage_offsets=torch.Size(storage_offsets), + lengths=torch.Size(lengths), + ) + + +def create_read_items_for_chunk_list( + fqn: str, + checkpoint_md: TensorStorageMetadata, + local_chunks: list[ChunkStorageMetadata], +) -> list[ReadItem]: + """ + Create a list of ``ReadItem`` based on the checkpoint and local chunks. + + This applies the resharding algorithm and computes the reads needed + to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. + + Args: + fqn (str) : The state_dict FQN to pass to ``ReadItem``. + checkpoint_md (TensorStorageMetadata): metadata for a given tensor + from a checkpoint. + local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be + loaded. + + Returns: + A list of ``ReadItem`` that will satisfy all input chunks. + """ + read_items = [] + # this is a naive quadratic algo that can be optimized later + for idx, shard in enumerate(local_chunks): + for storage_idx, storage_md in enumerate(checkpoint_md.chunks): + if not _check_shard_metadata_pair_overlap(shard, storage_md): + continue + + storage_offsets = [] + dest_offsets = [] + lengths = [] + for ( + _dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor( + saved_shard=storage_md, current_shard=shard + ): + storage_offsets.append(offset_for_saved_tensor) + dest_offsets.append(offset_for_current_tensor) + lengths.append(length) + + read_items.append( + _create_read_item_for_tensor( + dest_index=MetadataIndex(fqn, shard.offsets, idx), + dest_offsets=dest_offsets, + storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx), + storage_offsets=storage_offsets, + lengths=lengths, + ) + ) + return read_items + + +def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan: + requests = [] + for fqn, obj in state_dict.items(): + if isinstance(obj, DTensor): + requests.append(_create_write_items_for_dtensor(fqn, obj)) + elif isinstance(obj, ShardedTensor): + requests.extend( + _create_write_item_for_shard(fqn, obj, shard_md) + for shard_md in obj.metadata().shards_metadata + ) + elif isinstance(obj, torch.Tensor): + requests.append(_create_write_item_for_tensor(fqn, obj)) + else: + requests.append(_create_write_item_for_bytesio(fqn, obj)) + return SavePlan(requests) + + +def _create_write_items(fqn: str, object: Any) -> list[WriteItem]: + if hasattr(object, "__create_write_items__"): + # DTensor implements _Checkpointable + return object.__create_write_items__(fqn, object) + elif isinstance(object, ShardedTensor): + return [ + _create_write_item_for_shard(fqn, object, shard.metadata) + for shard in object.local_shards() + ] + elif isinstance(object, torch.Tensor): + return [_create_write_item_for_tensor(fqn, object)] + else: + return [_create_write_item_for_bytesio(fqn, object)] + + +def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: + sizes, offsets = compute_local_shape_and_global_offset( + tensor.shape, tensor.device_mesh, tensor.placements + ) + sizes, offsets = torch.Size(sizes), torch.Size(offsets) + return ChunkStorageMetadata( + offsets=offsets, + sizes=sizes, + ) + + +def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]: + if hasattr(tensor, "__create_chunk_list__"): + # DTensor implements _Checkpointable + local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(tensor, ShardedTensor): + local_chunks = [ + _chunk_for_shard(shard.metadata) for shard in tensor.local_shards() + ] + elif isinstance(tensor, torch.Tensor): + local_chunks = [_create_chunk_from_tensor(tensor)] + else: + raise ValueError( + "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] " + f",but got {type(tensor)}" + ) + + return local_chunks + + +def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]: + if not isinstance(md, BytesStorageMetadata): + try: + local_chunks = _create_chunk_list(obj) + except ValueError as ex: + raise ValueError( + f"Invalid checkpoint metadata for {fqn}, " + + f"expected BytesStorageMetadata but found {type(md)}", + ) from ex + + return create_read_items_for_chunk_list(fqn, md, local_chunks) + else: + return [ + _create_read_item_for_byteio( + dest_index=MetadataIndex(fqn), + dest_offset=0, + storage_index=MetadataIndex(fqn), + storage_offset=0, + length=0, + ) + ] + + +def _init_state_dict(state_dict: dict[str, Any]) -> Any: + """ + Initializes meta tensor if the meta tensor is DTensor or torch.Tensor. + """ + + def dtensor_func(value: DTensor): + device = getattr(value, "device", None) + if device == torch.device("meta"): + device_type = dist.distributed_c10d._get_pg_default_device().type + device = cast( + torch.device, _get_device_module(device_type).current_device() + ) + new_local_tensor = torch.empty_like(value.to_local(), device=device) + # We need to pass shape and stride explicitly, since DTensor might be + # sharded unevenly. + dtensor = DTensor.from_local( + new_local_tensor, + device_mesh=value.device_mesh, + placements=value.placements, + shape=value.size(), + stride=value.stride(), + ) + return dtensor + else: + return value + + def sharded_tensor_func(value: Any): + device = getattr(value, "device", None) + if device == torch.device("meta"): + raise RuntimeError( + f"Found unsupported type {type(value)} for meta device loading." + ) + else: + return value + + def tensor_func(value: torch.Tensor): + device = getattr(value, "device", None) + if device == torch.device("meta"): + device_type = dist.distributed_c10d._get_pg_default_device().type + device = cast( + torch.device, _get_device_module(device_type).current_device() + ) + tensor = torch.empty_like(value, device=device) + return tensor + else: + return value + + _iterate_state_dict( + state_dict, + dtensor_func, + sharded_tensor_func, + tensor_func, + ) + + +def _iterate_state_dict( + iter_object: Any, + dtensor_func: Callable, + sharded_tensor_func: Callable, + tensor_func: Callable, +): + """ + Iterate through the state dict, applying the given functions to each tensor type + and update the state dict in place. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + + # TODO: let state_dict_util._iterate_state_dict() to support in place option + so we don't need to have two versions of _iterate_state_dict. + """ + + if isinstance(iter_object, DTensor): + return dtensor_func(iter_object) + elif isinstance(iter_object, ShardedTensor): + return sharded_tensor_func(iter_object) + elif isinstance(iter_object, torch.Tensor): + return tensor_func(iter_object) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + return iter_object + elif isinstance(iter_object, dict): + for key, value in iter_object.items(): + iter_object[key] = _iterate_state_dict( + value, dtensor_func, sharded_tensor_func, tensor_func + ) + return iter_object + elif isinstance(iter_object, (list, tuple)): + ret = [ + _iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func) + for v in iter_object + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) # type: ignore[assignment] + return ret diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/quantized_hf_storage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/quantized_hf_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..464052d99062a9b7b4e4b156cbe7a25d0fedc017 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/quantized_hf_storage.py @@ -0,0 +1,506 @@ +# mypy: allow-untyped-defs +import json +import logging +import math +from pathlib import Path +from typing import Any + +import torch +from torch.distributed.checkpoint._hf_utils import _metadata_fn +from torch.distributed.checkpoint.metadata import TensorStorageMetadata +from torch.distributed.checkpoint.planner import LoadPlanner, ReadItem + +from .hf_storage import HuggingFaceStorageReader + + +logger: logging.Logger = logging.getLogger(__name__) + +__all__ = ["QuantizedHuggingFaceStorageReader"] + + +class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader): + """ + Extension of HuggingFaceStorageReader that handles quantized tensors. + Checkpoint should have the full tensor in a SafeTensor file. The quantized + tensor should not be sharded across multiple files. + + This reader handles the dequantization of tensors during the read process, + converting them from quantized blocks to full dequantized tensors before + copying to the target tensor. + """ + + def __init__( + self, + path: str, + thread_count: int = 1, + target_dtype: torch.dtype = torch.float32, + block_size: int = 128, + ): + """ + Initialize the HuggingFace storage reader to load quantized checkpoints + + Args: + path: directory where the checkpoint will be read from. + thread_count: Number of threads to use to read distributed checkpoint. Defaults to 1. + target_dtype: Target dtype for dequantized tensor. Defaults to torch.float32. + block_size: Fixed block size for dequantization. Defaults to 128. + """ + super().__init__(path=path, thread_count=thread_count) + + self.target_dtype: torch.dtype = target_dtype + self.block_size: int = block_size + self._weight_scale_mapping: dict[str, str] = {} + # Track which file contains each tensor + self._weight_map: dict[str, str] = {} + # Cache for full tensor shapes (fqn -> shape) + self._tensor_full_shapes: dict[str, torch.Size] = {} + + def read_metadata(self) -> Any: + metadata = super().read_metadata() + + # Load quantization metadata first. + self._load_quantization_metadata() + + # Build a cache of FQN -> full tensor shape, correcting for quantized tensors. + for fqn, tensor_metadata in metadata.state_dict_metadata.items(): + # Only process TensorStorageMetadata which has size attribute. + if isinstance(tensor_metadata, TensorStorageMetadata): + # Check if this is a MXFP4 quantized tensor that needs shape correction. + if fqn.endswith("_blocks"): + # Save the quantized tensor shapes for lookup when dequantization. + self._tensor_full_shapes[fqn + "_quantized"] = tensor_metadata.size + *prefix_shape, G, B = tensor_metadata.size + dequantized_size = torch.Size([*prefix_shape, G * B * 2]) + + # Update the metadata with the size after dequantization. + # Metadata used by planner to slice state dict. + tensor_metadata.size = dequantized_size + self._tensor_full_shapes[fqn] = dequantized_size + else: + self._tensor_full_shapes[fqn] = tensor_metadata.size + + return metadata + + def _load_quantization_metadata(self): + """Load quantization metadata from the checkpoint.""" + checkpoint_path = Path(self.path) + # Load weight mapping from index file + index_file = checkpoint_path / _metadata_fn + + with open(index_file) as f: + index_data = json.load(f) + weight_map = index_data.get("weight_map", {}) + self._build_weight_scale_mapping(weight_map) + + def _build_weight_scale_mapping(self, weight_map: dict[str, str]): + """Analyze and build weight-scale tensor pairs from weight mapping.""" + # Store the complete weight map for file location lookups. + self._weight_map = weight_map + + for tensor_name in weight_map: + if tensor_name.endswith(".weight_scale_inv"): + weight_name = tensor_name.replace(".weight_scale_inv", ".weight") + if weight_name in weight_map: + self._weight_scale_mapping[weight_name] = tensor_name + # Handle MXFP4 format: _blocks and _scales. + elif tensor_name.endswith("_scales"): + blocks_name = tensor_name.replace("_scales", "_blocks") + if blocks_name in weight_map: + self._weight_scale_mapping[blocks_name] = tensor_name + + def _process_read_request( + self, f: Any, req: ReadItem, planner: LoadPlanner + ) -> None: + """Override the Helper function that processes a single read request.""" + tensor_fqn = req.storage_index.fqn + + # Check if this is a quantized tensor that needs dequantization + if self._is_tensor_quantized(tensor_fqn): + tensor = self._read_quantized_tensor_with_block_alignment(req, f) + else: + # Standard tensor reading + slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) + ) + tensor = f.get_slice(tensor_fqn)[slices] + + target_tensor = planner.resolve_tensor(req).detach() + + if target_tensor.size() != tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + def _get_slice_to_block_mapping( + self, req: ReadItem + ) -> tuple[tuple[int, int], tuple[int, int], slice, slice]: + """ + Calculate which blocks correspond to the requested slice. + + Args: + req: Read request containing tensor info and required slices + + Returns: + Tuple of (row_block_range, col_block_range, row_slice, col_slice) + """ + # Get the slice information + row_slice = slice( + req.storage_offsets[0], req.storage_offsets[0] + req.lengths[0] + ) + col_slice = slice( + req.storage_offsets[1], req.storage_offsets[1] + req.lengths[1] + ) + + # Calculate which blocks this slice spans + row_start_block = row_slice.start // self.block_size + row_end_block = (row_slice.stop - 1) // self.block_size + 1 # Inclusive end + + col_start_block = col_slice.start // self.block_size + col_end_block = (col_slice.stop - 1) // self.block_size + 1 # Inclusive end + + return ( + (row_start_block, row_end_block), + (col_start_block, col_end_block), + row_slice, + col_slice, + ) + + def _dequantize_tensor_mxfp4( + self, + blocks: torch.Tensor, + scales: torch.Tensor, + req: ReadItem, + group_start: int, + offset_in_first_group: int, + ) -> torch.Tensor: + """ + Dequantize a 4D tensor using MXFP4 format. + Adapted from openai's implementation: + https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68 + + Args: + blocks: Sliced quantized weight tensor of shape [a_slice, b_slice, groups_slice, B] in uint8 + scales: FULL scale tensor of shape [a, b, c] in uint8 (will be converted to exponents) + req: Read request containing slice information + group_start: The starting group index in the checkpoint + offset_in_first_group: Offset in values within the first group + + Returns: + Dequantized tensor matching the requested shape + """ + # FP4 lookup table + FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + + # blocks: [a_slice, b_slice, groups_slice, B] uint8. + # Read slightly more groups than needed, and slice at the end. + + # Slice the scales to match the blocks dimensions. + # [a_full, b_full, c_full] -> [a_slice, b_slice, groups_slice] + dim0_start = req.storage_offsets[0] + dim0_end = dim0_start + req.lengths[0] + dim1_start = req.storage_offsets[1] + dim1_end = dim1_start + req.lengths[1] + num_groups = blocks.shape[2] + scales = scales[ + dim0_start:dim0_end, + dim1_start:dim1_end, + group_start : group_start + num_groups, + ] + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, ( + f"{blocks.shape=} does not match {scales.shape=}" + ) + + lut = torch.tensor(FP4_VALUES, dtype=self.target_dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty( + rows_total, B * 2, dtype=self.target_dtype, device=blocks.device + ) + + rows_per_chunk = 16384 * 512 + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + + del idx_lo, idx_hi, blk, exp + + result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + + # Slice the last dimension to match the requested range. + if offset_in_first_group > 0 or result.shape[-1] > req.lengths[2]: + end_offset = offset_in_first_group + req.lengths[2] + result = result[..., offset_in_first_group:end_offset] + + return result + + def _dequantize_tensor( + self, + weight: torch.Tensor, + scale_inv: torch.Tensor, + full_tensor_shape: torch.Size, + slice_info: tuple[tuple[int, int], tuple[int, int], slice, slice], + ) -> torch.Tensor: + """ + Dequantize a sliced tensor using the appropriate portion of the scale tensor. + + Args: + weight: Sliced quantized weight tensor + scale_inv: Full scale inverse tensor for dequantization + full_tensor_shape: Shape of the original full tensor + slice_info: Block mapping information from _get_slice_to_block_mapping + + Returns: + Dequantized tensor + """ + (row_block_range, col_block_range, row_slice, col_slice) = slice_info + + # Convert to float32 for computation + # Certain quantized dtypes like Float8_e4m3fn + # don't support multiplication on CPU yet in PyTorch. + upcasted_weight = weight.to(torch.float32) + + # Create output tensor in target dtype + dequantized = weight.detach().to(dtype=self.target_dtype, copy=True) + + # Get the actual slice boundaries + row_start_global = row_slice.start + row_end_global = row_slice.stop + col_start_global = col_slice.start + col_end_global = col_slice.stop + + # Apply scaling factors to each block that intersects with our slice + for block_i in range(row_block_range[0], row_block_range[1]): + for block_j in range(col_block_range[0], col_block_range[1]): + # Calculate the block boundaries in global coordinates + block_row_start_global = block_i * self.block_size + block_row_end_global = min( + block_row_start_global + self.block_size, full_tensor_shape[0] + ) + block_col_start_global = block_j * self.block_size + block_col_end_global = min( + block_col_start_global + self.block_size, full_tensor_shape[1] + ) + + # Find the intersection of the block with our slice + intersect_row_start = max(block_row_start_global, row_start_global) + intersect_row_end = min(block_row_end_global, row_end_global) + intersect_col_start = max(block_col_start_global, col_start_global) + intersect_col_end = min(block_col_end_global, col_end_global) + + # Skip if no intersection + if ( + intersect_row_start >= intersect_row_end + or intersect_col_start >= intersect_col_end + ): + continue + + # Convert global coordinates to local coordinates in the sliced tensor + local_row_start = intersect_row_start - row_start_global + local_row_end = intersect_row_end - row_start_global + local_col_start = intersect_col_start - col_start_global + local_col_end = intersect_col_end - col_start_global + + # Get the block from the sliced tensor + block = upcasted_weight[ + local_row_start:local_row_end, local_col_start:local_col_end + ] + + # Apply the scale factor + scale = scale_inv[block_i, block_j] + block = block * scale + + # Convert block to target dtype and store + block_converted = block.to(dtype=self.target_dtype) + dequantized[ + local_row_start:local_row_end, local_col_start:local_col_end + ] = block_converted + + return dequantized + + def _is_tensor_quantized(self, tensor_fqn: str) -> bool: + """ + Check if a tensor is a quantized. + + Args: + tensor_fqn: Fully qualified name of the tensor + + Returns: + True if tensor is quantized and has a corresponding scale tensor, + False otherwise + """ + # Skip scale tensors themselves + if tensor_fqn.endswith((".weight_scale_inv", "_scales")): + return False + + # Check if this weight tensor has a corresponding scale tensor + if tensor_fqn not in self._weight_scale_mapping: + return False + + return True + + def _read_quantized_tensor_with_block_alignment( + self, req: ReadItem, safetensor_file: Any + ) -> torch.Tensor: + """ + Read a quantized tensor with block alignment. + + Args: + req: Read request containing tensor info and required slices + safetensor_file: Open safetensors file handle + + Returns: + Dequantized tensor ready for use + """ + tensor_fqn = req.storage_index.fqn + scale_fqn = self._weight_scale_mapping[tensor_fqn] + + try: + group_start = 0 + offset_in_first_group = 0 + if tensor_fqn.endswith("_blocks"): + # Full tensor is a 4D MXFP4 quantized tensor: [..., G, B]. + # Each group G produces B * 2 dequantized values. + # Checkpoint [..., G, B] -> dequantized [..., G*B*2]. + + # The planner gives 3D requests based on the dequantized shape. + # Need to figure out which groups (dimension 2 in checkpoint) to read. + + # Use the quantized checkpoint shape to get the correct B. + *prefix_shape, B = self._tensor_full_shapes[tensor_fqn + "_quantized"] + values_per_group = B * 2 # Each byte has 2 nibbles (4-bit values). + + # Calculate which groups we need based on the requested range in dim 2. + # Ensure the reequest is in 3D. + assert len(req.storage_offsets) == 3 + + # Positions in dequantized space. + dim2_start_deq = req.storage_offsets[2] + dim2_length_deq = req.lengths[2] + dim2_end_deq = dim2_start_deq + dim2_length_deq + + # Convert to group indices. + group_start = dim2_start_deq // values_per_group + group_end = (dim2_end_deq + values_per_group - 1) // values_per_group + + # Read only the necessary groups from checkpoint. + weight_slices_4d = ( + slice( + req.storage_offsets[0], req.storage_offsets[0] + req.lengths[0] + ), + slice( + req.storage_offsets[1], req.storage_offsets[1] + req.lengths[1] + ), + slice(group_start, group_end), + slice(None), # Read all B values for each group. + ) + quantized_tensor = safetensor_file.get_slice(tensor_fqn)[ + weight_slices_4d + ] + + # Also track the offset within the first group + offset_in_first_group = dim2_start_deq - ( + group_start * values_per_group + ) + else: + # 2D quantized tensor, use 2d block partition. + weight_slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) + ) + quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices] + + # Load the corresponding scale inverse tensor (full tensor) + scale_file_name = self._weight_map.get(scale_fqn) + if scale_file_name is None: + raise ValueError(f"Scale tensor {scale_fqn} not found in weight_map") + + # Check if scale tensor is in the same file as the weight tensor + weight_file_name = self._weight_map.get(tensor_fqn) + + if scale_file_name == weight_file_name: + # Scale tensor is in the same file, use current handle + scale_inv = safetensor_file.get_tensor(scale_fqn) + else: + # Scale tensor is in a different file, need to open it + from safetensors import safe_open # type: ignore[import] + + scale_file_path = Path(self.path) / scale_file_name + with safe_open( + scale_file_path, framework="pt", device="cpu" + ) as scale_file: + scale_inv = scale_file.get_tensor(scale_fqn) + + # Get the full tensor shape from our O(1) lookup cache + full_tensor_shape = self._tensor_full_shapes.get(tensor_fqn) + if full_tensor_shape is None: + raise ValueError(f"Could not find full tensor shape for {tensor_fqn}") + + # Determine which dequantization function to use. + if len(full_tensor_shape) == 2: + # 2D block-wise quantization, e.g., used in deepseek v3.1 + slice_info = self._get_slice_to_block_mapping(req) + dequantized_tensor = self._dequantize_tensor( + weight=quantized_tensor, + scale_inv=scale_inv, + full_tensor_shape=full_tensor_shape, + slice_info=slice_info, + ) + elif tensor_fqn.endswith("_blocks"): + # 4D with blocks along dimension 2, used in MXFP4, e.g. gpt-oss + dequantized_tensor = self._dequantize_tensor_mxfp4( + blocks=quantized_tensor, + scales=scale_inv, + req=req, + group_start=group_start, + offset_in_first_group=offset_in_first_group, + ) + else: + raise ValueError("Unsupported quantization types") + + return dequantized_tensor + + except Exception as e: + logger.error("Failed to read the quantized tensor!!") + raise e diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/resharding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/resharding.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f24b891aa895d3a445908fe6d084e13f9b05da --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/resharding.py @@ -0,0 +1,69 @@ +from torch.distributed.checkpoint.metadata import ChunkStorageMetadata + + +__all__: list[str] = [] + + +def _check_shard_metadata_pair_overlap( + shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata +) -> bool: + """Check if two shards overlap.""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.offsets) + for i in range(ndims): + if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]: + return False + if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]: + return False + + return True + + +def _shards_get_overlap_region_wrt_saved_tensor( + saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata +) -> list[tuple[int, int, int, int]]: + """ + Return the overlapping region between saved_shard and current_shard. + + There returned list has the same number of elements as the tensor's dimension. + For each element, we produce a tuple with the following contents: + (dimension, `saved_shard` offset, `current_shard` offset, length) + + Offsets are relative to each shard. + """ + narrows = [] + for dim, ( + saved_shard_offset, + current_shard_offset, + saved_shard_size, + current_shard_size, + ) in enumerate( + zip( + saved_shard.offsets, + current_shard.offsets, + saved_shard.sizes, + current_shard.sizes, + ) + ): + min_range_end = min( + saved_shard_offset + saved_shard_size, + current_shard_offset + current_shard_size, + ) + + length = min_range_end - max(current_shard_offset, saved_shard_offset) + + if saved_shard_offset > current_shard_offset: + offset_for_saved_tensor = 0 + offset_for_current_tensor = saved_shard_offset - current_shard_offset + else: + offset_for_saved_tensor = current_shard_offset - saved_shard_offset + offset_for_current_tensor = 0 + + narrows.append( + (dim, offset_for_saved_tensor, offset_for_current_tensor, length) + ) + + return narrows diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/staging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/staging.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbacc66aaaffe038bced44b9ca7b466a2246f90 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/staging.py @@ -0,0 +1,474 @@ +import os +import tempfile +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import nullcontext +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, cast, Optional, Union +from typing_extensions import deprecated, Protocol, runtime_checkable + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint._pg_transport import PGTransport +from torch.distributed.checkpoint._state_dict_stager import StateDictStager +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + + +__all__ = ["AsyncStager", "BlockingAsyncStager", "DefaultStager", "StagingOptions"] + +""" +Experimental staging module for PyTorch Distributed Checkpointing. +This module provides advanced staging capabilities for checkpoints including: +- Asynchronous staging using ThreadPoolExecutor +- Pinned memory allocation for faster CPU-GPU transfers +- Shared memory support for multi-process scenarios +- Non-blocking CUDA operations with stream synchronization +- Caching of frequently used storages for efficient memory management +- Automatic resource cleanup and memory management +Classes: + AsyncStager: Protocol defining the staging interface + StagingOptions: Configuration dataclass for staging behavior + DefaultStager: Default implementation with comprehensive staging features + BlockingAsyncStager: Implementation of AsyncStager which stages the state_dict + on CPU RAM and blocks until the copy is complete. Please use DefaultStager instead. +""" + + +@runtime_checkable +class AsyncStager(Protocol): + """ + This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users + to customize how data is staged previous to executing the usual dcp.save path in parallel. + The expected order of operations (concretely defined in `torch.distributed.state_dict_saver.async_save`) + is the following: + + 1. AsyncStager.stage_data(state_dict): + This call gives the AsyncStager the opportunity to 'stage' + the state_dict. The expectation and purpose of staging in this context is to create a "training-safe" + representation of the state dict, meaning that any updates to module data after staging is complete + should not be reflected in the state dict returned from this method. For example, in the default + case a copy of the entire state dict is created on CPU RAM and returned here, allowing users + to continue training without risking changes to data which is being serialized. + + 2. dcp.save is called on the state_dict returned from stage in parallel. This call is responsible + for serializing the state_dict and writing it to storage. + + 3. If AsyncStager.should_synchronize_after_execute is True, this method will be called immediately after + the serialization thread starts and before returning from dcp.async_save. If this is set to False, + the assumption is the user has defined a custom synchronization point for the purpose of further + optimizing save latency in the training loop (for example, by overlapping staging with the + forward/backward pass), and it is the respondsibility of the user to call `AsyncStager.synchronize_staging` + at the appropriate time. + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = True + + @property + def should_synchronize_after_execute(self) -> bool: + """ + Whether to synchronize after executing the stage. + """ + return self._synchronize_after_execute + + def stage( + self, state_dict: STATE_DICT_TYPE + ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + """ + Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is + inoculated from any updates incurred after the stage call is complete. + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement stage method" + ) + + @deprecated( + "`synchronize_staging` is deprecated and will be removed in future versions." + "Please use staging_future from AsyncSaveResponse instead.", + category=FutureWarning, + ) + def synchronize_staging(self) -> None: + """ + In the case `stage` is async in some way, this method should be called to ensure staging + is complete and it is safe to begin modifying the original `state_dict` + """ + + def close(self) -> None: + """ + Clean up all resources used by the stager. + """ + + +@dataclass +class StagingOptions: + """ + Configuration options for checkpoint staging behavior. + + Attributes: + use_pinned_memory (bool): Enable pinned memory allocation for faster + CPU-GPU transfers. Requires CUDA to be available. Default: True + use_shared_memory (bool): Enable shared memory for multi-process + scenarios. Useful when multiple processes need access to the + same staged data. Default: True + use_async_staging (bool): Enable asynchronous staging using a + background thread pool. Allows overlapping computation with + staging operations. Requires CUDA. Default: True + use_non_blocking_copy (bool): Use non-blocking device memory + copies with stream synchronization. Improves performance by + allowing CPU work to continue during GPU transfers. Default: True + + Note: + CUDA-dependent features will raise exception if CUDA is not available. + """ + + use_pinned_memory: bool = True + use_shared_memory: bool = True + use_async_staging: bool = True + use_non_blocking_copy: bool = True + + +class DefaultStager(AsyncStager): + """ + DefaultStager provides a full-featured staging implementation that combines + multiple optimization techniques for efficient checkpoint preparation. + + The staging process works as follows: + 1. State dictionary is submitted for staging (sync or async) + 2. Tensors are copied from GPU to optimized CPU storage + 3. CUDA operations are synchronized if non-blocking copies are used + 4. Staged state dictionary is returned or made available via Future + + Usage Patterns: + # Synchronous staging + stager = DefaultStager(StagingOptions(use_async_staging=False)) + staged_dict = stager.stage(state_dict) + stager.close() + + # Asynchronous staging + stager = DefaultStager(StagingOptions(use_async_staging=True)) + future = stager.stage(state_dict) + # ... do other work ... + staged_dict = future.result() + stager.close() + + # Context manager pattern (recommended) + stager = DefaultStager(config) + with stager: + result = stager.stage(state_dict) + + Performance Considerations: + - Async staging provides best performance when model computation + can overlap with staging operations + - Pinned memory improves CPU-GPU transfer speeds but uses more memory + - Shared memory allows efficient IPC to checkpoint process + - Non-blocking copies reduce GPU idle time during memory transfers + + Thread Safety: + DefaultStager is not thread-safe. Each thread should use its own + instance, or external synchronization should be provided. + """ + + def __init__( + self, + config: StagingOptions = StagingOptions(), + ): + self._config = config + self._state_dict_stager = StateDictStager( + pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory + ) + self._staging_executor = None + self._staging_stream = None + if self._config.use_async_staging: + # pyrefly: ignore [bad-assignment] + self._staging_executor = ThreadPoolExecutor(max_workers=1) + if torch.accelerator.is_available(): + # Note: stream needs to be initialized on the main thread after default cuda + # stream is setup/used to avoid the risk of accidentally reusing the main + # compute stream or in other cases kernels actually launching from the + # main thread. + # pyrefly: ignore [bad-assignment] + self._staging_stream = torch.Stream() + + if self._config.use_non_blocking_copy: + if not torch.accelerator.is_available(): + raise AssertionError( + "Non-blocking copy requires that the current accelerator is available." + ) + + self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None + + def stage( + self, + state_dict: STATE_DICT_TYPE, + **kwargs: Any, + ) -> Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]]: + """ + This function is responsible for staging staging the state_dict. + See class docstring for more details on staging. + If use_async_staging is True, it will return a Future object that will be + fulfilled when staging is complete. + If use_async_staging is False, it will return the fully staged state_dict. + + Args: + state_dict (STATE_DICT_TYPE): The state_dict to be staged. + """ + if self._config.use_async_staging: + if self._staging_executor is None: + raise AssertionError( + "staging_executor should not be None for async staging" + ) + self._staging_future = self._staging_executor.submit( + self._stage, + state_dict, + **kwargs, + ) + return self._staging_future + else: + return self._stage(state_dict, **kwargs) + + def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE: + if self._config.use_non_blocking_copy: + if not (self._staging_stream or not self._config.use_async_staging): + raise AssertionError( + "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized." + ) + with ( + self._staging_stream + if self._staging_stream is not None + else nullcontext() + ): + state_dict = self._state_dict_stager.stage( + state_dict, non_blocking=self._config.use_non_blocking_copy + ) + # waits for the enqued copy operations to finish. + self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize() + else: + state_dict = self._state_dict_stager.stage(state_dict, non_blocking=False) + return state_dict + + def close(self) -> None: + """ + Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor + used for async staging operations and cleans up the underlying StateDictStager's + cached storages. Should be called when the stager is no longer needed to prevent + resource leaks, especially in long-running applications. After calling close(), + the stager should not be used for further staging operations. + + Example Usage: + stager = DefaultStager(StagingOptions(use_async_staging=True)) + future = stager.stage(state_dict) + result = future.result() + stager.close() # Clean up all resources + """ + if self._staging_executor: + self._staging_executor.shutdown(wait=True) + + def synchronize_staging(self) -> None: + """ + When use_async_staging is True, this method will wait until staging is complete. + If use_async_staging is False, this method is a no-op. + """ + if self._staging_future is not None: + self._staging_future.result() + + +class BlockingAsyncStager(AsyncStager): + """ + An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. + This implementation also provides an option to optimize stage latency using pinned memory. + + N.B. synchronize_staging is a no-op in this case. + + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = False + + def __init__( + self, + cache_staged_state_dict: bool = False, + type_check: bool = False, + ): + """ + Initializes the BlockingAsyncStager. + + Args: + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. + type_check: Whether to perform a type check during cpu_offload. Defaults to False. + + """ + self.cache_staged_state_dict = cache_staged_state_dict + self.type_check = type_check + self.state_dict_cache: Optional[STATE_DICT_TYPE] = None + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """ + Returns a copy of `state_dict` on the CPU. + """ + + if not self.cache_staged_state_dict: + staged_state_dict = _create_cpu_state_dict(state_dict) + _copy_state_dict(state_dict, staged_state_dict, type_check=self.type_check) + return staged_state_dict + + if self.state_dict_cache is None: + self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True) + return _copy_state_dict(state_dict, self.state_dict_cache) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. + """ + + def close(self) -> None: + pass + + +class _ReplicationStager(AsyncStager): + """ + An AsyncStager implementation that replicates state_dict across training ranks + using PGTransport. + + Args: + pg: ProcessGroup for distributed communication + timeout: Timeout for communication operations + device: Device to use for tensor operations + storage_dir: Directory to store persisted state_dicts + + Warning: This is experimental and subject to change. + """ + + _synchronize_after_execute: bool = False + + def __init__( + self, + pg: ProcessGroup, + timeout: timedelta = timedelta(minutes=30), + device: torch.device = torch.device("cpu"), + storage_dir: Optional[str] = None, + ): + self._pg = pg + self._timeout = timeout + # pyrefly: ignore [read-only] + self._device = device + self._transport = PGTransport(pg, timeout, device, None) + + # Set up storage directory for persisting exchanged state_dicts + if storage_dir is None: + self._storage_dir = tempfile.mkdtemp(prefix="replication_stager_") + else: + self._storage_dir = storage_dir + os.makedirs(self._storage_dir, exist_ok=True) + + def stage( + self, state_dict: STATE_DICT_TYPE + ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + """ + Stage the state_dict by replicating it across ranks. Returns a state_dict representing + the received replica. + + Perform the actual replication logic. Creates bidirectional pairs where each rank exchanges + state_dict with its partner at (rank + world_size//2) % world_size. + Uses simple rank-based ordering to prevent deadlocks. + + Assumes world_size is always even. + """ + if not dist.is_initialized(): + return state_dict + + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Calculate partner rank using half-world offset + # creates bidirectional pairs for replication. + offset = world_size // 2 + partner_rank = (current_rank + offset) % world_size + + # Use simple rank-based ordering to prevent deadlocks. + # Lower-numbered rank sends first, higher-numbered rank receives first. + if current_rank < partner_rank: + # Send first, then receive + self._transport.send_checkpoint([partner_rank], state_dict) + received_state_dict = self._transport.recv_checkpoint(partner_rank) + else: + # Receive first, then send + received_state_dict = self._transport.recv_checkpoint(partner_rank) + self._transport.send_checkpoint([partner_rank], state_dict) + + # Persist the received state_dict for future discoverability + received_state_dict = cast(STATE_DICT_TYPE, received_state_dict) + self._persist_state_dict(received_state_dict, current_rank, partner_rank) + + return received_state_dict + + def _persist_state_dict( + self, state_dict: STATE_DICT_TYPE, current_rank: int, partner_rank: int + ) -> None: + """ + Persist the received state_dict to disk for future discoverability. + Only keeps one replica per rank, overwriting any previous replica. + Uses atomic write pattern (temp file + rename). + + Args: + state_dict: The state_dict received from partner rank + current_rank: Current rank that received the state_dict + partner_rank: Rank that sent the state_dict + """ + final_path = self._get_persisted_path(current_rank, partner_rank) + temp_path = final_path + ".tmp" + + try: + # Ensure parent directory exists and is writable + os.makedirs(os.path.dirname(final_path), exist_ok=True) + + # Write to temporary file with explicit flushing + with open(temp_path, "wb") as f: + torch.save(state_dict, f) + # Flush application buffers to OS buffers + f.flush() + # Force OS buffers to disk for durability + os.fsync(f.fileno()) + + # Atomic rename to final location + os.rename(temp_path, final_path) + except Exception as e: + # Clean up temp file if it exists + try: + if os.path.exists(temp_path): + os.remove(temp_path) + except Exception: + pass # Ignore cleanup errors + # Re-raise the original exception with more context + raise RuntimeError( + f"Failed to persist state_dict from rank {partner_rank} to rank {current_rank}: {e}" + ) from e + + def _get_persisted_path(self, current_rank: int, partner_rank: int) -> str: + """ + Get the file path where a state_dict would be persisted. + + Args: + current_rank: Current rank + + Returns: + File path for the persisted state_dict + """ + filename = f"rank_{current_rank}_replica_partner_{partner_rank}.pt" + return os.path.join(self._storage_dir, filename) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. + """ + + def close(self) -> None: + """ + Clean up resources. Persisted files are intentionally left for future discovery. + """ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..6a31144348acb669e0a2f8e17805d5650d5d61d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py @@ -0,0 +1,1634 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import gc +import warnings +from collections.abc import Callable, Generator, Iterable +from dataclasses import asdict, dataclass, field +from itertools import chain +from typing import Any, cast, no_type_check, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._state_dict_utils import ( + _broadcast_state_dict, + _distribute_state_dict, + _flatten_state_dict, + _gather_state_dict, + _offload_state_dict_to_cpu, + _unflatten_state_dict, +) +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp._common_utils import ( + _get_module_fsdp_state_if_fully_sharded_module, + FSDP_WRAPPED_MODULE, +) +from torch.distributed.tensor import DTensor +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils._pytree import tree_map_only + + +__all__ = [ + "FQNS_T", + "PrimitiveType", + "ValueType", + "DictValueType", + "ListDictValueType", + "OptimizerStateType", + "StateDictOptions", + "get_model_state_dict", + "get_optimizer_state_dict", + "get_state_dict", + "set_model_state_dict", + "set_optimizer_state_dict", + "set_state_dict", +] + + +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" + +FQNS_T = set[str] +PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] +ValueType = Union[ + PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"] +] +DictValueType = dict[str, ValueType] +ListDictValueType = list[DictValueType] +OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]] + + +_patched_state_dict: set[Callable] = set() + + +@contextlib.contextmanager +def _gc_context(): + is_enabled = gc.isenabled() + gc.disable() + try: + yield + finally: + if is_enabled: + gc.enable() + + +@dataclass +class StateDictOptions: + """ + This dataclass specifies how get_state_dict/set_state_dict will work. + + - ``full_state_dict``: if this is set to True, all the tensors in the + returned state_dict will be gathered. No ShardedTensor and DTensor + will be in the returned state_dict. + + - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if + ``full_state_dict`` is also true, then only the rank0 will get the + state_dict and all other ranks will get empty state_dict. + + - ``ignore_frozen_params``: if the value is True, the returned state_dict + won't contain any frozen parameters -- the ``requires_grad`` is False. + The default value is False. + + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option + indicates whether to keep the submodule prefixes from the state_dict keys. + or example, if the submodule is ``module.pretrain`` and the full FQN of + the parameter is ``pretrain.layer1.weight`` of the param. When this option + is True, the parameter's key in the returned state_dict will be + ``pretrain.layer1.weight``. If the options is False, the key will be + ``layer1.weight``. + Note that if ``keep_submodule_prefixes`` is False, there may be conflicted + FQNs, hence there should be only one submodule in ``submodules``. + + - ``strict``: the ``strict`` option when ``set_state_dict`` calls + model.load_state_dict(). + + - ``broadcast_from_rank0``: when the option is True, rank0 should receive a + full state_dict and will broadcast the tensors in the state_dict/ + optim_state_dict one by one to other ranks. Other ranks will receive + the tensors and shard according to the local shards in the model and + optimizer. ``full_state_dict`` must be set to True when using this option. + This option currently only supports DTensor, not the legacy ShardedTensor. + """ + + full_state_dict: bool = False + cpu_offload: bool = False + ignore_frozen_params: bool = False + keep_submodule_prefixes: bool = True + strict: bool = True + broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False + dsd_fqn_modifiers: str = "_fqn_modifiers" + + +@dataclass +class _StateDictInfo(StateDictOptions): + fqn_param_mapping: dict[ + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], + ] = field(default_factory=dict) + shared_params_mapping: dict[ + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], + ] = field(default_factory=dict) + submodule_prefixes: set[str] = field(default_factory=set) + handle_model: bool = True + handle_optim: bool = True + fsdp_context: Callable = contextlib.nullcontext + fsdp_modules: list[nn.Module] = field(default_factory=list) + + +def _get_fqns( + model: nn.Module, + name: str, + dsd_fqn_modifiers: str = "_fqn_modifiers", + skip_ddp_prefix: bool = True, + skip_compiler_prefix: bool = True, +) -> FQNS_T: + """ + This API is used to convert the name of a parameter to the FQNs. For FSDP + without `use_orig_params`, the name of FlatParameter can be mapped to + multiple original parameters. As a result, the return type of this function + is `set[str]`. + + Args: + module (nn.Module): the root model. + name (str): the name + skip_ddp_prefix (bool): whether to skip DDP's `module` prefix + + Returns: + The canonical FQNs based on the model traversal. + """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") + if "." not in name: + return {name} + + obj_names = name.split(".") + fqn_obj_names = [] + curr_obj = model + for i, curr_obj_name in enumerate(obj_names): + if isinstance(curr_obj, DDP): + if curr_obj_name != "module": + raise AssertionError(f"Expected 'module', got '{curr_obj_name}'") + curr_obj = curr_obj.module + if not skip_ddp_prefix: + fqn_obj_names.append(curr_obj_name) + elif isinstance(curr_obj, FSDP): + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: + prefix = ".".join(fqn_obj_names) + flat_param = getattr(curr_obj, _FLAT_PARAM) + if prefix: + prefix = f"{prefix}." + return {f"{prefix}{fqn}" for fqn in flat_param._fqns} + curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) + if curr_obj_name != FSDP_WRAPPED_MODULE: + # pyrefly: ignore [bad-argument-type] + fqn_obj_names.append(curr_obj_name) + curr_obj = getattr(curr_obj, curr_obj_name) + elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): + if curr_obj_name != "_orig_mod": + raise AssertionError(f"Expected '_orig_mod', got '{curr_obj_name}'") + curr_obj = curr_obj._orig_mod + if not skip_compiler_prefix: + fqn_obj_names.append(curr_obj_name) + else: + # In some modules, _fqn_modifiers would not shown in the state_dict keys, + # skip them in the fqn to ensure load stat dict successfully for them. + if hasattr(curr_obj, dsd_fqn_modifiers): + if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get( + curr_obj_name + ): + if hasattr(curr_obj, removed_fqn): + curr_obj = getattr(curr_obj, removed_fqn) + # pyrefly: ignore [bad-argument-type] + fqn_obj_names.append(curr_obj_name) + if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: + if i != len(obj_names) - 1: + raise RuntimeError("Expect `_extra_state` to be the last obj name") + else: + curr_obj = getattr(curr_obj, curr_obj_name) + + return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")} + + +class _EXTRA_STATE: + pass + + +def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"): + visited_modules: set[nn.Module] = set() + + def recurse(module: nn.Module, curr_fqn: str) -> Generator: + visited_modules.add(module) + + curr_fqn = f"{curr_fqn}." if curr_fqn else "" + for name, submodule in module.named_children(): + if submodule in visited_modules: + continue + # if user have state_dict_hooks in their model, they can add the state_dict key changes + # at dsd_fqn_modifiers in input to align with the function of state_dict_hook + if ( + hasattr(module, dsd_fqn_modifiers) + and name in getattr(module, dsd_fqn_modifiers)().values() + ): + # skip _fqn_modifiers here thus remove the last `.` added + new_fqn = curr_fqn[:-1] + else: + new_fqn = f"{curr_fqn}{name}" + yield from recurse(submodule, new_fqn) + + for name, obj in chain( + module.named_buffers(recurse=False), module.named_parameters(recurse=False) + ): + if name in module._non_persistent_buffers_set: + continue + new_fqn = f"{curr_fqn}{name}" + yield new_fqn, obj + + if ( + getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state) + != nn.Module.get_extra_state + ): + new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}" + yield new_fqn, _EXTRA_STATE() + + yield from recurse(model, "") + + +def _verify_options( + model: nn.Module, + optims: tuple[torch.optim.Optimizer, ...], + optim_only: bool, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> _StateDictInfo: + """ + Verify the model and options passed by the user and generates _StateDictInfo. + """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + stacklevel=2, + ) + if optim_only and not optims: + raise RuntimeError( + "Optimizers are not passed in but optim_only is set to True." + ) + + options = options or StateDictOptions() + + fqn_param_mapping: dict[ + Union[str, torch.Tensor], Union[set[str], torch.Tensor] + ] = {} + shared_params_mapping: dict[ + Union[str, torch.Tensor], Union[set[str], torch.Tensor] + ] = {} + for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + + fqns = _get_fqns(model, name) + fqn = fqn_param_mapping.get(param) + if fqn is not None: + cast(set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() + for fqn in fqns: + if not isinstance(param, _EXTRA_STATE): + fqn_param_mapping[fqn] = param + + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + + submodule_prefixes: set[str] = set() + if submodules: + submodules = set(submodules) + for name, module in model.named_modules(): + if module not in submodules: + continue + fqns = _get_fqns(model, name) + if len(fqns) != 1: + raise AssertionError("Submodule FQN should only have 1 instance") + submodule_prefixes.update(f"{fqn}." for fqn in fqns) + + if options.broadcast_from_rank0 and not options.full_state_dict: + raise ValueError( + "full_state_dict must be True when broadcast_from_rank0 is True." + ) + fsdp_modules = FSDP.fsdp_modules(model) + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig + fsdp_context: Callable + if fsdp_modules: + # FSDP API only work if at least one FSDP instance exists. + if options.full_state_dict: + state_dict_config = FullStateDictConfig( + offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload + ) + optim_state_dict_config = FullOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + rank0_only=(options.cpu_offload or options.broadcast_from_rank0), + ) + state_dict_type = StateDictType.FULL_STATE_DICT + else: + state_dict_config = ShardedStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT + + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="FSDP.state_dict_type", category=FutureWarning + ) + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + + fsdp_context = functools.partial( + fsdp_state_dict_type_without_warning, + module=model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ) + else: + fsdp_context = contextlib.nullcontext + + return _StateDictInfo( + **asdict(options), + fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, + submodule_prefixes=submodule_prefixes, + fsdp_context=fsdp_context, + fsdp_modules=cast(list[nn.Module], fsdp_modules), + handle_model=not optim_only, + handle_optim=(len(optims) > 0), + ) + + +def _verify_state_dict( + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + for module in info.fsdp_modules: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state is None: + raise AssertionError("Expected a fsdp_state with a fsdp module.") + + # Verify if the model_state_dict and optim_state_dict are valid. This API + # should give the users an explicit error message to debug or report. + if ( + info.handle_model + and not model_state_dict + and not info.submodule_prefixes + and not info.ignore_frozen_params + and not (info.cpu_offload and info.full_state_dict) + and info.strict + and not info.broadcast_from_rank0 + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save " + "or load, but model state_dict is empty." + f"rank = {dist.get_rank()=}." + ) + + if info.handle_optim: + if ( + not optim_state_dict + and not (info.cpu_offload and info.full_state_dict) + and (not info.broadcast_from_rank0) + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save, " + f"or load but optim state_dict is empty. {optim_state_dict}" + ) + + for key in model_state_dict: + if _FLAT_PARAM in key: + raise RuntimeError( + f"{key} contains {_FLAT_PARAM}. This can happen if the model " + "is not the root module." + ) + + +def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable: + call = getattr(obj, api) + if call in _patched_state_dict: + call = functools.partial(getattr(obj.__class__, api), self=obj) + return call + + +def _maybe_full_or_cpu_state_dict( + state_dict: dict[str, Any], info: _StateDictInfo +) -> dict[str, Any]: + if info.full_state_dict: + ranks_only = ( + () + if (not info.cpu_offload or not torch.distributed.is_initialized()) + else (0,) + ) + return _gather_state_dict( + state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only + ) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + +@torch.no_grad() +def _get_model_state_dict( + model: nn.Module, info: _StateDictInfo +) -> dict[str, ValueType]: + if not info.handle_model: + return {} + + with info.fsdp_context(): + state_dict = _state_dict_fn(model, "state_dict")() + + for key in list(state_dict.keys()): + fqns = _get_fqns(model, key) + if len(fqns) != 1: + raise AssertionError( + f"Expected 1 FQN for key '{key}', got {len(fqns)}: {fqns}" + ) + fqn = next(iter(fqns)) + if fqn != key: + # As we only support FSDP, DDP, and TP, the only cases are + # wrapper-based DDP and compiler. Verify if the assumption + # is correct. + def verify(key, fqn) -> bool: + if len(fqn) >= len(key): + return False + fqn_split = fqn.split(".") + key_split = key.split(".") + fqn_idx = 0 + for key_idx, key_name in enumerate(key_split): + if key_name == fqn_split[fqn_idx]: + fqn_idx += 1 + if fqn_idx == len(fqn_split): + return key_idx == len(key_split) - 1 + elif key_name in ("module", "_orig_mod"): + continue + else: + return False + return True + + if not verify(key, fqn): + raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}") + state_dict[fqn] = state_dict.pop(key) + + if info.submodule_prefixes: + new_state_dict: dict[str, ValueType] = {} + # TODO: make this faster. + for fqn in state_dict: + for prefix in info.submodule_prefixes: + if not fqn.startswith(prefix): + continue + if info.keep_submodule_prefixes: + new_state_dict[fqn] = state_dict[fqn] + else: + new_fqn = fqn[len(prefix) :] + new_state_dict[new_fqn] = state_dict[fqn] + state_dict = new_state_dict + + if info.ignore_frozen_params: + for key, param in model.named_parameters(): + if param.requires_grad: + continue + fqns = _get_fqns(model, key) + for fqn in fqns: + state_dict.pop(fqn) + + return _maybe_full_or_cpu_state_dict(state_dict, info) + + +@torch.no_grad() +def _load_model_state_dict( + model: nn.Module, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> _IncompatibleKeys: + if not info.handle_model or (not state_dict and not info.broadcast_from_rank0): + return _IncompatibleKeys({}, {}) + + local_state_dict = {} + for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers): + fqns = _get_fqns(model, key, info.dsd_fqn_modifiers) + fqns_with_prefix = _get_fqns( + model, + key, + info.dsd_fqn_modifiers, + skip_ddp_prefix=False, + skip_compiler_prefix=False, + ) + + for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix): + if ( + not info.broadcast_from_rank0 or dist.get_rank() == 0 + ) and fqn != fqn_with_prefix: + load_value = state_dict.pop(fqn, None) + if load_value is None: + if info.strict: + raise RuntimeError(f"Missing key: {fqn}.") + else: + state_dict[fqn_with_prefix] = load_value + local_state_dict[fqn_with_prefix] = value + + assign = False + if info.broadcast_from_rank0 or info.full_state_dict: + devices = set() + for value in local_state_dict.values(): + if torch.is_tensor(value) and value.dim() > 0: + devices.add(value.device) + # In lora state_dict, there could be multiple devices, with meta device inside. + # Take the other device in the broadcast/distribtue, and set assign to True + if torch.device("meta") in devices: + devices.remove(torch.device("meta")) + assign = True + if len(devices) == 0: + devices.add(dist.distributed_c10d._get_pg_default_device()) + elif len(devices) > 1: + raise ValueError("Multiple devices found") + + if info.broadcast_from_rank0: + _broadcast_state_dict( + state_dict, + local_state_dict, + device=devices.pop(), + strict=info.strict, + cpu_offload=info.cpu_offload, + ) + elif info.full_state_dict: + _distribute_state_dict(state_dict, local_state_dict, device=devices.pop()) + state_dict.update(local_state_dict) + + with info.fsdp_context(): + return cast( + _IncompatibleKeys, + _state_dict_fn(model, "load_state_dict")( + state_dict=state_dict, strict=info.strict, assign=assign + ), + ) + + +def _init_optim_state(optim: torch.optim.Optimizer) -> None: + """ + Initialize optim states by calling the step() with zero grads. + """ + if optim.state: + # The optimizer state is initialized. + return + + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.grad is not None: + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.requires_grad: + param.grad = torch.zeros_like(param) + + # Some optimizers will update parameters regardless of grads due to lr, so + # make lr to zero when calling `step()`. + lrs = [] + for param_group in optim.param_groups: + if "lr" in param_group: + lrs.append(param_group["lr"]) + param_group["lr"] = ( + torch.tensor(0.0) + if isinstance(param_group["lr"], torch.Tensor) + else 0.0 + ) + optim.step(closure=None) + # Whether to recover the "lr" should not matter too much as we will + # restore checkpointing later. + for param_group in optim.param_groups: + if "lr" in param_group: + param_group["lr"] = lrs.pop(0) + optim.zero_grad(set_to_none=True) + + +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_groups": [ + { + "lr": 0.0, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_groups.layer1.weight.lr": 0.1, + "param_groups.layer2.weight.lr": 0.1, + "param_groups.layer1.weight.betas": (0.9, 0.95), + "param_groups.layer2.weight.betas": (0.9, 0.95), + } + + The "state" section supports arbitrary levels of nesting for optimizers like Shampoo. + """ + + def _flatten_state_nested_dict( + nested_dict: dict[str, Any], prefix: str + ) -> dict[str, ValueType]: + """ + Recursively flatten a nested dictionary with dot-separated keys. + + Args: + nested_dict: The dictionary to flatten + prefix: The prefix to prepend to all keys + + Returns: + Flattened dictionary with dot-separated keys + """ + flattened: dict[str, ValueType] = {} + + for key, value in nested_dict.items(): + # Convert all keys to strings for flattening + str_key = str(key) + full_key = f"{prefix}.{str_key}" if prefix else str_key + + if isinstance(value, dict): + # Recursively flatten nested dictionaries + flattened.update(_flatten_state_nested_dict(value, full_key)) + else: + # Base case: store the value with the flattened key + _raise_if_type_not_supported(value) + flattened[full_key] = value + + return flattened + + def _raise_if_type_not_supported(v): + if not isinstance(v, (torch.Tensor, int, float, dict)): + raise NotImplementedError( + "Flattening optimizer state_dict only supports " + "tensor, int, float, dict states now. " + f"Type is {type(v)}." + ) + + ret: dict[str, ValueType] = {} + + # Handle the "state" section with recursive flattening + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): + state_prefix = f"{_STATE}.{fqn}" + ret.update( + _flatten_state_nested_dict(cast(dict[str, Any], state), state_prefix) + ) + + # Handle the "param_groups" section with two-level flattening + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) + for fqn in cast(list[str], fqns): + for k, v in param_group.items(): + ret[f"{_PG}.{fqn}.{k}"] = v + + return ret + + +def _unflatten_optim_state_dict( + optim: torch.optim.Optimizer, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + Supports arbitrary levels of nesting in the state section through recursive reconstruction. + + See the docstring of _flatten_optim_state_dict() for more detail. + """ + + def _reconstruct_nested_dict( + flattened_key: str, flattened_dict: dict[str, ValueType] + ) -> dict[str, ValueType]: + """ + Reconstructs a potentially nested value from flattened keys. + For non-nested values, returns the value directly. + For nested values, reconstructs the nested structure with string keys. + """ + + # Create the prefix to search for nested keys + # e.g., if flattened_key is "state.layer1.weight", prefix becomes "state.layer1.weight." + prefix = f"{flattened_key}." + # Initialize an empty dictionary to build our nested structure + nested_dict: dict[str, Any] = {} + + # Iterate through all keys in the flattened dictionary + for key, value in flattened_dict.items(): + # Check if this key is nested under our target key + # e.g., "state.layer1.weight.exp_avg" starts with "state.layer1.weight." + if not key.startswith(prefix): + # Skip keys that don't belong to this nested structure + continue + + # Remove the prefix to get just the nested part + # e.g., "state.layer1.weight.exp_avg" -> "exp_avg" + remaining_key = key[len(prefix) :] + # Split the remaining key into parts to build the nested structure + # e.g., "step" -> ["step"] or "momentum_buffer" -> ["momentum_buffer"] + parts = remaining_key.split(".") + # Start at the root of our new nested dictionary + current = nested_dict + + # Navigate through or create the nested dictionary structure + # For each part except the last one (which will hold the value) + for part in parts[:-1]: + # Create the nested dictionary if it doesn't exist yet + if part not in current: + current[part] = {} + # Move deeper into the nested structure + assert isinstance(current[part], dict) + current = current[part] + + # Set the value at the final level using the last part as the key + # e.g., current["exp_avg"] = tensor(...) + current[parts[-1]] = value + + # Return the reconstructed nested dictionary (empty dict if no keys matched at all) + return nested_dict + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + # If a parameter is shared, only one of the FQN will be used. + # So we need to verify which if this fqn is actually used in + # the state_dict. + if fqn in info.shared_params_mapping: + in_params = False + for k in param_group: + if k == _PARAMS: + continue + flatten_key = f"{_PG}.{fqn}.{k}" + if flatten_key in state_dict: + in_params = True + break + else: + in_params = True + + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + if not isinstance(params, list): + raise AssertionError(f"Expected list, got {type(params)}") + params.append(fqn) + + # Only add state if param requires grad + if not param.requires_grad: + continue + + # Reconstruct state for this parameter + state[fqn] = {} + for state_name in optim.state[param]: + flattened_state_key = f"{_STATE}.{fqn}.{state_name}" + + if flattened_state_key not in state_dict: + # Try to reconstruct the value + reconstructed_value = _reconstruct_nested_dict( + flattened_state_key, state_dict + ) + cast(DictValueType, state[fqn])[state_name] = ( + reconstructed_value + ) + else: + # Existing keys mean no nesting, directly use the value. + cast(DictValueType, state[fqn])[state_name] = state_dict[ + flattened_state_key + ] + + first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0] + for k in param_group: + if k == _PARAMS: + continue + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + +@torch.no_grad() +def _get_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + info: _StateDictInfo, +) -> OptimizerStateType: + if not info.handle_optim: + return {} + + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} + for optim in optimizers: + _init_optim_state(optim) + osd = _state_dict_fn(optim, "state_dict")() + if info.fsdp_modules: + with info.fsdp_context(): + osd = FSDP.optim_state_dict(model, optim, osd) + + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + # There are no easy ways to do this conversion systematically. + # We can only use a string replacement without correctness check. + if not osd: + continue + for k in list(osd[_STATE].keys()): + if "_orig_mod" in k: + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params + else: + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) + param_pid_mapping = dict(zip(params, range(len(params)))) + fqn_pid_mapping = {} + for key, param in model.named_parameters(): + fqns = _get_fqns(model, key) + if len(fqns) != 1: + raise AssertionError( + f"Expected 1 FQN for key '{key}', got {len(fqns)}" + ) + fqn = next(iter(fqns)) + if param not in param_pid_mapping: + continue + pid = param_pid_mapping[param] + fqn_pid_mapping[fqn] = pid + fqn_pid_mapping[pid] = fqn + + # Only convert top-level parameter IDs to FQNs, preserve nested key types + for key in list(osd[_STATE].keys()): + fqn = fqn_pid_mapping[key] + # Move the entire state dict value (which may contain nested integer keys) + # without modifying its internal structure + osd[_STATE][fqn] = osd[_STATE].pop(key) + + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] + + if not osd: + continue + + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) + + if info.flatten_optimizer_state_dict: + optim_state_dict = cast( + OptimizerStateType, _flatten_optim_state_dict(optim_state_dict) + ) + + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) + + +def _split_optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> OptimizerStateType: + """ + Extract the corresponding optim state_dict from ``optim_state_dict`` for + ``optim`` and return the result optim state_dict. + + Args: + model (nn.Module): the root model. + optim (torch.optim.Optimizer): the optimizer. + optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that + contains the optim state_dict of ``optim``. + info (_StateDictInfo): state dict information. + + Returns: + The optim state_dict of ``optim``. + """ + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + pg_mapping: dict[int, int] = {} + + if all(isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE])): + return optim_state_dict + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + if not isinstance(params, list): + raise AssertionError(f"Expected list, got {type(params)}") + params.append(fqn) + if param.requires_grad: + if fqn in cast(DictValueType, optim_state_dict[_STATE]): + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + elif info.strict: + raise RuntimeError( + f"Missing optimizer state for parameter '{fqn}' in checkpoint. " + "The parameter requires gradients but has no saved optimizer state. " + "To load anyway, use StateDictOptions(strict=False)." + ) + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + if len(param_group[_PARAMS]) == 0: + # Param_group with empty params. + ret = [] + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if len(cast(list[str], loaded_param_group[_PARAMS])) == 0: + ret.append(loaded_param_group) + if len(ret) != 1: + raise ValueError( + "There are param groups that have zero parameters. " + "In such a case, DSD only support exactly one param group " + "with zero parameters." + "But the loaded state_dict has zero or more than one param groups " + "that have zero parameters." + ) + if len(optim_state_dict[_PG]) != len(optim.param_groups): + raise ValueError( + "When there is a parameter group that has zero parameters, " + "multiple optimizers are not supported." + ) + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): + pg_idx = pg_mapping.get(id(param_group), -1) + if pg_idx == -1: + continue + + for key, value in param_group.items(): + if key == _PARAMS: + continue + # TODO: check if value is the same if exists. + pg_state[pg_idx][key] = value + + return return_osd + + +@torch.no_grad() +def _load_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + if not info.handle_optim: + return + + for optim in optimizers: + _init_optim_state(optim) + if state_dict: + if _STATE in state_dict: + optim_state_dict = _split_optim_state_dict( + model, optim, state_dict, info + ) + else: + optim_state_dict = _unflatten_optim_state_dict( + optim, cast(dict[str, ValueType], state_dict), info + ) + else: + optim_state_dict = {} + if info.fsdp_modules: + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + for original_fqn, _ in model.named_parameters(): + fqns = _get_fqns(model, original_fqn) + fqns_with_compiler = _get_fqns( + model, original_fqn, skip_compiler_prefix=False + ) + if fqns == fqns_with_compiler: + continue + + if len(fqns) != 1: + raise AssertionError( + f"Expected 1 FQN for '{original_fqn}', got {len(fqns)}" + ) + fqn = fqns.pop() + fqn_with_compiler = fqns_with_compiler.pop() + for g in optim_state_dict[_PG]: + val = cast(dict[str, Any], g) + params = [ + key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS] + ] + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) + for k in list(osd_state.keys()): + if fqn in k: + osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) + + with info.fsdp_context(): + optim_state_dict = FSDP.optim_state_dict_to_load( + model, optim, optim_state_dict + ) + elif info.full_state_dict: + info.full_state_dict = False + local_state_dict = _get_optim_state_dict(model, (optim,), info) + info.full_state_dict = True + device = None + + def _device(t): + if t.dim() > 0: + nonlocal device + if device is None: + device = t.device + elif device != t.device: + raise ValueError("Device mismatch") + return t + + _ = tree_map_only(torch.Tensor, _device, local_state_dict) + if device is None: + raise AssertionError("Expected device to be set") + flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) + flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) + if info.broadcast_from_rank0: + _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + else: + _distribute_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd: + if optim_key not in flatten_local_osd: + if optim_key not in osd_mapping: + raise AssertionError( + f"Expected key '{optim_key}' in osd_mapping" + ) + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] + optim_state_dict = _unflatten_state_dict( + flatten_local_osd, local_osd_mapping + ) + for pg in optim_state_dict[_PG]: + if _PARAMS not in pg: + cast(dict[str, ValueType], pg)[_PARAMS] = [] + + # Note that we do not have to convert the FQN back to param id here if + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. + _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) + + +def get_model_state_dict( + model: nn.Module, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> dict[str, ValueType]: + """ + Return the model state_dict of ``model``. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``model``. + + :rtype: typing.Dict[str, ValueType] + """ + with _gc_context(): + info = _verify_options( + model, + (), + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + _verify_state_dict(model_state_dict, {}, info) + return model_state_dict + + +def get_optimizer_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> OptimizerStateType: + """ + Return the combined state_dict for optimizers. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``optimizers``. + + :rtype: OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=True, + submodules=submodules, + options=options, + ) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict({}, optim_state_dict, info) + return optim_state_dict + + +def get_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> tuple[dict[str, ValueType], OptimizerStateType]: + """ + Return the model state_dict and optimizers state_dict. + + ``get_state_dict`` can process any module that is parallelized by PyTorch + FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any + combination of these parallelisms. The main functions of ``get_state_dict`` + are: 1.) returning a model and optimizer state_dict that can be resharded + with a different number of trainers and/or different parallelisms. + 2.) hiding the parallelism-specific state_dict APIs. Users don't have to call + these APIs. + 3.) sanity checking the result state_dict. + + The keys of the result state dictionary are the canonical FQNs (Fully + Qualified Names). A canonical FQN refers to the FQN based on a parameter's + position in an nn.Module hierarchy. More specifically, a canonical FQN to a + parameter is the FQN returned by ``module.named_parameters()`` or + ``module.named_buffers()`` when the module is not distributed by any + parallelisms. Since the optimizer internally uses parameter IDs to represent + a parameter, there will be a conversion from the parameter IDs to the + canonical FQNs when calling this API. + + ``get_state_dict`` can also process a module that is not parallelized. In + such a case, ``get_state_dict`` only performs one function -- converting the + optimizer parameter IDs to the canonical FQNs. + + Example: + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> from torch.distributed.checkpoint.state_dict import get_state_dict + + >>> fsdp_model = FSDP(copy.deepcopy(model)) + >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + >>> ddp_model = DDP(copy.deepcopy(model)) + >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + + + >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) + >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( + ... fsdp_model, fsdp_optim + ... ) + + >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), + >>> # the asserts will fail. + >>> assert ddp_state_dict == fsdp_state_dict + >>> assert ddp_optim_state == fsdp_optim_state_dict + + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + ``Tuple`` that contain model state_dict and optimizer state_dict. + + :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] + """ + + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict(model_state_dict, optim_state_dict, info) + return model_state_dict, optim_state_dict + + +def _unflatten_model_state_dict( + model: nn.Module, + state_dict: Union[dict[nn.Module, dict[str, ValueType]], dict[str, ValueType]], +) -> dict[str, ValueType]: + if not state_dict: + return {} + + if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + "is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocessing the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + stacklevel=2, + ) + cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict) + new_state_dict: dict[str, ValueType] = {} + for submodule, sub_state_dict in cast_state_dict.items(): + for name, m in model.named_modules(): + if m != submodule: + continue + + fqns = _get_fqns(model, name) + if len(fqns) != 1: + raise AssertionError( + "FQNs for a submodule should only have 1 element" + ) + prefix = f"{next(iter(fqns))}." + new_state_dict.update( + {prefix + subfqn: value for subfqn, value in sub_state_dict.items()} + ) + return new_state_dict + else: + return cast(dict[str, ValueType], state_dict) + + +def set_model_state_dict( + model: nn.Module, + model_state_dict: dict[str, ValueType], + *, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict. + + The counterpart of ``get_model_state_dict`` to set the state_dict to the + model. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + :type model_state_dict: typing.Dict[str, ValueType] + """ + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + info = _verify_options(model, (), optim_only=False, options=options) + + _verify_state_dict(model_state_dict, {}, info) + return _load_model_state_dict(model, model_state_dict, info) + + +def set_optimizer_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + optim_state_dict: OptimizerStateType, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Load the optimizers state_dict. + + The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the + optimizers. See ``set_state_dict`` for the detail usage. + + WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after + ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be + initialized correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + None + + :type optim_state_dict: typing.OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options(model, optimizers, optim_only=True, options=options) + + _verify_state_dict({}, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + + +def set_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict and optimizers state_dict. + + The counterpart of ``get_state_dict`` to set the state_dict to the model and + optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not + have to be returned by ``get_state_dict`` but must meet the following + requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``, + 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, + 3) optimizer state_dict cannot contain the parameter IDs; the keys should be + the canonical FQNs. + + WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()`` + is called on the optimizers. Otherwise, the optimizer states won't be initialized + correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys of the model state_dict. + * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict. + + :type model_state_dict: typing.Dict[str, ValueType] + :type optim_state_dict: typing.OptimizerStateType + """ + + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, optimizers, optim_only=not model_state_dict, options=options + ) + + _verify_state_dict(model_state_dict, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + return _load_model_state_dict(model, model_state_dict, info) + + +# TODO: correct the state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_model_state_dict( + model: nn.Module, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_model_state_dict, + model=model, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + model.state_dict = state_dict_call + + _load_state_dict_call = functools.partial( + set_model_state_dict, + model=model, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(model_state_dict=state_dict) + + model.load_state_dict = load_state_dict_call + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + + +# TODO: correct the load_state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_optimizer_state_dict( + model: nn.Module, + *, + optimizers: tuple[torch.optim.Optimizer, ...], + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Note that if there are multiple optimizers, all of the optimizers will be patched. + So users only need to call one of the state_dict() to get the full result. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + _load_state_dict_call = functools.partial( + set_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(optim_state_dict=state_dict) + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + for optim in optimizers: + optim.state_dict = state_dict_call + optim.load_state_dict = load_state_dict_call diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..178e190e937fb5fab1aa582464e20f1cff8d7abf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py @@ -0,0 +1,389 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import logging +import os +import warnings +from typing import Any, cast, Optional, TYPE_CHECKING, Union +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner +from torch.distributed.checkpoint.logger import _dcp_method_logger +from torch.distributed.checkpoint.stateful import Stateful + +from ._storage_utils import _storage_setup +from .default_planner import DefaultLoadPlanner +from .planner import LoadPlan, LoadPlanner +from .storage import StorageReader +from .utils import _api_bc_check, _DistWrapper, _profile + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import Metadata + +__all__ = ["load_state_dict", "load"] + +logger = logging.getLogger() + + +@deprecated( + "`load_state_dict` is deprecated and will be removed in future versions. " + "Please use `load` instead.", + category=FutureWarning, +) +def load_state_dict( + state_dict: dict[str, Any], + storage_reader: StorageReader, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + """This method is deprecated. Please switch to 'load'.""" + storage_reader.reset() + with _profile(): + # TODO: test returning `load` here instead. + return _load_state_dict( + state_dict, + storage_reader, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) +@_api_bc_check +def load( + state_dict: dict[str, Any], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_reader: Optional[StorageReader] = None, + planner: Optional[LoadPlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, +) -> None: + """ + Load a checkpoint into a distributed state dict in SPMD style. + + Each rank must have the same keys in their ``state_dict`` provided to this + API. Mismatched keys may result in hangs or errors. If unsure, you can use + the ``utils._assert_same_keys`` API to check (but may incur communication + costs). + + Each rank will try to read the least amount of data necessary + to fulfill the requested `state_dict`. When loading :class:`ShardedTensor` + or :class:`DTensor` instances, each rank only reads data for their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + load will first call ``state_dict`` before attempting deserialization, followed by + ``load_state_dict`` once the deserialization is complete. + For each non-``Stateful`` object, load will deserialize the object, and then replace + it in the ``state_dict`` with the deserialized object. + + .. warning:: + All tensors in ``state_dict`` must be allocated on their + destination device *prior to* calling this function. + + All non-tensor data is loaded using `torch.load()` and modified in place + on state_dict. + + .. warning:: + Users must call `load_state_dict` on the root module to ensure load + pos-processing and non-tensor data properly propagates. + + .. note: + If no process group is initialized, this function will assume the intent + is to load a checkpoint into the local process. This can be useful in the + case of local inference, and when using regular Tensors (as opposed to DTensor + or ShardedTensor) + + .. note: + Rank 0 is assumed to be the coordinator rank. + + Args: + state_dict (Dict[str, Any]): The state_dict to load the checkpoint into. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_reader (Optional[StorageReader]): + Instance of StorageWriter used to perform reads. If this is not + specified, DCP will automatically infer the reader based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[LoadPlanner]): + Instance of LoadPlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + no_dist (bool): If ``True``, this function will assume the intent is to load + a checkpoint without using cross-rank synchronization. (Default: ``False``) + Returns: + None. + + Examples + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + >>> optimizer = Adagrad(my_model.parameters()) + >>> model_state_dict = my_model.state_dict() + >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader( + ... "/checkpoint/1" + ... ) + + >>> torch.distributed.checkpoint.load_state_dict( + >>> state_dict=model_state_dict, + >>> storage_reader=fs_storage_reader, + >>> ) + + >>> # module.load_state_dict() function might have customized steps + >>> # to flush the state_dict, must call it to + >>> # ensure correct behavior. + >>> my_model.load_state_dict(model_state_dict) + + .. note:: + load_state_dict uses collectives to coordinate reads across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``torch.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that each + rank has an individual GPU, via ``torch.cuda.set_device()``. + """ + + no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process.", + stacklevel=2, + ) + + with _profile(): + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + # All ranks must have the same keys in their `state_dict` provided to + # this API. See documentation for more details. + # Here we simply sort the keys to ensure that all ranks load values in + # the same order. + keys = sorted(state_dict.keys()) + + statetful_sd = {} + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + statetful_sd[key] = ( + elem.state_dict() if isinstance(elem, Stateful) else elem + ) + + _load_state_dict( + state_dict=statetful_sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=planner, + ) + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + if isinstance(elem, Stateful): + # If the state_dict is a Stateful object, + # DCP does an in-place load in the original state dict. + elem.load_state_dict(statetful_sd[key]) + else: + # Otherwise, replace the state_dict with the loaded state_dict. + state_dict[key] = statetful_sd[key] + + +def _load_state_dict( + state_dict: dict[str, Any], + storage_reader: StorageReader, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if planner is None: + planner = DefaultLoadPlanner() + + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = distW.group + + use_collectives = True + metadata: Optional[Metadata] = None + + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + nonlocal use_collectives + nonlocal metadata + + # Use global metadata if available, otherwise fallback to rank local metadata + try: + metadata = storage_reader.read_metadata() + except Exception: + logger.info( + "Global metadata is not found. Falling back to rank local metadata." + ) + + if ( + not metadata + and "kwargs" in inspect.signature(storage_reader.read_metadata).parameters + ): + try: + metadata = storage_reader.read_metadata(rank=distW.rank) # noqa: F841 + use_collectives = False + except Exception: + logger.info("Rank local metadata is not found.") + + if planner is None: + raise AssertionError("planner is None") + if metadata is None: + raise AssertionError("metadata is None") + planner.set_up_planner(state_dict, metadata, distW.is_coordinator) + + if ( + "kwargs" + in inspect.signature(storage_reader.set_up_storage_reader).parameters + ): + storage_reader.set_up_storage_reader( + metadata, + distW.is_coordinator, + rank=distW.rank, + use_collectives=use_collectives, + ) + else: + storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_reader.prepare_local_plan(local_plan) + return local_plan + + @_dcp_method_logger(**ckpt_kwargs) + def global_step(all_local_plans): + if planner is None: + raise AssertionError("planner is None") + all_local_plans = planner.create_global_plan(all_local_plans) + all_local_plans = storage_reader.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan: Optional[LoadPlan] = None + if use_collectives: + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + local_plan: LoadPlan = local_step() + global_plan: list[LoadPlan] = global_step([local_plan]) + central_plan = global_plan[0] + + @_dcp_method_logger(**ckpt_kwargs) + def read_data(): + if planner is None: + raise AssertionError("planner is None") + if central_plan is None: + raise AssertionError("central_plan is None") + final_local_plan = planner.finish_plan(central_plan) + all_reads = storage_reader.read_data(final_local_plan, planner) + + all_reads.wait() + return None + + if use_collectives: + _ = distW.all_gather("read", read_data) + else: + read_data() + distW.barrier() + + +def _load_state_dict_from_keys( + keys: Optional[Union[set[str], str]] = None, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_reader: Optional[StorageReader] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> dict[str, Any]: + """ + Load only the specified keys from the checkpoint, if no keys are specified, the entire + checkpoint will be loaded. Note, this method completely loads the checkpoint into the + current process and is not distributed. + + .. warning:: + + + .. warning:: + + All non-tensor data is loaded using `torch.load()` + + .. note: + As opposed to the usual pattern, this function does not take a state dict as input + and does not load inplace. Instead, a new state dict is directly initialized and read + from file. + + .. note: + If no process group is initialized, this function will assume the intent + is to load a checkpoint into the local process. This can be useful in the + case of local inference, and when using regular Tensors (as opposed to DTensor + or ShardedTensor) + + .. note: + Rank 0 is assumed to be the coordinator rank. + + Args: + keys (Optional[Union[set[str], str]]): + Loads any key specified in this set. If no keys are specified, the entire checkpoint + is loaded. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_reader (Optional[StorageReader]): + Instance of StorageWriter used to perform reads. If this is not + specified, DCP will automatically infer the reader based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + State dict from specified keys + """ + torch._C._log_api_usage_once( + "torch.distributed.checkpoint._load_state_dict_from_keys" + ) + + no_dist = not (dist.is_available() and dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process.", + stacklevel=2, + ) + + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + if isinstance(keys, str): + keys = {keys} + + sd: dict[str, Any] = {} + _load_state_dict( + state_dict=sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=_EmptyStateDictLoadPlanner(keys=keys), + ) + + return sd diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..370f97cd1cd013246563f021749b6537a327b235 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py @@ -0,0 +1,496 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import os +import warnings +from concurrent.futures import Future +from dataclasses import dataclass +from enum import Enum +from typing import cast, Optional, TYPE_CHECKING, Union +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +from torch.distributed._state_dict_utils import STATE_DICT_TYPE +from torch.distributed.checkpoint._async_process_executor import ( + _ProcessBasedAsyncCheckpointExecutor, +) +from torch.distributed.checkpoint._async_thread_executor import ( + _ThreadBasedAsyncCheckpointExecutor, +) +from torch.distributed.checkpoint._storage_utils import _storage_setup +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from torch.distributed.checkpoint.logger import _dcp_method_logger +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner +from torch.distributed.checkpoint.staging import ( + AsyncStager, + DefaultStager, + StagingOptions, +) +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.storage import StorageWriter, WriteResult +from torch.distributed.distributed_c10d import _get_default_group + +from .utils import _api_bc_check, _DistWrapper, _profile + + +if TYPE_CHECKING: + from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor + + +__all__ = [ + "save_state_dict", + "save", + "async_save", + "AsyncCheckpointerType", + "AsyncSaveResponse", +] + + +class AsyncCheckpointerType(Enum): + """Enum for async checkpointer type.""" + + THREAD = "thread" + PROCESS = "process" + + +@deprecated( + "`save_state_dict` is deprecated and will be removed in future versions." + "Please use `save` instead.", + category=FutureWarning, +) +def save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + """This method is deprecated. Please switch to 'save'.""" + storage_writer.reset() + + # TODO: test returning `save` here instead. + with _profile(): + return _save_state_dict( + state_dict, + storage_writer, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) # type: ignore[arg-type] +@_api_bc_check +def save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, +) -> Metadata: + """ + Save a distributed model in SPMD style. + + This function is different from ``torch.save()`` as it handles + ``ShardedTensor`` , and ``DTensor`` by having each rank only save their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + save will call ``state_dict`` before serialization. + + .. warning:: + There is no guarantees of Backwards Compatibility across PyTorch versions + for saved state_dicts. + + .. warning:: + If using the `process_group` argument, make sure that only its ranks + call `save_state_dict` and that all data in state_dict belong to it. + + .. note:: + When saving checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of + the shard_group should be calling `save_state_dict` and the corresponding process + group needs to be passed in. + + .. note:: + If no process group is available, this function assumes the intention is to save the + state_dict in the local process. + + .. note: + Rank 0 is assumed to be the coordinator rank. + + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform writes. If this is not + specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + no_dist (bool): + If ``True``, this function will assume the intent is to load + a checkpoint on a single rank/process. + (Default: ``False``) + use_collectives (bool): If ``False``, this function will assume the intent is to save + a checkpoint without using cross-rank synchronization. + (Default: ``True``) + This configuration is experimental and should be used with caution. + It will change the format of the saved checkpoint and may not be backward compatible. + + Returns: + Metadata: Metadata object for the saved checkpoint. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... ) + >>> torch.distributed.checkpoint.save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + + .. note:: + save_state_dict uses collectives to coordinate writes across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``torch.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that + each rank has an individual GPU, via ``torch.cuda.set_device()``. + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint.save") + + no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process.", + stacklevel=2, + ) + + with _profile(): + storage_writer = cast( + StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) + ) + + return _save_state_dict( + state_dict=_stateful_to_state_dict(state_dict), + storage_writer=storage_writer, + process_group=process_group, + no_dist=no_dist, + planner=planner, + use_collectives=use_collectives, + ) + + +@dataclass +class AsyncSaveResponse: + """This class contains futures for staging and upload completion. + It is returned by async_save(). + staging_completion is a future that indicates when local copy + of state_dict is complete. + upload_completion is a future that indicates when a checkpoint + completed saving. + """ + + staging_completion: Future[None] + upload_completion: Future[None] + + +@_dcp_method_logger(log_exceptions=True) +def async_save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD, + async_stager: Optional[AsyncStager] = None, + no_dist: bool = False, + use_collectives: bool = True, +) -> Union[Future, AsyncSaveResponse]: + """Asynchronous version of ``save``. This code first de-stages the state_dict on to the + staging storage (defaults to CPU memory), and then calls the `save` in a separate thread. + + .. warning:: + This feature is experimental and subject to change. + MUST CALL CLOSE AFTER LAST CHECKPOINT IS SAVED + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform 'stage' and 'save'. If + this is not specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + async_checkpointer_type (AsyncCheckpointerType): + whether to do checkpoint in separate thread or process + (Default: ``AsyncCheckpointerType.THREAD``) + async_stager (AsyncStager): + provides staging implementation. If storage_writer implements AsyncStager + and async_stager is provided, async_stager will be used for staging + no_dist (bool): + If ``True``, this function will assume the intent is to save + a checkpoint on a single rank/process. + (Default: ``False``) + use_collectives: If False, Save the checkpoint without rank coordination. (Default: ``True``) + This configuration is experimental and should be used with caution. + It will change the format of the saved checkpoint and may not be backward compatible. + + Returns: + Future: A future holding the resultant Metadata object from `save`. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... ) + >>> checkpoint_future = torch.distributed.checkpoint.async_save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + >>> + >>> # ... do some work ... + >>> + >>> checkpoint_future.result() + + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save") + + if dist.is_available() and dist.is_initialized(): + pg = process_group or _get_default_group() + if torch.device("cpu") not in pg._device_types: + raise AssertionError( + "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" + ) + + if async_stager is None: + if storage_writer is not None and isinstance(storage_writer, AsyncStager): + # bwc with old storage_writers + async_stager = storage_writer + else: + async_stager = DefaultStager( + StagingOptions( + False, + False, + False, + False, + ) + ) + + state_dict = _stateful_to_state_dict(state_dict) + + @_dcp_method_logger(log_exceptions=True) + def stage_state_dict() -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + return async_stager.stage(state_dict) + + staging_future_or_state_dict = stage_state_dict() + + upload_executor: _AsyncCheckpointExecutor = ( + _ProcessBasedAsyncCheckpointExecutor() + if async_checkpointer_type == AsyncCheckpointerType.PROCESS + else _ThreadBasedAsyncCheckpointExecutor() + ) + + upload_future: Future = upload_executor.execute_save( + staging_future_or_state_dict, + checkpoint_id=checkpoint_id, + # pyrefly: ignore [bad-argument-type] + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + no_dist=no_dist, + use_collectives=use_collectives, + ) + + if isinstance(staging_future_or_state_dict, Future): + staging_future = staging_future_or_state_dict + return_staging_future: Future[None] = Future() + + def callback( + original_staging_future: Future[STATE_DICT_TYPE], + return_staging_future: Future[None] = return_staging_future, + ): + try: + original_staging_future.result() + return_staging_future.set_result(None) + except Exception as e: + return_staging_future.set_exception(e) + + if not staging_future.done(): + staging_future.add_done_callback(callback) + else: + return_staging_future.set_result(None) + + # return new AsyncSaveResponse for users using new ZOC implementation + return AsyncSaveResponse( + staging_completion=return_staging_future, upload_completion=upload_future + ) + else: + + @_dcp_method_logger(log_exceptions=True) + def maybe_synchronize_staging(): + if async_stager.should_synchronize_after_execute: + async_stager.synchronize_staging() + + maybe_synchronize_staging() + return upload_future + + +@_dcp_method_logger(log_exceptions=True) +def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object.""" + stateful_state_dict = {} + for key, elem in state_dict.items(): + # Apply _dcp_method_logger to each state_dict() call + def _elem_to_state_dict(elem): + return elem.state_dict() if isinstance(elem, Stateful) else elem + + _elem_to_state_dict.__name__ = f"_stateful_to_state_dict.{key}" + + stateful_state_dict[key] = _dcp_method_logger(log_exceptions=True)( + _elem_to_state_dict + )(elem) + return stateful_state_dict + + +def _save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, + use_collectives: bool = True, +) -> Metadata: + torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + if planner is None: + raise AssertionError("planner is None") + + global_metadata = None + + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = distW.group + + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + if planner is None: + raise AssertionError("planner is None") + storage_meta = storage_writer.storage_meta() + if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters: + warnings.warn( + "The function definition for SavePlanner.set_up_planner has been updated" + " to include the storage_meta argument. Please update your implementation" + " to include this parameter.", + stacklevel=2, + ) + planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type] + else: + planner.set_up_planner( + state_dict=state_dict, + storage_meta=storage_meta, + is_coordinator=distW.is_coordinator, + ) + + if ( + "kwargs" + in inspect.signature(storage_writer.set_up_storage_writer).parameters + ): + storage_writer.set_up_storage_writer( + distW.is_coordinator, + rank=distW.rank, + use_collectives=use_collectives, + ) + else: + storage_writer.set_up_storage_writer(distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + return local_plan + + @_dcp_method_logger(**ckpt_kwargs) + def global_step(all_local_plans): + nonlocal global_metadata + + if planner is None: + raise AssertionError("planner is None") + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan: Optional[SavePlan] = None + if use_collectives: + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + local_plan: SavePlan = local_step() + global_plan: list[SavePlan] = global_step([local_plan]) + central_plan = global_plan[0] + + @_dcp_method_logger(**ckpt_kwargs) + def write_data(): + if planner is None: + raise AssertionError("planner is None") + if central_plan is None: + raise AssertionError("central_plan is None") + final_local_plan = planner.finish_plan(central_plan) + all_writes = storage_writer.write_data(final_local_plan, planner) + + all_writes.wait() + return all_writes.value() + + @_dcp_method_logger(**ckpt_kwargs) + def finish_checkpoint(all_results): + if global_metadata is None: + raise AssertionError("global_metadata is None") + storage_writer.finish(metadata=global_metadata, results=all_results) + return global_metadata + + if use_collectives: + metadata = distW.all_reduce("write", write_data, finish_checkpoint) + else: + write_results: list[WriteResult] = write_data() + metadata = finish_checkpoint([write_results]) + distW.barrier() + + return metadata diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/stateful.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/stateful.py new file mode 100644 index 0000000000000000000000000000000000000000..15e227d92fb5d29631b0316b3971c435120ad15b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/stateful.py @@ -0,0 +1,42 @@ +from typing import Any, TypeVar +from typing_extensions import Protocol, runtime_checkable + + +__all__ = ["Stateful", "StatefulT"] + + +@runtime_checkable +class Stateful(Protocol): + """ + Stateful protocol for objects that can be checkpointed and restored. + """ + + def state_dict(self) -> dict[str, Any]: + """ + Objects should return their state_dict representation as a dictionary. + The output of this function will be checkpointed, and later restored in + `load_state_dict()`. + + .. warning:: + Because of the inplace nature of restoring a checkpoint, this function + is also called during `torch.distributed.checkpoint.load`. + + + Returns: + Dict: The objects state dict + """ + + ... + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Restore the object's state from the provided state_dict. + + Args: + state_dict: The state dict to restore from + """ + + ... + + +StatefulT = TypeVar("StatefulT", bound=Stateful) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/storage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..b184d7b1700528ad22bc10726cb6619975e8d9e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/storage.py @@ -0,0 +1,288 @@ +import abc +import os +from dataclasses import dataclass +from typing import Any, Optional, Union + +from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + SavePlan, + SavePlanner, +) +from torch.futures import Future + + +__all__ = ["WriteResult", "StorageWriter", "StorageReader"] + + +@dataclass(frozen=True) +class WriteResult: + index: MetadataIndex + + size_in_bytes: int + storage_data: Any + + +class StorageWriter(abc.ABC): + """ + Interface used by ``save_state_dict`` to write to storage. + + One StorageWriter instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expect the following sequence of calls. + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) set_up_storage_writer() + 2) (all ranks) prepare_local_plan() + 3) (coordinator) prepare_global_plan() + 4) (all ranks) write_data() + 5) (coordinator) finish() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint write is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint write. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def set_up_storage_writer( + self, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + """ + Initialize this instance. + + Args: + is_coordinator (bool): Whether this instance is responsible for coordinating + the checkpoint. + """ + + @abc.abstractmethod + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + """ + Perform storage-specific local planning. + + While this method can produce a completely different plan, the recommended + way is to store storage specific data in SavePlan::storage_data. + + Args: + plan (SavePlan): The local plan from the ``SavePlanner`` in use. + + Returns: + A transformed ``SavePlan`` after storage local planning + """ + + @abc.abstractmethod + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + """ + Perform centralized planning of storage. + + This method is only called on the coordinator instance. + + While this method can produce a completely different plan, the preferred + way is to store storage specific data in SavePlan::storage_data. + + Args: + plans: A list of ``SavePlan`` instances, one for each rank. + + Returns: + A list of transformed ``SavePlan`` after storage global planning + """ + + @abc.abstractmethod + def write_data( + self, plan: SavePlan, planner: SavePlanner + ) -> Future[list[WriteResult]]: + """ + Write all items from ``plan`` using ``planner`` to resolve the data. + + A subclass should call ``SavePlanner::resolve_data`` on each item + from the plan to get access to the underlying object to write. + + Subclasses should lazily call `resolve_data` as it can allocate memory. + In case of tensors, make following assumptions: + + - They might be on any device, including not matching the one on ``WriteItem::tensor_data`` + - They might be views or not contiguous. Only the projection needs to be saved. + + Args: + plan (SavePlan): The save plan to execute. + planner (SavePlanner): Planner object to be used to resolve items to data. + + Returns: + A future that completes to a list of WriteResult + """ + + @abc.abstractmethod + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + """ + Write the metadata and marks the current checkpoint as successful. + + The actual format/schema used for serializing `metadata` is an + implementation detail. The only requirement is that it's recoverable + in to the same object graph. + + Args: + metadata (Metadata): metadata for the new checkpoint + results: A list of WriteResults from all ranks. + + Returns: + None + """ + + @classmethod + @abc.abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """ + Check if the given checkpoint_id is supported by the storage. This allow + us to enable automatic storage selection. + """ + ... + + def storage_meta(self) -> Optional[StorageMeta]: + """ + Return the storage-specific metadata. This is used to store additional information + in a checkpoint that can be useful for providing request-level observability. StorageMeta + is passed to the ``SavePlanner`` during save calls. Returns None by default. + + TODO: provide an example + """ + return None + + +class StorageReader(abc.ABC): + """ + Interface used by ``load_state_dict`` to read from storage. + + One StorageReader instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expected the following sequence of calls by ``load_state_dict``: + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) read_metadata() + 2) (all ranks) set_up_storage_reader() + 3) (all ranks) prepare_local_plan() + 4) (coordinator) prepare_global_plan() + 5) (all ranks) read_data() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint read is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint read. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is more like a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: + """ + Read the checkpoint metadata. + + Returns: + The metadata object associated with the checkpoint being loaded. + + """ + + @abc.abstractmethod + def set_up_storage_reader( + self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + """ + Initialize this instance. + + Args: + metadata (Metadata): The metadata schema to use. + is_coordinator (bool): Whether this instance is responsible for coordinating + the checkpoint. + """ + + @abc.abstractmethod + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """ + Perform storage-specific local planning. + + While this method can produce a completely different plan, the recommended + way is to store storage specific data in LoadPlan::storage_data. + + Args: + plan (LoadPlan): The local plan from the ``LoadPlan`` in use. + + Returns: + A transformed ``LoadPlan`` after storage local planning + """ + + @abc.abstractmethod + def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]: + """ + Perform centralized planning of storage loading. + + This method is only called on the coordinator instance. + + While this method can produce a completely different plan, the preferred + way is to store storage specific data in LoadPlan::storage_data. + + Args: + plans: A list of ``LoadPlan`` instances, one for each rank. + + Returns: + A list of transformed ``LoadPlan`` after storage global planning + """ + + @abc.abstractmethod + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + """ + Read all items from ``plan`` using ``planner`` to resolve the data. + + A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO + object into the right place. + + A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the + tensors that in should load data into. + + It's the StorageLayer responsibility to properly schedule any cross device copies + required. + + Args: + plan (LoadPlan): The local plan to execute on + planner (LoadPlanner): The planner object to use to resolve items. + + Returns: + A future that completes once all reads are finished. + """ + + @classmethod + @abc.abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """ + Check if the given checkpoint_id is supported by the storage. This allow + us to enable automatic storage selection. + """ + ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..073649c5f124d1817af12d161d8a80b76ae3ceda --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py @@ -0,0 +1,485 @@ +# mypy: allow-untyped-defs +import cProfile +import inspect +import io +import itertools +import os +import warnings +from collections.abc import Callable, Sequence +from contextlib import contextmanager +from functools import wraps +from pstats import Stats +from typing import Any, cast, Optional, TypeVar, Union + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharded_tensor.shard import Shard + +from .api import ( + _is_wrapped_exception, + _wrap_exception, + CheckpointException, + WRAPPED_EXCEPTION, +) +from .metadata import MetadataIndex, STATE_DICT_TYPE + + +__all__ = ["find_tensor_shard", "find_state_dict_object"] + +T = TypeVar("T") +R = TypeVar("R") + + +def _get_failure_dict( + results: list[Union[T, WRAPPED_EXCEPTION]], +) -> dict[int, WRAPPED_EXCEPTION]: + return cast( + dict[int, WRAPPED_EXCEPTION], + {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}, + ) + + +def _all_gather_keys( + local_dict: dict[str, Any], group: Optional[dist.ProcessGroup] = None +) -> set[str]: + """Gathers all keys, and returns them sorted.""" + keys = list(local_dict.keys()) + gathered_keys: list[list[str]] = [None] * dist.get_world_size(group) # type: ignore[list-item] + + dist.all_gather_object(gathered_keys, keys, group=group) + return set(itertools.chain.from_iterable(gathered_keys)) + + +def _assert_same_keys( + state_dict: dict[str, Any], process_group: Optional[dist.ProcessGroup] = None +) -> None: + """ + Asserts that all ranks have the same keys in their state dict. + This is a collective call which requires all ranks in ``process_group`` to + join. It will also induce cross-rank communication and block CPU. + """ + + if dist.get_world_size(process_group) == 1: + return + + all_keys = _all_gather_keys(state_dict, process_group) + my_keys = set(state_dict.keys()) + diff = all_keys - my_keys + if len(diff) > 0: + raise AssertionError( + f"Key(s) present in other ranks but not this one, difference: {diff}" + ) + + +class _DistWrapper: + """ + This is a wrapper around PG that provides a series of features around object collectives. + + It works without distributed initialized, where most collectives turns into nops. + + All variants that take functions are exception robust, meaning that if one or more + ranks raise errors, all ranks will observe those. + """ + + def __init__( + self, + group: Optional[dist.ProcessGroup], + use_dist: bool, + coordinator_rank: int, + ): + self.group = group + self.use_dist = use_dist + self.coordinator_rank = coordinator_rank + if self.use_dist: + self.global_coordinator_rank = ( + dist.get_global_rank(group, coordinator_rank) + if group is not None + else coordinator_rank + ) + self.rank = dist.get_rank(group) + self.is_coordinator = self.rank == coordinator_rank + else: + self.global_coordinator_rank = 0 + self.rank = 0 + self.is_coordinator = True + + def get_rank(self) -> int: + return self.rank + + def get_world_size(self) -> int: + if self.use_dist: + return dist.get_world_size(self.group) + return 1 + + def broadcast_object(self, object: Optional[T]) -> T: + """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled.""" + object_list = [object] + if self.use_dist: + dist.broadcast_object_list( + object_list=object_list, + group=self.group, + src=self.global_coordinator_rank, + ) + return cast(T, object_list[0]) + + def gather_object(self, object: T) -> Optional[list[T]]: + """Implement functionality similar to c10d::gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = ( + cast(list[T], [None] * dist.get_world_size(self.group)) + if self.is_coordinator + else None + ) + + dist.gather_object( + obj=object, + object_gather_list=gather_objs if self.is_coordinator else None, + dst=self.global_coordinator_rank, + group=self.group, + ) + result = gather_objs + else: + result = [object] + return result + + def all_gather_object(self, object: T) -> list[T]: + """Implement functionality similar to c10d::all_gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = cast(list[T], [None] * dist.get_world_size(self.group)) + + dist.all_gather_object( + object_list=gather_objs, obj=object, group=self.group + ) + else: + gather_objs = [object] + return gather_objs + + def scatter_object(self, object_list: Optional[list[T]]) -> T: + """Implement functionality similar to c10d::scatter_object but without distributed enabled.""" + if self.use_dist: + gather_result = cast(list[T], [None]) + dist.scatter_object_list( + scatter_object_output_list=gather_result, + scatter_object_input_list=object_list if self.is_coordinator else None, + src=self.global_coordinator_rank, + group=self.group, + ) + + local_reply = gather_result[0] + else: + if object_list is None: + raise AssertionError("object_list is None") + local_reply = object_list[0] + return local_reply + + def reduce_scatter( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[list[T]], list[R]], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Scatter to each rank part of the result. + """ + local_data: Union[WRAPPED_EXCEPTION, T] + try: + local_data = map_fun() + except BaseException as e: # noqa: B036 + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + all_results: Optional[list[Union[R, CheckpointException]]] = None + if self.is_coordinator: + if all_data is None: + raise AssertionError("all_data is None") + node_failures = _get_failure_dict(all_data) + + if len(node_failures) == 0: + try: + # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? + all_results = cast( + list[Union[R, CheckpointException]], + reduce_fun(cast(list[T], all_data)), + ) + except BaseException as e: # noqa: B036 + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + all_results = [ + CheckpointException(step, node_failures) + ] * self.get_world_size() + + result = self.scatter_object(all_results) + if isinstance(result, CheckpointException): + raise result + return result + + def all_reduce( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[list[T]], R], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Broadcast the reduced value to all ranks. + """ + local_data: Union[T, WRAPPED_EXCEPTION] + try: + local_data = map_fun() + except BaseException as e: # noqa: B036 + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + result: Optional[Union[R, CheckpointException]] = None + if self.is_coordinator: + if all_data is None: + raise AssertionError("all_data is None") + node_failures = _get_failure_dict(all_data) + if len(node_failures) == 0: + try: + result = reduce_fun(cast(list[T], all_data)) + except BaseException as e: # noqa: B036 + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + result = CheckpointException(step, node_failures) + + # pyrefly: ignore [bad-argument-type] + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(R, final_result) + + def all_gather( + self, + step: str, + map_fun: Callable[[], T], + ) -> list[T]: + """ + Compute a value on each rank, then all_gather them. + + This method operates in the following way: + Run ``map_cp`` on all ranks + all_gather the values to all ranks + """ + result: Union[T, WRAPPED_EXCEPTION] + try: + result = map_fun() + except BaseException as e: # noqa: B036 + result = _wrap_exception(e) + + all_results = self.all_gather_object(result) + + node_failures = _get_failure_dict(all_results) + if len(node_failures) > 0: + raise CheckpointException(step, node_failures) + return cast(list[T], all_results) + + def broadcast( + self, + step: str, + map_fun: Callable[[], T], + ) -> T: + """ + Compute a value on rank 0 and broadcast it. + + This method operates in the following way: + Run ``map_cp`` on rank 0 + broadcast the value + """ + result: Optional[Union[T, CheckpointException]] = None + if self.is_coordinator: + try: + result = map_fun() + except BaseException as e: # noqa: B036 + result = CheckpointException(step, {self.rank: _wrap_exception(e)}) + # pyrefly: ignore [bad-argument-type] + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(T, final_result) + + def barrier(self) -> None: + """ + Add a synchronization point across all processes when using distributed. + If torch.distributed is initialized, this function will invoke a barrier across the global process group. + If torch.distributed is not initialized, this function is a no-op. + """ + if not self.use_dist: + return + dist.barrier(group=self.group) + + +def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: + if index.offset is None: + raise ValueError( + f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided" + ) + + shards = tensor.local_shards() + # index fast path + if index.index is not None: + if ( + len(shards) > index.index + and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset + ): + return shards[index.index] + + for shard in shards: + if torch.Size(shard.metadata.shard_offsets) == index.offset: + return shard + raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") + + +def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: + if hasattr(tensor, "__get_tensor_shard__"): + # DTensor implements _Checkpointable + return tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + if isinstance(tensor, ShardedTensor): + return _find_shard(tensor, index).tensor + if index.offset is not None: + # special case looking up a tensor by origin + if index.offset == torch.Size([0] * len(tensor.size())): + return tensor + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return tensor + + +def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: + if index.fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{index.fqn}'") + obj = state_dict[index.fqn] + + if isinstance(obj, torch.Tensor): + return find_tensor_shard(obj, index) + elif index.offset is not None: + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return obj + + +def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]: + return [i_a + i_b for i_a, i_b in zip(a, b)] + + +def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]: + return [i_a - i_b for i_a, i_b in zip(a, b)] + + +class _ReaderView(io.IOBase): + def __init__(self, base_stream: io.IOBase, offset: int, len: int): + super().__init__() + self.offset = offset + self.len = len + self.base_stream = base_stream + self.seek(0) + + def seek(self, offset: int, whence: int = os.SEEK_SET, /) -> int: + if whence == os.SEEK_SET: + offset = self.offset + offset + elif whence == os.SEEK_END: + whence = os.SEEK_SET + offset = (self.offset + self.len) - offset + return self.base_stream.seek(offset, whence) + + def tell(self) -> int: + return self.base_stream.tell() - self.offset + + def readable(self) -> bool: + return self.base_stream.readable() + + def seekable(self) -> bool: + return self.base_stream.seekable() + + def readinto(self, b): + max_size = self.len - self.tell() + if max_size == 0: + return 0 + if len(b) > max_size: + b = memoryview(b)[:max_size] + return self.base_stream.readinto(b) # type: ignore[attr-defined] + + def read(self, size=-1): + max_size = self.len - self.tell() + if size == -1 or size > max_size: + size = max_size + return self.base_stream.read(size) + + +def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase: + # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader + return _ReaderView(file, offset, length) + + +def _normalize_device_info(device_type: str, device_id: int) -> str: + """Device info normalization.""" + if device_type == "cpu": + return "cpu" + return f"{device_type}:{device_id}" + + +# TODO: integrate with distributed logging flag +ENABLE_PROFILE = False + + +@contextmanager +def _profile(): + # Only log the profiling when it is enable and is on rank0 or dist is not + # available. + if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0): + profiler = cProfile.Profile() + profiler.enable() + try: + yield + finally: + profiler.disable() + stats = Stats(profiler) + stats.sort_stats("time").print_stats(10) + else: + yield + + +def _api_bc_check(func): + @wraps(func) + def inner_func(*args, **kwargs) -> Any: + if len(args) == 2: + warnings.warn( + f"The argument order of {func.__name__} has been changed. " + "Please check the document to avoid future breakages.", + stacklevel=2, + ) + sig = inspect.signature(func) + kwonlyargs = [ + p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY + ] + if "storage_writer" in kwonlyargs: + if "storage_writer" in kwargs: + raise AssertionError(f"storage_writer in kwargs: {(args, kwargs)}") + kwargs["storage_writer"] = args[1] + elif "storage_reader" in kwonlyargs: + if "storage_reader" in kwargs: + raise AssertionError(f"storage_reader in kwargs: {(args, kwargs)}") + kwargs["storage_reader"] = args[1] + else: + raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}") + return func(args[0], **kwargs) + else: + return func(*args, **kwargs) + + return inner_func diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93295802ae847cb939954e8c8918dfd2ce49cf4f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__init__.py @@ -0,0 +1,88 @@ +import logging +import multiprocessing +import socket + +# import for registration side effect +import torch.distributed.debug._handlers # noqa: F401 +from torch._C._distributed_c10d import _WorkerServer +from torch.distributed.debug._store import get_rank, tcpstore_client + + +__all__ = [ + "start_debug_server", + "stop_debug_server", +] + +logger: logging.Logger = logging.getLogger(__name__) + +_WORKER_SERVER: _WorkerServer | None = None +_DEBUG_SERVER_PROC: multiprocessing.Process | None = None + + +def start_debug_server(port: int = 25999, worker_port: int = 0) -> None: + """ + Start the debug server stack on all workers. The frontend debug server is + only started on rank0 while the per rank worker servers are started on all + ranks. + + This server provides an HTTP frontend that allows for debugging slow and + deadlocked distributed jobs across all ranks simultaneously. This collects + data such as stack traces, FlightRecorder events, and performance profiles. + + This depends on dependencies which are not installed by default. + + Dependencies: + - Jinja2 + - aiohttp + + WARNING: This is intended to only be used in trusted network environments. + The debug server is not designed to be secure and should not be exposed to + the public internet. See SECURITY.md for more details. + + WARNING: This is an experimental feature and may change at any time. + + Args: + port (int): The port to start the frontend debug server on. + worker_port (int): The port to start the worker server on. Defaults to 0, which + will cause the worker server to bind to an ephemeral port. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _WORKER_SERVER is None, "debug server already started" + assert _DEBUG_SERVER_PROC is None, "debug server already started" + + logger.info("Starting debug server on port %d", port) + + store = tcpstore_client() + + _WORKER_SERVER = _WorkerServer("::", worker_port) + + RANK = get_rank() + store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}") + + from torch.distributed.debug._frontend import main + + if RANK == 0: + _DEBUG_SERVER_PROC = multiprocessing.Process( + target=main, args=(port,), daemon=True + ) + _DEBUG_SERVER_PROC.start() + + +def stop_debug_server() -> None: + """ + Shutdown the debug server and stop the frontend debug server process. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _DEBUG_SERVER_PROC is not None + assert _WORKER_SERVER is not None + + logger.info("Stopping debug server") + + _DEBUG_SERVER_PROC.terminate() + _WORKER_SERVER.shutdown() + _DEBUG_SERVER_PROC.join() + + _WORKER_SERVER = None + _DEBUG_SERVER_PROC = None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e48339a757075ded6523816c48893880ebd4d769 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_frontend.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_frontend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa71e86edfdf8242e1f914f305df95c19337c565 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_frontend.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_handlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..413299dc72c87f218fb2f98104a90965c1fa383c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_handlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_store.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_store.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..283102a0dc4d6351266cda1e1401984f1a59fe4a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_store.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_frontend.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..16cccb88632f0372bac132a01cc8b97f60223852 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_frontend.py @@ -0,0 +1,553 @@ +import asyncio +import json +import logging +import socket +import threading +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import parse_qs, urlparse + +from jinja2 import DictLoader, Environment +from tabulate import tabulate + +from torch.distributed.debug._store import get_world_size, tcpstore_client +from torch.distributed.flight_recorder.components.builder import build_db +from torch.distributed.flight_recorder.components.config_manager import JobConfig +from torch.distributed.flight_recorder.components.types import ( + Collective, + Group, + Membership, + NCCLCall, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class Response: + status_code: int + text: str + + def raise_for_status(self): + if self.status_code != 200: + raise RuntimeError(f"HTTP {self.status_code}: {self.text}") + + def json(self): + return json.loads(self.text) + + +def fetch_thread_pool(urls: list[str]) -> Iterable[Response]: + # late import for optional dependency + import requests + + max_workers = 20 + + def get(url: str) -> Response: + resp = requests.post(url) + return Response(resp.status_code, resp.text) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + resps = executor.map(get, urls) + + return resps + + +def fetch_aiohttp(urls: list[str]) -> Iterable[Response]: + # late import for optional dependency + import aiohttp + + async def fetch(session: aiohttp.ClientSession, url: str) -> Response: + async with session.post(url) as resp: + text = await resp.text() + return Response(resp.status, text) + + async def gather(urls: list[str]) -> Iterable[Response]: + async with aiohttp.ClientSession() as session: + return await asyncio.gather(*[fetch(session, url) for url in urls]) + + return asyncio.run(gather(urls)) + + +def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]: + store = tcpstore_client() + keys = [f"rank{r}" for r in range(get_world_size())] + addrs = store.multi_get(keys) + addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] + + try: + resps = fetch_aiohttp(addrs) + except ImportError: + resps = fetch_thread_pool(addrs) + + return addrs, resps + + +def format_json(blob: str): + parsed = json.loads(blob) + return json.dumps(parsed, indent=2) + + +templates = { + "base.html": """ + + + {% block title %}{% endblock %} - PyTorch Distributed + + + + + + + +
+ {% block header %}{% endblock %} + {% block content %}{% endblock %} +
+ """, + "index.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}Index{% endblock %}

+{% endblock %} +{% block content %} +Hi +{% endblock %} + """, + "raw_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{title}}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "json_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ format_json(resp.text) }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "profile.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}torch.profiler{% endblock %}

+{% endblock %} + +{% block content %} +
+ + + +
+ + + + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} + + + + {% endif %} + {% endfor %} +{% endblock %} + """, + "tcpstore.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}TCPStore Keys{% endblock %}

+{% endblock %} +{% block content %} +
+    {% for k, v in zip(keys, values) -%}
+{{ k }}: {{ v | truncate(100) }}
+    {% endfor %}
+    
+{% endblock %} + """, + "fr_trace.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} +

Groups

+ {{ groups | safe }} +

Memberships

+ {{ memberships | safe }} +

Collectives

+ {{ collectives | safe }} +

NCCL Calls

+ {{ ncclcalls | safe }} +{% endblock %} + """, + "pyspy_dump.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}py-spy Stack Traces{% endblock %}

+{% endblock %} +{% block content %} +
+ + + + + +
+ + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, +} + + +class _IPv6HTTPServer(ThreadingHTTPServer): + address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore + request_queue_size: int = 1024 + + +class HTTPRequestHandler(BaseHTTPRequestHandler): + frontend: "FrontendServer" + + def log_message(self, format, *args): + logger.info( + "%s %s", + self.client_address[0], + format % args, + ) + + def do_GET(self): + self.frontend._handle_request(self) + + def get_path(self) -> str: + return urlparse(self.path).path + + def get_query(self) -> dict[str, list[str]]: + return parse_qs(self.get_raw_query()) + + def get_raw_query(self) -> str: + return urlparse(self.path).query + + def get_query_arg( + self, name: str, default: object = None, type: type = str + ) -> object: + query = self.get_query() + if name not in query: + return default + return type(query[name][0]) + + +class FrontendServer: + def __init__(self, port: int): + # Setup templates + loader = DictLoader(templates) + self._jinja_env = Environment(loader=loader, enable_async=True) + self._jinja_env.globals.update( + zip=zip, + format_json=format_json, + enumerate=enumerate, + ) + + # Create routes + self._routes = { + "/": self._handle_index, + "/stacks": self._handle_stacks, + "/pyspy_dump": self._handle_pyspy_dump, + "/fr_trace": self._handle_fr_trace, + "/fr_trace_json": self._handle_fr_trace_json, + "/fr_trace_nccl": self._handle_fr_trace_nccl, + "/fr_trace_nccl_json": self._handle_fr_trace_nccl_json, + "/profile": self._handle_profiler, + "/wait_counters": self._handle_wait_counters, + "/tcpstore": self._handle_tcpstore, + } + + # Create HTTP server + RequestHandlerClass = type( + "HTTPRequestHandler", + (HTTPRequestHandler,), + {"frontend": self}, + ) + + server_address = ("", port) + self._server = _IPv6HTTPServer(server_address, RequestHandlerClass) + + self._thread = threading.Thread( + target=self._serve, + args=(), + daemon=True, + name="distributed.debug.FrontendServer", + ) + self._thread.start() + + def _serve(self) -> None: + try: + self._server.serve_forever() + except Exception: + logger.exception("got exception in frontend server") + + def join(self) -> None: + self._thread.join() + + def _handle_request(self, req: HTTPRequestHandler) -> None: + path = req.get_path() + if path not in self._routes: + req.send_error(404, f"Handler not found: {path}") + return + + handler = self._routes[path] + try: + resp = handler(req) + # Catch SystemExit to not crash when FlightRecorder errors. + except (Exception, SystemExit) as e: + logger.exception( + "Exception in frontend server when handling %s", + path, + ) + req.send_error(500, f"Exception: {repr(e)}") + return + + req.send_response(200) + req.send_header("Content-type", "text/html") + req.end_headers() + req.wfile.write(resp) + + def _render_template(self, template: str, **kwargs: object) -> bytes: + return self._jinja_env.get_template(template).render(**kwargs).encode() + + def _handle_index(self, req: HTTPRequestHandler) -> bytes: + return self._render_template("index.html") + + def _handle_stacks(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_traceback") + return self._render_template( + "raw_resp.html", title="Stacks", addrs=addrs, resps=resps + ) + + def _handle_pyspy_dump(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("pyspy_dump", req.get_raw_query()) + return self._render_template( + "pyspy_dump.html", + addrs=addrs, + resps=resps, + ) + + def _render_fr_trace(self, addrs: list[str], resps: list[Response]) -> bytes: + config = JobConfig() + # pyrefly: ignore [bad-assignment] + args = config.parse_args(args=[]) + args.allow_incomplete_ranks = True + args.verbose = True + + details = {} + for rank, resp in enumerate(resps): + resp.raise_for_status() + dump = { + "rank": rank, + "host_name": addrs[rank], + **resp.json(), + } + if "entries" not in dump: + dump["entries"] = [] + details[f"rank{rank}.json"] = dump + + version = next(iter(details.values()))["version"] + + db = build_db(details, args, version) + + return self._render_template( + "fr_trace.html", + title="FlightRecorder", + groups=tabulate(db.groups, headers=Group._fields, tablefmt="html"), + memberships=tabulate( + db.memberships, headers=Membership._fields, tablefmt="html" + ), + collectives=tabulate( + db.collectives, headers=Collective._fields, tablefmt="html" + ), + ncclcalls=tabulate(db.ncclcalls, headers=NCCLCall._fields, tablefmt="html"), + ) + + def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + + return self._render_fr_trace(addrs, list(resps)) + + def _handle_fr_trace_json(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + + return self._render_template( + "json_resp.html", + title="FlightRecorder", + addrs=addrs, + resps=resps, + ) + + def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return self._render_fr_trace(addrs, list(resps)) + + def _handle_fr_trace_nccl_json(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return self._render_template( + "json_resp.html", + title="FlightRecorder NCCL", + addrs=addrs, + resps=resps, + ) + + def _handle_profiler(self, req: HTTPRequestHandler) -> bytes: + duration = req.get_query_arg("duration", default=1.0, type=float) + + addrs, resps = fetch_all("torch_profile", f"duration={duration}") + + return self._render_template("profile.html", addrs=addrs, resps=resps) + + def _handle_wait_counters(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("wait_counter_values") + return self._render_template( + "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps + ) + + def _handle_tcpstore(self, req: HTTPRequestHandler) -> bytes: + store = tcpstore_client(prefix="") + keys = store.list_keys() + keys.sort() + values = [repr(v) for v in store.multi_get(keys)] + return self._render_template("tcpstore.html", keys=keys, values=values) + + +def main(port: int) -> None: + logger.setLevel(logging.INFO) + + server = FrontendServer(port=port) + logger.info("Frontend server started on port %d", server._server.server_port) + server.join() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_handlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..b8095c5b34bea5d2408ed87b21b541bf8966f4ad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_handlers.py @@ -0,0 +1,23 @@ +import pathlib +import tempfile +import time + +from torch._C._distributed_c10d import _register_handler, _Request, _Response +from torch.profiler import _ExperimentalConfig, profile + + +def _torch_profile(req: _Request, resp: _Response) -> None: + experimental_config = _ExperimentalConfig( + profile_all_threads=True, + ) + duration = float(req.get_param("duration")) + with profile(record_shapes=True, experimental_config=experimental_config) as prof: + time.sleep(duration) + + with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f: + prof.export_chrome_trace(f.name) + resp.set_content(pathlib.Path(f.name).read_bytes(), "application/json") + resp.set_status(200) + + +_register_handler("torch_profile", _torch_profile) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_store.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_store.py new file mode 100644 index 0000000000000000000000000000000000000000..487dd30abd6aff96d676ee3cf10d98490613e2a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/debug/_store.py @@ -0,0 +1,25 @@ +import os + +import torch.distributed as dist + + +def get_rank() -> int: + return int(os.environ["RANK"]) + + +def get_world_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def tcpstore_client(prefix: str = "debug_server") -> dist.Store: + MASTER_ADDR = os.environ["MASTER_ADDR"] + MASTER_PORT = int(os.environ["MASTER_PORT"]) + + store = dist.TCPStore( + host_name=MASTER_ADDR, + port=MASTER_PORT, + is_master=False, + ) + if prefix: + store = dist.PrefixStore(prefix, store) + return store diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c9b29a750593a812907ce2cf4c800d7d1435bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__init__.py @@ -0,0 +1,77 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" + +Torchelastic agent and user worker failover contract: + +**TL;DR;**: + +* TE(torchelastic) expects user workers to finish with the 5 minutes drift +* It is better to design DDP app to fail for all workers, rather than a single one. +* TE does not synchronize number of restarts between agents +* TE re-rendezvous does not trigger restart decrease +* When a single agent finishes its job(successfully or not), it will close rendezvous. + If other agents still have workers in progress, they will be terminated. +* Based on above, scale down does not work if at least single agent finishes the job. +* When Scale up is detected by agents, it will not decrease ``max_restarts`` + + +In general TE(torchelastic) can launch arbitrary user code, but there is some +clarifications need to be done around what failover mechanism torchelastic +provides and what failover mechanism it expects from user workers. + +Torchelastic currently supports DDP style applications. That means that +TE expects *ALL* workers finish approximately at the same time. In practice, +it is nearly to impossible to guarantee that all workers in arbitrary +DDP application finish at the time, so TE provides a finalization barrier +that waits for TIMEOUT(5 minutes) for worker finalization. + +**Worker Failure** + +When worker fails, TE will check the number of restarts +available, if there is more than 0 restarts, TE will start a new rendezvous +round and restart the worker process. New rendezvous round will other +TE agents to terminate their workers. + +.. note:: The TE agent does not synchronize restarts between themselves. + When a single agent performs restart, it will trigger a local ``max_restarts`` + decrease, other agent will not decrease their ``max_restarts``. + the user to run the distributed application locally on a dev host. + +A single worker failure can cause the whole cluster to fail: +If a single worker is constantly failing, it will cause the TE agent +``max_restarts`` to go to zero. This will cause an agent to finish its +work and close rendezvous. If there are any other workers on different +agents, they will be terminated. + + +**Re-Rendezvous** + +Re-rendezvous occurs when TE agents detect a new node +trying to joint a cluster. TE will not decrease ``max_restarts``. TE agents +will terminate its workers and start a new rendezvous round. + +Note about DynamicRendezvous(etcd-v2, c10d-experimental): If the rendezvous +has already max_nodes, the new node won't be added to the wait list right +away since there is no need to tear down a rendezvous that is already fully +utilized. The new node will wait until its timeout (600 secs by default) +and periodically check the number of participants. If the number becomes +less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs. + +*Scale up event*. When scale up event happens, torchelastic rendezvous +will detect that there are new nodes trying to join. Torchelastic agent +will stop all workers and perform re-rendezvous. Note: when scale up event +happens, *``max_restarts``* will *not* decrease. + +*Scale down event*. When scale down event happens, rendezvous will not +notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` , +it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` , +TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*. + +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cb5d2933493bf948bb222304c0641993f8abc56 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b30e1eafcb31723395e0006875a8db575012826 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b138f0e753df501d8aeb2d672de27f8c495e8ef8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0d76131fe40d70945ffa8ff97431954151d50e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__init__.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +The elastic agent is the control plane of torchelastic. + +It is a process that launches and manages underlying worker processes. +The agent is responsible for: + +1. Working with distributed torch: the workers are started with all the + necessary information to successfully and trivially call + ``torch.distributed.init_process_group()``. + +2. Fault tolerance: monitors workers and upon detecting worker failures + or unhealthiness, tears down all workers and restarts everyone. + +3. Elasticity: Reacts to membership changes and restarts workers with the new + members. + +The simplest agents are deployed per node and works with local processes. +A more advanced agent can launch and manage workers remotely. Agents can +be completely decentralized, making decisions based on the workers it manages. +Or can be coordinated, communicating to other agents (that manage workers +in the same job) to make a collective decision. +""" + +from .api import ( # noqa: F401 + ElasticAgent, + RunResult, + SimpleElasticAgent, + Worker, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f72fe2022ad63a0349bd37fccecece55381952b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e3b746f6b6fa76289222f4ed959d118b7797cc7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fd7f573914c7429905453e89a57970407fd1cb6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eec7f4052c863ec5d69aecfefe1224029dbf57b1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2575aa137a58128213173dde681c313bb24fc5a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py @@ -0,0 +1,995 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import json +import os +import signal +import socket +import time +import traceback +import warnings +from collections import defaultdict +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +import torch.distributed.elastic.rendezvous as rdzv +import torch.distributed.elastic.utils.store as store_util +from torch.distributed.elastic.events import Event, EventSource, record +from torch.distributed.elastic.metrics import prof, put_metric +from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException +from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError +from torch.distributed.elastic.utils.logging import get_logger +from torch.numa.binding import NumaOptions + + +__all__ = [ + "WorkerSpec", + "Worker", + "WorkerState", + "WorkerGroup", + "RunResult", + "ElasticAgent", + "SimpleElasticAgent", +] +_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state" + +DEFAULT_ROLE = "default" +logger = get_logger(__name__) + + +@dataclass +class WorkerSpec: + """ + Blueprint information about a particular type of worker. + + For a given role, there must only exist a single worker spec. + Worker spec is expected to be homogeneous across all nodes (machine), + that is each node runs the same number of workers for a particular spec. + + Args: + role: user-defined role for the workers with this spec + local_world_size: number local workers to run + fn: (deprecated use entrypoint instead) + entrypoint: worker function or command + args: arguments to pass to ``entrypoint`` + rdzv_handler: handles rdzv for this set of workers + max_restarts: number of max retries for the workers + monitor_interval: monitor status of workers every ``n`` seconds + master_port: fixed port to run the c10d store on rank 0 + if not specified then will chose a random free port + master_addr: fixed master_addr to run the c10d store on rank 0 + if not specified then will chose hostname on agent rank 0 + redirects: redirect std streams to a file, + selectively redirect for a particular + local rank by passing a map + tee: tees the specified std stream(s) to console + file, + selectively tee for a particular local rank by passing a map, + takes precedence over ``redirects`` settings. + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. + duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines + that match _any_ of the filter strings. + duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines + that match _any_ of the filter strings. + virtual_local_rank: Enable virtual local rank mode for workers (defaults to False). + When enabled, LOCAL_RANK is set to 0 for all workers and + CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its + assigned GPU at device index 0. + """ + + role: str + local_world_size: int + rdzv_handler: rdzv.RendezvousHandler + fn: Callable | None = None + # TODO @kiuk - make entrypoint a required field + entrypoint: Callable | str | None = None + args: tuple = () + max_restarts: int = 3 + monitor_interval: float = 0.1 + master_port: int | None = None + master_addr: str | None = None + local_addr: str | None = None + event_log_handler: str = "null" + numa_options: NumaOptions | None = None + duplicate_stdout_filters: list[str] | None = None + duplicate_stderr_filters: list[str] | None = None + virtual_local_rank: bool = False + + def __post_init__(self): + assert self.local_world_size > 0 + assert self.monitor_interval > 0 + + if self.fn: + warnings.warn( + "WorkerSpec.fn will be deprecated," + " please use WorkerSpec.entrypoint instead", + stacklevel=2, + category=DeprecationWarning, + ) + self.entrypoint = self.fn + assert self.entrypoint + + def get_entrypoint_name(self): + """Get the entry point name. + + If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__`` + else if the entrypoint is a binary (e.g. ``str``), returns the binary name. + """ + if isinstance(self.entrypoint, str): + return os.path.basename(self.entrypoint) + else: + assert self.entrypoint is not None + return self.entrypoint.__qualname__ + + +class Worker: + """A worker instance. + + Contrast this with ``WorkerSpec`` that represents the specifications of a + worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to + a ``WorkerSpec`` as an object is to a class. + + The ``id`` of the worker is interpreted + by the specific implementation of ``ElasticAgent``. For a local + agent, it could be the ``pid (int)`` of the worker, for a remote + agent it could be encoded as ``host:port (string)``. + + Args: + id (Any): uniquely identifies a worker (interpreted by the agent) + local_rank (int): local rank of the worker + global_rank (int): global rank of the worker + role_rank (int): rank of the worker across all workers that have the same role + world_size (int): number of workers (globally) + role_world_size (int): number of workers that have the same role + """ + + __slots__ = [ + "id", + "local_rank", + "global_rank", + "role_rank", + "world_size", + "role_world_size", + ] + + def __init__( + self, + local_rank: int, + global_rank: int = -1, + role_rank: int = -1, + world_size: int = -1, + role_world_size: int = -1, + ): + # unique identifier for this worker + self.id: Any = None + + # rank of the worker among workers with the same role being monitored + # by the same ``agent`` instance. + self.local_rank: int = local_rank + + # rank of the worker among all the workers across all roles + # across all ``agent`` instances. + # Global rank is not stable between re-rendezvous. + self.global_rank: int = global_rank + + # rank of the worker among all the workers with the same role + # across all ``agent`` instances. + # Role rank is not stable between re-rendezvous. + self.role_rank: int = role_rank + + # total number of workers (globally). Due to elasticity + # the world size may change between re-rendezvous. + self.world_size: int = world_size + + # total number of workers that share the same role. Due to elasticity + # the role world size may change between re-rendezvous. + self.role_world_size: int = role_world_size + + def __str__(self): + return ( + f"local_rank={self.local_rank},global_rank={self.global_rank}" + f",role_rank={self.role_rank},world_size={self.world_size}" + f",role_world_size={self.role_world_size}" + ) + + def __repr__(self): + return str(self) + + +class WorkerState(str, Enum): + """A state of the ``WorkerGroup``. + + Workers in a worker group change state as a unit. If a single worker + in a worker group fails the entire set is considered failed:: + + UNKNOWN - agent lost track of worker group state, unrecoverable + INIT - worker group object created not yet started + HEALTHY - workers running and healthy + UNHEALTHY - workers running and unhealthy + STOPPED - workers stopped (interrupted) by the agent + SUCCEEDED - workers finished running (exit 0) + FAILED - workers failed to successfully finish (exit !0) + + + A worker group starts from an initial ``INIT`` state, + then progresses to ``HEALTHY`` or ``UNHEALTHY`` states, + and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state. + + Worker groups can be interrupted and temporarily put into ``STOPPED`` state + by the agent. Workers in ``STOPPED`` state are scheduled to be restarted + in the near future by the agent. Some examples of workers being put into + ``STOPPED`` state are: + + 1. Worker group failure|unhealthy observed + 2. Membership change detected + + When actions (start, stop, rdzv, retry, etc) on worker group fails + and results in the action being partially applied to the worker group + the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled + exceptions during state change events on the agent. The agent is not + expected to recover worker groups in ``UNKNOWN`` state and is better off + self terminating and allowing the job manager to retry the node. + """ + + UNKNOWN = "UNKNOWN" + INIT = "INIT" + HEALTHY = "HEALTHY" + UNHEALTHY = "UNHEALTHY" + STOPPED = "STOPPED" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + @staticmethod + def is_running(state: "WorkerState") -> bool: + """Return the state of the Worker. + + Returns: + True if the worker state represents workers still running + (e.g. that the process exists but not necessarily healthy). + """ + return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY} + + +class WorkerGroup: + """A set of ``Worker`` instances. + + The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker + group contains cross instance workers or not depends on the implementation of the agent. + """ + + __slots__ = [ + "spec", + "workers", + "store", + "group_rank", + "group_world_size", + "state", + "master_addr", + "master_port", + ] + + def __init__(self, spec: WorkerSpec): + self.spec = spec + self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)] + + # assigned after rdzv + self.store = None + self.group_rank = None + self.group_world_size = None + self.master_addr = None + self.master_port = None + + self.state = WorkerState.INIT + + +class _RoleInstanceInfo: + """The class is used by the agent to exchange the information with other agents. + + The information is used to determine the rank of the workers that agent + manages in heterogeneous environments, where different agents can have + different number of workers. + """ + + __slots__ = ["role", "rank", "local_world_size"] + + def __init__(self, role: str, rank: int, local_world_size: int): + r"""Initialize the agent class instance. + + Args: + role (str): user-defined role for the workers with this spec + rank (int): the rank of the agent + local_world_size (int): number of local workers to run + """ + self.role = role + self.rank = rank + self.local_world_size = local_world_size + + def serialize(self) -> bytes: + dict_data = { + "role": self.role, + "rank": self.rank, + "local_world_size": self.local_world_size, + } + return json.dumps(dict_data).encode(encoding="UTF-8") + + @staticmethod + def deserialize(data: bytes): + dict_data = json.loads(data.decode(encoding="UTF-8")) + return _RoleInstanceInfo( + dict_data["role"], dict_data["rank"], dict_data["local_world_size"] + ) + + @staticmethod + def compare(obj1, obj2) -> int: + if obj1.role == obj2.role: + return obj1.rank - obj2.rank + elif obj1.role > obj2.role: + return 1 + else: + return -1 + + @staticmethod + def find_role_boundaries(roles_infos: list, role: str) -> tuple[int, int]: + start_idx, end_idx = -1, -1 + for idx, role_info in enumerate(roles_infos): + if role_info.role == role: + if start_idx == -1: + start_idx = idx + end_idx = idx + return (start_idx, end_idx) + + +@dataclass +class RunResult: + """Return results of the worker executions. + + Run results follow an "all-or-nothing" policy where the run is successful if and + only if ALL local workers managed by this agent complete successfully. + + If the result is successful (e.g. ``is_failed() = False``) then the ``return_values`` + field contains the outputs (return values) of the workers managed by THIS agent mapped + by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of + global rank 0. + + .. note:: ``return_values`` are only meaningful for when the worker entrypoint + is a function. Workers specified as a binary entrypoint do not canonically + have a return value and the ``return_values`` field is meaningless and + may be empty. + + If ``is_failed()`` returns ``True`` then the ``failures`` field contains the + failure information, again, mapped by the GLOBAL rank of the worker that failed. + + The keys in ``return_values`` and ``failures`` are mutually exclusive, that is, + a worker's final state can only be one of: succeeded, failed. Workers intentionally + terminated by the agent according to the agent's restart policy, are not represented + in either ``return_values`` nor ``failures``. + """ + + state: WorkerState + return_values: dict[int, Any] = field(default_factory=dict) + failures: dict[int, ProcessFailure] = field(default_factory=dict) + + def is_failed(self) -> bool: + return self.state == WorkerState.FAILED + + +def _get_fq_hostname() -> str: + return socket.getfqdn(socket.gethostname()) + + +class ElasticAgent(abc.ABC): + """An agent process responsible for managing one or more worker processes. + + The worker processes are assumed to be regular distributed PyTorch scripts. + When the worker process is created by the agent, the agent provides the + necessary information for the worker processes to properly initialize + a torch process group. + + The exact deployment topology and ratio of agent-to-worker is dependent + on the specific implementation of the agent and the user's job placement + preferences. For instance, to run a distributed training job on GPU with + 8 trainers (one per GPU) one can: + + 1. Use 8 x single GPU instances, place an agent per instance, managing + 1 worker per agent. + 2. Use 4 x double GPU instances, place an agent per instance, managing + 2 workers per agent. + 3. Use 2 x quad GPU instances, place an agent per instance, managing + 4 workers per agent. + 4. Use 1 x 8 GPU instance, place an agent per instance, managing + 8 workers per agent. + + Usage + :: + + group_result = agent.run() + if group_result.is_failed(): + # workers failed + failure = group_result.failures[0] + logger.exception("worker 0 failed with exit code : %s", failure.exit_code) + else: + return group_result.return_values[0] # return rank 0's results + + """ + + @abc.abstractmethod + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + """Run the agent. + + Supports retrying the worker group on failures up to ``max_restarts``. + + Returns: + The result of the execution, containing the return values or + failure details for each worker mapped by the worker's global rank. + + Raises: + Exception - any other failures NOT related to worker process + """ + raise NotImplementedError + + @abc.abstractmethod + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + """Return the ``WorkerGroup`` for the given ``role``. + + Note that the worker group is a mutable object and hence in a + multi-threaded/process environment it may change state. + Implementers are encouraged (but not required) to return + a defensive read-only copy. + """ + raise NotImplementedError + + +class SimpleElasticAgent(ElasticAgent): + """An ``ElasticAgent`` that manages one particular type of worker role. + + An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` + such as one particular type of worker role. + """ + + def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300): + self._worker_group = WorkerGroup(spec) + self._remaining_restarts = self._worker_group.spec.max_restarts + self._store = None + self._exit_barrier_timeout = exit_barrier_timeout + self._total_execution_time = 0 + + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + return self._worker_group + + @abc.abstractmethod + def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: + r"""Start ``worker_group.spec.local_world_size`` number of workers. + + This is according to worker spec for the worker group . + Returns a map of ``local_rank`` to worker ``id``. + """ + raise NotImplementedError + + @abc.abstractmethod + def _stop_workers(self, worker_group: WorkerGroup) -> None: + r"""Stop all workers in the given worker group. + + Implementers must deal with workers in all states defined by + ``WorkerState``. That is, it must gracefully handle stopping + non-existent workers, unhealthy (stuck) workers, etc. + """ + raise NotImplementedError + + @abc.abstractmethod + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + r"""Check on the workers for the ``worker_group``. + + This function also returns the new state of the worker group. + """ + raise NotImplementedError + + @abc.abstractmethod + def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None: + """Clean up any resources that were allocated during the agent's work. + + Args: + death_sig: Signal to send to the child process, SIGTERM is default + """ + raise NotImplementedError + + @prof + def _rendezvous(self, worker_group: WorkerGroup) -> None: + r"""Run rendezvous for the workers specified by the worker spec. + + Assigns workers a new global rank and world size. + Updates the rendezvous store for the worker group. + """ + spec = worker_group.spec + + with self.record_duration("RENDEZVOUS"): + rdzv_info = spec.rdzv_handler.next_rendezvous() + store = rdzv_info.store + group_rank = rdzv_info.rank + group_world_size = rdzv_info.world_size + + # master_addr/master_port could be explicitly overridden + # TODO: BC - specific to static rdzv and can be simplified further + master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr + master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port + + self._store = store + + with self.record_duration("ASSIGN_WORKER_RANKS"): + workers = self._assign_worker_ranks( + store, group_rank, group_world_size, spec + ) + worker_group.workers = workers + worker_group.store = store + worker_group.group_rank = group_rank + worker_group.group_world_size = group_world_size + worker_group.master_addr = master_addr + worker_group.master_port = master_port + + restart_count = spec.max_restarts - self._remaining_restarts + + logger.info( + "[%(role)s] Rendezvous complete for workers. Result:\n" + " restart_count=%(restart_count)s\n" + " master_addr=%(master_addr)s\n" + " master_port=%(master_port)s\n" + " group_rank=%(group_rank)s\n" + " group_world_size=%(group_world_size)s\n" + " local_ranks=%(local_ranks)s\n" + " role_ranks=%(role_ranks)s\n" + " global_ranks=%(global_ranks)s\n" + " role_world_sizes=%(role_world_sizes)s\n" + " global_world_sizes=%(global_world_sizes)s\n" + " event_log_handler=%(event_log_handler)s\n", + { + "role": spec.role, + "restart_count": restart_count, + "master_addr": master_addr, + "master_port": master_port, + "group_rank": group_rank, + "group_world_size": group_world_size, + "local_ranks": [worker.local_rank for worker in workers], + "role_ranks": [worker.role_rank for worker in workers], + "global_ranks": [worker.global_rank for worker in workers], + "role_world_sizes": [worker.role_world_size for worker in workers], + "global_world_sizes": [worker.world_size for worker in workers], + "event_log_handler": spec.event_log_handler, + }, + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _assign_worker_ranks( + self, store, group_rank: int, group_world_size: int, spec: WorkerSpec + ) -> list[Worker]: + """Determine proper ranks for worker processes. + + Fast Path: when all workers have the same role and world size. We calculate + the global rank to be group_rank * group_world_size + local_rank. And the + `role_world_size` is the same as `global_world_size`. No TCP store is used in + this case. This is only enabled when users set the environment variable + `TORCH_ELASTIC_WORKER_IDENTICAL` to 1. + + Time complexity: each worker O(1), overall O(1) + + Slow Path: when workers have different roles and world sizes. We use the + the following algorithm: + + 1. Each agent writes its configuration(group_rank, group_world_size + , num_workers) to the common store. + 2. The rank 0 agent reads all the role_info from the store and + determines each agents worker ranks. + 3. Determine the global rank: the global rank of the workers is computed + by cumulative sum of the local_world_size for all workers in front of it. + For efficiency reasons each worker is assigned a base global rank + such that it's workers are in the range [base_global_rank, + base_global_rank + local_world_size). + 4. Determine the role rank: The role rank is determined using the algorithms + in the point 3 with the exception that the ranks are calculated with + respect to the role name. + 5. The rank 0 agent writes the assigned ranks to the store. + 6. Each agent reads the assigned ranks from the store. + + Time complexity: each worker O(1), rank0 O(n), overall O(n) + """ + + if os.environ.get("TORCH_ELASTIC_WORKER_IDENTICAL", "0") == "1": + global_world_size = group_world_size * spec.local_world_size + base_global_rank = group_rank * spec.local_world_size + base_role_rank = base_global_rank + role_world_size = global_world_size + else: + ROLE_INFO_PREFIX = "torchelastic/role_info/" + ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/" + + agent_role_info = _RoleInstanceInfo( + spec.role, group_rank, spec.local_world_size + ) + store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize()) + + # tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations. + if group_rank == 0: + role_infos_bytes = store.multi_get( + [f"torchelastic/role_info/{i}" for i in range(group_world_size)] + ) + role_infos = [ + _RoleInstanceInfo.deserialize(info_bytes) + for info_bytes in role_infos_bytes + ] + + role_sizes = defaultdict(lambda: 0) + global_size = 0 + for role_info in role_infos: + role_sizes[role_info.role] += role_info.local_world_size + global_size += role_info.local_world_size + + base_global_rank = 0 + role_ranks = defaultdict(lambda: 0) + + keys = [] + values = [] + for i, role_info in enumerate(role_infos): + keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}") + values.append( + json.dumps( + [ + base_global_rank, + global_size, + role_ranks[role_info.role], + role_sizes[role_info.role], + ] + ) + ) + + base_global_rank += role_info.local_world_size + role_ranks[role_info.role] += role_info.local_world_size + + store.multi_set(keys, values) + + # get will block until the data is available in the store. + ( + base_global_rank, + global_world_size, + base_role_rank, + role_world_size, + ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) + + workers = [] + for local_rank in range(spec.local_world_size): + worker = Worker( + local_rank=local_rank, + global_rank=base_global_rank + local_rank, + role_rank=base_role_rank + local_rank, + world_size=global_world_size, + role_world_size=role_world_size, + ) + workers.append(worker) + return workers + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _initialize_workers(self, worker_group: WorkerGroup) -> None: + r"""Start a fresh set of workers for the worker_group. + + Essentially, a rendezvous followed by a ``start_workers``. + The caller should first call ``_stop_workers()`` to stop running workers + prior to calling this method. + + Optimistically sets the state of the worker group that + just started as ``HEALTHY`` and delegates the actual monitoring + of state to ``_monitor_workers()`` method + """ + role = worker_group.spec.role + logger.info("[%s] Rendezvous'ing worker group", role) + + # TODO after stopping workers, wait at least monitor_interval*2 for + # workers on different nodes to fail on a collective op before waiting + # on the rdzv barrier, this way we ensure that nodes enter rdzv + # at around the same time and reduce false positive rdzv timeout errors + self._rendezvous(worker_group) + + logger.info("[%s] Starting worker group", role) + worker_ids = self._start_workers(worker_group) + for local_rank, w_id in worker_ids.items(): + worker = worker_group.workers[local_rank] + worker.id = w_id + record( + self._construct_event("START", EventSource.WORKER, worker), + worker_group.spec.event_log_handler, + ) + + worker_group.state = WorkerState.HEALTHY + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _restart_workers(self, worker_group: WorkerGroup) -> None: + """Restart (stops, rendezvous, starts) all local workers in the group.""" + role = worker_group.spec.role + logger.info("[%s] Stopping worker group", role) + self._stop_workers(worker_group) + worker_group.state = WorkerState.STOPPED + self._initialize_workers(worker_group) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + start_time = time.monotonic() + shutdown_called: bool = False + try: + result = self._invoke_run(role) + self._total_execution_time = int(time.monotonic() - start_time) + self._record_metrics(result) + self._record_worker_events(result) + return result + except RendezvousGracefulExitError as e: + logger.info("Rendezvous gracefully exited: %s", e) # noqa: G200 + except SignalException as e: + logger.warning("Received %s death signal, shutting down workers", e.sigval) + self._shutdown(e.sigval) + shutdown_called = True + raise + finally: + if not shutdown_called: + self._shutdown() + # record the execution time in case there were any exceptions during run. + self._total_execution_time = int(time.monotonic() - start_time) + + def get_event_failed(self) -> Event: + return self._construct_event( + state="FAILED", + source=EventSource.AGENT, + raw_error=traceback.format_exc(), + ) + + def get_event_succeeded(self) -> Event: + return self._construct_event( + state="SUCCEEDED", + source=EventSource.AGENT, + ) + + def _record_worker_events(self, result: RunResult) -> None: + for worker in self._worker_group.workers: + failure = result.failures.get(worker.global_rank) + state: str = self._get_worker_state(worker, result) + raw_error = json.dumps(failure.error_file_data) if failure else None + exit_code = failure.exitcode if failure else None + worker_pid = failure.pid if failure else None + record( + self._construct_event( + state=state, + source=EventSource.WORKER, + worker=worker, + raw_error=raw_error, + exit_code=exit_code, + worker_pid=worker_pid, + ), + self._worker_group.spec.event_log_handler, + ) + + def _get_worker_state(self, worker: Worker, result: RunResult) -> str: + failure = result.failures.get(worker.global_rank) + if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure: + # The worker got terminated by the torchelastic agent via SIGTERM signal + return "TERMINATED" + elif failure or worker.global_rank in result.return_values: + return result.state.value + else: + raise ValueError(f"Unknown worker: {worker.global_rank}") + + @contextmanager + def record_duration(self, state: str): + start_time = time.perf_counter() + try: + yield + finally: + end_time = time.perf_counter() + duration_ms = (end_time - start_time) * 1000 + record( + self._construct_event( + state=state, source=EventSource.AGENT, duration_ms=duration_ms + ), + self._worker_group.spec.event_log_handler, + ) + + def _construct_event( + self, + state: str, + source: EventSource, + worker: Worker | None = None, + raw_error: str | None = None, + duration_ms: float | None = None, + exit_code: int | None = None, + worker_pid: int | None = None, + ) -> Event: + wg = self._worker_group + spec = wg.spec + md = { + "group_world_size": wg.group_world_size, + "entry_point": spec.get_entrypoint_name(), + } + if worker: + md["local_rank"] = (worker.local_rank,) + md["role_rank"] = (worker.role_rank,) + md["role_world_size"] = (worker.role_world_size,) + md["exit_code"] = (exit_code,) + md["worker_pid"] = (worker_pid,) + global_rank = worker.global_rank + worker_id = str(worker.id) + else: + global_rank = None + worker_id = None + md_str = json.dumps(md) + metadata = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": global_rank, + "group_rank": wg.group_rank, + "worker_id": worker_id, + "role": spec.role, + "hostname": _get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": raw_error, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + "duration_ms": duration_ms, + } + + return Event( + f"torchelastic.worker.status.{state}", source=source, metadata=metadata + ) + + def _record_metrics(self, group_results: RunResult): + is_failed = group_results.is_failed() + self._record_flakiness_metric(is_failed) + spec = self._worker_group.spec + restarts_happened = self._remaining_restarts != spec.max_restarts + put_metric(f"workers.{spec.role}.run_total", 1) + self._record_metric_with_condition( + "run_success_with_retries", not is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_success_no_retries", not is_failed and not restarts_happened + ) + self._record_metric_with_condition( + "run_failed_with_retries", is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_failed_no_retries", is_failed and not restarts_happened + ) + + def _record_metric_with_condition(self, metric_name, condition): + spec = self._worker_group.spec + if condition: + put_metric(f"workers.{spec.role}.{metric_name}", 1) + else: + put_metric(f"workers.{spec.role}.{metric_name}", 0) + + def _record_flakiness_metric(self, is_failed: bool = False): + if is_failed: + flakiness = 100.0 + else: + spec = self._worker_group.spec + flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / ( + spec.max_restarts + 1 + ) + spec = self._worker_group.spec + + put_metric(f"workers.{spec.role}.flakiness", int(flakiness)) + + def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: + # NOTE: currently only works for a single role + + spec = self._worker_group.spec + role = spec.role + + logger.info( + "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name() + ) + + self._initialize_workers(self._worker_group) + monitor_interval = spec.monitor_interval + rdzv_handler = spec.rdzv_handler + + while True: + assert self._worker_group.state != WorkerState.INIT + time.sleep(monitor_interval) + run_result = self._monitor_workers(self._worker_group) + state = run_result.state + self._worker_group.state = state + + put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts) + put_metric(f"workers.{role}.{state.name.lower()}", 1) + + if state == WorkerState.SUCCEEDED: + logger.info( + "[%s] worker group successfully finished." + " Waiting %s seconds for other agents to finish.", + role, + self._exit_barrier_timeout, + ) + self._exit_barrier() + return run_result + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: + if self._remaining_restarts > 0: + logger.info( + "[%s] Worker group %s. " + "%s/%s attempts left;" + " will restart worker group", + role, + state.name, + self._remaining_restarts, + spec.max_restarts, + ) + self._remaining_restarts -= 1 + self._restart_workers(self._worker_group) + else: + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED + return run_result + elif state == WorkerState.HEALTHY: + # membership changes do not count as retries + num_nodes_waiting = rdzv_handler.num_nodes_waiting() + group_rank = self._worker_group.group_rank + if num_nodes_waiting > 0: + logger.info( + "[%s] Detected %s " + "new nodes from group_rank=%s; " + "will restart worker group", + role, + num_nodes_waiting, + group_rank, + ) + self._restart_workers(self._worker_group) + else: + raise Exception( # noqa: TRY002 + f"[{role}] Worker group in {state.name} state" + ) + + def _exit_barrier(self): + """ + Define a barrier that keeps the agent process alive until all workers finish. + + Wait for ``exit_barrier_timeout`` seconds for all agents to finish + executing their local workers (either successfully or not). This + acts as a safety guard against user scripts that terminate at different + times. + """ + logger.info( + "Local worker group finished (%s). " + "Waiting %s seconds for other agents to finish", + self._worker_group.state, + self._exit_barrier_timeout, + ) + start = time.time() + try: + store_util.barrier( + store=self._store, + world_size=self._worker_group.group_world_size, + key_prefix=_TERMINAL_STATE_SYNC_ID, + barrier_timeout=self._exit_barrier_timeout, + ) + logger.info( + "Done waiting for other agents. Elapsed: %s seconds", + time.time() - start, + ) + except SignalException as e: + logger.warning("Got termination signal: %s", e.sigval) + raise + except Exception: + logger.exception( + "Error waiting on exit barrier. Elapsed: %s seconds", + time.time() - start, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/health_check_server.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/health_check_server.py new file mode 100644 index 0000000000000000000000000000000000000000..4815d86aa289c531a01bfcc8277b7ae9ffb2930e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/health_check_server.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable + +from torch.distributed.elastic.utils.logging import get_logger + + +log = get_logger(__name__) + +__all__ = ["HealthCheckServer", "create_healthcheck_server"] + + +class HealthCheckServer: + """ + Interface for health check monitoring server, which can be extended + by starting tcp/http server on the specified port. + + Args: + + alive_callback: Callable[[], int], callback to last progress time of agent + + port: int, port number to start tcp/http server + + timeout: int, timeout seconds to decide agent is alive/dead + """ + + _alive_callback: Callable[[], int] + _port: int + _timeout: int + + def __init__( + self, alive_callback: Callable[[], int], port: int, timeout: int + ) -> None: + self._alive_callback = alive_callback + self._port = port + self._timeout = timeout + + def start(self) -> None: + """ + Unsupported functionality for Pytorch, doesn't start any health check server + """ + log.warning("No health check server started") + + def stop(self) -> None: + """ + Function to stop health check server + """ + log.info("Stopping noop health check server.") + + +def create_healthcheck_server( + alive_callback: Callable[[], int], + port: int, + timeout: int, +) -> HealthCheckServer: + """ + creates health check server object + """ + return HealthCheckServer(alive_callback, port, timeout) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..ef281b6c58c318a06e2c97832ab43171313e56df --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import json +import os +import signal +import socket +import time +import uuid +from string import Template +from typing import Any, TYPE_CHECKING + +import torch.distributed.elastic.timer as timer +from torch.distributed.elastic import events +from torch.distributed.elastic.agent.server.api import ( + RunResult, + SimpleElasticAgent, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from torch.distributed.elastic.agent.server.health_check_server import ( + create_healthcheck_server, + HealthCheckServer, +) +from torch.distributed.elastic.metrics.api import prof +from torch.distributed.elastic.multiprocessing import ( + LogsSpecs, + PContext, + start_processes, +) +from torch.distributed.elastic.utils import macros +from torch.distributed.elastic.utils.logging import get_logger + + +if TYPE_CHECKING: + from torch.distributed.elastic.events.api import EventMetadataValue + +logger = get_logger(__name__) + +__all__ = [ + "LocalElasticAgent", + "TORCHELASTIC_ENABLE_FILE_TIMER", + "TORCHELASTIC_TIMER_FILE", + "TORCHELASTIC_HEALTH_CHECK_PORT", +] + +TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER" +TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT" +TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" + + +class LocalElasticAgent(SimpleElasticAgent): + """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. + + This agent is deployed per host and is configured to spawn ``n`` workers. + When using GPUs, ``n`` maps to the number of GPUs available on the host. + + The local agent does not communicate to other local agents deployed on + other hosts, even if the workers may communicate inter-host. The worker id + is interpreted to be a local process. The agent starts and stops all worker + processes as a single unit. + + + The worker function and argument passed to the worker function must be + python multiprocessing compatible. To pass multiprocessing data structures + to the workers you may create the data structure in the same multiprocessing + context as the specified ``start_method`` and pass it as a function argument. + + The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait + for other agents to finish. This acts as a safety net to handle cases where + workers finish at different times, to prevent agents from viewing workers + that finished early as a scale-down event. It is strongly advised that the + user code deal with ensuring that workers are terminated in a synchronous + manner rather than relying on the exit_barrier_timeout. + + A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an + environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has + been defined in the ```LocalElasticAgent``` process. + Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE``` + can be set with a unique file name for the named pipe. If the environment + variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent``` + will internally create a unique file name and set it to the environment + variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will + be propagated to the worker processes to allow them to connect to the same + named pipe that ```LocalElasticAgent``` uses. + + Logs are written to the specified log directory. Each log line will be by default + prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``). + Log prefixes can be customized by passing a `template string + `_ as the + ``log_line_prefix_template`` argument. + The following macros (identifiers) are substituted at runtime: + ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with + global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``. + + + Example launching function + + :: + + def trainer(args) -> str: + return "do train" + + def main(): + start_method="spawn" + shared_queue= multiprocessing.get_context(start_method).Queue() + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint=trainer, + args=("foobar",), + ...) + agent = LocalElasticAgent(spec, start_method) + results = agent.run() + + if results.is_failed(): + print("trainer failed") + else: + print(f"rank 0 return value: {results.return_values[0]}") + # prints -> rank 0 return value: do train + + Example launching binary + + :: + + def main(): + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint="/usr/local/bin/trainer", + args=("--trainer-args", "foobar"), + ...) + agent = LocalElasticAgent(spec) + results = agent.run() + + if not results.is_failed(): + print("binary launches do not have return values") + + """ + + def __init__( + self, + spec: WorkerSpec, + logs_specs: LogsSpecs, + start_method="spawn", + exit_barrier_timeout: float = 300, + log_line_prefix_template: str | None = None, + ): + super().__init__(spec, exit_barrier_timeout) + self._start_method = start_method + self._pcontext: PContext | None = None + self._rdzv_handler = spec.rdzv_handler + self._log_line_prefix_template = log_line_prefix_template + self._worker_watchdog: timer.FileTimerServer | None = None + self._logs_specs = logs_specs + self._health_check_server: HealthCheckServer | None = None + + def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None: + enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER + watchdog_enabled = os.getenv(enable_watchdog_env_name) + watchdog_file_env_name = TORCHELASTIC_TIMER_FILE + watchdog_file_path = os.getenv(watchdog_file_env_name) + if watchdog_enabled is not None and str(watchdog_enabled) == "1": + if watchdog_file_path is None: + watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) + logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) + if not envs: + logger.warning( + "Empty envs variables, using empty run_id for FileTimerServer" + ) + run_id = "" + else: + run_id = envs[0]["TORCHELASTIC_RUN_ID"] + self._worker_watchdog = timer.FileTimerServer( + file_path=watchdog_file_path, + run_id=run_id, + max_interval=0.1, + daemon=True, + log_event=self._log_watchdog_event, + ) + self._worker_watchdog.start() + logger.info("FileTimerServer started") + else: + logger.info( + "Environment variable '%s' not found. Do not start FileTimerServer.", + enable_watchdog_env_name, + ) + # Propagate the watchdog file env to worker processes + if watchdog_file_path is not None: + for worker_env in envs.values(): + worker_env[watchdog_file_env_name] = watchdog_file_path + + @staticmethod + def _get_current_time_secs() -> int: + return int(time.time()) + + def _setup_healthcheck(self) -> None: + healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT + healthcheck_port = os.getenv(healthcheck_port_env_name) + if healthcheck_port is not None: + logger.info( + "Found healthcheck port %s: %s", + healthcheck_port_env_name, + healthcheck_port, + ) + if self._worker_watchdog is None: + logger.info( + "FileTimerServer doesn't exist, using current time as dummy callback" + ) + alive_callback = LocalElasticAgent._get_current_time_secs + else: + alive_callback = self._worker_watchdog.get_last_progress_time + + try: + healthcheck_port_as_int = int(healthcheck_port) + self._health_check_server = create_healthcheck_server( + alive_callback=alive_callback, + port=healthcheck_port_as_int, + timeout=60, + ) + self._health_check_server.start() + except ValueError: + logger.info( + "Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.", + healthcheck_port, + ) + else: + logger.info( + "Environment variable '%s' not found. Do not start health check.", + healthcheck_port_env_name, + ) + + def _get_fq_hostname(self) -> str: + return socket.getfqdn(socket.gethostname()) + + def _log_watchdog_event( + self, + name: str, + request: timer.FileTimerRequest | None, + ) -> None: + wg = self._worker_group + spec = wg.spec + md = {"watchdog_event": name} + if request is not None: + md["worker_pid"] = str(request.worker_pid) + md["scope_id"] = request.scope_id + md["expiration_time"] = str(request.expiration_time) + md["signal"] = str(request.signal) + md_str = json.dumps(md) + state = "RUNNING" + metadata: dict[str, EventMetadataValue] = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": None, + "group_rank": wg.group_rank, + "worker_id": None, + "role": spec.role, + "hostname": self._get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": None, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + } + # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later. + # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry. + event = events.Event( + name=name, source=events.EventSource.AGENT, metadata=metadata + ) + events.record(event, self._worker_group.spec.event_log_handler) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _stop_workers(self, worker_group: WorkerGroup) -> None: + self._shutdown() + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: + spec = worker_group.spec + store = worker_group.store + assert store is not None + restart_count = spec.max_restarts - self._remaining_restarts + + use_agent_store: bool = spec.rdzv_handler.use_agent_store + logger.info("use_agent_store: %s", use_agent_store) + + args: dict[int, tuple] = {} + envs: dict[int, dict[str, str]] = {} + log_line_prefixes: dict[int, str] | None = ( + {} if self._log_line_prefix_template else None + ) + for worker in worker_group.workers: + local_rank = worker.local_rank + worker_env = { + "RANK": str(worker.global_rank), + "GROUP_RANK": str(worker_group.group_rank), + "ROLE_RANK": str(worker.role_rank), + "ROLE_NAME": spec.role, + "LOCAL_WORLD_SIZE": str(spec.local_world_size), + "WORLD_SIZE": str(worker.world_size), + "GROUP_WORLD_SIZE": str(worker_group.group_world_size), + "ROLE_WORLD_SIZE": str(worker.role_world_size), + "MASTER_ADDR": worker_group.master_addr, + "MASTER_PORT": str(worker_group.master_port), + "TORCHELASTIC_RESTART_COUNT": str(restart_count), + "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), + "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), + "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), + "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv( + "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1) + ), + } + self._set_local_rank_env(worker_env, local_rank, spec) + if "OMP_NUM_THREADS" in os.environ: + worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] + + if self._log_line_prefix_template: + log_line_prefix = Template( + self._log_line_prefix_template + ).safe_substitute( + role_name=spec.role, + rank=worker.global_rank, + local_rank=local_rank, + ) + # pyrefly: ignore [unsupported-operation] + log_line_prefixes[local_rank] = log_line_prefix + + # pyrefly: ignore [unsupported-operation] + envs[local_rank] = worker_env + worker_args = list(spec.args) + worker_args = macros.substitute(worker_args, str(local_rank)) + args[local_rank] = tuple(worker_args) + + self._setup_local_watchdog(envs=envs) + self._setup_healthcheck() + + assert spec.entrypoint is not None + assert self._logs_specs is not None + self._pcontext = start_processes( + name=spec.role, + entrypoint=spec.entrypoint, + args=args, + envs=envs, + logs_specs=self._logs_specs, + log_line_prefixes=log_line_prefixes, + start_method=self._start_method, + numa_options=spec.numa_options, + duplicate_stdout_filters=spec.duplicate_stdout_filters, + duplicate_stderr_filters=spec.duplicate_stderr_filters, + ) + + return self._pcontext.pids() + + def _set_local_rank_env( + self, worker_env: dict[str, str | None], local_rank: int, spec: WorkerSpec + ) -> None: + # Set CUDA_VISIBLE_DEVICES and LOCAL_RANK based on virtual_local_rank mode. + # Virtual mode: Each worker sees only its assigned GPU as device 0, LOCAL_RANK=0 + # Traditional mode: Workers see all GPUs, LOCAL_RANK matches actual local rank + + if spec.virtual_local_rank: + # Set LOCAL_RANK=0 and use CUDA_VISIBLE_DEVICES to control the actual GPU access. + + worker_env["LOCAL_RANK"] = "0" + + # Map local_rank through existing CUDA_VISIBLE_DEVICES + # HIP uses CUDA_VISIBLE_DEVICES as a compatibility hack: + # https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#cuda-visible-devices + parent_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if parent_visible_devices is not None: + # Parse comma-separated list of GPU IDs + available_gpus = parent_visible_devices.split(",") + if local_rank >= len(available_gpus): + raise ValueError( + f"local_rank {local_rank} exceeds available GPUs in " + f"CUDA_VISIBLE_DEVICES={parent_visible_devices}" + ) + + visible_gpu = available_gpus[local_rank].strip() + else: + # No restriction, use local_rank directly + visible_gpu = str(local_rank) + + worker_env["CUDA_VISIBLE_DEVICES"] = visible_gpu + return + + # In traditional mode, don't override CUDA_VISIBLE_DEVICES + # (inherit from parent environment) + worker_env["LOCAL_RANK"] = str(local_rank) + + if "CUDA_VISIBLE_DEVICES" in os.environ: + worker_env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"] + + def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None: + if self._worker_watchdog is not None: + self._worker_watchdog.stop() + self._worker_watchdog = None + if self._health_check_server is not None: + self._health_check_server.stop() + self._health_check_server = None + if self._pcontext: + self._pcontext.close(death_sig) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + role = worker_group.spec.role + worker_pids = {w.id for w in worker_group.workers} + assert self._pcontext is not None + pc_pids = set(self._pcontext.pids().values()) + if worker_pids != pc_pids: + logger.error( + "[%s] worker pids do not match process_context pids." + " Expected: %s, actual: %s", + role, + worker_pids, + pc_pids, + ) + return RunResult(state=WorkerState.UNKNOWN) + + result = self._pcontext.wait(0) + if result: + if result.is_failed(): + # map local rank failure to global rank + worker_failures = {} + for local_rank, failure in result.failures.items(): + worker = worker_group.workers[local_rank] + worker_failures[worker.global_rank] = failure + return RunResult( + state=WorkerState.FAILED, + failures=worker_failures, + ) + else: + # copy ret_val_queue into a map with a global ranks + workers_ret_vals = {} + for local_rank, ret_val in result.return_values.items(): + worker = worker_group.workers[local_rank] + workers_ret_vals[worker.global_rank] = ret_val + return RunResult( + state=WorkerState.SUCCEEDED, + return_values=workers_ret_vals, + ) + else: + return RunResult(state=WorkerState.HEALTHY) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/control_plane.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/control_plane.py new file mode 100644 index 0000000000000000000000000000000000000000..817255edd23dcee2deea8554ada3637d30f9885f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/control_plane.py @@ -0,0 +1,53 @@ +import os +from collections.abc import Generator +from contextlib import contextmanager, ExitStack + +from torch.distributed.elastic.multiprocessing.errors import record + + +__all__ = [ + "worker_main", +] + +TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" + + +@contextmanager +def _worker_server(socket_path: str) -> Generator[None, None, None]: + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer(socket_path) + try: + yield + finally: + server.shutdown() + + +@record +@contextmanager +def worker_main() -> Generator[None, None, None]: + """ + This is a context manager that wraps your main entry function. This combines + the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that + exposes handlers via a unix socket specified by + ``Torch_WORKER_SERVER_SOCKET``. + + Example + + :: + + @worker_main() + def main(): + pass + + + if __name__ == "__main__": + main() + + """ + with ExitStack() as stack: + socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) + if socket_path is not None: + stack.enter_context(_worker_server(socket_path)) + + yield diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..deea40f3899aee490a899cfa1dd6d3019512cb9e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__init__.py @@ -0,0 +1,173 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Module contains events processing mechanisms that are integrated with the standard python logging. + +Example of usage: + +:: + + from torch.distributed.elastic import events + + event = events.Event( + name="test_event", source=events.EventSource.WORKER, metadata={...} + ) + events.get_logging_handler(destination="console").info(event) + +""" + +import inspect +import logging +import os +import socket +import traceback +from typing import Optional + +from torch.distributed.elastic.events.handlers import get_logging_handler + +from .api import ( # noqa: F401 + Event, + EventMetadataValue, + EventSource, + NodeState, + RdzvEvent, +) + + +_events_loggers: dict[str, logging.Logger] = {} + + +def _get_or_create_logger(destination: str = "null") -> logging.Logger: + """ + Construct python logger based on the destination type or extends if provided. + + Available destination could be found in ``handlers.py`` file. + The constructed logger does not propagate messages to the upper level loggers, + e.g. root logger. This makes sure that a single event can be processed once. + + Args: + destination: The string representation of the event handler. + Available handlers found in ``handlers`` module + """ + global _events_loggers + + if destination not in _events_loggers: + _events_logger = logging.getLogger(f"torchelastic-events-{destination}") + _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) + # Do not propagate message to the root logger + _events_logger.propagate = False + + logging_handler = get_logging_handler(destination) + _events_logger.addHandler(logging_handler) + + # Add the logger to the global dictionary + _events_loggers[destination] = _events_logger + + return _events_loggers[destination] + + +def record(event: Event, destination: str = "null") -> None: + _get_or_create_logger(destination).info(event.serialize()) + + +def record_rdzv_event(event: RdzvEvent) -> None: + _get_or_create_logger("dynamic_rendezvous").info(event.serialize()) + + +def construct_and_record_rdzv_event( + run_id: str, + message: str, + node_state: NodeState, + name: str = "", + hostname: str = "", + pid: int | None = None, + master_endpoint: str = "", + local_id: int | None = None, + rank: int | None = None, +) -> None: + """ + Initialize rendezvous event object and record its operations. + + Args: + run_id (str): The run id of the rendezvous. + message (str): The message describing the event. + node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED). + name (str): Event name. (E.g. Current action being performed). + hostname (str): Hostname of the node. + pid (Optional[int]): The process id of the node. + master_endpoint (str): The master endpoint for the rendezvous store, if known. + local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py + rank (Optional[int]): The rank of the node, if known. + Returns: + None + Example: + >>> # See DynamicRendezvousHandler class + >>> def _record( + ... self, + ... message: str, + ... node_state: NodeState = NodeState.RUNNING, + ... rank: Optional[int] = None, + ... ) -> None: + ... construct_and_record_rdzv_event( + ... name=f"{self.__class__.__name__}.{get_method_name()}", + ... run_id=self._settings.run_id, + ... message=message, + ... node_state=node_state, + ... hostname=self._this_node.addr, + ... pid=self._this_node.pid, + ... local_id=self._this_node.local_id, + ... rank=rank, + ... ) + """ + # We don't want to perform an extra computation if not needed. + if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler): + return + + # Set up parameters. + if not hostname: + hostname = socket.getfqdn() + if not pid: + pid = os.getpid() + + # Determines which file called this function. + callstack = inspect.stack() + filename = "no_file" + if len(callstack) > 1: + stack_depth_1 = callstack[1] + filename = os.path.basename(stack_depth_1.filename) + if not name: + name = stack_depth_1.function + + # Delete the callstack variable. If kept, this can mess with python's + # garbage collector as we are holding on to stack frame information in + # the inspect module. + del callstack + + # Set up error trace if this is an exception + if node_state == NodeState.FAILED: + error_trace = traceback.format_exc() + else: + error_trace = "" + + # Initialize event object + event = RdzvEvent( + name=f"{filename}:{name}", + run_id=run_id, + message=message, + hostname=hostname, + pid=pid, + node_state=node_state, + master_endpoint=master_endpoint, + rank=rank, + local_id=local_id, + error_trace=error_trace, + ) + + # Finally, record the event. + record_rdzv_event(event) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c09fd1a11fe2030cb57cbc3f06358cb7b474288 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecade519782adde92f158ae60bb5c9e872f7076e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63acaefc4ba760cf1a4ca2803b4d025cdd33f971 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/api.py new file mode 100644 index 0000000000000000000000000000000000000000..31afe29ff5f597b27b453e9993e1257e3f1f8d2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/api.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Union + + +__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] + +EventMetadataValue = Union[str, int, float, bool, None] + + +class EventSource(str, Enum): + """Known identifiers of the event producers.""" + + AGENT = "AGENT" + WORKER = "WORKER" + + +@dataclass +class Event: + """ + The class represents the generic event that occurs during the torchelastic job execution. + + The event can be any kind of meaningful action. + + Args: + name: event name. + source: the event producer, e.g. agent or worker + timestamp: timestamp in milliseconds when event occurred. + metadata: additional data that is associated with the event. + """ + + name: str + source: EventSource + timestamp: int = 0 + metadata: dict[str, EventMetadataValue] = field(default_factory=dict) + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "Event"]) -> "Event": + if isinstance(data, Event): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined] + # pyrefly: ignore [unbound-name] + return Event(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) + + +class NodeState(str, Enum): + """The states that a node can be in rendezvous.""" + + INIT = "INIT" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + +@dataclass +class RdzvEvent: + """ + Dataclass to represent any rendezvous event. + + Args: + name: Event name. (E.g. Current action being performed) + run_id: The run id of the rendezvous + message: The message describing the event + hostname: Hostname of the node + pid: The process id of the node + node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED) + master_endpoint: The master endpoint for the rendezvous store, if known + rank: The rank of the node, if known + local_id: The local_id of the node, if defined in dynamic_rendezvous.py + error_trace: Error stack trace, if this is an error event. + """ + + name: str + run_id: str + message: str + hostname: str + pid: int + node_state: NodeState + master_endpoint: str = "" + rank: int | None = None + local_id: int | None = None + error_trace: str = "" + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent": + if isinstance(data, RdzvEvent): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined] + # pyrefly: ignore [unbound-name] + return RdzvEvent(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/handlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..30d925353253d5bab4c4780f298e7fa68a4409e5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/events/handlers.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + + +_log_handlers: dict[str, logging.Handler] = { + "console": logging.StreamHandler(), + "dynamic_rendezvous": logging.NullHandler(), + "null": logging.NullHandler(), +} + + +def get_logging_handler(destination: str = "null") -> logging.Handler: + global _log_handlers + return _log_handlers[destination] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c2330924879ddbe35629a82d94a9b0c4c9c339 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__init__.py @@ -0,0 +1,168 @@ +#!/usr/bin/env/python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Metrics API. + +**Overview**: + +The metrics API in torchelastic is used to publish telemetry metrics. +It is designed to be used by torchelastic's internal modules to +publish metrics for the end user with the goal of increasing visibility +and helping with debugging. However you may use the same API in your +jobs to publish metrics to the same metrics ``sink``. + +A ``metric`` can be thought of as timeseries data +and is uniquely identified by the string-valued tuple +``(metric_group, metric_name)``. + +torchelastic makes no assumptions about what a ``metric_group`` is +and what relationship it has with ``metric_name``. It is totally up +to the user to use these two fields to uniquely identify a metric. + +.. note:: The metric group ``torchelastic`` is reserved by torchelastic for + platform level metrics that it produces. + For instance torchelastic may output the latency (in milliseconds) + of a re-rendezvous operation from the agent as + ``(torchelastic, agent.rendezvous.duration.ms)`` + +A sensible way to use metric groups is to map them to a stage or module +in your job. You may also encode certain high level properties +the job such as the region or stage (dev vs prod). + +**Publish Metrics**: + +Using torchelastic's metrics API is similar to using python's logging +framework. You first have to configure a metrics handler before +trying to add metric data. + +The example below measures the latency for the ``calculate()`` function. + +:: + + import time + import torch.distributed.elastic.metrics as metrics + + # makes all metrics other than the one from "my_module" to go /dev/null + metrics.configure(metrics.NullMetricsHandler()) + metrics.configure(metrics.ConsoleMetricsHandler(), "my_module") + + + def my_method(): + start = time.time() + calculate() + end = time.time() + metrics.put_metric("calculate_latency", int(end - start), "my_module") + +You may also use the torch.distributed.elastic.metrics.prof` decorator +to conveniently and succinctly profile functions + +:: + + # -- in module examples.foobar -- + + import torch.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricsHandler(), "foobar") + metrics.configure(metrics.ConsoleMetricsHandler(), "Bar") + + + @metrics.prof + def foo(): + pass + + + class Bar: + @metrics.prof + def baz(): + pass + +``@metrics.prof`` will publish the following metrics +:: + + .success - 1 if the function finished successfully + .failure - 1 if the function threw an exception + .duration.ms - function duration in milliseconds + +**Configuring Metrics Handler**: + +`torch.distributed.elastic.metrics.MetricHandler` is responsible for emitting +the added metric values to a particular destination. Metric groups can be +configured with different metric handlers. + +By default torchelastic emits all metrics to ``/dev/null``. +By adding the following configuration metrics, +``torchelastic`` and ``my_app`` metric groups will be printed out to +console. + +:: + + import torch.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic") + metrics.configure(metrics.ConsoleMetricHandler(), group="my_app") + +**Writing a Custom Metric Handler**: + +If you want your metrics to be emitted to a custom location, implement +the `torch.distributed.elastic.metrics.MetricHandler` interface +and configure your job to use your custom metric handler. + +Below is a toy example that prints the metrics to ``stdout`` + +:: + + import torch.distributed.elastic.metrics as metrics + + + class StdoutMetricHandler(metrics.MetricHandler): + def emit(self, metric_data): + ts = metric_data.timestamp + group = metric_data.group_name + name = metric_data.name + value = metric_data.value + print(f"[{ts}][{group}]: {name}={value}") + + + metrics.configure(StdoutMetricHandler(), group="my_app") + +Now all metrics in the group ``my_app`` will be printed to stdout as: + +:: + + [1574213883.4182858][my_app]: my_metric= + [1574213940.5237644][my_app]: my_metric= + +""" + +from typing import Optional + +from .api import ( # noqa: F401 + configure, + ConsoleMetricHandler, + get_elapsed_time_ms, + getStream, + MetricData, + MetricHandler, + MetricsConfig, + NullMetricHandler, + prof, + profile, + publish_metric, + put_metric, +) + + +def initialize_metrics(cfg: MetricsConfig | None = None): + pass + + +try: + from torch.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F401 F403 +except ModuleNotFoundError: + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..215f1b4a4b4919ea7340c24975315a8742082af9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f68f313f6f9a550770557368fe84cdb99826e1f4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py new file mode 100644 index 0000000000000000000000000000000000000000..102049481538d15a7fe995a8602ba45d6842303e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import time +from collections import namedtuple +from functools import wraps +from typing_extensions import deprecated + + +__all__ = [ + "MetricsConfig", + "MetricHandler", + "ConsoleMetricHandler", + "NullMetricHandler", + "MetricStream", + "configure", + "getStream", + "prof", + "profile", + "put_metric", + "publish_metric", + "get_elapsed_time_ms", + "MetricData", +] + +MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) + + +class MetricsConfig: + __slots__ = ["params"] + + def __init__(self, params: dict[str, str] | None = None): + self.params = params + if self.params is None: + self.params = {} + + +class MetricHandler(abc.ABC): + @abc.abstractmethod + def emit(self, metric_data: MetricData): + pass + + +class ConsoleMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + print( + f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}" + ) + + +class NullMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + pass + + +class MetricStream: + def __init__(self, group_name: str, handler: MetricHandler): + self.group_name = group_name + self.handler = handler + + def add_value(self, metric_name: str, metric_value: int): + self.handler.emit( + MetricData(time.time(), self.group_name, metric_name, metric_value) + ) + + +_metrics_map: dict[str, MetricHandler] = {} +_default_metrics_handler: MetricHandler = NullMetricHandler() + + +# pyre-fixme[9]: group has type `str`; used as `None`. +def configure(handler: MetricHandler, group: str | None = None): + if group is None: + global _default_metrics_handler + # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used + # as `MetricHandler`. + _default_metrics_handler = handler + else: + _metrics_map[group] = handler + + +def getStream(group: str): + handler = _metrics_map.get(group, _default_metrics_handler) + return MetricStream(group, handler) + + +def _get_metric_name(fn): + qualname = fn.__qualname__ + split = qualname.split(".") + if len(split) == 1: + module = fn.__module__ + if module: + return module.split(".")[-1] + "." + split[0] + else: + return split[0] + else: + return qualname + + +def prof(fn=None, group: str = "torchelastic"): + r""" + @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates. + + The metric name defaults to the qualified name (``class_name.def_name``) of the function. + If the function does not belong to a class, it uses the leaf module name instead. + + Usage + + :: + + @metrics.prof + def x(): + pass + + + @metrics.prof(group="agent") + def y(): + pass + """ + + def wrap(f): + @wraps(f) + def wrapper(*args, **kwargs): + key = _get_metric_name(f) + try: + start = time.time() + result = f(*args, **kwargs) + put_metric(f"{key}.success", 1, group) + except Exception: + put_metric(f"{key}.failure", 1, group) + raise + finally: + put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined] + return result + + return wrapper + + if fn: + return wrap(fn) + else: + return wrap + + +@deprecated("Deprecated, use `@prof` instead", category=FutureWarning) +def profile(group=None): + """ + @profile decorator adds latency and success/failure metrics to any given function. + + Usage + + :: + + @metrics.profile("my_metric_group") + def some_function(): + """ + + def wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + start_time = time.time() + result = func(*args, **kwargs) + # pyrefly: ignore [bad-argument-type] + publish_metric(group, f"{func.__name__}.success", 1) + except Exception: + # pyrefly: ignore [bad-argument-type] + publish_metric(group, f"{func.__name__}.failure", 1) + raise + finally: + publish_metric( + # pyrefly: ignore [bad-argument-type] + group, + f"{func.__name__}.duration.ms", + get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined] + ) + return result + + return wrapper + + return wrap + + +def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"): + """ + Publish a metric data point. + + Usage + + :: + + put_metric("metric_name", 1) + put_metric("metric_name", 1, "metric_group_name") + """ + getStream(metric_group).add_value(metric_name, metric_value) + + +@deprecated( + "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead", + category=FutureWarning, +) +def publish_metric(metric_group: str, metric_name: str, metric_value: int): + metric_stream = getStream(metric_group) + metric_stream.add_value(metric_name, metric_value) + + +def get_elapsed_time_ms(start_time_in_seconds: float): + """Return the elapsed time in millis from the given start time.""" + end_time = time.time() + return int((end_time - start_time_in_seconds) * 1000) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60b7cd32fd2531a3e3b04416b75a29767ba835fa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__init__.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Library that launches and manages ``n`` copies of worker subprocesses either specified by a function or a binary. + +For functions, it uses ``torch.multiprocessing`` (and therefore python +``multiprocessing``) to spawn/fork worker processes. For binaries it uses python +``subprocessing.Popen`` to create worker processes. + + +Usage 1: Launching two trainers as a function + +:: + + from torch.distributed.elastic.multiprocessing import Std, start_processes + + + def trainer(a, b, c): + pass # train + + + # runs two trainers + # LOCAL_RANK=0 trainer(1,2,3) + # LOCAL_RANK=1 trainer(4,5,6) + ctx = start_processes( + name="trainer", + entrypoint=trainer, + args={0: (1, 2, 3), 1: (4, 5, 6)}, + envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, + log_dir="/tmp/foobar", + redirects=Std.ALL, # write all worker stdout/stderr to a log file + tee={0: Std.ERR}, # tee only local rank 0's stderr to console + ) + + # waits for all copies of trainer to finish + ctx.wait() + +Usage 2: Launching 2 echo workers as a binary + +:: + + # same as invoking + # echo hello + # echo world > stdout.log + ctx = start_processes( + name="echo" + entrypoint="echo", + log_dir="/tmp/foobar", + args={0: "hello", 1: "world"}, + redirects={1: Std.OUT}, + ) + +Just like ``torch.multiprocessing``, the return value of the function +:func:`start_processes` is a process context (:class:`api.PContext`). If a function +was launched, a :class:`api.MultiprocessContext` is returned and if a binary +was launched a :class:`api.SubprocessContext` is returned. Both are specific +implementations of the parent :class:`api.PContext` class. +""" + +from collections.abc import Callable +from typing import Optional, Union + +from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 + _validate_full_rank, + DefaultLogsSpecs, + LogsDest, + LogsSpecs, + MultiprocessContext, + PContext, + ProcessFailure, + RunProcsResult, + SignalException, + Std, + SubprocessContext, + to_map, +) +from torch.distributed.elastic.utils.logging import get_logger +from torch.numa.binding import NumaOptions + + +__all__ = [ + "start_processes", + "MultiprocessContext", + "PContext", + "ProcessFailure", + "RunProcsResult", + "SignalException", + "Std", + "LogsDest", + "LogsSpecs", + "DefaultLogsSpecs", + "SubprocessContext", + "to_map", +] + + +def start_processes( + name: str, + entrypoint: Callable | str, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + start_method: str = "spawn", + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, +) -> PContext: + """ + Start ``n`` copies of ``entrypoint`` processes with the provided options. + + ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary). + The number of copies is determined by the number of entries for ``args`` and + ``envs`` arguments, which need to have the same key set. + + ``args`` and ``env`` parameters are the arguments and environment variables + to pass down to the entrypoint mapped by the replica index (local rank). + All local ranks must be accounted for. + That is, the keyset should be ``{0,1,...,(nprocs-1)}``. + + .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings. + If any other type is given, then it is casted to a string representation + (e.g. ``str(arg1)``). Furthermore, a binary failure will only write + an ``error.json`` error file if the main function is annotated with + ``torch.distributed.elastic.multiprocessing.errors.record``. For function launches, + this is done by default and there is no need to manually annotate + with the ``@record`` annotation. + + Inside ``logs_specs``, ``redirects`` and ``tee`` are bitmasks specifying which std + stream(s) to redirect to a log file in the ``log_dir``. Valid mask values are defined + in ``Std``. To redirect/tee only certain local ranks, pass ``redirects`` as a map + with the key as the local rank to specify the redirect behavior for. + Any missing local ranks will default to ``Std.NONE``. + + ``duplicate_stdout_filters`` and ``duplicate_stderr_filters``, if non-empty, + duplicate stdouts and stderrs respectively specified in ``logs_specs``'s ``tee`` + to a file containing only lines that match _any_ of the filter strings. The log + file is aggregated across all ranks selected by ``tee``. + + ``tee`` acts like the unix "tee" command in that it redirects + prints to console. + To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter. + + For each process, the ``log_dir`` will contain: + + #. ``{local_rank}/error.json``: if the process failed, a file with the error info + #. ``{local_rank}/stdout.log``: if ``redirect & STDOUT == STDOUT`` + #. ``{local_rank}/stderr.log``: if ``redirect & STDERR == STDERR`` + #. ``filtered_stdout.log``: if ``duplicate_stdout_filters`` is non-empty + #. ``filtered_stderr.log``: if ``duplicate_stderr_filters`` is non-empty + + .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory. + + Example: + :: + + log_dir = "/tmp/test" + + # ok; two copies of foo: foo("bar0"), foo("bar1") + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # invalid; envs missing for local rank 1 + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}}, + log_dir=log_dir + ) + + # ok; two copies of /usr/bin/touch: touch file1, touch file2 + start_processes( + name="trainer", + entrypoint="/usr/bin/touch", + args:{0:("file1",), 1:("file2",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # caution; arguments casted to string, runs: + # echo "1" "2" "3" and echo "[1, 2, 3]" + start_processes( + name="trainer", + entrypoint="/usr/bin/echo", + args:{0:(1,2,3), 1:([1,2,3],), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + Args: + name: a human readable short name that describes what the processes are + (used as header when tee'ing stdout/stderr outputs) + entrypoint: either a ``Callable`` (function) or ``cmd`` (binary) + args: arguments to each replica + envs: env vars to each replica + log_dir: directory used to write log files + start_method: multiprocessing start method (spawn, fork, forkserver) + ignored for binaries + logs_specs: defines ``log_dir``, ``redirects``, and ``tee``. + inside ``logs_specs``: + - redirects: which std streams to redirect to a log file + - tee: which std streams to redirect + print to console + local_ranks_filter: which ranks' logs to print to console + duplicate_stdout_filters: filters for the duplicated stdout logs + duplicate_stderr_filters: filters for the duplicated stderr logs + + """ + + nprocs = len(args) + _validate_full_rank(args, nprocs, "args") + _validate_full_rank(envs, nprocs, "envs") + + context: PContext + if isinstance(entrypoint, str): + context = SubprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + duplicate_stdout_filters=duplicate_stdout_filters, + duplicate_stderr_filters=duplicate_stderr_filters, + logs_specs=logs_specs, + log_line_prefixes=log_line_prefixes, + numa_options=numa_options, + ) + else: + context = MultiprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + duplicate_stdout_filters=duplicate_stdout_filters, + duplicate_stderr_filters=duplicate_stderr_filters, + log_line_prefixes=log_line_prefixes, + start_method=start_method, + logs_specs=logs_specs, + numa_options=numa_options, + ) + + try: + context.start() + return context + except Exception: + context.close() + raise diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b5f297c65e4f0b77f2492ae965244220be8fbd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f265be83e23d9f375c557db02ae4b2d8616fa275 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ec77d7b0e7cb9ec3ffd7b4d3903b5fa42b8a23c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e89854e7a3a746983f35cc5f3ed14878fd2cb76 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py new file mode 100644 index 0000000000000000000000000000000000000000..45351c380ca0db821149edd174cb588192619be6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py @@ -0,0 +1,1036 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import logging +import os +import re +import shutil +import signal +import subprocess +import sys +import tempfile +import threading +import time +from abc import ABC, abstractmethod +from collections.abc import Callable +from contextlib import nullcontext +from dataclasses import dataclass, field +from enum import IntFlag +from multiprocessing import synchronize +from types import FrameType +from typing import Any, TextIO, Union + +import torch.multiprocessing as mp +from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record +from torch.distributed.elastic.multiprocessing.redirects import ( + redirect_stderr, + redirect_stdout, +) +from torch.distributed.elastic.multiprocessing.subprocess_handler import ( + get_subprocess_handler, + SubprocessHandler, +) +from torch.distributed.elastic.multiprocessing.tail_log import TailLog +from torch.numa.binding import maybe_wrap_with_numa_binding, NumaOptions + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + +__all__ = [ + "DefaultLogsSpecs", + "SignalException", + "Std", + "to_map", + "RunProcsResult", + "PContext", + "get_std_cm", + "MultiprocessContext", + "SubprocessContext", + "LogsDest", + "LogsSpecs", +] + + +class SignalException(Exception): + """ + Exception is raised inside the torchelastic agent process by the termination handler + if the death signal got received by the process. + """ + + def __init__(self, msg: str, sigval: signal.Signals) -> None: + super().__init__(msg) + self.sigval = sigval + + +def _terminate_process_handler(signum: int, frame: FrameType | None) -> None: + """Termination handler that raises exceptions on the main process. + + When the process receives death signal(SIGTERM, SIGINT), this termination handler will + be invoked. It raises the ``SignalException`` exception that should be processed by the + user code. Python does not terminate process after the termination handler is finished, + so the exception should not be silently ignored, otherwise the process will never + be terminated. + """ + sigval = signal.Signals(signum) + raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) + + +def _get_kill_signal() -> signal.Signals: + """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGKILL + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str): + actual_keys = set(d.keys()) + expected_keys = set(range(nprocs)) + + if actual_keys != expected_keys: + raise RuntimeError( + f"{what}, local rank mapping mismatch," + f" expected: {expected_keys}, actual: {actual_keys}" + ) + + +_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$" +_VALUE_REGEX = r"^[0123]$" + + +class Std(IntFlag): + NONE = 0 + OUT = 1 + ERR = 2 + ALL = OUT | ERR + + @classmethod + def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]: + """ + Example: + :: + + from_str("0") -> Std.NONE + from_str("1") -> Std.OUT + from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR} + + Any other input raises an exception + """ + + def to_std(v: str) -> Std: # type: ignore[return] + s = Std(int(v)) + if s in Std: + return s + # return None -> should NEVER reach here since we regex check input + + if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0) + return to_std(vm) + elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2) + d: dict[int, Std] = {} + for m in vm.split(","): + i, v = m.split(":") + d[int(i)] = to_std(v) + return d + else: + raise ValueError( + f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>" + ) + + +def to_map(val_or_map: Std | dict[int, Std], local_world_size: int) -> dict[int, Std]: + """ + Certain APIs take redirect settings either as a single value (e.g. apply to all + local ranks) or as an explicit user-provided mapping. This method is a convenience + method that converts a value or mapping into a mapping. + + Example: + :: + + to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} + to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} + to_map( + {0: Std.OUT, 1: Std.OUT}, local_world_size=2 + ) # returns: {0: Std.OUT, 1: Std.OUT} + """ + if isinstance(val_or_map, Std): + return dict.fromkeys(range(local_world_size), val_or_map) + else: + map = {} + for i in range(local_world_size): + map[i] = val_or_map.get(i, Std.NONE) + return map + + +@dataclass +class LogsDest: + """ + For each log type, holds mapping of local rank ids to file paths. + """ + + stdouts: dict[int, str] = field(default_factory=dict) + stderrs: dict[int, str] = field(default_factory=dict) + tee_stdouts: dict[int, str] = field(default_factory=dict) + tee_stderrs: dict[int, str] = field(default_factory=dict) + error_files: dict[int, str] = field(default_factory=dict) + filtered_stdout: str = field(default_factory=str) + filtered_stderr: str = field(default_factory=str) + + +class LogsSpecs(ABC): + """ + Defines logs processing and redirection for each worker process. + + Args: + log_dir: + Base directory where logs will be written. + redirects: + Streams to redirect to files. Pass a single ``Std`` + enum to redirect for all workers, or a mapping keyed + by local_rank to selectively redirect. + tee: + Streams to duplicate to stdout/stderr. + Pass a single ``Std`` enum to duplicate streams for all workers, + or a mapping keyed by local_rank to selectively duplicate. + """ + + def __init__( + self, + log_dir: str | None = None, + redirects: Std | dict[int, Std] = Std.NONE, + tee: Std | dict[int, Std] = Std.NONE, + local_ranks_filter: set[int] | None = None, + ) -> None: + self._root_log_dir = log_dir + self._redirects = redirects + self._tee = tee + self._local_ranks_filter = local_ranks_filter + + @abstractmethod + def reify( + self, + envs: dict[int, dict[str, str]], + ) -> LogsDest: + """ + Given the environment variables, builds destination of log files for each of the local ranks. + + Envs parameter contains env variables dict for each of the local ranks, where entries are defined in: + :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`. + """ + + @property + @abstractmethod + def root_log_dir(self) -> str: + pass + + +class DefaultLogsSpecs(LogsSpecs): + """ + Default LogsSpecs implementation: + + - `log_dir` will be created if it doesn't exist + - Generates nested folders for each attempt and rank. + """ + + def __init__( + self, + log_dir: str | None = None, + redirects: Std | dict[int, Std] = Std.NONE, + tee: Std | dict[int, Std] = Std.NONE, + local_ranks_filter: set[int] | None = None, + ) -> None: + if log_dir != os.devnull: + if not log_dir: + log_dir = tempfile.mkdtemp(prefix="torchelastic_") + elif not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + else: + if os.path.isfile(log_dir): + raise NotADirectoryError(f"log_dir: {log_dir} is a file") + super().__init__(log_dir, redirects, tee, local_ranks_filter) + # initialized only once + self._run_log_dir = None + + @property + def root_log_dir(self) -> str: + return str(self._root_log_dir) + + def _make_log_dir(self, log_dir: str | None, rdzv_run_id: str): + base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_") + os.makedirs(base_log_dir, exist_ok=True) + dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir) + logger.info("log directory set to: %s", dir) + return dir + + def reify( + self, + envs: dict[int, dict[str, str]], + ) -> LogsDest: + """ + Uses following scheme to build log destination paths: + + - `//attempt_//stdout.log` + - `//attempt_//stderr.log` + - `//attempt_//error.json` + - `//attempt_/filtered_stdout.log` + - `//attempt_/filtered_stderr.log` + """ + nprocs = len(envs) + global_env = {} # use only to query properties that are not dependent on a rank + if nprocs > 0: + global_env = envs[0] + else: + logger.warning( + "Empty envs map provided when defining logging destinations." + ) + # Keys are always defined, but values can be missing in unit tests + run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id") + restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0") + + attempt_log_dir: str = "" + if self._root_log_dir != os.devnull: + if not self._run_log_dir: + self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id) + + attempt_log_dir = os.path.join( + self._run_log_dir, f"attempt_{restart_count}" + ) # type: ignore[call-overload] + shutil.rmtree(attempt_log_dir, ignore_errors=True) + os.makedirs(attempt_log_dir) + + if self._root_log_dir == os.devnull: + attempt_log_dir = os.devnull + + # create subdirs for each local rank in the logs_dir + # logs_dir + # |- 0 + # |- error.json + # |- stdout.log + # |- stderr.log + # |- ... + # |- (nprocs-1) + redirs = to_map(self._redirects, nprocs) + ts = to_map(self._tee, nprocs) + + # to tee stdout/stderr we first redirect into a file + # then tail -f stdout.log/stderr.log so add tee settings to redirects + for local_rank, tee_std in ts.items(): + redirect_std = redirs[local_rank] + redirs[local_rank] = redirect_std | tee_std + + SYS_STREAM = "" # special case to indicate to output to console + stdouts = dict.fromkeys(range(nprocs), SYS_STREAM) + stderrs = dict.fromkeys(range(nprocs), SYS_STREAM) + tee_stdouts: dict[int, str] = {} + tee_stderrs: dict[int, str] = {} + error_files = {} + + for local_rank in range(nprocs): + if attempt_log_dir == os.devnull: + tee_stdouts[local_rank] = os.devnull + tee_stderrs[local_rank] = os.devnull + error_files[local_rank] = os.devnull + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = "" + else: + clogdir = os.path.join(attempt_log_dir, str(local_rank)) + os.mkdir(clogdir) + + rd = redirs[local_rank] + if (rd & Std.OUT) == Std.OUT: + stdouts[local_rank] = os.path.join(clogdir, "stdout.log") + if (rd & Std.ERR) == Std.ERR: + stderrs[local_rank] = os.path.join(clogdir, "stderr.log") + + t = ts[local_rank] + if t & Std.OUT == Std.OUT: + tee_stdouts[local_rank] = stdouts[local_rank] + if t & Std.ERR == Std.ERR: + tee_stderrs[local_rank] = stderrs[local_rank] + + if ( + self._local_ranks_filter + and local_rank not in self._local_ranks_filter + ): + # If stream is tee'd, only write to file, but don't tail + if local_rank in tee_stdouts: + tee_stdouts.pop(local_rank, None) + if local_rank in tee_stderrs: + tee_stderrs.pop(local_rank, None) + + # If stream is not redirected, don't print + if stdouts[local_rank] == SYS_STREAM: + stdouts[local_rank] = os.devnull + if stderrs[local_rank] == SYS_STREAM: + stderrs[local_rank] = os.devnull + + error_file = os.path.join(clogdir, "error.json") + error_files[local_rank] = error_file + logger.info( + "Setting worker%s reply file to: %s", local_rank, error_file + ) + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file + + return LogsDest( + stdouts, + stderrs, + tee_stdouts, + tee_stderrs, + error_files, + os.path.join(attempt_log_dir, "filtered_stdout.log"), + os.path.join(attempt_log_dir, "filtered_stderr.log"), + ) + + def __repr__(self) -> str: + return ( + f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, " + f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DefaultLogsSpecs): + return False + + return ( + self._root_log_dir == other._root_log_dir + and self._redirects == other._redirects + and self._tee == other._tee + and self._local_ranks_filter == other._local_ranks_filter + ) + + +@dataclass +class RunProcsResult: + """ + Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``. + + Note the following: + + 1. All fields are mapped by local rank + 2. ``return_values`` - only populated for functions (not the binaries). + 3. ``stdouts`` - path to stdout.log (empty string if no redirect) + 4. ``stderrs`` - path to stderr.log (empty string if no redirect) + + """ + + return_values: dict[int, Any] = field(default_factory=dict) + failures: dict[int, ProcessFailure] = field(default_factory=dict) + stdouts: dict[int, str] = field(default_factory=dict) + stderrs: dict[int, str] = field(default_factory=dict) + + def is_failed(self) -> bool: + return len(self.failures) > 0 + + +class PContext(abc.ABC): + """ + The base class that standardizes operations over a set of processes that are launched via different mechanisms. + + The name ``PContext`` is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``. + + .. warning:: stdouts and stderrs should ALWAYS be a superset of + tee_stdouts and tee_stderrs (respectively) this is b/c + tee is implemented as a redirect + tail -f + + Args: + duplicate_stdout_filters: + If non-empty, duplicates stdouts specified in ``logs_specs``'s ``tee`` + to a file containing only lines that match _any_ of the filter strings. + The log file is aggregated across all ranks selected by ``tee``. + duplicate_stderr_filters: + If non-empty, duplicates stderrs specified in ``logs_specs``'s ``tee`` + to a file containing only lines that match _any_ of the filter strings. + The log file is aggregated across all ranks selected by ``tee``. + """ + + def __init__( + self, + name: str, + entrypoint: Callable | str, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, + ): + self.name = name + # validate that all mappings have the same number of keys and + # all local ranks are accounted for + nprocs = len(args) + + # TODO log_line_prefixes can be expanded too + logs_dest = logs_specs.reify(envs) + + _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts") + _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs") + + self.entrypoint = entrypoint + self.args = args + self.envs = envs + self.stdouts = logs_dest.stdouts + self.stderrs = logs_dest.stderrs + self.error_files = logs_dest.error_files + self.nprocs = nprocs + self.filtered_stdout: TextIO | None = None + self.filtered_stderr: TextIO | None = None + + self._tail_logs = [ + TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes), + TailLog(name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes), + ] + + if duplicate_stdout_filters: + self.filtered_stdout = open( # noqa: SIM115 + logs_dest.filtered_stdout, mode="w", errors="replace", buffering=1 + ) + self._tail_logs.append( + TailLog( + name, + logs_dest.tee_stdouts, + self.filtered_stdout, + log_line_prefixes, + log_line_filter=lambda line: any( + needle in line for needle in duplicate_stdout_filters + ), + ) + ) + + if duplicate_stderr_filters: + self.filtered_stderr = open( # noqa: SIM115 + logs_dest.filtered_stderr, mode="w", errors="replace", buffering=1 + ) + self._tail_logs.append( + TailLog( + name, + logs_dest.tee_stderrs, + self.filtered_stderr, + log_line_prefixes, + log_line_filter=lambda line: any( + needle in line for needle in duplicate_stderr_filters + ), + ) + ) + + def start(self) -> None: + """Start processes using parameters defined in the constructor.""" + if threading.current_thread() is threading.main_thread(): + # Register signal handlers for the signals specified in the environment variable + signals_to_handle = os.environ.get( + "TORCHELASTIC_SIGNALS_TO_HANDLE", "SIGTERM,SIGINT,SIGHUP,SIGQUIT" + ) + signal_list = signals_to_handle.split(",") + + for sig_name in signal_list: + try: + sig = getattr(signal, sig_name.strip()) + signal.signal(sig, _terminate_process_handler) + logger.info("Registered signal handler for %s", sig_name) + except (AttributeError, ValueError): + logger.warning( + "Failed to register signal handler for %s", + sig_name, + exc_info=True, + ) + except RuntimeError: + if IS_WINDOWS and sig_name.strip() in [ + "SIGHUP", + "SIGQUIT", + "SIGUSR1", + "SIGUSR2", + ]: + logger.info( + "Signal %s is not supported on Windows, skipping", sig_name + ) + else: + logger.warning( + "Failed to register signal handler for %s", + sig_name, + exc_info=True, + ) + else: + logger.warning( + "Failed to register signal handlers since torchelastic is running on a child thread. " + "This could lead to orphaned worker processes if the torchrun is terminated." + ) + self._start() + for tail_log in self._tail_logs: + tail_log.start() + + @abc.abstractmethod + def _start(self) -> None: + """Start processes using strategy defined in a particular context.""" + raise NotImplementedError + + @abc.abstractmethod + def _poll(self) -> RunProcsResult | None: + """ + Poll the run status of the processes running under this context. + This method follows an "all-or-nothing" policy and returns + a ``RunProcessResults`` object if either all processes complete + successfully or any process fails. Returns ``None`` if + all processes are still running. + """ + raise NotImplementedError + + def wait(self, timeout: float = -1, period: float = 1) -> RunProcsResult | None: + """ + Wait for the specified ``timeout`` seconds, polling every ``period`` seconds + for the processes to be done. Returns ``None`` if the processes are still running + on timeout expiry. Negative timeout values are interpreted as "wait-forever". + A timeout value of zero simply queries the status of the processes (e.g. equivalent + to a poll). + + .. note:: + Multiprocessing library registers SIGTERM and SIGINT signal handlers that raise + ``SignalException`` when the signals received. It is up to the consumer of the code + to properly handle the exception. It is important not to swallow the exception otherwise + the process would not terminate. Example of the typical workflow can be: + + .. code-block:: python + pc = start_processes(...) + try: + pc.wait(1) + .. do some other work + except SignalException as e: + pc.shutdown(e.sigval, timeout=30) + + If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating + received signal. If child processes will not terminate in the timeout time, the process will send + the SIGKILL. + """ + if timeout == 0: + return self._poll() + + if timeout < 0: + timeout = sys.maxsize + + expiry = time.time() + timeout + while time.time() < expiry: + pr = self._poll() + if pr: + return pr + time.sleep(period) + + return None + + @abc.abstractmethod + def pids(self) -> dict[int, int]: + """Return pids of processes mapped by their respective local_ranks.""" + raise NotImplementedError + + @abc.abstractmethod + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + """ + raise NotImplementedError + + def close(self, death_sig: signal.Signals | None = None, timeout: int = 30) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + + Args: + death_sig: Death signal to terminate processes. + timeout: Time to wait for processes to finish, if process is + still alive after this time, it will be terminated via SIGKILL. + """ + if not death_sig: + death_sig = _get_default_signal() + self._close(death_sig=death_sig, timeout=timeout) + for tail_log in self._tail_logs: + tail_log.stop() + if self.filtered_stdout: + self.filtered_stdout.close() + if self.filtered_stderr: + self.filtered_stderr.close() + + +def get_std_cm(std_rd: str, redirect_fn): + if IS_WINDOWS or IS_MACOS or not std_rd: + return nullcontext() + else: + return redirect_fn(std_rd) + + +def _wrap( + local_rank: int, + fn: Callable, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + stdout_redirects: dict[int, str], # redirect file for stdout (to console if None) + stderr_redirects: dict[int, str], # redirect file for stderr (to console if None) + ret_vals: dict[int, mp.SimpleQueue], + queue_finished_reading_event: synchronize.Event, + numa_options: NumaOptions | None, +) -> None: + # get the per-rank params up front so we fail fast if no mapping is found + args_ = args[local_rank] + env_ = envs[local_rank] + ret_val_ = ret_vals[local_rank] + + stdout_rd = stdout_redirects[local_rank] + stderr_rd = stderr_redirects[local_rank] + + stdout_cm = get_std_cm(stdout_rd, redirect_stdout) + stderr_cm = get_std_cm(stderr_rd, redirect_stderr) + + for k, v in env_.items(): + os.environ[k] = v + + with stdout_cm, stderr_cm: + fn = maybe_wrap_with_numa_binding( + fn, gpu_index=local_rank, numa_options=numa_options + ) + ret = record(fn)(*args_) + ret_val_.put(ret) + queue_finished_reading_event.wait() + + +class MultiprocessContext(PContext): + """``PContext`` holding worker processes invoked as a function.""" + + def __init__( + self, + name: str, + entrypoint: Callable, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + start_method: str, + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + duplicate_stdout_filters, + duplicate_stderr_filters, + ) + + self.start_method = start_method + # each ret_val queue will always contain a single element. + self._ret_vals = { + local_rank: mp.get_context(self.start_method).SimpleQueue() + for local_rank in range(self.nprocs) + } + + # see comments in ``join()`` for what this is + self._return_values: dict[int, Any] = {} + self._pc: mp.ProcessContext | None = None + # Note: set method should ONLY be invoked for the use case when all processes finished + # successfully. If any process died on event.wait() calling set() method will deadlock. + self._worker_finished_event = mp.get_context(self.start_method).Event() + + self._numa_options: NumaOptions | None = numa_options + + def _start(self): + if self._pc: + raise ValueError( + "The process context already initialized." + " Most likely the start method got called twice." + ) + self._pc = mp.start_processes( + fn=_wrap, + args=( + self.entrypoint, + self.args, + self.envs, + self.stdouts, + self.stderrs, + self._ret_vals, + self._worker_finished_event, + self._numa_options, + ), + nprocs=self.nprocs, + join=False, + daemon=False, + start_method=self.start_method, + ) + + def _is_done(self) -> bool: + return len(self._return_values) == self.nprocs + + def _poll(self) -> RunProcsResult | None: + assert self._pc is not None # assertion for mypy type checker + + try: + # torch.mp.ProcessContext Throws an Exception if some/all of + # worker processes failed + # timeout < 0 checks worker status and return immediately + # Join will never return success since we use synchronize.Event to wait + # for all processes to finish. + self._pc.join(-1) + + # IMPORTANT: we use multiprocessing.Queue to carry worker return values + # back to the parent, the worker process will wait before terminating + # until all the buffered items are fed by the feeder thread to the underlying + # pipe. Hence to prevent deadlocks on large return values, + # we opportunistically try queue.get on each join call + # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms + for local_rank in range(self.nprocs): + return_queue = self._ret_vals[local_rank] + if not return_queue.empty(): + # save the return values temporarily into a member var + self._return_values[local_rank] = return_queue.get() + + if self._is_done(): + # we should ALWAYS have ALL the return values when all the processes are done + self._worker_finished_event.set() + + # At this point workers finished running the user function + # But the child process might still have not exited. Wait for them. + # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits. + while not self._pc.join(): + logger.debug( + "entrypoint fn finished, waiting for all child procs to exit..." + ) + + _validate_full_rank( + self._return_values, self.nprocs, "return_value queue" + ) + self.close() + return RunProcsResult( + return_values=self._return_values, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + else: + return None + except (mp.ProcessRaisedException, mp.ProcessExitedException) as e: + failed_local_rank = e.error_index + + # entrypoint for MultiprocessContext will always be a Callable + fn_name = self.entrypoint.__qualname__ # type: ignore[union-attr] + failed_proc = self._pc.processes[failed_local_rank] + error_filepath = self.error_files[failed_local_rank] + + logger.exception( + "failed (exitcode: %s)" + " local_rank: %s (pid: %s)" + " of fn: %s (start_method: %s)", + failed_proc.exitcode, + failed_local_rank, + e.error_pid, + fn_name, + self.start_method, + ) + + self.close() + return RunProcsResult( + failures={ + failed_local_rank: ProcessFailure( + local_rank=failed_local_rank, + pid=e.error_pid, + exitcode=failed_proc.exitcode, + error_file=error_filepath, + ) + }, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + + def pids(self) -> dict[int, int]: + assert self._pc is not None # assertion for mypy type checking + return dict(enumerate(self._pc.pids())) + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self._pc: + return + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Closing process %s via signal %s", proc.pid, death_sig.name + ) + try: + os.kill(proc.pid, death_sig) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + end = time.monotonic() + timeout + for proc in self._pc.processes: + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + proc.join(time_to_wait) + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + proc.pid, + death_sig, + _get_kill_signal(), + ) + try: + os.kill(proc.pid, _get_kill_signal()) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + proc.join() + + +class SubprocessContext(PContext): + """``PContext`` holding worker processes invoked as a binary.""" + + def __init__( + self, + name: str, + entrypoint: str, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + duplicate_stdout_filters, + duplicate_stderr_filters, + ) + + # state vector; _vdone[local_rank] -> is local_rank finished or not + self._running_local_ranks: set[int] = set(range(self.nprocs)) + self._failures: dict[int, ProcessFailure] = {} + self.subprocess_handlers: dict[int, SubprocessHandler] = {} + self._numa_options: NumaOptions | None = numa_options + + def _start(self): + if self.subprocess_handlers: + raise ValueError( + "The subprocess handlers already initialized. Most likely the start method got called twice." + ) + self.subprocess_handlers = { + local_rank: get_subprocess_handler( + entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str + args=self.args[local_rank], + env=self.envs[local_rank], + stdout=self.stdouts[local_rank], + stderr=self.stderrs[local_rank], + local_rank_id=local_rank, + numa_options=self._numa_options, + ) + for local_rank in range(self.nprocs) + } + + def _capture_process_failures(self, done_local_ranks: set[int]): + for local_rank in self._running_local_ranks: + handler = self.subprocess_handlers[local_rank] + exitcode = handler.proc.poll() + if exitcode is not None: + done_local_ranks.add(local_rank) + if exitcode != 0: # failed or signaled + self._failures[local_rank] = ProcessFailure( + local_rank=local_rank, + pid=handler.proc.pid, + exitcode=exitcode, + error_file=self.error_files[local_rank], + ) + # else: --> succeeded; nothing to do + + def _poll(self) -> RunProcsResult | None: + done_local_ranks: set[int] = set() + self._capture_process_failures(done_local_ranks) + + self._running_local_ranks.difference_update(done_local_ranks) + + # if ALL procs are finished or ANY have failed + if not self._running_local_ranks or self._failures: + self.close() # terminate all running procs + self._capture_process_failures( + done_local_ranks + ) # log sigterms and sigkill exit codes in the self._failures for bookkeeping purposes + + result = RunProcsResult( + failures=self._failures, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + if result.is_failed(): + first_failure = min(result.failures.values(), key=lambda f: f.timestamp) + logger.error( + "failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s", + first_failure.exitcode, + first_failure.local_rank, + first_failure.pid, + self.entrypoint, + ) + else: + # Populate return with dummy values. This provides consistency with MultiprocessingHandler + result.return_values = dict.fromkeys(range(self.nprocs)) + + return result + else: # there are no failures and procs still running + return None + + def pids(self) -> dict[int, int]: + return { + local_rank: sh.proc.pid + for local_rank, sh in self.subprocess_handlers.items() + } + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self.subprocess_handlers: + return + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Sending process %s closing signal %s", + handler.proc.pid, + death_sig.name, + ) + handler.close(death_sig=death_sig) + end = time.monotonic() + timeout + for handler in self.subprocess_handlers.values(): + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + try: + handler.proc.wait(time_to_wait) + except subprocess.TimeoutExpired: + # Ignore the timeout expired exception, since + # the child process will be forcefully terminated via SIGKILL + pass + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + handler.proc.pid, + death_sig, + _get_kill_signal(), + ) + handler.close(death_sig=_get_kill_signal()) + handler.proc.wait() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f61c99dc5c7779a5d839ca5b0364616b55079286 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Each host in a distributed PyTorch job runs with a single TorchElastic agent, +and multiple workers (as children processes of the TorchElastic agent). +Since the workers are user-provided (your PyTorch script/job), TorchElastic +has a way to propagate errors on the trainers through the agent and up to the +scheduler, which ultimately informs the end-user about the state of the job +and applies any retry policies. + +TorchElastic categorizes errors into 3 categories: + ++----------------+----------------+--------------------------------------------------------------+ +| Category | Sub-Category | Description | ++================+================+==============================================================+ +| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) | +| +----------------+--------------------------------------------------------------+ +| | Worker Failure | any failures on the worker child process | ++----------------+----------------+--------------------------------------------------------------+ +| Platform Error | n/a | failures caused by the agent | ++----------------+----------------+--------------------------------------------------------------+ +| Infra Error | n/a | failures outside the domain of the agent and workers | +| | | (e.g. host failures) | ++----------------+----------------+--------------------------------------------------------------+ + +All errors other than "Worker Failure" are either raised canonically from the +agent process or implicitly or explicitly crash the agent process. So the +standard language (python) provided exception handling strategies apply. + +Worker Failures are special because the exception/failure originates on a different +process from the agent so the error needs to be propagated inter-process +(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process). + +TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes` +to launch the workers which has a simple file based inter-process error propagation +built-in. + +Any function or binary entrypoint decorated with :func:`record` +will write uncaught exceptions (with the trace information) to a file specified by the +environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent) +sets this env var on each child it launches, then aggregates the error files for all +children, and propagates the one with the **smallest** timestamp (e.g. the **first** error). +""" + +import json +import os +import signal +import socket +import time +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import datetime +from functools import wraps +from string import Template +from typing import Any, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +from torch.distributed.elastic.utils.logging import get_logger + +from .error_handler import ErrorHandler # noqa: F401 +from .handlers import get_error_handler # noqa: F401 + + +__all__ = [ + "ProcessFailure", + "ChildFailedError", + "record", + "ErrorHandler", + "get_error_handler", +] + +logger = get_logger(__name__) + + +JSON = dict[str, Any] + +_EMPTY_ERROR_DATA: dict[str, Any] = {"message": ""} +_NOT_AVAILABLE = "" + +_R = TypeVar("_R") +_P = ParamSpec("_P") + + +@dataclass +class ProcessFailure: + """ + Represent the failed process result. When the worker process fails, it may record failure root cause into the file. + + Tries to read the failure timestamp from the provided ``error_file``, + if the ``error_file`` does not exist, the timestamp is the current + timestamp (seconds since epoch). + + The ``message`` field is a concise explanation of the failure. If + the error file exists then the message is obtained from the error file. + Otherwise one is generated based on the failure signature. + + .. note:: It is assumed that the ``error_file`` is written by + ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``. + Otherwise the behavior is undefined. + + """ + + local_rank: int + pid: int + exitcode: int + error_file: str + error_file_data: JSON = field(init=False) + message: str = field(init=False) + timestamp: int = field(init=False) + + def __post_init__(self): + self.error_file_data = _EMPTY_ERROR_DATA + if os.path.isfile(self.error_file): + try: + with open(self.error_file) as fp: + self.error_file_data = json.load(fp) + logger.debug( + "User process failed with error data: %s", + json.dumps(self.error_file_data, indent=2), + ) + self.message, self.timestamp = self._get_error_data( + self.error_file_data + ) + except Exception: + logger.exception("Failed to parse reply file: %s", self.error_file) + raise + else: + self._set_no_reply_file() + + # make up an informative message if not already present + if not self.message: + # signals typically do not generate an error file message + if self.exitcode < 0: + self.message = ( + f"Signal {-self.exitcode} ({self.signal_name()})" + f" received by PID {self.pid}" + ) + else: + self.error_file_data["errorTraits"] = { + "category": "system_terminated_error", + "retryability": "False", + } + self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html" + + def _get_error_data(self, error_file_data: dict[str, Any]) -> tuple[str, int]: + message = error_file_data["message"] + if isinstance(message, str): + timestamp = int(error_file_data.get("timestamp", 0)) + else: + timestamp = int(message["extraInfo"]["timestamp"]) + return (message, timestamp) + + def _set_no_reply_file(self): + self.error_file = _NOT_AVAILABLE + self.error_file_data = _EMPTY_ERROR_DATA + self.message = "" + self.timestamp = int(time.time()) + + def signal_name(self) -> str: + if self.exitcode < 0: + # We don't want to kill the parent process trying to find the signal name. + # if the signal doesn't map to a known name, use not available. + try: + return signal.Signals(-self.exitcode).name + except Exception: + return _NOT_AVAILABLE + else: + return _NOT_AVAILABLE + + def timestamp_isoformat(self): + """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS).""" + return datetime.fromtimestamp(self.timestamp).isoformat(sep="_") + + +GlobalRank = int + +_FAILURE_FORMAT_TEMPLATE = """[${idx}]: + time : ${time} + host : ${hostname} + rank : ${rank} (local_rank: ${local_rank}) + exitcode : ${exitcode} (pid: ${pid}) + error_file: ${error_file} + traceback : ${message}""" + +# extra new lines before and after are intentional +_MSG_FORMAT_TEMPLATE = """ +${boarder} +${title} +${section} +Failures: +${other_failures} +${section} +Root Cause (first observed failure): +${root_failure} +${boarder}""" + + +class ChildFailedError(Exception): + """ + Special exception type that can be raised from a function annotated with the + ``@record`` decorator to have the child process' (root exception) propagate + up the stack as-is (e.g. without being wrapped in the parent's traceback). + + Useful in cases where the parent is a simple nanny process + and the child (worker) processes are actually doing meaningful compute. + In this case, errors typically occur on the child process as the parent + is not doing anything non-trivial, and child errors should be propagated + to the scheduler for accurate root cause diagnostics. + + .. note:: The propagation relies on error files rather than exception handling to + support both function and binary launches. + + Example: + :: + + # process tree on a host (container) + 0: scheduler-init-process: + |- 1: torchelastic_agent: + |- 2: trainer_0 (ok) + |- 3: trainer_1 (fail) -> error.json + |- ... + |- n+2: trainer_n (ok) + |- n+3: other processes + |- ... + + In the example above, trainer 1's failure (written into error.json) is + the root cause and should be reported to the scheduler's init process. + The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})`` + upon detecting trainer 1's failure which would propagate the contents + of trainer 1's error file to the scheduler's init process. + """ + + def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]): + self.name = name + self.failures = failures + assert ( + self.failures + ) # does not make sense to create a ChildFaileError with no failures + super().__init__(self.format_msg()) + + def get_first_failure(self) -> tuple[GlobalRank, ProcessFailure]: + rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp) + return rank, self.failures[rank] + + def format_msg(self, boarder_delim="=", section_delim="-"): + title = f"{self.name} FAILED" + root_rank, _root_failure = self.get_first_failure() + + root_failure_fmt: str = "" + other_failures_fmt: list[str] = [] + width = len(title) + for idx, (rank, failure) in enumerate(self.failures.items()): + fmt, w = self._format_failure(idx, rank, failure) + width = max(width, w) + if rank == root_rank: + root_failure_fmt = fmt + else: + other_failures_fmt.append(fmt) + + # upper boundary on width + width = min(width, 60) + + return Template(_MSG_FORMAT_TEMPLATE).substitute( + boarder=boarder_delim * width, + title=title, + section=section_delim * width, + root_failure=root_failure_fmt, + other_failures="\n".join(other_failures_fmt or [" "]), + ) + + def _format_failure( + self, idx: int, rank: int, failure: ProcessFailure + ) -> tuple[str, int]: + # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) + # or a dict (json) of the form + # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} + # so the display logic is: + # 1. if failure.message is not a dict (it is a str) just show it as is + # 2. else try to get the traceback (py_callstack) + # 3. if the traceback is not there, use the message + # 4. if the message is not there show + msg = failure.message + if isinstance(failure.message, dict): + msg = ( + failure.message.get("extraInfo", {}) + .get("py_callstack", failure.message.get("message", "")) + .replace("\n", "\n ") # to properly indent the traceback + ) + + fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( + idx=idx, + time=failure.timestamp_isoformat(), + hostname=socket.getfqdn(), + rank=rank, + local_rank=failure.local_rank, + exitcode=failure.exitcode, + pid=failure.pid, + error_file=failure.error_file, + message=msg, + ) + width = 0 + for line in fmt.split("\n"): + width = max(width, len(line)) + return fmt, width + + +def record( + fn: Callable[_P, _R], error_handler: ErrorHandler | None = None +) -> Callable[_P, _R | None]: + """ + Syntactic sugar to record errors/exceptions that happened in the decorated + function using the provided ``error_handler``. + + Using this decorator is equivalent to: + + :: + + error_handler = get_error_handler() + error_handler.initialize() + try: + foobar() + except ChildFailedError as e: + _, failure = e.get_first_failure() + error_handler.dump_error_file(failure.error_file, failure.exitcode) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + .. important:: use this decorator once per process at the top level method, + typically this is the main method. + + Example + + :: + + @record + def main(): + pass + + + if __name__ == "__main__": + main() + + """ + if not error_handler: + error_handler = get_error_handler() + + def wrap(f: Callable[_P, _R]) -> Callable[_P, _R | None]: + @wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs): + assert error_handler is not None # assertion for mypy type checker + error_handler.initialize() + try: + return f(*args, **kwargs) + except SystemExit as se: + # For run_path based entrypoints, SystemExit with code = 0 will never exit. + # Handling it here by returning a value: + if se.code == 0: + return None + else: + raise + except ChildFailedError as e: + rank, failure = e.get_first_failure() + if failure.error_file != _NOT_AVAILABLE: + error_handler.dump_error_file(failure.error_file, failure.exitcode) + else: + logger.info( + ( + "local_rank %s FAILED with no error file." + " Decorate your entrypoint fn with @record for traceback info." + " See: https://pytorch.org/docs/stable/elastic/errors.html", + rank, + ) + ) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + return wrapper + + return wrap(fn) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19583df34ad1af77c2ea9f69bcc5fea8559c7359 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7efc947f0b399c6c64fdd2eb2573f123afe39ef6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f87812a55ccee96c879df3dcff31fcb6057237c4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6613e54dee10edbec54abcc5bc689b01676358 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import faulthandler +import json +import logging +import os +import time +import traceback +import warnings +from typing import Any + + +__all__ = ["ErrorHandler"] + +logger = logging.getLogger(__name__) + + +class ErrorHandler: + """ + Write the provided exception object along with some other metadata about + the error in a structured way in JSON format to an error file specified by the + environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment + variable is not set, then simply logs the contents of what would have been + written to the error file. + + This handler may be subclassed to customize the handling of the error. + Subclasses should override ``initialize()`` and ``record_exception()``. + """ + + def _get_error_file_path(self) -> str | None: + """ + Return the error file path. + + May return ``None`` to have the structured error be logged only. + """ + return os.environ.get("TORCHELASTIC_ERROR_FILE", None) + + def initialize(self) -> None: + """ + Call prior to running code that we wish to capture errors/exceptions. + + Typically registers signal/fault handlers. Users can override this + function to add custom initialization/registrations that aid in + propagation/information of errors/signals/exceptions/faults. + """ + try: + faulthandler.enable(all_threads=True) + except Exception as e: + warnings.warn( + f"Unable to enable fault handler. {type(e).__name__}: {e}", stacklevel=2 + ) + + def _write_error_file(self, file_path: str, error_msg: str) -> None: + """Write error message to the file.""" + try: + with open(file_path, "w") as fp: + fp.write(error_msg) + except Exception as e: + warnings.warn( + f"Unable to write error to file. {type(e).__name__}: {e}", stacklevel=2 + ) + + def record_exception(self, e: BaseException) -> None: + """ + Write a structured information about the exception into an error file in JSON format. + + If the error file cannot be determined, then logs the content + that would have been written to the error file. + """ + file = self._get_error_file_path() + if file: + data = { + "message": { + "message": f"{type(e).__name__}: {e}", + "extraInfo": { + "py_callstack": traceback.format_exc(), + "timestamp": str(int(time.time())), + }, + } + } + with open(file, "w") as fp: + json.dump(data, fp) + + def override_error_code_in_rootcause_data( + self, + rootcause_error_file: str, + rootcause_error: dict[str, Any], + error_code: int = 0, + ): + """Modify the rootcause_error read from the file, to correctly set the exit code.""" + if "message" not in rootcause_error: + logger.warning( + "child error file (%s) does not have field `message`. \n" + "cannot override error code: %s", + rootcause_error_file, + error_code, + ) + elif isinstance(rootcause_error["message"], str): + logger.warning( + "child error file (%s) has a new message format. \n" + "skipping error code override", + rootcause_error_file, + ) + else: + rootcause_error["message"]["errorCode"] = error_code + + def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): + """Dump parent error file from child process's root cause error and error code.""" + with open(rootcause_error_file) as fp: + rootcause_error = json.load(fp) + # Override error code since the child process cannot capture the error code if it + # is terminated by signals like SIGSEGV. + if error_code: + self.override_error_code_in_rootcause_data( + rootcause_error_file, rootcause_error, error_code + ) + logger.debug( + "child error file (%s) contents:\n%s", + rootcause_error_file, + json.dumps(rootcause_error, indent=2), + ) + + my_error_file = self._get_error_file_path() + if my_error_file: + # Guard against existing error files + # This can happen when the child is created using multiprocessing + # and the same env var (TORCHELASTIC_ERROR_FILE) is used on the + # parent and child to specify the error files (respectively) + # because the env vars on the child is set in the wrapper function + # and by default the child inherits the parent's env vars, if the child + # process receives a signal before the wrapper function kicks in + # and the signal handler writes to the error file, then the child + # will write to the parent's error file. In this case just log the + # original error file contents and overwrite the error file. + self._rm(my_error_file) + self._write_error_file(my_error_file, json.dumps(rootcause_error)) + logger.info("dumped error file to parent's %s", my_error_file) + else: + logger.error( + "no error file defined for parent, to copy child error file (%s)", + rootcause_error_file, + ) + + def _rm(self, my_error_file): + if os.path.isfile(my_error_file): + # Log the contents of the original file. + with open(my_error_file) as fp: + try: + original = json.dumps(json.load(fp), indent=2) + logger.warning( + "%s already exists" + " and will be overwritten." + " Original contents:\n%s", + my_error_file, + original, + ) + except json.decoder.JSONDecodeError: + logger.warning( + "%s already exists" + " and will be overwritten." + " Unable to load original contents:\n", + my_error_file, + ) + os.remove(my_error_file) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..6721217a41190c2bdd6bf2293540a33c893c145d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# Multiprocessing error-reporting module + + +from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler + + +__all__ = ["get_error_handler"] + + +def get_error_handler() -> ErrorHandler: + return ErrorHandler() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/redirects.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/redirects.py new file mode 100644 index 0000000000000000000000000000000000000000..057013fbb9e5b8a2aeca69b41d7679cbe75c0e28 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/redirects.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +# !/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Taken and modified from original source: +# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ +import ctypes +import logging +import os +import sys +from contextlib import contextmanager +from functools import partial + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + + +def get_libc(): + if IS_WINDOWS or IS_MACOS: + logger.warning( + "NOTE: Redirects are currently not supported in Windows or MacOs." + ) + return None + else: + return ctypes.CDLL("libc.so.6") + + +libc = get_libc() + + +def _c_std(stream: str): + return ctypes.c_void_p.in_dll(libc, stream) + + +def _python_std(stream: str): + return {"stdout": sys.stdout, "stderr": sys.stderr}[stream] + + +_VALID_STD = {"stdout", "stderr"} + + +@contextmanager +def redirect(std: str, to_file: str): + """ + Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``. + + This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``). + See usage for details. + + Directory of ``dst_filename`` is assumed to exist and the destination file + is overwritten if it already exists. + + .. note:: Due to buffering cross source writes are not guaranteed to + appear in wall-clock order. For instance in the example below + it is possible for the C-outputs to appear before the python + outputs in the log file. + + Usage: + + :: + + # syntactic-sugar for redirect("stdout", "tmp/stdout.log") + with redirect_stdout("/tmp/stdout.log"): + print("python stdouts are redirected") + libc = ctypes.CDLL("libc.so.6") + libc.printf(b"c stdouts are also redirected" + os.system("echo system stdouts are also redirected") + + print("stdout restored") + + """ + if std not in _VALID_STD: + raise ValueError( + f"unknown standard stream <{std}>, must be one of {_VALID_STD}" + ) + + c_std = _c_std(std) + python_std = _python_std(std) + std_fd = python_std.fileno() + + def _redirect(dst): + libc.fflush(c_std) + python_std.flush() + os.dup2(dst.fileno(), std_fd) + + with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst: + _redirect(dst) + try: + yield + finally: + _redirect(orig_std) + + +redirect_stdout = partial(redirect, "stdout") +redirect_stderr = partial(redirect, "stderr") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f56d423ce080fd7c331dc9b43eda58e5370678fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torch.distributed.elastic.multiprocessing.subprocess_handler.handlers import ( + get_subprocess_handler, +) +from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) + + +__all__ = ["SubprocessHandler", "get_subprocess_handler"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b00f08cc63eda1caedd4c46fbcad7e58fc0fd5c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3007e360dd135bb9fe64b911327a69524c00fae7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66af441290a827e76da61e1f1f21933f0909362 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1742626e285838485c19911704792510d13fb4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) +from torch.numa.binding import NumaOptions + + +__all__ = ["get_subprocess_handler"] + + +def get_subprocess_handler( + entrypoint: str, + args: tuple, + env: dict[str, str], + stdout: str, + stderr: str, + local_rank_id: int, + numa_options: NumaOptions | None = None, +) -> SubprocessHandler: + return SubprocessHandler( + entrypoint=entrypoint, + args=args, + env=env, + stdout=stdout, + stderr=stderr, + local_rank_id=local_rank_id, + numa_options=numa_options, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..268817108d8cd20f6ba0130818286d297da78c4e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import signal +import sys +from subprocess import Popen +from typing import Any + +from torch.numa.binding import maybe_wrap_command_args_with_numa_binding, NumaOptions + + +__all__ = ["SubprocessHandler"] + +IS_WINDOWS = sys.platform == "win32" + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +class SubprocessHandler: + """ + Convenience wrapper around python's ``subprocess.Popen``. Keeps track of + meta-objects associated to the process (e.g. stdout and stderr redirect fds). + """ + + def __init__( + self, + entrypoint: str, + args: tuple, + env: dict[str, str], + stdout: str | None, + stderr: str | None, + local_rank_id: int, + numa_options: NumaOptions | None, + ): + self._stdout = open(stdout, "w") if stdout else None # noqa: SIM115 + self._stderr = open(stderr, "w") if stderr else None # noqa: SIM115 + # inherit parent environment vars + env_vars = os.environ.copy() + env_vars.update(env) + + args_str = (entrypoint, *[str(e) for e in args]) + args_str = maybe_wrap_command_args_with_numa_binding( + args_str, + gpu_index=local_rank_id, + numa_options=numa_options, + ) + + self.local_rank_id = local_rank_id + + self.proc: Popen = self._popen(args_str, env_vars) + + def _popen(self, args: tuple, env: dict[str, str]) -> Popen: + kwargs: dict[str, Any] = {} + if not IS_WINDOWS: + kwargs["start_new_session"] = True + + return Popen( + # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes], + # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got + # `Tuple[str, *Tuple[Any, ...]]`. + args=args, + env=env, + stdout=self._stdout, + stderr=self._stderr, + **kwargs, + ) + + def close(self, death_sig: signal.Signals | None = None) -> None: + if not death_sig: + death_sig = _get_default_signal() + if IS_WINDOWS: + self.proc.send_signal(death_sig) + else: + os.killpg(self.proc.pid, death_sig) + if self._stdout: + self._stdout.close() + if self._stderr: + self._stderr.close() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py new file mode 100644 index 0000000000000000000000000000000000000000..77d410cce55c09b0acd79ebf4583028f5a7bb759 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import time +from collections.abc import Callable +from concurrent.futures.thread import ThreadPoolExecutor +from threading import Event +from typing import TextIO, TYPE_CHECKING + + +if TYPE_CHECKING: + from concurrent.futures._base import Future + +__all__ = ["tail_logfile", "TailLog"] + +logger = logging.getLogger(__name__) + + +def tail_logfile( + header: str, + file: str, + dst: TextIO, + finished: Event, + interval_sec: float, + log_line_filter: Callable[[str], bool] | None = None, +): + while not os.path.exists(file): + if finished.is_set(): + return + time.sleep(interval_sec) + + with open(file, errors="replace") as fp: + while True: + line = fp.readline() + + if line: + if log_line_filter and log_line_filter(line): + dst.write(f"{header}{line}") + else: # reached EOF + if finished.is_set(): + # log line producer is finished + break + else: + # log line producer is still going + # wait for a bit before looping again + time.sleep(interval_sec) + + +class TailLog: + """ + Tail the given log files. + + The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until + the log files are created by the producer and will tail the contents of the + log files until the ``stop()`` method is called. + + .. warning:: ``TailLog`` will wait indefinitely for the log file to be created! + + Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``, + where the ``name`` is user-provided and ``idx`` is the index of the log file + in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the + header for each log file. + + Usage: + + :: + + log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"} + tailer = TailLog("trainer", log_files, sys.stdout).start() + # actually run the trainers to produce 0_stdout.log and 1_stdout.log + run_trainers() + tailer.stop() + + # once run_trainers() start writing the ##_stdout.log files + # the tailer will print to sys.stdout: + # >>> [trainer0]:log_line1 + # >>> [trainer1]:log_line1 + # >>> [trainer0]:log_line2 + # >>> [trainer0]:log_line3 + # >>> [trainer1]:log_line2 + + .. note:: Due to buffering log lines between files may not necessarily + be printed out in order. You should configure your application's + logger to suffix each log line with a proper timestamp. + + """ + + def __init__( + self, + name: str, + log_files: dict[int, str], + dst: TextIO, + log_line_prefixes: dict[int, str] | None = None, + interval_sec: float = 0.1, + log_line_filter: Callable[[str], bool] = (lambda _: True), + ): + n = len(log_files) + self._threadpool = None + if n > 0: + # pyrefly: ignore [bad-assignment] + self._threadpool = ThreadPoolExecutor( + max_workers=n, + thread_name_prefix=f"{self.__class__.__qualname__}_{name}", + ) + + self._name = name + self._dst = dst + self._log_files = log_files + self._log_line_prefixes = log_line_prefixes + self._log_line_filter = log_line_filter + self._finished_events: dict[int, Event] = { + local_rank: Event() for local_rank in log_files + } + self._futs: list[Future] = [] + self._interval_sec = interval_sec + self._stopped = False + + def start(self) -> "TailLog": + if not self._threadpool or not self._dst: + return self + + for local_rank, file in self._log_files.items(): + header = f"[{self._name}{local_rank}]:" + if self._log_line_prefixes and local_rank in self._log_line_prefixes: + header = self._log_line_prefixes[local_rank] + self._futs.append( + self._threadpool.submit( + tail_logfile, + header=header, + file=file, + dst=self._dst, + finished=self._finished_events[local_rank], + interval_sec=self._interval_sec, + log_line_filter=self._log_line_filter, + ) + ) + return self + + def stop(self) -> None: + for finished in self._finished_events.values(): + finished.set() + + for local_rank, f in enumerate(self._futs): + try: + f.result() + except Exception as e: + logger.exception( + "error in log tailor for %s%s. %s", + self._name, + local_rank, + e.__class__.__qualname__, + ) + + if self._threadpool: + self._threadpool.shutdown(wait=True) + + self._stopped = True + + def stopped(self) -> bool: + return self._stopped diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c387a3ec2833ac643c571afa7a194a1dc0d3fbea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__init__.py @@ -0,0 +1,163 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +In the context of Torch Distributed Elastic we use the term *rendezvous* to +refer to a particular functionality that combines a **distributed +synchronization** primitive with **peer discovery**. + +It is used by Torch Distributed Elastic to gather participants of a training +job (i.e. nodes) such that they all agree on the same list of participants and +everyone's roles, as well as make a consistent collective decision on when +training can begin/resume. + +Torch Distributed Elastic rendezvous provides the following critical +functionalities: + +**Barrier**: + +Nodes performing rendezvous will all block until the rendezvous is considered +complete - this happens when at least ``min`` total number of nodes have joined +the rendezvous barrier (for the same job). This also implies the barrier is not +necessarily of fixed size. + +There's an additional small waiting time after reaching ``min`` number of +nodes - this is used to ensure the rendezvous is not completed "too quickly" +(which could potentially exclude additional nodes attempting to join at +approximately the same time). + +If ``max`` number of nodes is gathered at the barrier, the rendezvous is +completed immediately. + +There's also an overall timeout which causes the rendezvous to fail if ``min`` +number of nodes is never reached - this is meant to be a simple fail-safe to +help release partially allocated job resources, in case there's a problem with +the resource manager, and is meant to be interpreted as non-retryable. + +**Exclusivity**: + +A simple distributed barrier would not be sufficient, as we also need to ensure +that only one group of nodes exists at any given time (for a given job). In +other words, new nodes (i.e. joining late) should not be able to form a parallel +independent group of workers for the same job. + +Torch Distributed Elastic rendezvous ensures that if a group of nodes has +already completed a rendezvous (and hence might already be training), then +additional "late" nodes attempting to rendezvous will only announce themselves +as waiting, and will have to wait until the (previously completed) existing +rendezvous is destroyed first. + +**Consistency**: + +When a rendezvous is completed, all its members will agree on the job membership +and everyone's role in it. This role is represented using an integer, called +rank, that is between between 0 and world size. + +Note that ranks are *not stable*, in the sense that the same node can be +assigned a different rank in the next (re-)rendezvous. + +**Fault-tolerance**: + +Torch Distributed Elastic rendezvous is designed to tolerate node failures +during the rendezvous process. Should a process crash (or lose network +connectivity, etc), between joining the rendezvous and it being completed, then +a re-rendezvous with remaining healthy nodes will happen automatically. + +A node can also fail *after* it has completed (or *has been observed* by other +nodes to have completed) the rendezvous - this scenario will be handled by the +Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a +re-rendezvous). + +**Shared key-value store**: + +When the rendezvous is completed, a shared key-value store is created and +returned. This store implements a ``torch.distributed.Store`` API (see +`distributed communication docs +`__). + +This store is only shared by the members of the completed rendezvous. It +is intended to be used by Torch Distributed Elastic to exchange information +necessary to initialize job control and data-planes. + +**Waiting workers and rendezvous closing**: + +Torch Distributed Elastic rendezvous handler object provides additional +functionalities, which are technically not part of the rendezvous process: + +1. Querying how many workers arrived late at the barrier, who can participate in + *next* rendezvous. + +2. Setting the rendezvous *closed* to signal all nodes not to participate in + next rendezvous. + +**DynamicRendezvousHandler**: + +Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler` +class that implements the rendezvous mechanism described above. It is a backend- +agnostic type that expects a particular :py:class:`.RendezvousBackend` instance +to be specified during construction. + +Torch distributed users can either implement their own backend type or use one +of the following implementations that come with PyTorch: + +- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default + ``TCPStore``) as the rendezvous backend. The main advantage of using a C10d + store is that it requires no 3rd-party dependency (such as etcd) to establish + a rendezvous. +- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy + :py:class:`.EtcdRendezvousHandler` class. Passing an + :py:class:`.EtcdRendezvousBackend` instance to + :py:class:`.DynamicRendezvousHandler` is functionally equivalent to + instantiating an :py:class:`.EtcdRendezvousHandler`. + + :: + + store = TCPStore("localhost") + + backend = C10dRendezvousBackend(store, "my_run_id") + + rdzv_handler = DynamicRendezvousHandler.from_backend( + run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4 + ) +""" + +from .api import ( + rendezvous_handler_registry, + RendezvousClosedError, + RendezvousConnectionError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousHandlerCreator, + RendezvousHandlerRegistry, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .registry import _register_default_handlers, _register_out_of_tree_handlers + + +_register_default_handlers() +_register_out_of_tree_handlers() + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e9e92d40c44b4a0f2f678c2f35035eaa1300d74 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b5919726ad53673252c6258f010a06ff90b6445 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef477d5938e66a2408ea60bd3ae38b26272f79d9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85a9f10997332a8de8466c8a55641af77fd33c41 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a49077ec650a7c781f4acaec0737f39ec20e6448 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b56fa973159c4e3e5320eae15eb078c29a743d28 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0a813f92ba006a6c3cbccd2fa5fa8b05c04c26f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3dd846ae46e10374c88c07d1c9d02959833ac73 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f774e5ca897fad3f6d9cb4b5633ff40fc00d9c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3f7fc7df67104e6bf8e92ca86cbf033dbeb117a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5283352e141002db2e4b10d884ca32d15fc8835d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19e72846fc88c78e8149fb457d5cb9a23ff9b37d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py new file mode 100644 index 0000000000000000000000000000000000000000..5890a97c672a61b5678e66b006ba173fe7668286 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + + +""" +This file is not meant to be used directly. It serves as a stub to allow +other files to be safely imported without requiring the installation of +the 'etcd' library. The classes and methods here raise exceptions to +indicate that the real 'etcd' module is needed. +""" + + +class EtcdStubError(ImportError): + """Custom exception to indicate that the real etcd module is required.""" + + def __init__(self) -> None: + super().__init__("The 'etcd' module is required but not installed.") + + +class EtcdAlreadyExist(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdCompareFailed(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdKeyNotFound(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdWatchTimedOut(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdEventIndexCleared(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdException(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdResult: + def __init__(self) -> None: + raise EtcdStubError + + +class Client: + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + def read(self, key: str) -> None: + raise EtcdStubError + + def write( + self, key: str, value: Any, ttl: int | None = None, **kwargs: Any + ) -> None: + raise EtcdStubError + + def test_and_set( + self, key: str, value: Any, prev_value: Any, ttl: int | None = None + ) -> None: + raise EtcdStubError diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3fa8183dfb81da2f0b675a5e1a5d1f6fee935f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/api.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import socket +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, ClassVar + +from torch.distributed import Store +from torch.distributed.elastic.utils.distributed import get_free_port + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] + + +class RendezvousError(Exception): + """Represents the base type for rendezvous errors.""" + + +class RendezvousClosedError(RendezvousError): + """Raised when a rendezvous is closed.""" + + +class RendezvousTimeoutError(RendezvousError): + """Raised when a rendezvous did not complete on time.""" + + +class RendezvousConnectionError(RendezvousError): + """Raised when the connection to a rendezvous backend has failed.""" + + +class RendezvousStateError(RendezvousError): + """Raised when the state of a rendezvous is corrupt.""" + + +class RendezvousGracefulExitError(RendezvousError): + """Raised when node wasn't not included in rendezvous and gracefully exits. + + Exception is a mechanism to exit the stack, however does not mean a failure. + """ + + +@dataclass +class RendezvousStoreInfo: + """Store address and port that can be used to bootstrap trainer distributed comms""" + + MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR" + MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT" + master_addr: str + master_port: int + + @staticmethod + def build( + rank: int, + store: Store, + local_addr: str | None, + server_port: int | None = None, + ) -> "RendezvousStoreInfo": + """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. + + If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor. + + Args: + rank: rank of the current node + store: store to use for rendezvous + local_addr: address of the current node, if not provided will be resolved from hostname + server_port: port of the TCPStore server, when the TCPStore is shared. + """ + # TODO swap to collectives comms API + if rank == 0: + addr = local_addr or socket.getfqdn() + # When TCPStore is not shared, we fallback to get_free_port. + port = server_port or get_free_port() + store.set( + RendezvousStoreInfo.MASTER_ADDR_KEY, + addr.encode(encoding="UTF-8"), # type: ignore[arg-type] + ) + store.set( + RendezvousStoreInfo.MASTER_PORT_KEY, + str(port).encode(encoding="UTF-8"), # type: ignore[arg-type] + ) + + addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") + port = int( + store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8") + ) + return RendezvousStoreInfo(master_addr=addr, master_port=port) + + +class RendezvousInfo: + """Holds the information about the rendezvous.""" + + def __init__( + self, + store: Store, + rank: int, + world_size: int, + bootstrap_store_info: RendezvousStoreInfo, + ): + self._store = store + self._rank = rank + self._world_size = world_size + self._bootstrap_store_info = bootstrap_store_info + + @property + def store(self) -> Store: + """Store used by torchelastic control plane""" + return self._store + + @property + def rank(self) -> int: + """Rank within a group""" + return self._rank + + @property + def world_size(self) -> int: + """Global group size""" + return self._world_size + + @property + def bootstrap_store_info(self) -> RendezvousStoreInfo | None: + """Store information that can used by trainer code to bootstrap distributed comms.""" + return self._bootstrap_store_info + + +class RendezvousHandler(ABC): + """Main rendezvous interface. + + Note: + Distributed Torch users normally **do not** need to implement their own + ``RendezvousHandler``. An implementation based on C10d Store is already + provided, and is recommended for most users. + """ + + @abstractmethod + def get_backend(self) -> str: + """Return the name of the rendezvous backend.""" + + @property + def use_agent_store(self) -> bool: + """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user + applications and will be available during application lifecycle. + + Rendezvous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`. + Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store. + """ + return False + + @abstractmethod + def next_rendezvous(self) -> RendezvousInfo: + """Main entry-point into the rendezvous barrier. + + Blocks until the rendezvous is complete and the current process is + included in the formed worker group, or a timeout occurs, or the + rendezvous was marked closed. + + Returns: + Instance of :py:class:`RendezvousInfo`. + + Raises: + RendezvousClosedError: + The rendezvous is closed. + RendezvousConnectionError: + The connection to the rendezvous backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + RendezvousTimeoutError: + The rendezvous did not complete on time. + """ + + @abstractmethod + def is_closed(self) -> bool: + """Check whether the rendezvous has been closed. + + A closed rendezvous means all future attempts to re-rendezvous within + same job will fail. + + ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual + propagation and should not be used for synchronization. The intention is + that if at least one node decides the job is finished, it will close the + rendezvous, and other nodes will soon observe this and stop running as + well. + """ + + @abstractmethod + def set_closed(self): + """Mark the rendezvous as closed.""" + + @abstractmethod + def num_nodes_waiting(self) -> int: + """Return the number of nodes who arrived late at the rendezvous + barrier, hence were not included in the current worker group. + + Callers should periodically call this method to check whether new + nodes are waiting to join the job and if so admit them by calling + :py:meth:`next_rendezvous()` (re-rendezvous). + """ + + @abstractmethod + def get_run_id(self) -> str: + """Return the run id of the rendezvous. + + The run id is a user-defined id that uniquely identifies an instance of + a distributed application. It typically maps to a job id and is used to + allow nodes to join the correct distributed application. + """ + + @abstractmethod + def shutdown(self) -> bool: + """Close all resources that were open for the rendezvous. + + Example:: + + rdzv_handler = ... + try: + store, rank, world_size = rdzv_handler.next_rendezvous() + finally: + rdzv_handler.shutdown() + """ + + +class RendezvousParameters: + """Hold the parameters to construct a :py:class:`RendezvousHandler`. + + Args: + backend: + The name of the backend to use to handle the rendezvous. + endpoint: + The endpoint of the rendezvous, usually in form [:]. + run_id: + The id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The address of the local node. + **kwargs: + Additional parameters for the specified backend. + """ + + def __init__( + self, + backend: str, + endpoint: str, + run_id: str, + min_nodes: int, + max_nodes: int, + local_addr: str | None = None, + **kwargs, + ): + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + if min_nodes < 1: + raise ValueError( + f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero." + ) + if max_nodes < min_nodes: + raise ValueError( + f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or " + f"equal to the minimum number of rendezvous nodes ({min_nodes})." + ) + + self.backend = backend + self.endpoint = endpoint + self.run_id = run_id + self.min_nodes = min_nodes + self.max_nodes = max_nodes + self.config = kwargs + self.local_addr = local_addr + + def get(self, key: str, default: Any = None) -> Any: + """Return the value for ``key`` if ``key`` exists, else ``default``.""" + return self.config.get(key, default) + + def get_as_bool(self, key: str, default: bool | None = None) -> bool | None: + """Return the value for ``key`` as a ``bool``.""" + value = self.get(key, default) + if value is None or isinstance(value, bool): + return value + if isinstance(value, int): + if value == 1: + return True + if value == 0: + return False + elif isinstance(value, str): + if value.lower() in ["1", "true", "t", "yes", "y"]: + return True + if value.lower() in ["0", "false", "f", "no", "n"]: + return False + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid boolean value." + ) + + def get_as_int(self, key: str, default: int | None = None) -> int | None: + """Return the value for ``key`` as an ``int``.""" + value = self.get(key, default) + if value is None: + return value + try: + return int(value) + except ValueError as e: + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid integer " + "value." + ) from e + + +RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler] + + +class RendezvousHandlerRegistry: + """Represent a registry of :py:class:`RendezvousHandler` backends.""" + + _registry: dict[str, RendezvousHandlerCreator] + + def __init__(self) -> None: + self._registry = {} + + def register(self, backend: str, creator: RendezvousHandlerCreator) -> None: + """Register a new rendezvous backend. + + Args: + backend: + The name of the backend. + creator: + The callback to invoke to construct the + :py:class:`RendezvousHandler`. + """ + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + current_creator: RendezvousHandlerCreator | None + try: + current_creator = self._registry[backend] + except KeyError: + current_creator = None + + if current_creator is not None and current_creator != creator: + raise ValueError( + f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it " + f"is already registered with '{current_creator}'." + ) + + self._registry[backend] = creator + + def create_handler(self, params: RendezvousParameters) -> RendezvousHandler: + """Create a new :py:class:`RendezvousHandler`.""" + try: + creator = self._registry[params.backend] + except KeyError as e: + raise ValueError( + f"The rendezvous backend '{params.backend}' is not registered. Did you forget " + f"to call `{self.register.__name__}`?" + ) from e + + handler = creator(params) + + # Do some sanity check. + if handler.get_backend() != params.backend: + raise RuntimeError( + f"The rendezvous backend '{handler.get_backend()}' does not match the requested " + f"backend '{params.backend}'." + ) + + return handler + + +# The default global registry instance used by launcher scripts to instantiate +# rendezvous handlers. +rendezvous_handler_registry = RendezvousHandlerRegistry() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..0296c4d45ddc13dadc9ee1d91f07a3950c277892 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -0,0 +1,270 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +import logging +import os +import tempfile +from base64 import b64decode, b64encode +from datetime import timedelta +from typing import Any, cast + +from torch.distributed import FileStore, Store, TCPStore +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousConnectionError, + RendezvousError, + RendezvousParameters, + RendezvousStateError, +) +from .dynamic_rendezvous import RendezvousBackend, Token +from .utils import _matches_machine_hostname, parse_rendezvous_endpoint + + +logger = logging.getLogger(__name__) + +# default port for the TCP store +DEFAULT_PORT = 29400 + + +class C10dRendezvousBackend(RendezvousBackend): + """Represents a C10d-backed rendezvous backend. + + Args: + store: + The :py:class:`torch.distributed.Store` instance to use to + communicate with the C10d store. + run_id: + The run id of the rendezvous. + """ + + # See the explanation in the __init__ method. + _NULL_SENTINEL = "Y2FuaW1hZGFt" + + _store: Store + _key: str + + def __init__(self, store: Store, run_id: str) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._store = store + + self._key = "torch.rendezvous." + run_id + + # The read operation of a store blocks the caller until the specified + # key becomes available. This behavior makes it tricky to use a store + # as a regular key-value dictionary. + # + # As a workaround we initially set a sentinel value as the rendezvous + # state. Whenever this value gets returned we treat it as a None. + self._call_store("compare_set", self._key, "", self._NULL_SENTINEL) + + @property + def name(self) -> str: + """See base class.""" + return "c10d" + + def get_state(self) -> tuple[bytes, Token] | None: + """See base class.""" + base64_state: bytes = self._call_store("get", self._key) + + return self._decode_state(base64_state) + + def set_state( + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: + """See base class.""" + base64_state_str: str = b64encode(state).decode() + + if token: + # Shortcut if we know for sure that the token is not valid. + if not isinstance(token, bytes): + result = self.get_state() + if result is not None: + return *result, False + return None + + token = token.decode() + else: + token = self._NULL_SENTINEL + + base64_state: bytes = self._call_store( + "compare_set", self._key, token, base64_state_str + ) + + state_token_pair = self._decode_state(base64_state) + if state_token_pair is None: + return None + + new_state, new_token = state_token_pair + + # C10d Store's compare_set method does not offer an easy way to find out + # whether our write attempt was successful. As a brute-force solution we + # perform a bitwise comparison of our local state and the remote state. + return new_state, new_token, new_state == state + + def _call_store(self, store_op: str, *args, **kwargs) -> Any: + try: + return getattr(self._store, store_op)(*args, **kwargs) + except (ValueError, RuntimeError, TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + def _decode_state(self, base64_state: bytes) -> tuple[bytes, Token] | None: + if base64_state == self._NULL_SENTINEL.encode(): + return None + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." + ) from exc + + return state, base64_state + + +def _create_tcp_store(params: RendezvousParameters) -> TCPStore: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT) + + cfg_is_host = params.get_as_bool("is_host") + # If the user has explicitly specified whether our process should host the + # the store, respect it. + if cfg_is_host is not None: + is_host = cfg_is_host + # Otherwise try to determine whether we are the host based on our hostname + # and IP address. + else: + is_host = _matches_machine_hostname(host) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # In specific cases we attempt to instantiate the store twice. For details + # see the explanation in the except clause below. + for is_server in [is_host, False]: + try: + store = TCPStore( + host, + port, + is_master=is_server, + multi_tenant=True, + timeout=timedelta(seconds=read_timeout), + ) + + if is_server: + msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend." + construct_and_record_rdzv_event( + run_id=params.run_id, message=msg, node_state=NodeState.INIT + ) + logger.info(msg) + + break + except (ValueError, RuntimeError, TimeoutError) as exc: + # If we heuristically inferred the value of is_host as True and our + # first attempt to instantiate the TCP store has failed, try it one + # more time with is_host set to False. As an edge case there can be + # more than one process that is part of the same rendezvous on this + # machine and only one of them will eventually host the store. + + if not is_server or cfg_is_host is not None: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store # type: ignore[possibly-undefined] + + +def _create_file_store(params: RendezvousParameters) -> FileStore: + # If a user specifies an endpoint, we treat it as a path to a file. + if params.endpoint: + path = params.endpoint + else: + try: + # The temporary file is readable and writable only by the user of + # this process. + _, path = tempfile.mkstemp() + except OSError as exc: + raise RendezvousError( + "The file creation for C10d store has failed. See inner exception for details." + ) from exc + + try: + store = FileStore(path) + except (ValueError, RuntimeError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store + + +def create_backend(params: RendezvousParameters) -> tuple[C10dRendezvousBackend, Store]: + """Create a new :py:class:`C10dRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | store_type | The type of the C10d store. The currently supported types | + | | are "tcp" and "file" which correspond to | + | | :py:class:`torch.distributed.TCPStore` and | + | | :py:class:`torch.distributed.FileStore`, respectively. | + | | Defaults to "tcp". | + +--------------+-----------------------------------------------------------+ + | read_timeout | The read timeout, in seconds, for store operations. | + | | Defaults to 60 seconds. | + | | | + | | Note this only applies to | + | | :py:class:`torch.distributed.TCPStore`. It is not relevant| + | | to :py:class:`torch.distributed.FileStore` which does not | + | | take in timeout as a parameter. | + +--------------+-----------------------------------------------------------+ + | is_host | A boolean value indicating whether this backend instance | + | | will host the C10d store. If not specified it will be | + | | inferred heuristically by matching the hostname or the IP | + | | address of this machine against the specified rendezvous | + | | endpoint. Defaults to ``None``. | + | | | + | | Note that this configuration option only applies to | + | | :py:class:`torch.distributed.TCPStore`. In normal | + | | circumstances you can safely skip it; the only time when | + | | it is needed is if its value cannot be correctly | + | | determined (e.g. the rendezvous endpoint has a CNAME as | + | | the hostname or does not match the FQDN of the machine). | + +--------------+-----------------------------------------------------------+ + """ + # As of today we only support TCPStore and FileStore. Other store types do + # not have the required functionality (e.g. compare_set) yet. + store_type = params.get("store_type", "tcp").strip().lower() + store: Store + + try: + if store_type == "file": + store = _create_file_store(params) + elif store_type == "tcp": + store = _create_tcp_store(params) + else: + raise ValueError( + "Invalid store type given. Currently only supports file and tcp." + ) + + backend = C10dRendezvousBackend(store, params.run_id) + + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise + + return backend, store diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..84adeea95573121e69f11c6faa52fe6601f271c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -0,0 +1,1453 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import pickle +import socket +import threading +import time +import weakref +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Any + +import torch.distributed as dist +from torch.distributed import Store +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousClosedError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .utils import _delay, _PeriodicTimer + + +__all__ = [ + "RendezvousBackend", + "RendezvousTimeout", + "RendezvousSettings", + "DynamicRendezvousHandler", + "create_handler", +] + +logger = logging.getLogger(__name__) + + +def get_method_name(depth=2): + if len(inspect.stack()) > depth: + return inspect.stack()[depth].function + return "no_method_name" + + +Token = Any +"""Represent an opaque fencing token used by the rendezvous backend.""" + + +class RendezvousBackend(ABC): + """Represent a backend that holds the rendezvous state.""" + + @property + @abstractmethod + def name(self) -> str: + """Get the name of the backend.""" + + @abstractmethod + def get_state(self) -> tuple[bytes, Token] | None: + """Get the rendezvous state. + + Returns: + A tuple of the encoded rendezvous state and its fencing token or + ``None`` if no state is found in the backend. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + @abstractmethod + def set_state( + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: + """Set the rendezvous state. + + The new rendezvous state is set conditionally: + + - If the specified ``token`` matches the fencing token stored in the + backend, the state will be updated. The new state will be returned + to the caller along with its fencing token. + - If the specified ``token`` does not match the fencing token stored + in the backend, the state won't be updated; instead the existing + state along with its fencing token will be returned to the caller. + - If the specified ``token`` is ``None``, the new state will be set + only if there is no existing state in the backend. Either the new + state or the existing state along with its fencing token will be + returned to the caller. + + Args: + state: + The encoded rendezvous state. + token: + An optional fencing token that was retrieved by a previous call + to :py:meth:`get_state` or ``set_state()``. + + Returns: + A tuple of the serialized rendezvous state, its fencing token, and + a boolean value indicating whether our set attempt succeeded. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + +class RendezvousTimeout: + """Hold the timeout configuration of a rendezvous. + + Args: + join: + The time within which the rendezvous is expected to complete. + last_call: + An additional wait amount before completing the rendezvous once the + rendezvous has the minimum number of required participants. + close: + The time within which the rendezvous is expected to close after a + call to :py:meth:`RendezvousHandler.set_closed` or + :py:meth:`RendezvousHandler.shutdown`. + heartbeat: + The time within which a keep-alive heartbeat is expected to + complete. + """ + + _ZERO = timedelta(0) + + _DEFAULT_TIMEOUTS = { + "join": timedelta(seconds=600), + "last_call": timedelta(seconds=30), + "close": timedelta(seconds=30), + "heartbeat": timedelta(seconds=5), + } + + _join: timedelta + _last_call: timedelta + _close: timedelta + _heartbeat: timedelta + + def __init__( + self, + join: timedelta | None = None, + last_call: timedelta | None = None, + close: timedelta | None = None, + heartbeat: timedelta | None = None, + ) -> None: + self._set_timeouts( + join=join, last_call=last_call, close=close, heartbeat=heartbeat + ) + + @property + def join(self) -> timedelta: + """Get the join timeout.""" + return self._join + + @property + def last_call(self) -> timedelta: + """Get the last call timeout.""" + return self._last_call + + @property + def close(self) -> timedelta: + """Get the close timeout.""" + return self._close + + @property + def heartbeat(self) -> timedelta: + """Get the keep-alive heartbeat timeout.""" + return self._heartbeat + + def _set_timeouts(self, **timeouts: timedelta | None): + for name, timeout in timeouts.items(): + if timeout is None: + timeout = self._DEFAULT_TIMEOUTS[name] + if timeout <= self._ZERO: + raise ValueError(f"The {name} timeout ({timeout}) must be positive.") + setattr(self, "_" + name, timeout) + + +@dataclass(repr=False, eq=False, frozen=True) +class RendezvousSettings: + """Hold the settings of the rendezvous. + + Attributes: + run_id: + The run id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + timeout: + The timeout configuration of the rendezvous. + keep_alive_interval: + The amount of time a node waits before sending a heartbeat to keep + it alive in the rendezvous. + keep_alive_max_attempt: + The maximum number of failed heartbeat attempts after which a node + is considered dead. + """ + + run_id: str + min_nodes: int + max_nodes: int + timeout: RendezvousTimeout + keep_alive_interval: timedelta + keep_alive_max_attempt: int + + +@dataclass(eq=True, order=True, frozen=True) +class _NodeDesc: + """Describe a node in the rendezvous. + + Attributes: + addr: + The FQDN of the node or user specified local node address. + pid: + The id of the process in which the rendezvous handler runs. + local_id: + A process-wide unique id. + """ + + addr: str + pid: int + local_id: int + + def __repr__(self) -> str: + return f"{self.addr}_{self.pid}_{self.local_id}" + + +class _NodeDescGenerator: + """Generate node descriptors. + + A node descriptor is a combination of an FQDN, a process id, and an auto- + incremented integer that uniquely identifies a node in the rendezvous. + """ + + _lock: threading.Lock + _local_id: int + + def __init__(self) -> None: + self._lock = threading.Lock() + + # An integer that is incremented with each call to generate(). + self._local_id = 0 + + def generate(self, local_addr: str | None = None) -> _NodeDesc: + # This method can be called by multiple threads concurrently; therefore, + # we must increment the integer atomically. + with self._lock: + local_id = self._local_id + + self._local_id += 1 + + return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id) + + +class _RendezvousState: + """Hold the state of a rendezvous. + + Attributes: + round: + The current round of the rendezvous. + complete: + A boolean value indicating whether the current round of the + rendezvous is complete. + deadline: + The time at which the current round of the rendezvous will be + considered complete if it is still waiting for nodes to join. + closed: + A boolean value indicating whether the rendezvous is closed. + participants: + A dictionary of the participants and their corresponding ranks. + wait_list: + A set of nodes that are waiting to participate in the next round of + the rendezvous. + redundancy_list: + A set of nodes that are redundant in the current round and can join + the next rendezvous without triggering re-rendezvous. + last_heartbeats: + A dictionary containing each node's last heartbeat time. + """ + + round: int + complete: bool + deadline: datetime | None + closed: bool + participants: dict[_NodeDesc, int] + wait_list: set[_NodeDesc] + redundancy_list: set[_NodeDesc] + last_heartbeats: dict[_NodeDesc, datetime] + + def __init__(self) -> None: + self.round = 0 + self.complete = False + self.deadline = None + self.closed = False + self.participants = {} + self.wait_list = set() + self.redundancy_list = set() + self.last_heartbeats = {} + + +def _remove_participant_epilogue( + state: _RendezvousState, settings: RendezvousSettings +) -> None: + if state.complete: + # If we do not have any participants left, move to the next round. + if not state.participants: + msg = "No participants left in the rendezvous, marking rendezvous as incomplete" + logger.debug(msg) + state.complete = False + + state.round += 1 + else: + if len(state.participants) < settings.min_nodes: + msg = ( + f"Number of participants {len(state.participants)}) less than" + f"min_nodes {settings.min_nodes}, clearning deadline in state" + ) + logger.debug(msg) + state.deadline = None + + +class _RendezvousStateHolder(ABC): + """Hold the shared rendezvous state synced with other nodes.""" + + @property + @abstractmethod + def state(self) -> _RendezvousState: + """Get the local state.""" + + @abstractmethod + def sync(self) -> bool | None: + """Read or writes the latest state. + + Returns: + A boolean value indicating whether the local state, in case marked + as dirty, was successfully synced with other nodes. + """ + + @abstractmethod + def mark_dirty(self) -> None: + """Mark the local state as dirty.""" + + +class _BackendRendezvousStateHolder(_RendezvousStateHolder): + """Hold the rendezvous state synced with other nodes via a backend. + + Args: + backend: + The rendezvous backend to use. + settings: + The rendezvous settings. + cache_duration: + The amount of time, in seconds, to cache the last rendezvous state + before requesting it from the backend again. + """ + + _backend: RendezvousBackend + _state: _RendezvousState + _settings: RendezvousSettings + _cache_duration: int + _token: Token + _dirty: bool + _last_sync_time: float + _dead_nodes: list[_NodeDesc] + + def __init__( + self, + backend: RendezvousBackend, + settings: RendezvousSettings, + cache_duration: int = 1, + ) -> None: + self._backend = backend + self._state = _RendezvousState() + self._settings = settings + self._cache_duration = cache_duration + self._token = None + self._dirty = False + self._last_sync_time = -1 + self._dead_nodes = [] + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING): + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + ) + + @property + def state(self) -> _RendezvousState: + """See base class.""" + return self._state + + def sync(self) -> bool | None: + """See base class.""" + state_bits: bytes | None = None + + token = None + + has_set: bool | None + + if self._dirty: + has_set = False + + state_bits = pickle.dumps(self._state) + + set_response = self._backend.set_state(state_bits, self._token) + if set_response is not None: + state_bits, token, has_set = set_response + else: + has_set = None + + if self._cache_duration > 0: + # Avoid overloading the backend if we are asked to retrieve the + # state repeatedly. Try to serve the cached state. + if self._last_sync_time >= max( + time.monotonic() - self._cache_duration, 0 + ): + return None + + get_response = self._backend.get_state() + if get_response is not None: + state_bits, token = get_response + + if state_bits is not None: + try: + self._state = pickle.loads(state_bits) + except pickle.PickleError as exc: + raise RendezvousStateError( + "The rendezvous state is corrupt. See inner exception for details." + ) from exc + else: + self._state = _RendezvousState() + + if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG): + node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes) + + msg = ( + f"As part of the sync operation the node(s) {node_list} have been removed from the " + f"rendezvous '{self._settings.run_id}' since they had no heartbeat." + ) + self._record(message=msg) + logger.debug(msg) + + self._token = token + + self._dirty = False + + self._last_sync_time = time.monotonic() + + self._sanitize() + + return has_set + + def _sanitize(self) -> None: + state = self._state + + expire_time = datetime.now(timezone.utc) - ( + self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt + ) + + # Filter out the dead nodes. + self._dead_nodes = [ + node + for node, last_heartbeat in state.last_heartbeats.items() + if last_heartbeat < expire_time + ] + + participant_removed = False + + for dead_node in self._dead_nodes: + msg = f"Detected dead node '{dead_node}', removing it from the rendezvous" + logger.debug(msg) + del state.last_heartbeats[dead_node] + + try: + del state.participants[dead_node] + + participant_removed = True + except KeyError: + pass + + try: + state.wait_list.remove(dead_node) + except KeyError: + pass + + try: + state.redundancy_list.remove(dead_node) + except KeyError: + pass + + if participant_removed: + # Common epilogue shared with the _remove_from_participants() + # function of _DistributedRendezvousOpExecutor. + _remove_participant_epilogue(state, self._settings) + + def mark_dirty(self) -> None: + """See base class. + + If the local rendezvous state is dirty, the next sync call will try to + write the changes back to the backend. However this attempt might fail + if another node, which had the same state, also made changes and wrote + them before us. + """ + self._dirty = True + + +class _Action(Enum): + """Specifies the possible actions based on the state of the rendezvous.""" + + KEEP_ALIVE = 1 + ADD_TO_PARTICIPANTS = 2 + ADD_TO_WAIT_LIST = 3 + ADD_TO_REDUNDANCY_LIST = 4 + REMOVE_FROM_PARTICIPANTS = 5 + REMOVE_FROM_WAIT_LIST = 6 + REMOVE_FROM_REDUNDANCY_LIST = 7 + MARK_RENDEZVOUS_COMPLETE = 8 + MARK_RENDEZVOUS_CLOSED = 9 + SYNC = 10 + ERROR_CLOSED = 11 + ERROR_TIMEOUT = 12 + FINISH = 13 + + +class _RendezvousContext: + """Holds the context of the rendezvous. + + Attributes: + node: + The node descriptor associated with the current rendezvous handler + instance. + state: + The current state of the rendezvous. + settings: + The rendezvous settings. + """ + + node: _NodeDesc + state: _RendezvousState + settings: RendezvousSettings + + def __init__( + self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings + ) -> None: + self.node = node + self.state = state + self.settings = settings + + +class _RendezvousOpExecutor(ABC): + """Execute rendezvous operations.""" + + @abstractmethod + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Callable[[timedelta], float] | None = None, + ) -> None: + """Execute a rendezvous operation. + + An operation is run inside a state machine and is expected to transition + the rendezvous from one state to another. + + Args: + state_handler: + A callable that is expected to return the next state transition + action based on the current state of the rendezvous. + deadline: + The time, in seconds, at which the operation will be considered + timed-out. + update_deadline: + Function to generate a new operation deadline if the current + node may participate in the next rendezvous. + """ + + +class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor): + """Execute rendezvous operations using a shared state. + + Args: + node: + The node descriptor associated with the current rendezvous handler + instance. + state_holder: + The ``RendezvousStateHolder`` to use to sync the rendezvous state + with other nodes. + settings: + The rendezvous settings. + """ + + _node: _NodeDesc + _state: _RendezvousState + _state_holder: _RendezvousStateHolder + _settings: RendezvousSettings + + def __init__( + self, + node: _NodeDesc, + state_holder: _RendezvousStateHolder, + settings: RendezvousSettings, + ) -> None: + self._node = node + self._state_holder = state_holder + self._settings = settings + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._node.addr, + pid=self._node.pid, + local_id=self._node.local_id, + ) + + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Callable[[timedelta], float] | None = None, + ) -> None: + """See base class.""" + action = None + while action != _Action.FINISH: + # Reads or writes the latest rendezvous state shared by all nodes in + # the rendezvous. Note that our local changes might get overridden + # by another node if that node synced its changes before us. + has_set = self._state_holder.sync() + if has_set is not None: + if has_set: + msg = ( + f"The node '{self._node}' has successfully synced its local changes with " + f"other nodes in the rendezvous '{self._settings.run_id}'." + ) + else: + msg = ( + f"The node '{self._node}' has a stale state and failed to sync its local " + f"changes with other nodes in the rendezvous '{self._settings.run_id}'." + ) + + self._record(message=msg) + logger.debug(msg) + + self._state = self._state_holder.state + + ctx = _RendezvousContext(self._node, self._state, self._settings) + + # Determine the next action to take based on the current state of + # the rendezvous. + action = state_handler(ctx, deadline) + + if action == _Action.FINISH: + continue + + if action == _Action.ERROR_CLOSED: + raise RendezvousClosedError + + if action == _Action.ERROR_TIMEOUT: + raise RendezvousTimeoutError + + if action == _Action.SYNC: + # Delay the execution by one second to avoid overloading the + # backend if we are asked to poll for state changes. + _delay(seconds=1) + else: + if action == _Action.KEEP_ALIVE: + self._keep_alive() + elif action == _Action.ADD_TO_PARTICIPANTS: + self._add_to_participants() + elif action == _Action.ADD_TO_WAIT_LIST: + self._add_to_wait_list() + elif action == _Action.ADD_TO_REDUNDANCY_LIST: + self._add_to_redundancy_list() + elif action == _Action.REMOVE_FROM_PARTICIPANTS: + self._remove_from_participants() + elif action == _Action.REMOVE_FROM_WAIT_LIST: + self._remove_from_wait_list() + elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST: + self._remove_from_redundancy_list() + # update deadline since the node may participate in rendezvous process + if update_deadline: + deadline = update_deadline(self._settings.timeout.join) + elif action == _Action.MARK_RENDEZVOUS_COMPLETE: + self._mark_rendezvous_complete() + elif action == _Action.MARK_RENDEZVOUS_CLOSED: + self._mark_rendezvous_closed() + + # Attempt to sync our changes back to other nodes. + self._state_holder.mark_dirty() + + def _keep_alive(self) -> None: + msg = ( + f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous " + f"'{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.last_heartbeats[self._node] = datetime.now(timezone.utc) + + def _add_to_participants(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + try: + state.wait_list.remove(self._node) + except KeyError: + pass + + # The ranks of the participants will be set once the rendezvous is + # complete. + state.participants[self._node] = 0 + + self._keep_alive() + + if len(state.participants) == self._settings.min_nodes: + state.deadline = ( + datetime.now(timezone.utc) + self._settings.timeout.last_call + ) + + if len(state.participants) == self._settings.max_nodes: + self._mark_rendezvous_complete() + + def _add_to_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + if self._node in self._state.redundancy_list: + self._state.redundancy_list.remove(self._node) + self._state.wait_list.add(self._node) + + self._keep_alive() + + def _add_to_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the redundancy list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.add(self._node) + + self._keep_alive() + + def _remove_from_participants(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + del state.participants[self._node] + + del state.last_heartbeats[self._node] + + # Common epilogue shared with the sanitizer() function of + # _BackendRendezvousStateHolder. + _remove_participant_epilogue(state, self._settings) + + def _remove_from_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.wait_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _remove_from_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the redundant list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _mark_rendezvous_complete(self) -> None: + msg = ( + f"The node '{self._node}' marked round {self._state.round} of the rendezvous " + f"'{self._settings.run_id}' as complete. Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + state = self._state + + state.complete = True + state.deadline = None + + # Assign the ranks. + for rank, node in enumerate(sorted(state.participants)): + state.participants[node] = rank + + def _mark_rendezvous_closed(self) -> None: + msg = ( + f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. " + "Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + self._state.closed = True + + +def _should_keep_alive(ctx: _RendezvousContext) -> bool: + """Determine whether a keep-alive heartbeat should be sent.""" + try: + last_heartbeat = ctx.state.last_heartbeats[ctx.node] + except KeyError: + return False + + return ( + last_heartbeat <= datetime.now(timezone.utc) - ctx.settings.keep_alive_interval + ) + + +class _RendezvousExitOp: + """Represent a rendezvous exit operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.node in ctx.state.participants: + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.REMOVE_FROM_PARTICIPANTS + return _Action.FINISH + + +class _RendezvousJoinOp: + """Represent a rendezvous join operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + state = ctx.state + + # A closed rendezvous means that it no longer accepts new nodes. + if state.closed: + if ctx.node in state.redundancy_list: + msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous." + raise RendezvousGracefulExitError(msg) + return _Action.ERROR_CLOSED + + if ctx.node in state.redundancy_list: + msg = f"The node {ctx.node} is in redundancy list" + logger.debug(msg) + # don't apply the timeout logic here, since we want to allow the node to rejoin + if len(state.participants) == ctx.settings.max_nodes: + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + else: + return _Action.SYNC + else: + # transition to waiting state that will respect timeouts. + msg = f"The node {ctx.node} is removed from redundancy list" + logger.debug(msg) + return _Action.REMOVE_FROM_REDUNDANCY_LIST + + is_participant = ctx.node in state.participants + + # If we are part of the rendezvous and it is already complete there is + # no further action to take. + if state.complete and is_participant: + return _Action.FINISH + + now = time.monotonic() + if now > deadline: + rollback_period = 5 # 5 seconds + + # If we still have time to rollback (a short period on top of the + # operation deadline), try to remove ourself from the rendezvous. + # It is okay if we can't though as our keep-alive will eventually + # expire. + if now <= deadline + rollback_period: + # If we are part of the rendezvous, it means we couldn't find + # enough participants to complete it on time. + if is_participant: + return _Action.REMOVE_FROM_PARTICIPANTS + # If we are in the wait list, it means we couldn't wait till the + # next round of the rendezvous. + if ctx.node in state.wait_list: + return _Action.REMOVE_FROM_WAIT_LIST + return _Action.ERROR_TIMEOUT + + if state.complete: + # If we are here, it means we are not part of the rendezvous. In + # case the rendezvous has capacity for additional participants add + # ourself to the wait list for the next round. + if len(state.participants) < ctx.settings.max_nodes: + if ctx.node not in state.wait_list: + return _Action.ADD_TO_WAIT_LIST + elif len(state.participants) >= ctx.settings.max_nodes: + if ( + ctx.node not in state.redundancy_list + and ctx.node not in state.wait_list + ): + return _Action.ADD_TO_REDUNDANCY_LIST + elif is_participant: + # If the rendezvous has enough number of participants including us, + # check whether we have passed the rendezvous deadline. If yes, + # complete it. + if ( + len(state.participants) >= ctx.settings.min_nodes + and len(state.participants) <= ctx.settings.max_nodes + and state.deadline is not None + ): + if state.deadline < datetime.now(timezone.utc): + msg = ( + f"The node '{ctx.node}' marking the rendezvous complete, " + f"quorum established within deadline" + ) + logger.debug(msg) + return _Action.MARK_RENDEZVOUS_COMPLETE + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached" + logger.debug(msg) + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants" + logger.debug(msg) + else: + # The rendezvous is not complete yet and we are not part of it. Try + # to join. + return _Action.ADD_TO_PARTICIPANTS + + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + + # At this point either the rendezvous is not complete, but we are part + # of it, which means we have to wait for other participants to join; or + # the rendezvous is complete, but we are not part of it, which means we + # have to wait for the next round. + return _Action.SYNC + + +class _RendezvousCloseOp: + """Represent a rendezvous close operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.state.closed: + return _Action.FINISH + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.MARK_RENDEZVOUS_CLOSED + + +class _RendezvousKeepAliveOp: + """Represent a rendezvous keep-alive update operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if _should_keep_alive(ctx): + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.KEEP_ALIVE + return _Action.FINISH + + +class DynamicRendezvousHandler(RendezvousHandler): + """Represent a handler that sets up a rendezvous among a set of nodes.""" + + # Static + _node_desc_generator = _NodeDescGenerator() + + _this_node: _NodeDesc + _settings: RendezvousSettings + _backend_name: str + _store: Store + _state_holder: _RendezvousStateHolder + _op_executor: _RendezvousOpExecutor + _heartbeat_lock: threading.Lock + _keep_alive_timer: _PeriodicTimer | None + + @classmethod + def from_backend( + cls, + run_id: str, + store: Store, + backend: RendezvousBackend, + min_nodes: int, + max_nodes: int, + local_addr: str | None = None, + timeout: RendezvousTimeout | None = None, + keep_alive_interval: int = 5, + keep_alive_max_attempt: int = 3, + ): + """Create a new :py:class:`DynamicRendezvousHandler`. + + Args: + run_id: + The run id of the rendezvous. + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The local node address. + timeout: + The timeout configuration of the rendezvous. + keep_alive_interval: + The amount of time a node waits before sending a heartbeat to keep + it alive in the rendezvous. + keep_alive_max_attempt: + The maximum number of failed heartbeat attempts after which a node + is considered dead. + """ + # We associate each handler instance with a unique node descriptor. + node = cls._node_desc_generator.generate(local_addr) + + settings = RendezvousSettings( + run_id, + min_nodes, + max_nodes, + timeout or RendezvousTimeout(), + keep_alive_interval=timedelta(seconds=keep_alive_interval), + keep_alive_max_attempt=keep_alive_max_attempt, + ) + + state_holder = _BackendRendezvousStateHolder(backend, settings) + + return cls(node, settings, backend.name, store, state_holder) + + def __init__( + self, + node: _NodeDesc, + settings: RendezvousSettings, + backend_name: str, + store: Store, + state_holder: _RendezvousStateHolder, + ) -> None: + if not settings.run_id: + raise ValueError("The run id must be a non-empty string.") + + if settings.min_nodes < 1: + raise ValueError( + f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero." + ) + + if settings.max_nodes < settings.min_nodes: + raise ValueError( + f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal " + f"to the minimum number of nodes ({settings.min_nodes})." + ) + + self._this_node = node + + self._settings = settings + + self._backend_name = backend_name + + self._store = store + + self._state_holder = state_holder + + self._op_executor = _DistributedRendezvousOpExecutor( + self._this_node, self._state_holder, self._settings + ) + + self._heartbeat_lock = threading.Lock() + + self._keep_alive_timer = None + + # Cached shared store server reference + self._shared_tcp_store_server: dist.Store | None = None + + self._bootstrap_store_info: RendezvousStoreInfo | None = None + + def _record( + self, + message: str, + node_state: NodeState = NodeState.RUNNING, + rank: int | None = None, + ) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._this_node.addr, + pid=self._this_node.pid, + local_id=self._this_node.local_id, + rank=rank, + ) + + def _create_tcp_store_server(self, master_addr, master_port) -> dist.TCPStore: + return dist.TCPStore( + host_name=master_addr, + port=master_port, + is_master=True, + multi_tenant=True, + ) + + @property + def settings(self) -> RendezvousSettings: + """Get the settings of the rendezvous.""" + return self._settings + + def get_backend(self) -> str: + """See base class.""" + return self._backend_name + + @property + def use_agent_store(self) -> bool: + """See base class.""" + return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1" + + def next_rendezvous(self) -> RendezvousInfo: + """See base class.""" + msg = ( + f"The node '{self._this_node}' attempts to join the next round of the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.info(msg) + + try: + self._stop_heartbeats() + + # Delay the execution for a small random amount of time if this is our + # first run. This will slightly skew the rendezvous attempts across the + # nodes and reduce the load on the backend. + if self._state_holder.state.round == 0: + _delay(seconds=(0, 0.3)) + + exit_op = _RendezvousExitOp() + join_op = _RendezvousJoinOp() + + deadline = self._get_deadline(self._settings.timeout.join) + self._op_executor.run(exit_op, deadline) + self._op_executor.run(join_op, deadline, self._get_deadline) + + self._start_heartbeats() + + rank, world_size = self._get_world() + store = self._get_store() + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + msg = ( + f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of " + f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size " + f"{world_size}." + ) + self._record(message=msg, rank=rank) + logger.info(msg) + + # opt-out option of TCPStore sharing + if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) + return RendezvousInfo( + store, + rank, + world_size, + bootstrap_store_info, + ) + + # This will only be hit when TCPStore sharing is enabled. + if self._bootstrap_store_info is None: + # To avoid race in get_free_port because we release the port after the call, + # we want to create a TCPStore server soon afterwards. + server_port = 0 + if rank == 0: + self._shared_tcp_store_server = self._create_tcp_store_server( + self._this_node.addr, server_port + ) + server_port = self._shared_tcp_store_server.port + self._bootstrap_store_info = RendezvousStoreInfo.build( + rank, + store, + local_addr=self._this_node.addr, + server_port=server_port, # For non-0 rank, this is a no-op + ) + + assert self._bootstrap_store_info is not None + if rank == 0: + assert self._shared_tcp_store_server is not None + + return RendezvousInfo( + store, + rank, + world_size, + self._bootstrap_store_info, # type: ignore[assignment] + ) + + def is_closed(self) -> bool: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return self._state_holder.state.closed + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def set_closed(self) -> None: + """See base class.""" + try: + with self._heartbeat_lock: + self._close() + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def num_nodes_waiting(self) -> int: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return len(self._state_holder.state.wait_list) + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def get_run_id(self) -> str: + """See base class.""" + return self._settings.run_id + + def shutdown(self) -> bool: + """See base class.""" + self._stop_heartbeats() + + try: + self._close() + + return True + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to shutdown the rendezvous " + f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + + return False + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def _close(self) -> None: + op = _RendezvousCloseOp() + + deadline = self._get_deadline(self._settings.timeout.close) + + self._op_executor.run(op, deadline) + + msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'." + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.info(msg) + + @staticmethod + def _keep_alive_weak(weak_self) -> None: + self = weak_self() + if self is not None: + self._keep_alive() + + def _keep_alive(self) -> None: + with self._heartbeat_lock: + op = _RendezvousKeepAliveOp() + + deadline = self._get_deadline(self._settings.timeout.heartbeat) + + try: + self._op_executor.run(op, deadline) + + msg = ( + f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.debug(msg) + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " + f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + + def _start_heartbeats(self) -> None: + self._keep_alive_timer = _PeriodicTimer( + self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) + ) + + self._keep_alive_timer.set_name( + f"RendezvousKeepAliveTimer_{self._this_node.local_id}" + ) + + self._keep_alive_timer.start() + + def _stop_heartbeats(self) -> None: + if self._keep_alive_timer is None: + return + + self._keep_alive_timer.cancel() + + def _get_world(self) -> tuple[int, int]: + state = self._state_holder.state + + return state.participants[self._this_node], len(state.participants) + + def _wrap_store(self, store: Store) -> Store: + key_prefix = ( + f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + ) + + return dist.PrefixStore(key_prefix, store) + + def _get_store(self) -> Store: + return self._wrap_store(self._store) + + def _get_deadline(self, timeout: timedelta) -> float: + return time.monotonic() + timeout.total_seconds() + + +def _get_timeout(params: RendezvousParameters, key: str) -> timedelta | None: + timeout = params.get_as_int(key + "_timeout") + if timeout is None: + return None + return timedelta(seconds=timeout) + + +def create_handler( + store: Store, backend: RendezvousBackend, params: RendezvousParameters +) -> DynamicRendezvousHandler: + """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters. + + Args: + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + + +-------------------+------------------------------------------------------+ + | Parameter | Description | + +===================+======================================================+ + | join_timeout | The total time, in seconds, within which the | + | | rendezvous is expected to complete. Defaults to 600 | + | | seconds. | + +-------------------+------------------------------------------------------+ + | last_call_timeout | An additional wait amount, in seconds, before | + | | completing the rendezvous once the minimum number of | + | | nodes has been reached. Defaults to 30 seconds. | + +-------------------+------------------------------------------------------+ + | close_timeout | The time, in seconds, within which the rendezvous is | + | | expected to close after a call to | + | | :py:meth:`RendezvousHandler.set_closed` or | + | | :py:meth:`RendezvousHandler.shutdown`. Defaults to | + | | 30 seconds. | + +-------------------+------------------------------------------------------+ + | heartbeat | The time, in seconds, within which a keep-alive | + | | heartbeat is expected to complete | + +-------------------+------------------------------------------------------+ + """ + try: + timeout = RendezvousTimeout( + _get_timeout(params, "join"), + _get_timeout(params, "last_call"), + _get_timeout(params, "close"), + _get_timeout(params, "heartbeat"), + ) + keep_alive_interval = params.get_as_int("keep_alive_interval", 5) + if keep_alive_interval is None: + raise TypeError( + "You passed 'keep_alive_interval=None' as a rendezvous configuration option" + ) + keep_alive_max_attempt = params.get_as_int("keep_alive_max_attempt", 3) + if keep_alive_max_attempt is None: + raise TypeError( + "You passed 'keep_alive_max_attempt=None' as a rendezvous configuration option" + ) + + return DynamicRendezvousHandler.from_backend( + params.run_id, + store, + backend, + params.min_nodes, + params.max_nodes, + params.local_addr, + timeout, + keep_alive_interval=keep_alive_interval, + keep_alive_max_attempt=keep_alive_max_attempt, + ) + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..93a7073bed87a33a7f2ba0dfb64c7daa57b9d55f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -0,0 +1,1080 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import sys +import threading +import time + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + +from torch.distributed.elastic.rendezvous import ( + RendezvousClosedError, + RendezvousError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, + RendezvousTimeoutError, +) + +from .etcd_store import cas_delay, EtcdStore +from .utils import parse_rendezvous_endpoint + + +__all__ = [ + "EtcdRendezvousRetryableFailure", + "EtcdRendezvousRetryImmediately", + "EtcdRendezvousHandler", + "EtcdRendezvous", + "create_rdzv_handler", +] + +_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") +_log_handler = logging.StreamHandler(sys.stderr) +_log_handler.setFormatter(_log_fmt) + +logger = logging.getLogger(__name__) +logger.propagate = False +logger.setLevel(logging.INFO) +logger.addHandler(_log_handler) + + +# Retryable failure exception means the we were too late to make +# a desired state transition (e.g. because of a race condition), +# and should now restart from the beginning. +# A small delay is recommended to avoid spamming Etcd. +class EtcdRendezvousRetryableFailure(Exception): + pass + + +# Similar to retryable failure, but the new state we observed suggests we +# can re-try immediately, i.e. without a need for "safety delay". +class EtcdRendezvousRetryImmediately(Exception): + pass + + +# Default timeout for the rendezvous. +_DEFAULT_TIMEOUT: int = 600 # 10 minutes + +# Additional waiting time after reaching the minimum number of nodes +# in case the rendezvous is elastic (min != max). +_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds + +# Various constants used internally in EtcdRendezvous +CONST_ETCD_SETUP_TTL = 5 +CONST_ETCD_FROZEN_TTL = 10 +CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10 + +# Ephemeral node TTL for worker's keep-alive key: +CONST_WORKER_KEEPALIVE_TTL = 10 + +# TTL for the ephemeral run_id-specific directory. All rendezvous state data +# for a specific run_id (job instance) is contained within directory. +# Its only role is to clean-up rendezvous data from old runs (for the case when +# etcd server is persistent), and has no affect on correctness, but should be +# larger than any timeouts that a worker process is expected to survive: +CONST_RUNID_SUBROOT_TTL = 7200 # 2 hours + + +class EtcdRendezvousHandler(RendezvousHandler): + """ + Implements a + :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` interface + backed by + :py:class:`torch.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`. + ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to + use and to pass implementation specific configurations to the rendezvous + module. The basic etcd rendezvous configuration URL looks like the following + :: + + etcd://:/?min_workers=&max_workers= # noqa: W605 + + -- example -- + + etcd://localhost:2379/1234?min_workers=1&max_workers=3 + + The URL above is interpreted as follows: + + 1. Use the rendezvous handler that is registered with the ``etcd`` + scheme + 2. The ``etcd`` endpoint to use is ``localhost:2379`` + 3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to + share a common etcd server for multiple jobs so long as the + ``job_ids`` are guaranteed to be unique). Note that the job id can be + any string (e.g. does not need to be a number) as long as it is + unique. + 4. ``min_workers=1`` and ``max_workers=3`` specifies a range for + membership size - Torch Distributed Elastic starts running the job as + long as the cluster size is greater than or equal to ``min_workers`` + and admits up to ``max_workers`` into the cluster. + + Below are a full list of the parameters that can be passed to etcd + rendezvous: + + +--------------------------------------------+--------------------------+ + | Parameter | Description | + +============================================+==========================+ + | min_workers | minimum number of | + | | workers for the | + | | rendezvous to be valid | + +--------------------------------------------+--------------------------+ + | max_workers | maximum number of | + | | workers to admit | + +--------------------------------------------+--------------------------+ + | timeout | total timeout within | + | | which next_rendezvous is | + | | expected to succeed | + | | (default 600s) | + +--------------------------------------------+--------------------------+ + | last_call_timeout | additional wait amount | + | | ("last call") after min | + | | number of workers has | + | | been reached (defaults | + | | to 30s) | + +--------------------------------------------+--------------------------+ + | etcd_prefix | path prefix (from etcd | + | | root), inside which all | + | | etcd nodes will be | + | | created (defaults to | + | | ``/torchelastic/p2p``) | + +--------------------------------------------+--------------------------+ + """ + + def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: str | None): + """ + Args: + rdzv_impl: the implementation of the rendezvous + local_addr: the local address of the current node + """ + + self._rdzv_impl = rdzv_impl + self._local_addr = local_addr + + def __del__(self): + # TODO: look into using weakref here instead. + del self._rdzv_impl + + def get_backend(self) -> str: + return "etcd" + + def next_rendezvous(self): + rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier() + + logger.info("Creating EtcdStore as the c10d::Store implementation") + store = self._rdzv_impl.setup_kv_store(rdzv_version) + + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._local_addr + ) + return RendezvousInfo(store, rank, world_size, bootstrap_store_info) + + def is_closed(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + return state["status"] == "closed" + except etcd.EtcdKeyNotFound: + # No rendezvous state, so it cannot be closed. + return False + + def set_closed(self): + self._rdzv_impl.set_closed() + + def num_nodes_waiting(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + if state["status"] == "final": + return state["num_workers_waiting"] + except etcd.EtcdKeyNotFound: + pass + return 0 + + def get_run_id(self) -> str: + return self._rdzv_impl._run_id + + def shutdown(self) -> bool: + try: + self.set_closed() + return True + except BaseException: # noqa: B036 + logger.warning("Shutdown failed", exc_info=True) + return False + + +# TODO: we should probably handle a few additional errors, +# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are +# only relevant for multi-node Etcd ensemble. A simple retry would work, +# but is verbose to add everywhere. Consider wrapping the client calls +# into auto-retry for these errors? +# +class EtcdRendezvous: + """A rendezvous implementation that uses `etcd `__ as the backend store.""" + + def __init__( + self, + client, + prefix, + run_id, + num_min_workers, + num_max_workers, + timeout, + last_call_timeout, + ): + self.client = client + logger.info("Etcd machines: %s", self.client.machines) + + self._prefix = prefix + self._run_id = run_id + self._num_min_workers = num_min_workers + self._num_max_workers = num_max_workers + self._timeout = timeout + self._last_call_timeout = last_call_timeout + + # For cleaning up TTL refresher threads (for ephemeral keys) + self._lease_run_id_stop = None + self._lease_this_rank_stop = None + + if not self._prefix.endswith("/"): + self._prefix += "/" + + # Setup a permanent prefix dir, if didn't exist + if self._prefix != "/": + self.create_path_if_not_exists(self._prefix) + + # Lease a "sub-root" node specific to this job instance (run_id) + self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL) + self._lease_run_id_stop = self.setup_lease_renewal( + self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL + ) + + # Subdir for all rendezvous work + self.create_path_if_not_exists(self.get_path("/rdzv")) + + # Create a rendezvous version counter, if doesn't exist + try: + self.client.write( + key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False + ) + except etcd.EtcdAlreadyExist: + pass + + def __del__(self): + # TODO: look into using weakref here instead. + if self._lease_run_id_stop is not None: + self._lease_run_id_stop.set() + + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + def rendezvous_barrier(self): + """ + Main entry point for next rendezvous. + + This method is blocking until rendezvous succeeds or a timeout occurs. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousTimeoutError - timeout waiting for rendezvous + RendezvousClosedError - rendezvous is or was closed while waiting + RendezvousError - other persistent errors that + render the rendezvous non-retryable + """ + self._rendezvous_deadline = time.time() + self._timeout + while True: + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + logger.info("Attempting to join next rendezvous") + try: + # Dis-own our lease in the previous rendezvous, if exists + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + return self.init_phase() + + except EtcdRendezvousRetryImmediately: + # The type of failure suggests we can retry without delay + pass + + except EtcdRendezvousRetryableFailure: + # In case of retryable failure, wait a small delay + # to avoid spamming etcd + time.sleep(1) + + except RendezvousTimeoutError: + logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler") + raise + + except RendezvousClosedError: + logger.info( + "Rendezvous for run_id=%s was observed to be closed", self._run_id + ) + raise + + except RendezvousError: + raise + + except Exception as e: + # In case of a general exception, wait a small delay + # to avoid spamming etcd + # FIXME: there are a few things that fall under this like + # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly. + logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) # noqa: G200 + time.sleep(1) + + def init_phase(self): + """ + Initially, the rendezvous state is expected to be one of: + + 1. empty (non-existent) - in this case we try to create a new one. + 2. joinable - we try to join it. + 3. final - we announce ourselves as waiting, and go into monitoring mode + + Any other state is considered transitional, and will be retried after + a short delay. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousClosedError - current rendezvous was/is closed + EtcdRendezvousRetryableFailure - observed some intermediate + state, which is best handled by retrying later + """ + try: + active_version = self.try_create_rendezvous() + state = json.loads(active_version.value) + logger.info("New rendezvous state created: %s", state) + except etcd.EtcdAlreadyExist: + active_version, state = self.get_rdzv_state() + # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound), + # but this is ok for us - just means we'll restart from beginning. + logger.info("Observed existing rendezvous state: %s", state) + + if state["status"] == "closed": + raise RendezvousClosedError + + if state["status"] == "joinable": + return self.join_phase(state["version"]) + + if state["status"] == "final": + self.handle_existing_rendezvous(state["version"]) + raise EtcdRendezvousRetryImmediately + + self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1) + raise EtcdRendezvousRetryableFailure + + def join_phase(self, expected_version): + """ + We observed a rendezvous state in 'joinable' state, and attempt to join this + particular version, and then wait for all other peers to join. + """ + # Failure to join will propagate an exception, causing a re-entry. + active_version, this_rank = self.join_rendezvous(expected_version) + state = json.loads(active_version.value) + logger.info( + "Joined rendezvous version %s as rank %s. Full state: %s", + state["version"], + this_rank, + state, + ) + + # If this worker was first to reach num_min_workers requirement, + # and rendezvous is still joinable (therefore it is elastic), + # then this worker will be responsible for waiting out the "last call" + # timeout and closing (i.e. transitioning to 'frozen') the rendezvous + # afterwards. + # As a safety against a potential failure of this worker (during the + # last call timeout), the rendezvous state is made ephemeral + # when min_num_workers is reached. + + if this_rank == self._num_min_workers - 1 and state["status"] == "joinable": + logger.info("Rank %s is responsible for join last call.", this_rank) + last_call_deadline = time.time() + self._last_call_timeout + self.handle_join_last_call(expected_version, last_call_deadline) + logger.info("Rank %s finished join last call.", this_rank) + + # Wait for rendezvous state to be frozen, which means a fixed set of peers + logger.info("Waiting for remaining peers.") + active_version = self.wait_for_peers(expected_version) + state = json.loads(active_version.value) + + assert state["version"] == expected_version, ( + "Logic error: failed to observe version mismatch" + ) + + return self.confirm_phase(expected_version, this_rank) + + def confirm_phase(self, expected_version, this_rank): + """ + Once the rendezvous state transitions from 'joinable' to 'frozen', + we have every participant confirm their membership and setup per-member + keep-alive TTL keys, and then wait for all other participants to confirm, + which would then successfully conclude this rendezvous. + """ + logger.info("All peers arrived. Confirming membership.") + self.confirm_membership(expected_version, this_rank) + + logger.info("Waiting for confirmations from all peers.") + active_version = self.wait_for_final(expected_version) + state = json.loads(active_version.value) + + logger.info( + "Rendezvous version %s is complete. Final state: %s", + state["version"], + state, + ) + + # Rendezvous version number; our rank in it; world size + return state["version"], this_rank, len(state["participants"]) + + def handle_existing_rendezvous(self, expected_version): + """ + Handle the case when there's an existing (state 'final) rendezvous already + in place, and we have to announce ourselves waiting, and wait until + the next rendezvous opportunity. + """ + # If state is 'final' -> increment num_workers_waiting + # Then, observe state changes: + # 1. if it's no longer final -> bail out and re-try + # 2. if keep alives are missing, destroy it and bail out. + active_state = self.announce_self_waiting(expected_version) + logger.info( + "Added self to waiting list. Rendezvous full state: %s", active_state.value + ) + + self.wait_for_rendezvous_to_free(expected_version) + logger.info( + "Previously existing rendezvous state changed. Will re-try joining." + ) + + def try_create_rendezvous(self): + """ + Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists). + + Raises: + RendezvousError - on unexpected state + """ + # Initially active_version is ephemeral - this is to handle the + # possibility that might fail to complete the setup transaction, + # i.e. the transition "setup" -> "joinable". + active_version = self.client.write( + key=self.get_path("/rdzv/active_version"), + value=json.dumps({"status": "setup"}), + prevExist=False, + ttl=CONST_ETCD_SETUP_TTL, + ) + + try: + version_counter = self.client.get(self.get_path("/rdzv/version_counter")) + version_counter.value = str(int(version_counter.value) + 1) + self.client.update(version_counter) + except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e: + raise RendezvousError( + "Unexpected state of EtcdRendezvousHandler, worker needs to die." + ) from e + + # Any failure below results in declaring a retryable rendezvous failure. + # The ephemeral /rdzv/active_version will expire and someone can then + # re-try the setup process. + + # Create directory node for participant data + self.client.write( + key=self.get_path(f"/rdzv/v_{version_counter.value}"), + value=None, + dir=True, + prevExist=False, + ) + + # Publish rendezvous version and signal it is ready-to-be-joined. + # If rendezvous was set closed just before this, a retry will happen, + # where the closed condition will be handled. + return self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps( + { + "status": "joinable", + "version": version_counter.value, + "participants": [], + } + ), + prev_value=active_version.value, + ) + + def join_rendezvous(self, expected_version): + """Helper method for the join phase.""" + # Use compare-and-swap to add self to rendezvous state: + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "joinable": + raise EtcdRendezvousRetryableFailure( + "Rendezvous state became non-joinable before we could join. " + "Must join next one." + ) + + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." + ) + + assert len(state["participants"]) < self._num_max_workers, ( + "Logic error: joinable rendezvous should always have space left" + ) + + this_rank = len(state["participants"]) + state["participants"].append(this_rank) + + # When reaching min workers, or changing state to frozen, we'll set + # the active_version node to be ephemeral. + set_ttl: int | None = None + if len(state["participants"]) == self._num_max_workers: + state["status"] = "frozen" + state["keep_alives"] = [] + set_ttl = CONST_ETCD_FROZEN_TTL + elif len(state["participants"]) >= self._num_min_workers: + set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL + + try: + # Compare-and-swap. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=set_ttl, + ) + # We succeeded joining. + return active_version, this_rank + + except etcd.EtcdCompareFailed: + logger.info("Join rendezvous CAS unsuccessful, retrying") + + def wait_for_peers(self, expected_version): + """Helper method for the join phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Success, all peers arrived. + return active_version + + elif state["status"] == "joinable" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def confirm_membership(self, expected_version, this_rank): + """Helper method for the confirm phase.""" + # Compare-and-swap loop + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "frozen": + raise EtcdRendezvousRetryImmediately( + "Rendezvous no longer frozen, before we confirmed. " + "Must join next one" + ) + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." + ) + + this_lease_key = self.get_path( + f"/rdzv/v_{expected_version}/rank_{this_rank}" + ) + self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL) + + state["keep_alives"].append(this_lease_key) + if len(state["keep_alives"]) == len(state["participants"]): + # Everyone confirmed (this rank is last to do so) + state["status"] = "final" + state["num_workers_waiting"] = 0 + finalize = True + else: + finalize = False + + try: + # Compare-and-swap. If new state is still frozen, keep it ephemeral. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=None if finalize else CONST_ETCD_FROZEN_TTL, + ) + + self._lease_this_rank_stop = self.setup_lease_renewal( + this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Confirm membership CAS unsuccessful, retrying") + + def wait_for_final(self, expected_version): + """Helper method for the confirm phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "final" and state["version"] == expected_version: + # Success. This rendezvous is final, and we accept it. + return active_version + + elif state["status"] == "frozen" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def announce_self_waiting(self, expected_version): + """ + Announce this worker is waiting (via num_workers_waiting counter) to join next + rendezvous, but only if state and version match. + """ + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "final" or state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately + + # Increment counter to signal an additional waiting worker. + state["num_workers_waiting"] += 1 + + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Announce self as waiting CAS unsuccessful, retrying") + + def wait_for_rendezvous_to_free(self, expected_version): + """ + When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join. + + Such opportunity may come from: + + 1. rendezvous state changed by someone else, in which case we unblock and retry. + 2. rendezvous becomes invalid because at least one member failed to renew their + leased keep_alive node. We detect this, and destroy the rendezvous. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] != "final" or state["version"] != expected_version: + return + + # Check if current rendezvous state is valid, in the sense that all + # its members are alive (renewing their lease). + # If not, try destroy this rendezvous, so a new one can be created. + alive_members = self.client.get( + self.get_path(f"/rdzv/v_{expected_version}") + ) + keep_alive_keys = [ch.key for ch in alive_members.children] + + for key in state["keep_alives"]: + if key not in keep_alive_keys: + # This participant didn't renew their lease. We'll declare this + # rendezvous version as dead (but only if it hadn't changed) + logger.info("Keep-alive key %s is not renewed.", key) + logger.info( + "Rendezvous version %s is incomplete. ", expected_version + ) + logger.info("Attempting to destroy it.") + + # Compare-and-delete operation. Throws if compare failed, + # which means rendezvous was already destroyed/re-created/closed, + # and we can try to re-enter the barrier. + self.client.delete( + key=self.get_path("/rdzv/active_version"), + prevValue=active_version.value, + ) + + logger.info( + "Destroyed rendezvous version %s successfully.", + expected_version, + ) + + # We can return (and retry) immediately + return + + # Existing rendezvous seems valid, no reason to destroy it. + # We just have to wait until something changes and re-check. + try: + overall_timeout = ( + max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + ) + self.client.watch( + key=self.get_path("/rdzv"), + index=active_version.etcd_index + 1, + recursive=True, + timeout=overall_timeout, + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + active_version, state = self.get_rdzv_state() + + def handle_join_last_call(self, expected_version, deadline): + """ + After we reach min number of workers, one particular worker takes on the + responsibility of waiting an additional timeout before closing the join window. + If the worker responsible for this fails, the rendezvous will be destroyed due + to expiring TTL, and the other participants will re-rendezvous. + + Here we expect to see state + Exit gracefully if either: + + 1. state becomes + 2. timeout happens (reaching deadline), in which case + we try the transition to + + Exit with exception otherwise. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Worker set became frozen before last-call timeout. This is possible + # when num_max_workers is reached before the timeout. + return + + if state["status"] != "joinable" or state["version"] != expected_version: + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + # If timeout occurred, attempt a state transition (joinable -> frozen) + if time.time() >= deadline: + state["status"] = "frozen" + state["keep_alives"] = [] + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=CONST_ETCD_FROZEN_TTL, + ) + # We successfully made this rendezvous frozen. + return + except etcd.EtcdCompareFailed: + logger.info( + "Join last-call transition CAS unsuccessful. Will retry" + ) + cas_delay() + active_version, state = self.get_rdzv_state() + continue + + # Timeout did not occur, so we must refresh TTL, and wait for + # further changes. Note: we only want TTL to be refreshed if + # state is still joinable, hence we use CAS for that here, + # even though we don't change any of the data. + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=active_version.value, + prev_value=active_version.value, + ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL, + ) + + # Minimize "oversleeping": + timeout = min( + CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2, + deadline - time.time() + 1.0, # Oversleeping by 1s is ok. + ) + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1, timeout=timeout + ) + except etcd.EtcdCompareFailed: + logger.info("Join last-call TTL refresh CAS unsuccessful, will retry") + cas_delay() + active_version, state = self.get_rdzv_state() + + def set_closed(self): + """ + Mark rendezvous 'closed' for current run_id, which is used to signal other + participants to not attempt to perform (re-)rendezvous. This is useful + when one of the workers decides the job is complete. + """ + while True: + active_version, state = self.get_rdzv_state() + + if state["status"] == "closed": + # Already closed by someone else. + return + + state["status"] = "closed" + try: + self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return + + except etcd.EtcdCompareFailed: + logger.info("Set closed CAS unsuccessful, retrying") + cas_delay() + + def get_rdzv_state(self): + active_version = self.client.get(key=self.get_path("/rdzv/active_version")) + return active_version, json.loads(active_version.value) + + def try_wait_for_state_change(self, etcd_index, timeout=None): + # Don't sleep past the overall deadline (at least more than by 1s) + overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + timeout = overall_timeout if timeout is None else min(timeout, overall_timeout) + + try: + self.client.watch( + self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + # Unfortunately, we have to do another fetch in order to get last etcd_index. + return self.get_rdzv_state() + + def get_path(self, path): + if not path.startswith("/"): + path = "/" + path + + return f"{self._prefix}run_{self._run_id}{path}" + + def create_path_if_not_exists(self, full_path, ttl=None): + try: + self.client.write( + key=full_path, value=None, dir=True, prevExist=False, ttl=ttl + ) + except etcd.EtcdAlreadyExist: + pass + + def setup_lease_renewal(self, full_path, ttl): + # NOTE: For ephemeral key TTL renewal (~lease) to work correctly, + # make sure you don't call any long-blocking methods that do not + # release the Python's GIL! An example of this is calling a pybind11 + # extension function that is blocking / long-running, but is not + # doing a scoped release of the GIL. + def lease_worker(client, path, ttl, stop_event): + while True: + try: + client.refresh(path, ttl=ttl) + except etcd.EtcdKeyNotFound: + break + except ConnectionRefusedError: + # This error usually occurs during test when the server already got terminated but the + # python garbage collector have not yet invoked the __del__ method. + break + + if stop_event.wait(timeout=ttl / 2): + break + + lease_stop_event = threading.Event() + lease_thread = threading.Thread( + target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event) + ) + + lease_thread.daemon = True + lease_thread.start() + + return lease_stop_event + + def store_extra_data(self, rdzv_version, key, value): + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + try: + # If first time we are storing anything: + extra_data = self.client.write( + key=node, value=json.dumps({key: value}), prevExist=False + ) + return + except etcd.EtcdAlreadyExist: + pass + + # CAS loop, to make sure we don't lose concurrent stores. + while True: + # We never delete extra_data. Failure here should be fatal, no special handling. + extra_data = self.client.get(node) + + new_extra_data_value = json.loads(extra_data.value) + new_extra_data_value[key] = value + + try: + extra_data = self.client.test_and_set( + key=node, + value=json.dumps(new_extra_data_value), + prev_value=extra_data.value, + ) + return + except etcd.EtcdCompareFailed: + logger.info("Store extra_data CAS unsuccessful, retrying") + time.sleep(0.1) + + def load_extra_data(self, rdzv_version, key, timeout=None): + # 'extra_data' node itself, and the directory it is located in: + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + node_dir = self.get_path(f"/rdzv/v_{rdzv_version}") + + # TODO: implement timeout + # https://github.com/pytorch/elastic/issues/12 + while True: + # Combined wait for the node itself, and the key inside it. + root = self.client.get(node_dir) + + # Find the extra_data node, if it exists + extra_data = [n for n in root.children if n.key == node] + assert len(extra_data) <= 1 + + # Node for extra_data exists, check the desired key inside it. + if len(extra_data) == 1: + extra_data_dict = json.loads(extra_data[0].value) + if key in extra_data_dict: + return extra_data_dict[key] + + # The 'extra_data' node doesn't exist, or they key isn't published yet. + # Wait for interesting events on the extra_data node and retry. + try: + self.client.watch(node, index=root.etcd_index + 1) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + def setup_kv_store(self, rdzv_version): + store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv") + self.create_path_if_not_exists(store_path) + return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path) + + +def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: + """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``.""" + hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379) + + # The communication protocol + protocol = params.config.get("protocol") + if protocol is None: + protocol = "http" + else: + if protocol != "http" and protocol != "https": + raise ValueError("The etcd protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.config.get("cert") + if ssl_cert is not None: + cert_key = params.config.get("key") + if cert_key is not None: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. + ssl_cert = (ssl_cert, cert_key) + + # The root certificate + ca_cert = params.config.get("cacert") + + return etcd.Client( + hostname, + port, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + + +# Handler for torch.distributed "static" registration +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Usage: + + :: + + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8, + timeout=300, + last_call_timeout=30, + etcd_prefix="custom_prefix", + protocol="https", + cacert="/etc/kubernetes/certs/ca.crt", + cert="/etc/kubernetes/certs/client.crt", + key="/etc/kubernetes/certs/client.key") + # -- or -- + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8) + + etcd_rdzv_handler = create_etcd_rendezvous_handler(rdzv_params) + + + Where: + run_id - unique id for this training job instance, + min_nodes - min number of workers expected to join the rendezvous, + max_nodes - max number of workers allowed to join the rendezvous, + defaults to min_workers is not specified. + timeout - total timeout within which next_rendezvous is expected to + succeed; a RendezvousTimeoutError is raised otherwise; + Defaults is 600 (10 minutes). + last_call_timeout - additional wait amount ("last call") after + min number of workers has been reached. + Defaults to 30 seconds. + etcd_prefix - path prefix (from etcd root), inside which all + etcd nodes will be created. + Default is "/torchelastic/p2p". + protocol - http (default) or https to access etcd. + cacert - CA cert to access etcd, only makes sense with https. + cert - client cert to access etcd, only makes sense with https. + key - client key to access etcd, only makes sense with https. + """ + client = _create_etcd_client(params) + + etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p") + + rdzv = EtcdRendezvous( + client=client, + prefix=etcd_prefix, + run_id=params.run_id, + num_min_workers=params.min_nodes, + num_max_workers=params.max_nodes, + timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), + last_call_timeout=params.get_as_int( + "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT + ), + ) + return EtcdRendezvousHandler( + rdzv_impl=rdzv, + local_addr=params.local_addr, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..4cda28221ff4ec79fbd468a5067c91942b9b7be4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -0,0 +1,214 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +from base64 import b64decode, b64encode +from typing import cast + +import urllib3.exceptions # type: ignore[import] + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + +from torch.distributed import Store + +from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError +from .dynamic_rendezvous import RendezvousBackend, Token +from .etcd_store import EtcdStore +from .utils import parse_rendezvous_endpoint + + +class EtcdRendezvousBackend(RendezvousBackend): + """Represents an etcd-based rendezvous backend. + + Args: + client: + The ``etcd.Client`` instance to use to communicate with etcd. + run_id: + The run id of the rendezvous. + key_prefix: + The path under which to store the rendezvous state in etcd. + ttl: + The TTL of the rendezvous state. If not specified, defaults to two hours. + """ + + _DEFAULT_TTL = 7200 # 2 hours + + _client: etcd.Client + _key: str + _ttl: int + + def __init__( + self, + client: etcd.Client, + run_id: str, + key_prefix: str | None = None, + ttl: int | None = None, + ) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._client = client + + if key_prefix: + self._key = key_prefix + "/" + run_id + else: + self._key = run_id + + if ttl and ttl > 0: + self._ttl = ttl + else: + self._ttl = self._DEFAULT_TTL + + @property + def name(self) -> str: + """See base class.""" + return "etcd-v2" + + def get_state(self) -> tuple[bytes, Token] | None: + """See base class.""" + try: + result = self._client.read(self._key) + except etcd.EtcdKeyNotFound: + return None + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + return self._decode_state(result) + + def set_state( + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: + """See base class.""" + base64_state = b64encode(state).decode() + + kwargs = {} + + def get_state(): + result = self.get_state() + if result is not None: + return *result, False + return None + + if token: + try: + token = int(token) + except ValueError: + return get_state() + + if token: + kwargs["prevIndex"] = token + else: + kwargs["prevExist"] = False + + try: + result = self._client.write(self._key, base64_state, self._ttl, **kwargs) + except (etcd.EtcdAlreadyExist, etcd.EtcdCompareFailed): + result = None + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + if result is None: + return get_state() + + tmp = *self._decode_state(result), True + return tmp + + def _decode_state(self, result: etcd.EtcdResult) -> tuple[bytes, Token]: + # pyrefly: ignore [missing-attribute] + base64_state = result.value.encode() + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." + ) from exc + + # pyrefly: ignore [missing-attribute] + return state, result.modifiedIndex + + +def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # The communication protocol + protocol = params.get("protocol", "http").strip().lower() + if protocol != "http" and protocol != "https": + raise ValueError("The protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.get("ssl_cert") + if ssl_cert: + ssl_cert_key = params.get("ssl_cert_key") + if ssl_cert_key: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. + ssl_cert = (ssl_cert, ssl_cert_key) + + # The root certificate + ca_cert = params.get("ca_cert") + + try: + return etcd.Client( + host, + port, + read_timeout=read_timeout, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + +def create_backend(params: RendezvousParameters) -> tuple[EtcdRendezvousBackend, Store]: + """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | read_timeout | The read timeout, in seconds, for etcd operations. | + | | Defaults to 60 seconds. | + +--------------+-----------------------------------------------------------+ + | protocol | The protocol to use to communicate with etcd. Valid | + | | values are "http" and "https". Defaults to "http". | + +--------------+-----------------------------------------------------------+ + | ssl_cert | The path to the SSL client certificate to use along with | + | | HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ssl_cert_key | The path to the private key of the SSL client certificate | + | | to use along with HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ca_cert | The path to the rool SSL authority certificate. Defaults | + | | to ``None``. | + +--------------+-----------------------------------------------------------+ + """ + client = _create_etcd_client(params) + + backend = EtcdRendezvousBackend( + client, params.run_id, key_prefix="/torch/elastic/rendezvous" + ) + + store = EtcdStore(client, "/torch/elastic/store") + + return backend, store diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py new file mode 100644 index 0000000000000000000000000000000000000000..347e7339d9a46a78c9edf20917eef6146672ffc8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import atexit +import logging +import os +import shlex +import shutil +import socket +import subprocess +import tempfile +import time +from typing import TextIO + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + pass + + +logger = logging.getLogger(__name__) + + +def find_free_port(): + """ + Find a free port and binds a temporary socket to it so that the port can be "reserved" until used. + + .. note:: the returned socket must be closed before using the port, + otherwise a ``address already in use`` error will happen. + The socket should be held and closed as close to the + consumer of the port as possible since otherwise, there + is a greater chance of race-condition where a different + process may see the port as being free and take it. + + Returns: a socket binded to the reserved free port + + Usage:: + + sock = find_free_port() + port = sock.getsockname()[1] + sock.close() + use_port(port) + """ + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + + for addr in addrs: + family, type, proto, _, _ = addr + try: + s = socket.socket(family, type, proto) + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() # type: ignore[possibly-undefined] + print(f"Socket creation attempt failed: {e}") + raise RuntimeError("Failed to create a socket") + + +def stop_etcd(subprocess, data_dir: str | None = None): + if subprocess and subprocess.poll() is None: + logger.info("stopping etcd server") + subprocess.terminate() + subprocess.wait() + + if data_dir: + logger.info("deleting etcd data dir: %s", data_dir) + shutil.rmtree(data_dir, ignore_errors=True) + + +class EtcdServer: + """ + .. note:: tested on etcd server v3.4.3. + + Starts and stops a local standalone etcd server on a random free + port. Useful for single node, multi-worker launches or testing, + where a sidecar etcd server is more convenient than having to + separately setup an etcd server. + + This class registers a termination handler to shutdown the etcd + subprocess on exit. This termination handler is NOT a substitute for + calling the ``stop()`` method. + + The following fallback mechanism is used to find the etcd binary: + + 1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH + 2. Uses ``/bin/etcd`` if one exists + 3. Uses ``etcd`` from ``PATH`` + + Usage + :: + + server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd") + server.start() + client = server.get_client() + # use client + server.stop() + + Args: + etcd_binary_path: path of etcd server binary (see above for fallback path) + """ + + def __init__(self, data_dir: str | None = None): + self._port = -1 + self._host = "localhost" + + root = os.path.dirname(__file__) + default_etcd_bin = os.path.join(root, "bin/etcd") + self._etcd_binary_path = os.environ.get( + "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin + ) + if not os.path.isfile(self._etcd_binary_path): + self._etcd_binary_path = "etcd" + + self._base_data_dir = ( + data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data") + ) + self._etcd_cmd = None + self._etcd_proc: subprocess.Popen | None = None + + def _get_etcd_server_process(self) -> subprocess.Popen: + if not self._etcd_proc: + raise RuntimeError( + "No etcd server process started. Call etcd_server.start() first" + ) + else: + return self._etcd_proc + + def get_port(self) -> int: + """Return the port the server is running on.""" + return self._port + + def get_host(self) -> str: + """Return the host the server is running on.""" + return self._host + + def get_endpoint(self) -> str: + """Return the etcd server endpoint (host:port).""" + return f"{self._host}:{self._port}" + + def start( + self, + timeout: int = 60, + num_retries: int = 3, + stderr: int | TextIO | None = None, + ) -> None: + """ + Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests. + + Args: + timeout: time (in seconds) to wait for the server to be ready + before giving up. + num_retries: number of retries to start the server. Each retry + will wait for max ``timeout`` before considering it as failed. + stderr: the standard error file handle. Valid values are + `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file + descriptor (a positive integer), an existing file object, and + `None`. + + Raises: + TimeoutError: if the server is not ready within the specified timeout + """ + curr_retries = 0 + while True: + try: + data_dir = os.path.join(self._base_data_dir, str(curr_retries)) + os.makedirs(data_dir, exist_ok=True) + return self._start(data_dir, timeout, stderr) + except Exception as e: + curr_retries += 1 + stop_etcd(self._etcd_proc) + logger.warning( # noqa: G200 + "Failed to start etcd server, got error: %s, retrying", str(e) + ) + if curr_retries >= num_retries: + shutil.rmtree(self._base_data_dir, ignore_errors=True) + raise + atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir) + + def _start( + self, data_dir: str, timeout: int = 60, stderr: int | TextIO | None = None + ) -> None: + sock = find_free_port() + sock_peer = find_free_port() + self._port = sock.getsockname()[1] + peer_port = sock_peer.getsockname()[1] + + etcd_cmd = shlex.split( + " ".join( + [ + self._etcd_binary_path, + "--enable-v2", + "--data-dir", + data_dir, + "--listen-client-urls", + f"http://{self._host}:{self._port}", + "--advertise-client-urls", + f"http://{self._host}:{self._port}", + "--listen-peer-urls", + f"http://{self._host}:{peer_port}", + ] + ) + ) + + logger.info("Starting etcd server: [%s]", etcd_cmd) + + sock.close() + sock_peer.close() + self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr) + self._wait_for_ready(timeout) + + def get_client(self): + """Return an etcd client object that can be used to make requests to this server.""" + return etcd.Client( + host=self._host, port=self._port, version_prefix="/v2", read_timeout=10 + ) + + def _wait_for_ready(self, timeout: int = 60) -> None: + client = etcd.Client( + host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5 + ) + max_time = time.time() + timeout + + while time.time() < max_time: + if self._get_etcd_server_process().poll() is not None: + # etcd server process finished + exitcode = self._get_etcd_server_process().returncode + raise RuntimeError( + f"Etcd server process exited with the code: {exitcode}" + ) + try: + logger.info("etcd server ready. version: %s", client.version) + return + except Exception: + time.sleep(1) + raise TimeoutError("Timed out waiting for etcd server to be ready!") + + def stop(self) -> None: + """Stop the server and cleans up auto generated resources (e.g. data dir).""" + logger.info("EtcdServer stop method called") + stop_etcd(self._etcd_proc, self._base_data_dir) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py new file mode 100644 index 0000000000000000000000000000000000000000..faaf77587bc9d66e42110f8b36c8c17e5aedec87 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py @@ -0,0 +1,215 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import random +import time +from base64 import b64decode, b64encode + +# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`. +from torch.distributed import Store + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + + +# Delay (sleep) for a small random amount to reduce CAS failures. +# This does not affect correctness, but will reduce requests to etcd server. +def cas_delay(): + time.sleep(random.uniform(0, 0.1)) + + +# pyre-fixme[11]: Annotation `Store` is not defined as a type. +class EtcdStore(Store): + """ + Implement a c10 Store interface by piggybacking on the rendezvous etcd instance. + + This is the store object returned by ``EtcdRendezvous``. + """ + + def __init__( + self, + etcd_client, + etcd_store_prefix, + # Default timeout same as in c10d/Store.hpp + timeout: datetime.timedelta | None = None, + ): + super().__init__() # required for pybind trampoline. + + self.client = etcd_client + self.prefix = etcd_store_prefix + + if timeout is not None: + self.set_timeout(timeout) + + if not self.prefix.endswith("/"): + self.prefix += "/" + + def set(self, key, value): + """ + Write a key/value pair into ``EtcdStore``. + + Both key and value may be either Python ``str`` or ``bytes``. + """ + self.client.set(key=self.prefix + self._encode(key), value=self._encode(value)) + + def get(self, key) -> bytes: + """ + Get a value by key, possibly doing a blocking wait. + + If key is not immediately present, will do a blocking wait + for at most ``timeout`` duration or until the key is published. + + + Returns: + value ``(bytes)`` + + Raises: + LookupError - If key still not published after timeout + """ + b64_key = self.prefix + self._encode(key) + kvs = self._try_wait_get([b64_key]) + + if kvs is None: + raise LookupError(f"Key {key} not found in EtcdStore") + + return self._decode(kvs[b64_key]) + + def add(self, key, num: int) -> int: + """ + Atomically increment a value by an integer amount. + + The integer is represented as a string using base 10. If key is not present, + a default value of ``0`` will be assumed. + + Returns: + the new (incremented) value + + + """ + b64_key = self._encode(key) + # c10d Store assumes value is an integer represented as a decimal string + try: + # Assume default value "0", if this key didn't yet: + node = self.client.write( + key=self.prefix + b64_key, + value=self._encode(str(num)), # i.e. 0 + num + prevExist=False, + ) + return int(self._decode(node.value)) + except etcd.EtcdAlreadyExist: + pass + + while True: + # Note: c10d Store does not have a method to delete keys, so we + # can be sure it's still there. + node = self.client.get(key=self.prefix + b64_key) + new_value = self._encode(str(int(self._decode(node.value)) + num)) + try: + node = self.client.test_and_set( + key=node.key, value=new_value, prev_value=node.value + ) + return int(self._decode(node.value)) + except etcd.EtcdCompareFailed: + cas_delay() + + def wait(self, keys, override_timeout: datetime.timedelta | None = None): + """ + Wait until all of the keys are published, or until timeout. + + Raises: + LookupError - if timeout occurs + """ + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get(b64_keys, override_timeout) + if kvs is None: + raise LookupError("Timeout while waiting for keys in EtcdStore") + # No return value on success + + def check(self, keys) -> bool: + """Check if all of the keys are immediately present (without waiting).""" + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get( + b64_keys, + override_timeout=datetime.timedelta(microseconds=1), # as if no wait + ) + return kvs is not None + + # + # Encode key/value data in base64, so we can store arbitrary binary data + # in EtcdStore. Input can be `str` or `bytes`. + # In case of `str`, utf-8 encoding is assumed. + # + def _encode(self, value) -> str: + if type(value) is bytes: + return b64encode(value).decode() + elif type(value) is str: + return b64encode(value.encode()).decode() + raise ValueError("Value must be of type str or bytes") + + # + # Decode a base64 string (of type `str` or `bytes`). + # Return type is `bytes`, which is more convenient with the Store interface. + # + def _decode(self, value) -> bytes: + if type(value) is bytes: + return b64decode(value) + elif type(value) is str: + return b64decode(value.encode()) + raise ValueError("Value must be of type str or bytes") + + # + # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys + # are published or timeout occurs. + # This is a helper method for the public interface methods. + # + # On success, a dictionary of {etcd key -> etcd value} is returned. + # On timeout, None is returned. + # + def _try_wait_get(self, b64_keys, override_timeout=None): + timeout = self.timeout if override_timeout is None else override_timeout # type: ignore[attr-defined] + deadline = time.time() + timeout.total_seconds() + + while True: + # Read whole directory (of keys), filter only the ones waited for + all_nodes = None + try: + all_nodes = self.client.get(key=self.prefix) + req_nodes = { + node.key: node.value + for node in all_nodes.children + if node.key in b64_keys + } + + if len(req_nodes) == len(b64_keys): + # All keys are available + return req_nodes + except etcd.EtcdKeyNotFound: + pass + + watch_timeout = deadline - time.time() + if watch_timeout <= 0: + return None + + try: + index = all_nodes.etcd_index + 1 if all_nodes else 0 + self.client.watch( + key=self.prefix, + recursive=True, + timeout=watch_timeout, + index=index, + ) + except etcd.EtcdWatchTimedOut: + if time.time() >= deadline: + return None + else: + continue + except etcd.EtcdEventIndexCleared: + continue diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ebada4623a814c6b8a2b802d544e5926426e13fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from importlib.metadata import entry_points + +from .api import ( + rendezvous_handler_registry as handler_registry, + RendezvousHandler, + RendezvousParameters, +) +from .dynamic_rendezvous import create_handler + + +log = logging.getLogger(__name__) + +__all__ = ["get_rendezvous_handler"] + + +def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import static_tcp_rendezvous + + return static_tcp_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import etcd_rendezvous + + return etcd_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler: + from .etcd_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler: + from .c10d_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _register_default_handlers() -> None: + handler_registry.register("etcd", _create_etcd_handler) + handler_registry.register("etcd-v2", _create_etcd_v2_handler) + handler_registry.register("c10d", _create_c10d_handler) + handler_registry.register("static", _create_static_handler) + + +def _register_out_of_tree_handlers() -> None: + discovered_handler_generators = entry_points(group="torchrun.handlers") + + for handler_generator in discovered_handler_generators: + try: + get_handler = discovered_handler_generators[handler_generator.name].load() + handler_registry.register(handler_generator.name, get_handler()) + except Exception: + log.warning( + "Exception while registering out of tree plugin %s: ", + handler_generator.name, + exc_info=True, + ) + + +def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Obtain a reference to a :py:class`RendezvousHandler`. + + Custom rendezvous handlers can be registered by + + :: + + from torch.distributed.elastic.rendezvous import rendezvous_handler_registry + from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler + + + def create_my_rdzv(params: RendezvousParameters): + return MyCustomRdzv(params) + + + rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv) + + my_rdzv_handler = get_rendezvous_handler( + "my_rdzv_backend_name", RendezvousParameters + ) + """ + return handler_registry.create_handler(params) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..52b68000530889b6be1a8ec78ea762f6e5817975 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import logging +from typing import cast + +from torch.distributed import PrefixStore, Store, TCPStore +from torch.distributed.elastic.rendezvous import ( + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, +) +from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint + + +__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] + +logger = logging.getLogger(__name__) + +_default_timeout_seconds = 600 + + +class StaticTCPRendezvous(RendezvousHandler): + """ + Static rendezvous that is a wrapper around the TCPStore. + + Creates TCPStore based on the input parameters with the + listener on the agent with group_rank=0 + """ + + def __init__( + self, + master_addr: str, + master_port: int, + rank: int, + world_size: int, + run_id: str, + timeout: int, + ): + self.master_addr = master_addr + self.master_port = master_port + self.rank = rank + self.world_size = world_size + self.run_id = run_id + self.timeout = datetime.timedelta(seconds=timeout) + self._store: Store | None = None + + def get_backend(self) -> str: + return "static" + + @property + def use_agent_store(self) -> bool: + return True + + def next_rendezvous(self) -> RendezvousInfo: + logger.info("Creating TCPStore as the c10d::Store implementation") + is_master = self.rank == 0 + if not self._store: + self._store = TCPStore( # type: ignore[call-arg] + self.master_addr, + self.master_port, + self.world_size, + is_master, + self.timeout, + multi_tenant=True, + ) + store = PrefixStore(self.run_id, self._store) + # TCPStore server instance is used by trainer code + bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port) + return RendezvousInfo( + store, + self.rank, + self.world_size, + bootstrap_store_info, + ) + + def is_closed(self): + return False + + def set_closed(self): + pass + + def num_nodes_waiting(self): + return 0 + + def get_run_id(self) -> str: + return self.run_id + + def shutdown(self) -> bool: + return True + + +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + if "rank" not in params.config: + raise ValueError( + "rank is absent in RendezvousParameters." + "Try add --node-rank to the cmd request" + ) + endpoint = params.endpoint.strip() + if not endpoint: + raise ValueError( + "endpoint is absent in RendezvousParameters" + "Try add --master-port and --master-addr to the cmd request" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1) + if master_port == -1: + raise ValueError( + f"Port is absent in endpoint: {endpoint}. Try launching with --master-port" + ) + world_size = params.max_nodes + rank = cast(int, params.config.get("rank")) + run_id = params.run_id + if "timeout" in params.config: + timeout = int(params.config["timeout"]) + else: + timeout = _default_timeout_seconds + + return StaticTCPRendezvous( + master_addr, master_port, rank, world_size, run_id, timeout + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05ebbba55913fc4f7d9843420a68b4ae233f3e14 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/utils.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ipaddress +import random +import re +import socket +import time +import weakref +from collections.abc import Callable +from datetime import timedelta +from threading import Event, Thread +from typing import Any + + +__all__ = ["parse_rendezvous_endpoint"] + + +def _parse_rendezvous_config(config_str: str) -> dict[str, str]: + """Extract key-value pairs from a rendezvous configuration string. + + Args: + config_str: + A string in format =,...,=. + """ + config: dict[str, str] = {} + + config_str = config_str.strip() + if not config_str: + return config + + key_values = config_str.split(",") + for kv in key_values: + key, *values = kv.split("=", 1) + + key = key.strip() + if not key: + raise ValueError( + "The rendezvous configuration string must be in format " + "=,...,=." + ) + + value: str | None + if values: + value = values[0].strip() + else: + value = None + if not value: + raise ValueError( + f"The rendezvous configuration option '{key}' must have a value specified." + ) + + config[key] = value + return config + + +def _try_parse_port(port_str: str) -> int | None: + """Try to extract the port number from ``port_str``.""" + if port_str and re.match(r"^[0-9]{1,5}$", port_str): + return int(port_str) + return None + + +def parse_rendezvous_endpoint( + endpoint: str | None, default_port: int +) -> tuple[str, int]: + """Extract the hostname and the port number from a rendezvous endpoint. + + Args: + endpoint: + A string in format [:]. + default_port: + The port number to use if the endpoint does not include one. + + Returns: + A tuple of hostname and port number. + """ + if endpoint is not None: + endpoint = endpoint.strip() + + if not endpoint: + return ("localhost", default_port) + + # An endpoint that starts and ends with brackets represents an IPv6 address. + if endpoint[0] == "[" and endpoint[-1] == "]": + host, *rest = endpoint, *[] + else: + host, *rest = endpoint.rsplit(":", 1) + + # Sanitize the IPv6 address. + if len(host) > 1 and host[0] == "[" and host[-1] == "]": + host = host[1:-1] + + if len(rest) == 1: + port = _try_parse_port(rest[0]) + if port is None or port >= 2**16: + raise ValueError( + f"The port number of the rendezvous endpoint '{endpoint}' must be an integer " + "between 0 and 65536." + ) + else: + port = default_port + + if not re.match(r"^[\w\.:-]+$", host): + raise ValueError( + f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of " + "labels, an IPv4 address, or an IPv6 address." + ) + + return host, port + + +def _matches_machine_hostname(host: str) -> bool: + """Indicate whether ``host`` matches the hostname of this machine. + + This function compares ``host`` to the hostname as well as to the IP + addresses of this machine. Note that it may return a false negative if this + machine has CNAME records beyond its FQDN or IP addresses assigned to + secondary NICs. + """ + if host == "localhost": + return True + + try: + addr = ipaddress.ip_address(host) + except ValueError: + addr = None + + if addr and addr.is_loopback: + return True + + try: + host_addr_list = socket.getaddrinfo( + host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + except (ValueError, socket.gaierror) as _: + host_addr_list = [] + + host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list] + + this_host = socket.gethostname() + if host == this_host: + return True + + addr_list = socket.getaddrinfo( + this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + for addr_info in addr_list: + # If we have an FQDN in the addr_info, compare it to `host`. + if addr_info[3] and addr_info[3] == host: + return True + + # Otherwise if `host` represents an IP address, compare it to our IP + # address. + if addr and addr_info[4][0] == str(addr): + return True + + # If the IP address matches one of the provided host's IP addresses + if addr_info[4][0] in host_ip_list: + return True + + return False + + +def _delay(seconds: float | tuple[float, float]) -> None: + """Suspend the current thread for ``seconds``. + + Args: + seconds: + Either the delay, in seconds, or a tuple of a lower and an upper + bound within which a random delay will be picked. + """ + if isinstance(seconds, tuple): + seconds = random.uniform(*seconds) + # Ignore delay requests that are less than 10 milliseconds. + if seconds >= 0.01: + time.sleep(seconds) + + +class _PeriodicTimer: + """Represent a timer that periodically runs a specified function. + + Args: + interval: + The interval, in seconds, between each run. + function: + The function to run. + """ + + # The state of the timer is hold in a separate context object to avoid a + # reference cycle between the timer and the background thread. + class _Context: + interval: float + function: Callable[..., None] + args: tuple[Any, ...] + kwargs: dict[str, Any] + stop_event: Event + + _name: str | None + _thread: Thread | None + _finalizer: weakref.finalize | None + + # The context that is shared between the timer and the background thread. + _ctx: _Context + + def __init__( + self, + interval: timedelta, + function: Callable[..., None], + *args: Any, + **kwargs: Any, + ) -> None: + self._name = None + + self._ctx = self._Context() + self._ctx.interval = interval.total_seconds() + self._ctx.function = function # type: ignore[assignment] + self._ctx.args = args or () + self._ctx.kwargs = kwargs or {} + self._ctx.stop_event = Event() + + self._thread = None + self._finalizer = None + + @property + def name(self) -> str | None: + """Get the name of the timer.""" + return self._name + + def set_name(self, name: str) -> None: + """Set the name of the timer. + + The specified name will be assigned to the background thread and serves + for debugging and troubleshooting purposes. + """ + if self._thread: + raise RuntimeError("The timer has already started.") + + self._name = name + + def start(self) -> None: + """Start the timer.""" + if self._thread: + raise RuntimeError("The timer has already started.") + + self._thread = Thread( + target=self._run, + name=self._name or "PeriodicTimer", + args=(self._ctx,), + daemon=True, + ) + + # We avoid using a regular finalizer (a.k.a. __del__) for stopping the + # timer as joining a daemon thread during the interpreter shutdown can + # cause deadlocks. The weakref.finalize is a superior alternative that + # provides a consistent behavior regardless of the GC implementation. + self._finalizer = weakref.finalize( + self, self._stop_thread, self._thread, self._ctx.stop_event + ) + + # We do not attempt to stop our background thread during the interpreter + # shutdown. At that point we do not even know whether it still exists. + self._finalizer.atexit = False + + self._thread.start() + + def cancel(self) -> None: + """Stop the timer at the next opportunity.""" + if self._finalizer: + self._finalizer() + + @staticmethod + def _run(ctx) -> None: + while not ctx.stop_event.wait(ctx.interval): + ctx.function(*ctx.args, **ctx.kwargs) + + @staticmethod + def _stop_thread(thread, stop_event): + stop_event.set() + + thread.join() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c2ea349cc67ff7175d5ef17ec63aecddbf52a7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Expiration timers are set up on the same process as the agent and +used from your script to deal with stuck workers. When you go into +a code-block that has the potential to get stuck you can acquire +an expiration timer, which instructs the timer server to kill the +process if it does not release the timer by the self-imposed expiration +deadline. + +Usage:: + + import torchelastic.timer as timer + import torchelastic.agent.server as agent + + def main(): + start_method = "spawn" + message_queue = mp.get_context(start_method).Queue() + server = timer.LocalTimerServer(message, max_interval=0.01) + server.start() # non-blocking + + spec = WorkerSpec( + fn=trainer_func, + args=(message_queue,), + ...) + agent = agent.LocalElasticAgent(spec, start_method) + agent.run() + + def trainer_func(message_queue): + timer.configure(timer.LocalTimerClient(message_queue)) + with timer.expires(after=60): # 60 second expiry + # do some work + +In the example above if ``trainer_func`` takes more than 60 seconds to +complete, then the worker process is killed and the agent retries the worker group. +""" + +from .api import ( # noqa: F401 + configure, + expires, + TimerClient, + TimerRequest, + TimerServer, +) +from .file_based_local_timer import ( # noqa: F401 + FileTimerClient, + FileTimerRequest, + FileTimerServer, +) +from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8517ee75690eadf898f8949645a8303a03d00f37 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7a84e49a17877405dd09b99037d969b975fbfdd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6071d014c2f35547d327b53ceaee771f9675b0a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b20478334723d94d38e7f2c2e007f2891c00c7e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0008664e4ccea8f14b84ca0f3b222abb553293c8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/api.py new file mode 100644 index 0000000000000000000000000000000000000000..efe942022246e90c3b6b68fae59be012d9c8d56b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/api.py @@ -0,0 +1,281 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import abc +import logging +import threading +import time +from contextlib import contextmanager +from inspect import getframeinfo, stack +from typing import Any + + +__all__ = [ + "TimerRequest", + "TimerClient", + "RequestQueue", + "TimerServer", + "configure", + "expires", +] + +logger = logging.getLogger(__name__) + + +class TimerRequest: + """ + Data object representing a countdown timer acquisition and release + that is used between the ``TimerClient`` and ``TimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + + .. note:: the type of ``worker_id`` is implementation specific. + It is whatever the TimerServer and TimerClient implementations + have on to uniquely identify a worker. + """ + + __slots__ = ["worker_id", "scope_id", "expiration_time"] + + def __init__(self, worker_id: Any, scope_id: str, expiration_time: float): + self.worker_id = worker_id + self.scope_id = scope_id + self.expiration_time = expiration_time + + def __eq__(self, other): + if isinstance(other, TimerRequest): + return ( + self.worker_id == other.worker_id + and self.scope_id == other.scope_id + and self.expiration_time == other.expiration_time + ) + return False + + +class TimerClient(abc.ABC): + """ + Client library to acquire and release countdown timers by communicating + with the TimerServer. + """ + + @abc.abstractmethod + def acquire(self, scope_id: str, expiration_time: float) -> None: + """ + Acquires a timer for the worker that holds this client object + given the scope_id and expiration_time. Typically registers + the timer with the TimerServer. + """ + + @abc.abstractmethod + def release(self, scope_id: str): + """ + Releases the timer for the ``scope_id`` on the worker this + client represents. After this method is + called, the countdown timer on the scope is no longer in effect. + """ + + +class RequestQueue(abc.ABC): + """ + Consumer queue holding timer acquisition/release requests + """ + + @abc.abstractmethod + def size(self) -> int: + """ + Returns the size of the queue at the time this method is called. + Note that by the time ``get`` is called the size of the queue + may have increased. The size of the queue should not decrease + until the ``get`` method is called. That is, the following assertion + should hold: + + size = q.size() + res = q.get(size, timeout=0) + assert size == len(res) + + -- or -- + + size = q.size() + res = q.get(size * 2, timeout=1) + assert size <= len(res) <= size * 2 + """ + + @abc.abstractmethod + def get(self, size: int, timeout: float) -> list[TimerRequest]: + """ + Gets up to ``size`` number of timer requests in a blocking fashion + (no more than ``timeout`` seconds). + """ + + +class TimerServer(abc.ABC): + """ + Entity that monitors active timers and expires them + in a timely fashion. This server is responsible for + reaping workers that have expired timers. + """ + + def __init__( + self, request_queue: RequestQueue, max_interval: float, daemon: bool = True + ): + """ + :param request_queue: Consumer ``RequestQueue`` + :param max_interval: max time (in seconds) to wait + for an item in the request_queue + :param daemon: whether to run the watchdog thread as a daemon + """ + super().__init__() + self._request_queue = request_queue + self._max_interval = max_interval + self._daemon = daemon + self._watchdog_thread: threading.Thread | None = None + self._stop_signaled = False + + @abc.abstractmethod + def register_timers(self, timer_requests: list[TimerRequest]) -> None: + """ + Processes the incoming timer requests and registers them with the server. + The timer request can either be a acquire-timer or release-timer request. + Timer requests with a negative expiration_time should be interpreted + as a release-timer request. + """ + + @abc.abstractmethod + def clear_timers(self, worker_ids: set[Any]) -> None: + """ + Clears all timers for the given ``worker_ids``. + """ + + @abc.abstractmethod + def get_expired_timers(self, deadline: float) -> dict[str, list[TimerRequest]]: + """ + Returns all expired timers for each worker_id. An expired timer + is a timer for which the expiration_time is less than or equal to + the provided deadline. + """ + + @abc.abstractmethod + def _reap_worker(self, worker_id: Any) -> bool: + """ + Reaps the given worker. Returns True if the worker has been + successfully reaped, False otherwise. If any uncaught exception + is thrown from this method, the worker is considered reaped + and all associated timers will be removed. + """ + + def _reap_worker_no_throw(self, worker_id: Any) -> bool: + """ + Wraps ``_reap_worker(worker_id)``, if an uncaught exception is + thrown, then it considers the worker as reaped. + """ + try: + return self._reap_worker(worker_id) + except Exception: + logger.exception( + "Uncaught exception thrown from _reap_worker(), " + "check that the implementation correctly catches exceptions", + ) + return True + + def _watchdog_loop(self): + while not self._stop_signaled: + try: + self._run_watchdog() + except Exception: + logger.exception("Error running watchdog") + + def _run_watchdog(self): + batch_size = max(1, self._request_queue.size()) + timer_requests = self._request_queue.get(batch_size, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_ids = set() + for worker_id, expired_timers in self.get_expired_timers(now).items(): + logger.info( + "Reaping worker_id=[%s]. Expired timers: %s", + worker_id, + self._get_scopes(expired_timers), + ) + if self._reap_worker_no_throw(worker_id): + logger.info("Successfully reaped worker=[%s]", worker_id) + reaped_worker_ids.add(worker_id) + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id + ) + self.clear_timers(reaped_worker_ids) + + def _get_scopes(self, timer_requests): + return [r.scope_id for r in timer_requests] + + def start(self) -> None: + logger.info( + "Starting %s... max_interval=%s, daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + + +_timer_client: TimerClient | None = None + + +def configure(timer_client: TimerClient): + """ + Configures a timer client. Must be called before using ``expires``. + """ + global _timer_client + _timer_client = timer_client + logger.info("Timer client configured to: %s", type(_timer_client).__name__) + + +@contextmanager +def expires(after: float, scope: str | None = None, client: TimerClient | None = None): + """ + Acquires a countdown timer that expires in ``after`` seconds from now, + unless the code-block that it wraps is finished within the timeframe. + When the timer expires, this worker is eligible to be reaped. The + exact meaning of "reaped" depends on the client implementation. In + most cases, reaping means to terminate the worker process. + Note that the worker is NOT guaranteed to be reaped at exactly + ``time.now() + after``, but rather the worker is "eligible" for being + reaped and the ``TimerServer`` that the client talks to will ultimately + make the decision when and how to reap the workers with expired timers. + + Usage:: + + torch.distributed.elastic.timer.configure(LocalTimerClient()) + with expires(after=10): + torch.distributed.all_reduce(...) + """ + if client is None: + if _timer_client is None: + raise RuntimeError("Configure timer client before using countdown timers.") + client = _timer_client + if scope is None: + # grab the caller file + lineno + caller = getframeinfo(stack()[1][0]) + scope = f"{caller.filename}#{caller.lineno}" + expiration = time.time() + after + client.acquire(scope, expiration) + try: + yield + finally: + client.release(scope) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/debug_info_logging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/debug_info_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e385d91283a7b610f00397bfa4bc4800a89761ca --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/debug_info_logging.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torch.distributed.elastic.utils.logging import get_logger + + +logger = get_logger(__name__) + +__all__ = ["log_debug_info_for_expired_timers"] + + +def log_debug_info_for_expired_timers( + run_id: str, + expired_timers: dict[int, list[str]], +): + if expired_timers: + logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..5855efefcc85342378c273657fed27b37160a6ba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py @@ -0,0 +1,444 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import io +import json +import os +import select +import signal +import sys +import threading +import time +from collections.abc import Callable +from typing import TypeVar +from typing_extensions import ParamSpec + +from torch.distributed.elastic.timer.api import TimerClient, TimerRequest +from torch.distributed.elastic.timer.debug_info_logging import ( + log_debug_info_for_expired_timers, +) +from torch.distributed.elastic.utils.logging import get_logger + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] + +logger = get_logger(__name__) + + +def _retry(max_retries: int, sleep_time: float) -> Callable: + """ + A simple retry wrapper. + + Args: + max_retries: int, the maximum number of retries. + sleep_time: float, the time to sleep between retries. + """ + + def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]: + def wrapper(*args: _P.args, **kwargs: _P.kwargs): + for i in range(max_retries): + try: + return func(*args, **kwargs) + except Exception: + logger.exception("Error running %s. Retrying...", func.__name__) + if i < max_retries - 1: + time.sleep(sleep_time) + else: + raise + + return wrapper + + return wrapper + + +class FileTimerRequest(TimerRequest): + """ + Data object representing a countdown timer acquisition and release + that is used between the ``FileTimerClient`` and ``FileTimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + ``signal`` is the signal to reap the worker process from the server + process. + """ + + __slots__ = ["version", "signal"] + + def __init__( + self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 + ) -> None: + super().__init__( + worker_id=worker_pid, scope_id=scope_id, expiration_time=expiration_time + ) + self.version = 1 + self.signal = signal + + @property + def worker_pid(self) -> int: + return self.worker_id + + def __eq__(self, other) -> bool: + if isinstance(other, FileTimerRequest): + return ( + super().__eq__(other) + and self.version == other.version + and self.signal == other.signal + ) + return False + + def to_json(self) -> str: + return json.dumps( + { + "version": self.version, + "pid": self.worker_pid, + "scope_id": self.scope_id, + "expiration_time": self.expiration_time, + "signal": self.signal, + }, + ) + + +class FileTimerClient(TimerClient): + """ + Client side of ``FileTimerServer``. This client is meant to be used + on the same host that the ``FileTimerServer`` is running on and uses + pid to uniquely identify a worker. + This client uses a named_pipe to send timer requests to the + ``FileTimerServer``. This client is a producer while the + ``FileTimerServer`` is a consumer. Multiple clients can work with + the same ``FileTimerServer``. + + Args: + + file_path: str, the path of a FIFO special file. ``FileTimerServer`` + must have created it by calling os.mkfifo(). + + signal: signal, the signal to use to kill the process. Using a + negative or zero signal will not kill the process. + """ + + def __init__( + self, + file_path: str, + signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] + ) -> None: + super().__init__() + self._file_path = file_path + self.signal = signal + + @_retry(max_retries=10, sleep_time=0.1) + def _open_non_blocking(self) -> io.TextIOWrapper | None: + # The server may have crashed or may haven't started yet. + # In such case, calling open() in blocking model blocks the client. + # To avoid such issue, open it in non-blocking mode, and an OSError will + # be raised if the server is not there. + fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK) + return os.fdopen(fd, "wt") + + def _send_request(self, request: FileTimerRequest) -> None: + try: + file = self._open_non_blocking() + except Exception as e: + raise BrokenPipeError( + "Could not send the FileTimerRequest because FileTimerServer is not available." + ) from e + with file: + json_request = request.to_json() + # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. + if len(json_request) > select.PIPE_BUF: + raise RuntimeError( + f"FileTimerRequest larger than {select.PIPE_BUF} bytes " + f"is not supported: {json_request}" + ) + file.write(json_request + "\n") + + def acquire(self, scope_id: str, expiration_time: float) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), + scope_id=scope_id, + expiration_time=expiration_time, + signal=self.signal, + ), + ) + + def release(self, scope_id: str) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 + ), + ) + + +class FileTimerServer: + """ + Server that works with ``FileTimerClient``. Clients are expected to be + running on the same host as the process that is running this server. + Each host in the job is expected to start its own timer server locally + and each server instance manages timers for local workers (running on + processes on the same host). + + Args: + + file_path: str, the path of a FIFO special file to be created. + + max_interval: float, max interval in seconds for each watchdog loop. + + daemon: bool, running the watchdog thread in daemon mode or not. + A daemon thread will not block a process to stop. + log_event: Callable[[Dict[str, str]], None], an optional callback for + logging the events in JSON format. + """ + + def __init__( + self, + file_path: str, + run_id: str, + max_interval: float = 10, + daemon: bool = True, + log_event: Callable[[str, FileTimerRequest | None], None] | None = None, + ) -> None: + self._file_path = file_path + self._run_id = run_id + self._max_interval = max_interval + self._daemon = daemon + self._timers: dict[tuple[int, str], FileTimerRequest] = {} + self._stop_signaled = False + self._watchdog_thread: threading.Thread | None = None + + self._is_client_started = False + if os.path.exists(self._file_path): + os.remove(self._file_path) + os.mkfifo(self._file_path) + # For test only. Count the number of requests received. + self._request_count = 0 + # For test only. Process all requests and stop the server. + self._run_once = False + self._log_event = ( + log_event if log_event is not None else lambda name, request: None + ) + self._last_progress_time = int(time.time()) + + def start(self) -> None: + logger.info( + "Starting %s... max_interval=%s, daemon=%s, file_path=%s", + type(self).__name__, + self._max_interval, + self._daemon, + self._file_path, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + self._log_event("watchdog started", None) + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + self._log_event("watchdog stopped", None) + + def run_once(self) -> None: + self._run_once = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join() + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + + @staticmethod + def is_process_running(pid: int): + """ + function to check process is running or not + """ + try: + # Check if the process exists and we can send signals to it + os.kill(pid, 0) + return True + except OSError: + return False + + def _watchdog_loop(self) -> None: + # Open the pipe in blocking mode blocks the server thread. + # This is fine for the following reasons: + # 1. No client case usually does not happen. + # 2. We are running the watchdog loop in a separate daemon + # thread, which will not block the process to stop. + try: + with open(self._file_path) as fd: + self._is_client_started = True + while not self._stop_signaled: + try: + run_once = self._run_once + self._run_watchdog(fd) + if run_once: + break + self._last_progress_time = int(time.time()) + except Exception: + logger.exception("Error running watchdog") + + except Exception: + logger.exception("Could not open the FileTimerServer pipe") + raise + + def _run_watchdog(self, fd: io.TextIOWrapper) -> None: + timer_requests = self._get_requests(fd, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_pids = set() + kill_process = False + reap_signal = 0 + + all_expired_timers = self.get_expired_timers(now) + log_debug_info_for_expired_timers( + self._run_id, + { + pid: [expired_timer.to_json() for expired_timer in expired_timers] + for pid, expired_timers in all_expired_timers.items() + }, + ) + + for worker_pid, expired_timers in all_expired_timers.items(): + logger.info( + "Reaping worker_pid=[%s]. Expired timers: %s", + worker_pid, + self._get_scopes(expired_timers), + ) + reaped_worker_pids.add(worker_pid) + # In case we have multiple expired timers, we find the first timer + # with a valid signal (>0) in the expiration time order. + expired_timers.sort(key=lambda timer: timer.expiration_time) + signal = 0 + expired_timer = None + for timer in expired_timers: + self._log_event("timer expired", timer) + if timer.signal > 0: + signal = timer.signal + expired_timer = timer + break + if signal <= 0: + logger.info( + "No signal specified with worker=[%s]. Do not reap it.", worker_pid + ) + continue + if self._reap_worker(worker_pid, signal): + logger.info( + "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal + ) + self._log_event("kill worker process", expired_timer) + kill_process = True + reap_signal = signal + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", + worker_pid, + ) + if kill_process and reap_signal > 0: + logger.info( + "Terminating the server process=[%s] because of expired timers", + os.getpid(), + ) + self._reap_worker(os.getpid(), reap_signal) + + self.clear_timers(reaped_worker_pids) + + def _get_scopes(self, timer_requests: list[FileTimerRequest]) -> list[str]: + return [r.scope_id for r in timer_requests] + + def _get_requests( + self, fd: io.TextIOWrapper, max_interval: float + ) -> list[FileTimerRequest]: + start = time.time() + requests = [] + while not self._stop_signaled or self._run_once: + # For named pipe, readline() is blocking when at least one writer opens. + # It returns only when flush() is called at the writer side. + # Note that flush() is automatically called inside close(). + # After the last writer closes, readline() is not blocking. + # It will return an empty string when it's at end-of-file. + # Since the client side always opens the pipe, writes a message and closes + # the pipe immediately, the readline() call below is not blocking for long. + json_request = fd.readline() + if len(json_request) == 0: + if self._run_once: + break + time.sleep(min(max_interval, 1)) + else: + request = json.loads(json_request) + pid = request["pid"] + scope_id = request["scope_id"] + expiration_time = request["expiration_time"] + signal = request["signal"] + requests.append( + FileTimerRequest( + worker_pid=pid, + scope_id=scope_id, + expiration_time=expiration_time, + signal=signal, + ) + ) + now = time.time() + if now - start > max_interval: + break + return requests + + def register_timers(self, timer_requests: list[FileTimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_pid + scope_id = request.scope_id + expiration_time = request.expiration_time + self._request_count += 1 + + key = (pid, scope_id) + # negative expiration is a proxy for a release call + if expiration_time < 0: + if key in self._timers: + del self._timers[key] + else: + self._timers[key] = request + + def clear_timers(self, worker_pids: set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_pids or not FileTimerServer.is_process_running(pid): + del self._timers[(pid, scope_id)] + + def get_expired_timers(self, deadline: float) -> dict[int, list[FileTimerRequest]]: + # pid -> [timer_requests...] + expired_timers: dict[int, list[FileTimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_pid, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_pid: int, signal: int) -> bool: + try: + os.kill(worker_pid, signal) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_pid) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_pid) + return False + + def get_last_progress_time(self) -> int: + return self._last_progress_time if self._is_client_started else int(time.time()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/local_timer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/local_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..5e66ef3fae34958422c1160bfdc1994b13bf1553 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/timer/local_timer.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import logging +import multiprocessing as mp +import os +import signal +import time +from queue import Empty +from typing import Any + +from .api import RequestQueue, TimerClient, TimerRequest, TimerServer + + +__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"] + +logger = logging.getLogger(__name__) + + +class LocalTimerClient(TimerClient): + """ + Client side of ``LocalTimerServer``. This client is meant to be used + on the same host that the ``LocalTimerServer`` is running on and uses + pid to uniquely identify a worker. This is particularly useful in situations + where one spawns a subprocess (trainer) per GPU on a host with multiple + GPU devices. + """ + + def __init__(self, mp_queue): + super().__init__() + self._mp_queue = mp_queue + + def acquire(self, scope_id, expiration_time): + pid = os.getpid() + acquire_request = TimerRequest(pid, scope_id, expiration_time) + self._mp_queue.put(acquire_request) + + def release(self, scope_id): + pid = os.getpid() + release_request = TimerRequest(pid, scope_id, -1) + self._mp_queue.put(release_request) + + +class MultiprocessingRequestQueue(RequestQueue): + """ + A ``RequestQueue`` backed by python ``multiprocessing.Queue`` + """ + + def __init__(self, mp_queue: mp.Queue): + super().__init__() + self._mp_queue = mp_queue + + def size(self) -> int: + return self._mp_queue.qsize() + + def get(self, size, timeout: float) -> list[TimerRequest]: + requests = [] + wait = timeout + for _ in range(size): + start = time.time() + + try: + r = self._mp_queue.get(block=True, timeout=wait) + except Empty: + break + + requests.append(r) + wait = wait - (time.time() - start) + if wait <= 0: + break + + return requests + + +class LocalTimerServer(TimerServer): + """ + Server that works with ``LocalTimerClient``. Clients are expected to be + subprocesses to the parent process that is running this server. Each host + in the job is expected to start its own timer server locally and each + server instance manages timers for local workers (running on processes + on the same host). + """ + + def __init__( + self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True + ): + super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) + self._timers: dict[tuple[Any, str], TimerRequest] = {} + + def register_timers(self, timer_requests: list[TimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_id + scope_id = request.scope_id + expiration_time = request.expiration_time + + # negative expiration is a proxy for a release call + if expiration_time < 0: + self._timers.pop((pid, scope_id), None) + else: + self._timers[(pid, scope_id)] = request + + def clear_timers(self, worker_ids: set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_ids: + self._timers.pop((pid, scope_id)) + + def get_expired_timers(self, deadline: float) -> dict[Any, list[TimerRequest]]: + # pid -> [timer_requests...] + expired_timers: dict[Any, list[TimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_id, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_id: int) -> bool: + try: + os.kill(worker_id, signal.SIGKILL) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_id) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_id) + return False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2bbf5bbe2348bb0eaa411a034710dd14f7648e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6fd20d93696f6634dca6c91ebfd53fcdf6d4b2e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fb77292217d06431fe63ef0cf47053f3ca2526f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9410ba0b00d99cf4bd7e04c6a45788765a3737c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9576034456c7720c4762c55e849dde17c730d594 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b507f8852df500a516c21b304c4a5ab0aac745d2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..144d32334605e2c5c65fe9273e927997e52ccf2c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2b881137047c23789a061a719437a43b1743959f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/api.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import socket +from string import Template +from typing import Any + + +def get_env_variable_or_raise(env_name: str) -> str: + r""" + Tries to retrieve environment variable. Raises ``ValueError`` + if no environment variable found. + + Args: + env_name (str): Name of the env variable + """ + value = os.environ.get(env_name, None) + if value is None: + msg = f"Environment variable {env_name} expected, but not set" + raise ValueError(msg) + return value + + +def get_socket_with_port() -> socket.socket: + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError: + s.close() + raise RuntimeError("Failed to create a socket") + + +class macros: + """ + Defines simple macros for caffe2.distributed.launch cmd args substitution + """ + + local_rank = "${local_rank}" + + @staticmethod + def substitute(args: list[Any], local_rank: str) -> list[str]: + args_sub = [] + for arg in args: + if isinstance(arg, str): + sub = Template(arg).safe_substitute(local_rank=local_rank) + args_sub.append(sub) + else: + args_sub.append(arg) + return args_sub diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c39bca6f3c8a31f5f2d7115ad12c1fc4925fe1d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__init__.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .cycling_iterator import CyclingIterator # noqa: F401 +from .elastic_distributed_sampler import ElasticDistributedSampler # noqa: F401 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71ba79a6438b1540a7db0ab5bc993d1b863528d9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55b97f7a41684fb90974a0085fac4f14ab8e0c87 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dde258358b89eb2d6e38e0ab66a9f19d62fa918d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..291a04226db79c77b3bde4cec239e45b31be81b5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +from collections.abc import Callable, Iterator +from typing import TypeVar +from typing_extensions import Self + + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +_T = TypeVar("_T") + +__all__ = ["CyclingIterator"] + + +class CyclingIterator(Iterator[_T]): + """ + An iterator decorator that cycles through the + underlying iterator "n" times. Useful to "unroll" + the dataset across multiple training epochs. + + The generator function is called as ``generator_fn(epoch)`` + to obtain the underlying iterator, where ``epoch`` is a + number less than or equal to ``n`` representing the ``k``th cycle + + For example if ``generator_fn`` always returns ``[1,2,3]`` + then ``CyclingIterator(n=2, generator_fn)`` will iterate through + ``[1,2,3,1,2,3]`` + """ + + def __init__( + self, + n: int, + generator_fn: Callable[[int], Iterator[_T]], + start_epoch: int = 0, + ): + self._n = n + self._epoch = start_epoch + self._generator_fn = generator_fn + self._iter = generator_fn(self._epoch) + + def __iter__(self) -> Self: + return self + + def __next__(self) -> _T: + try: + return next(self._iter) + except StopIteration as eod: # eod == end of data + if self._epoch < self._n - 1: + self._epoch += 1 + self._iter = self._generator_fn(self._epoch) + return self.__next__() + else: + raise eod diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c824cc2fd018c005a59d0927a53ca449bf99102d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections.abc import Iterator, Sized +from typing import cast, TypeVar + +import torch +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler + + +T = TypeVar("T") + +__all__ = ["ElasticDistributedSampler"] + + +class ElasticDistributedSampler(DistributedSampler[T]): + """ + Sampler that restricts data loading to a subset of + the dataset for elastic training. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Args: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + start_index (optional): Which index of the dataset to start sampling from + """ + + def __init__( + self, + dataset: Dataset[T], + num_replicas: int | None = None, + rank: int | None = None, + start_index: int = 0, + ): + super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank) + if not isinstance(dataset, Sized): + raise TypeError("Dataset must be an instance of collections.abc.Sized") + + # Cast to Sized for mypy + # pyrefly: ignore [redundant-cast] + sized_dataset = cast(Sized, dataset) + + if start_index >= len(sized_dataset): + raise ValueError( + f"Start index {start_index} should be less than dataset size {len(sized_dataset)}" + ) + + self.start_index = start_index + sized_dataset = cast(Sized, self.dataset) + self.num_samples = math.ceil( + float(len(sized_dataset) - self.start_index) / self.num_replicas + ) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self) -> Iterator[T]: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + sized_dataset = cast(Sized, self.dataset) + indices = ( + torch.randperm(len(sized_dataset) - self.start_index, generator=g) + .add(self.start_index) + .tolist() + ) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/distributed.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..7b294d222ea7de5f0b7e91ac27ef876768d47eb6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/distributed.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import datetime +import os +import socket +from contextlib import closing + +import torch.distributed as dist +from torch.distributed.elastic.utils.logging import get_logger +from torch.distributed.elastic.utils.store import barrier + + +__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"] + +logger = get_logger(__name__) + +_ADDRESS_IN_USE = "Address already in use" +_SOCKET_TIMEOUT = "Socket Timeout" + +_TCP_STORE_INIT = "_tcp_store/num_members" + + +def create_c10d_store( + is_server: bool, + server_addr: str, + server_port: int = -1, + world_size: int = 1, + timeout: float = (60 * 10), # 10 min + wait_for_workers: bool = True, + retries=3, + use_libuv: bool | None = None, +): + if use_libuv is not None: + logger.warning( + "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment " + 'variable to "0" to disable libuv, or "1" to enable it. If the env var ' + "is not set, libuv will be used by default." + ) + + # check os.environ for use_libuv + use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option + + if server_port == -1 and world_size > 1: + raise ValueError( + f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}" + ) + + if server_port != -1: + logger.info("sever_port: %s, specified, ignoring retries", server_port) + + # only retry when server_port is NOT static + attempt = retries if server_port == -1 else 1 + while True: + if server_port != -1: + port = server_port + else: + port = get_free_port() + + logger.info( + "Creating c10d store on %s:%s\n" + " world_size : %s\n" + " is_server : %s\n" + " timeout(sec): %s\n" + " use_libuv : %s\n", + server_addr, + port, + world_size, + is_server, + timeout, + use_libuv, + ) + + try: + store = dist.TCPStore( + host_name=server_addr, + port=port, + world_size=world_size, + is_master=is_server, + timeout=datetime.timedelta(seconds=timeout), + wait_for_workers=wait_for_workers, + use_libuv=use_libuv, + ) + # skips full rank check when we don't have to wait for all workers + if wait_for_workers: + _check_full_rank(store, world_size, timeout=timeout) + logger.info("Successfully created c10d store") + return store + except RuntimeError as e: + # this is brittle, but the underlying exception type is not properly pybinded + # so we parse the error msg for now, interestingly this is how torch itself + # detects timeouts and port conflicts in their own unittests + # see - caffe2/torch/testing/_internal/common_utils.py + # TODO properly map the exceptions in pybind (c10d/init.cpp) + if str(e) == _ADDRESS_IN_USE: # this will only happen on the server + if attempt < retries: + logger.warning( + "port: %s already in use, attempt: [%s/%s]", + port, + attempt, + retries, + ) + attempt += 1 + else: + raise RuntimeError( + f"on {server_addr}, port: {port} already in use" + ) from e + else: + raise + + +def _check_full_rank(store, world_size, timeout): + try: + barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout) + except RuntimeError as e: + if str(e) == _SOCKET_TIMEOUT: + raise TimeoutError( + f"timed out waiting for all {world_size} members to join" + ) from e + else: + raise + + +def get_free_port(): + """ + Returns an unused port on localhost. + + This function finds an unused port on localhost by opening to socket to bind + to a port and then closing it. + + Returns: + int: an unused port on localhost + + Example: + >>> # xdoctest: +SKIP("Nondeterministic") + >>> get_free_port() + 63976 + + .. note:: + The port returned by :func:`get_free_port` is not reserved and may be + taken by another process after this function returns. + """ + sock = get_socket_with_port() + with closing(sock): + return sock.getsockname()[1] + + +def get_socket_with_port() -> socket.socket: + """ + Returns a free port on localhost that is "reserved" by binding a temporary + socket on it. Close the socket before passing the port to the entity + that requires it. Usage example + + :: + + sock = _get_socket_with_port() + with closing(sock): + port = sock.getsockname()[1] + sock.close() + # there is still a race-condition that some other process + # may grab this port before func() runs + func(port) + """ + + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() + logger.warning("Socket creation attempt failed.", exc_info=e) + raise RuntimeError("Failed to create a socket") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/log_level.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/log_level.py new file mode 100644 index 0000000000000000000000000000000000000000..87ea0f7d64182488b40fd7fed6965ce57ec475a0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/log_level.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def get_log_level() -> str: + """ + Return default log level for pytorch. + """ + return "WARNING" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/logging.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..aadf37eb16b8084486a537b18f399098cbcc4fb5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/logging.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import warnings + +from torch.distributed.elastic.utils.log_level import get_log_level + + +def get_logger(name: str | None = None) -> logging.Logger: + """ + Util function to set up a simple logger that writes + into stderr. The loglevel is fetched from the LOGLEVEL + env. variable or WARNING as default. The function will use the + module name of the caller if no name is provided. + + Args: + name: Name of the logger. If no name provided, the name will + be derived from the call stack. + """ + + # Derive the name of the caller, if none provided + # Use depth=2 since this function takes up one level in the call stack + return _setup_logger(name or _derive_module_name(depth=2)) + + +def _setup_logger(name: str | None = None) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(os.environ.get("LOGLEVEL", get_log_level())) + return logger + + +def _derive_module_name(depth: int = 1) -> str | None: + """ + Derives the name of the caller module from the stack frames. + + Args: + depth: The position of the frame in the stack. + """ + try: + stack = inspect.stack() + assert depth < len(stack) + # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index) + frame_info = stack[depth] + + module = inspect.getmodule(frame_info[0]) + if module: + module_name = module.__name__ + else: + # inspect.getmodule(frame_info[0]) does NOT work (returns None) in + # binaries built with @mode/opt + # return the filename (minus the .py extension) as modulename + filename = frame_info[1] + module_name = os.path.splitext(os.path.basename(filename))[0] + return module_name + except Exception as e: + warnings.warn( + f"Error deriving logger module name, using . Exception: {e}", + RuntimeWarning, + stacklevel=2, + ) + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/store.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/store.py new file mode 100644 index 0000000000000000000000000000000000000000..598899e936aa0c9a1c43dda38ef2479eec03f842 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/elastic/utils/store.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable, Iterable +from contextlib import contextmanager +from datetime import timedelta + +import torch + + +DistStoreError = torch._C._DistStoreError + +_NUM_MEMBERS = "/num_members" +_LAST_MEMBER_CHECKIN = "/last_member" +_TRACE = "/TRACE" +_TRACING_GATE = "/TRACING_GATE" +_MAX_TRACE_MISSING_RANKS = 16 + + +__all__ = ["store_timeout", "get_all", "synchronize", "barrier"] + + +@contextmanager +def store_timeout(store, timeout: float): + """ + This sets the timeout and then restores the old timeout when the context + manager exits. + + Args: + store: the store to set the timeout on + timeout: the timeout to set + """ + + old_timeout = store.timeout + store.set_timeout(timedelta(seconds=timeout)) + yield + store.set_timeout(old_timeout) + + +def get_all(store, rank: int, prefix: str, world_size: int): + r""" + Given a store and a prefix, the method goes through the array of keys + of the following format: ``{prefix}{idx}``, where idx is in a range + from 0 to size, and tries to retrieve the data. + + The Rank0 process waits at the end to make sure all other processes + finished the procedure before exiting. + + Usage + + :: + + values = get_all(store, "torchelastic/data", 3) + value1 = values[0] # retrieves the data for key torchelastic/data0 + value2 = values[1] # retrieves the data for key torchelastic/data1 + value3 = values[2] # retrieves the data for key torchelastic/data2 + + """ + data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) + + barrier_key = _barrier_nonblocking( + store=store, + world_size=world_size, + key_prefix=f"{prefix}/finished", + ) + if rank == 0: + # Rank0 runs the TCPStore daemon, as a result it needs to exit last. + # Otherwise, the barrier may timeout if rank0 process finished the work + # before other processes finished `get_all` method + store.wait([barrier_key]) + + return data_arr + + +def synchronize( + store, + data: bytes, + rank: int, + world_size: int, + key_prefix: str, + timeout: float = 300, +) -> list[bytes]: + """ + Synchronizes ``world_size`` agents between each other using the underlying c10d store. + The ``data`` will be available on each of the agents. + + Note: The data on the path is not deleted, as a result there can be stale data if + you use the same key_prefix twice. + + Time complexity: O(N) per worker, O(N^2) globally. + """ + with store_timeout(store, timeout): + store.set(f"{key_prefix}{rank}", data) + agent_data = get_all(store, rank, key_prefix, world_size) + return agent_data + + +def _try_detecting_missing_ranks( + store, + world_size: int, + key_prefix: str, + rank: int, + rank_decoder: Callable[[int], str], + trace_timeout: float, +) -> Iterable[str] | None: + store.set(f"{key_prefix}{rank}{_TRACE}", "") + + def _find_missing_ranks(): + missing_rank_info = set() + ranks_missing = 0 + for i in range(1, world_size): + # reduce noise, assuming in general 8 ranks per node + # It is valuable to know that 1 or >1 nodes have timed-out. + if ranks_missing >= _MAX_TRACE_MISSING_RANKS: + break + try: + if ranks_missing == 0: + store.wait( + [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout) + ) + else: + # use a shortest timeout, some ranks have failed to check-in + store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1)) + except DistStoreError: + ranks_missing += 1 + missing_rank_info.add(rank_decoder(i)) + return missing_rank_info + + def _checkin(): + try: + store.wait([f"{key_prefix}{_TRACING_GATE}"]) + return [f"[]"] + except DistStoreError: + # in case rank0 is the source of the timeout, original exception will be raised + return None + + if rank == 0: + missing_rank_info = _find_missing_ranks() + store.set(f"{key_prefix}{_TRACING_GATE}", "") + return missing_rank_info + else: + return _checkin() + + +def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: + """ + Does all the non-blocking operations for a barrier and returns the final key + that can be waited on. + """ + num_members_key = key_prefix + _NUM_MEMBERS + last_member_key = key_prefix + _LAST_MEMBER_CHECKIN + + idx = store.add(num_members_key, 1) + if idx == world_size: + store.set(last_member_key, "") + + return last_member_key + + +def barrier( + store, + world_size: int, + key_prefix: str, + barrier_timeout: float = 300, + rank: int | None = None, + rank_tracing_decoder: Callable[[int], str] | None = None, + trace_timeout: float = 10, +) -> None: + """ + A global lock between agents. This will pause all workers until at least + ``world_size`` workers respond. + + This uses a fast incrementing index to assign waiting ranks and a success + flag set by the last worker. + + Time complexity: O(1) per worker, O(N) globally. + + Optionally, passing rank will enable tracing of missing ranks on timeouts. + `rank_tracing_decoder` lambda arg can be used to convert rank data + into a more meaningful information at an app level (e.g. hostname). + + Note: Since the data is not removed from the store, the barrier can be used + once per unique ``key_prefix``. + """ + + if rank is None: + assert rank_tracing_decoder is None, "Tracing requires rank information" + + with store_timeout(store, barrier_timeout): + last_member_key = _barrier_nonblocking( + store=store, world_size=world_size, key_prefix=key_prefix + ) + try: + store.wait([last_member_key]) + except DistStoreError as e: + if rank is None: + raise e + else: + missing_ranks = _try_detecting_missing_ranks( + store, + world_size, + key_prefix, + rank, + rank_tracing_decoder or (lambda x: str(x)), + trace_timeout, + ) + if missing_ranks is not None: + raise DistStoreError( + "Timed out waiting on barrier on " + "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format( + rank, + key_prefix, + world_size, + f"[{', '.join(missing_ranks)}]", + barrier_timeout, + ) + ) from None + else: + raise e diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..608a329871232668b5e5a6d364dd3fae185b018a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/fr_trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/fr_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c816a5aa2ac2c025f748d0ac8ba9b079f245e6d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/__pycache__/fr_trace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffef2a5ac91b738f8d67c5af7d4c255eb1be9c6b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/builder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a2983037ccf742245848372ece15f1b50411a30 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/builder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/config_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/config_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1feab2b5555f01da323943ae5ce82d9620af7451 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/config_manager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/fr_logger.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/fr_logger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..418c4773f510668cf28e886aad1384d0eec3e6e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/fr_logger.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/loader.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/loader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89d10e3c95f2ee76cb917808ca1070a67cb2eb42 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/loader.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/types.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cead31206a375d410b1f4f13302e7da9b71c8d28 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/types.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e53b0a9b094c6025928e46eb31c775c437451987 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/builder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..56736450e3f2a8decdc6dfc11c929d8a1bdfb16f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/builder.py @@ -0,0 +1,457 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import ast +import copy +import os +import sys +from typing import Any # type: ignore[attr-defined] + +from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger +from torch.distributed.flight_recorder.components.types import ( + Collective, + Database, + EntryState, + Group, + MatchStateRecord, + Membership, + NCCLCall, + Op, + Traceback, +) +from torch.distributed.flight_recorder.components.utils import ( + add_stack_id_in_entries, + align_trace_from_beginning, + check_current_entry_match, + check_no_missing_dump_files, + check_version, + error_analysis, + find_coalesced_group as find_coalesced_group_p2p_only, + find_coalesced_group_with_non_p2p, + get_version_detail, + just_print_entries, + match_coalesced_groups as match_coalesced_groups_p2p_only, + match_coalesced_groups_with_non_p2p, +) + + +__all__ = [ + "build_groups_memberships", + "build_collectives", + "transform_ft", + "build_db", +] + +# Set up logging +logger: FlightRecorderLogger = FlightRecorderLogger() + + +try: + from tabulate import tabulate +except ModuleNotFoundError: + logger.warning("tabulate is not installed. Proceeding without it.") + + # Define a no-op tabulate function + def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc] + return data + + +""" +Flat DB builder +""" + + +def build_groups_memberships( + pg_config: Any, +) -> tuple[ + list[Group], + dict[Any, Group], + list[Membership], + dict[str, set[Any]], + dict[tuple[str, int], str], +]: + """ + pg_config: { + global_rank: { + (pg_guid, desc, ranks) + } + } + + `pg_guid` is a system generated id, but depending on the mode of PG creation it could be a globally incrementing int + or a hash of the ranks. See `_process_group_name` in distributed_c10d.py. + `desc` is provided by the user (optionally) and should be 'meaningful' (e.g. TP/PP/DP group) + `ranks` is a list of the 'global ranks' that are members of the PG. + + (pg_guid, desc, ranks) tuples are appended lazily to the flight buffer when `getNCCLComm` is called on a PG and + the `enabled_` flag is true for that PG. + - the order of calling (init_process_group, new_group, etc) does not affect the order of the tuples in the list + + Returns: + `groups`: a groups table where each row is a Group namedtuple. + `_groups`: a dict that is indexed by pg_guid with Group namedtuple as value. + `memberships`: a membership table where each row is a Membership namedtuple. + `_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value. + `_pg_guids`: a dict that is indexed by (pg_uid, global_rank) with pg_guid as value. + """ + # flat lists for return + groups = [] + memberships = [] + + # dicts for faster cross-rank validation + _groups = {} + _memberships = {} + _pg_guids = {} + for global_rank in pg_config: + for pg_uid in pg_config[global_rank]: + desc = pg_config[global_rank][pg_uid]["desc"] + ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"]) + # With the adoption of the split_group API, we can have multiple PGs with the same pg_guid (PG Name) + # So we need to add the hash of all its ranks within the PG as well. + # Also guid must be a string because `_process_group_name` returns a string. + pg_guid = pg_uid + str(hash(frozenset(ranks))) + _pg_guids[(pg_uid, global_rank)] = pg_guid + if isinstance(ranks, str): + # TODO Bug in FR data format? ranks is '[0, 1,...]' + ranks = eval(ranks) + + if pg_guid not in _groups: + groups.append(Group(id=pg_guid, desc=desc, size=len(ranks))) + for rank in ranks: + memberships.append(Membership(group_id=pg_guid, global_rank=rank)) + _groups[pg_guid] = groups[-1] + _memberships[pg_guid] = set(ranks) + else: + # validation across ranks + assert _groups[pg_guid].desc == desc, ( + f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}" + ) + assert _memberships[pg_guid] == set(ranks), ( + f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}" + ) + return groups, _groups, memberships, _memberships, _pg_guids + + +def build_collectives( + all_entries: dict[int, list[dict[str, Any]]], + _groups: dict[str, Group], + _memberships: dict[str, set[Any]], + _pg_guids: dict[tuple[str, int], str], + version: str, + mismatch_cap: int = 10, +) -> tuple[list[Traceback], list[Collective], list[NCCLCall]]: + """ + groups, memberships are the non-flat dicts that are indexable + all_entries is a raw dict from the original dumps: + + all_entries: { + global_rank: [ + { + record_id: ordered id of the event in the trace buffer + pg_id: ProcessGroupNCCL::uid_ + *note: `pg_id` corresponds to nothing in groups table + process_group: (pg_name, desc) + *note: `pg_name`, `desc` corresponds to `pg_id`, `desc` in groups table + collective_seq_id: ordered id for collective operations and coalesced group operations + p2p_seq_id: ordered id for point-to-point operations + op_id: ordered id including individual ops inside coalescing group + profiling_name: descriptive name of the operation + 'time_created_ns', + 'input_sizes', + 'output_sizes', + 'state', + 'time_discovered_started_ns', + 'time_discovered_completed_ns', + 'retired', + 'frames', + } + ] + } + """ + tracebacks: list[Traceback] = [] + + collectives: list[Collective] = [] + nccl_calls: list[NCCLCall] = [] + + # once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect + # instead, just record the remaining ops as NCCLCalls + mismatch = {_groups[g].id: 0 for g in _groups} + + # For best effort partial analysis. + dumps_ranks = {int(key) for key in all_entries} + """ + - it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time + - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast) + - within a group, the first collective must be the same on all ranks in the group, then it can be marked as a + collective and removed + """ + while all_entries: + # we greedily match collectives, starting arbitrarily with the trace from the first rank + # later, if we exhaust the first rank, we continue with the next 'first rank' + rank_iter = iter(all_entries) + first_rank = next(rank_iter) + other_ranks = list(rank_iter) + + if len(all_entries[first_rank]) == 0: + all_entries.pop(first_rank) + continue + + # lets match the first collective! we need to know which ranks are involved, and ensure that this same + # collective is also the first one on those ranks within that group + entries = all_entries[first_rank] + current_entry = entries[0] + desc = current_entry["process_group"][1] + # For db build and logs printing, we want to use the original pg_name, not the hash one. + original_pg_name = current_entry["process_group"][0] + pg_name = _pg_guids[(original_pg_name, first_rank)] + expected_ranks = set(_memberships[pg_name]) + entry_state = EntryState(current_entry, expected_ranks) + match_record = MatchStateRecord( + expected_ranks=expected_ranks, + other_ranks=other_ranks, + entry_state=entry_state, + candidate_ranks={first_rank}, + candidate_idx={}, + found_ranks=set(), + found_idx={}, + errors=set(), + ) + + major_v, minor_v = get_version_detail(version) + find_coalesced_group = ( + find_coalesced_group_p2p_only + if major_v <= 2 and minor_v < 7 + else find_coalesced_group_with_non_p2p + ) + maybe_coalesced_group = find_coalesced_group( + pg_name, entries, _pg_guids, first_rank + ) + if len(maybe_coalesced_group) > 1: + num_coalesced_entries = len(maybe_coalesced_group) + # We need a copy of the original expected ranks to avoid modifying it. + candidate_ranks = copy.deepcopy(expected_ranks) + done_ranks = set() + all_coalesced_entries = {} + while candidate_ranks: + curr = candidate_ranks.pop() + done_ranks.add(curr) + grp = ( + find_coalesced_group(pg_name, all_entries[curr], _pg_guids, curr) # type: ignore[index] + if curr in all_entries # type: ignore[comparison-overlap] + else [] + ) + all_coalesced_entries[curr] = grp + for _, entry in grp: + op = Op(entry, _memberships, pg_name) + peer = None + if op.type == "send": + assert op._src_g == curr, ( + f"Send src error: {curr} expected but {op._src_g} is set" + ) + peer = op._dst_g + elif op.type == "recv": + assert op._dst_g == curr, ( + f"Recv dst error: {curr} expected but {op._dst_g} is set" + ) + peer = op._src_g + if peer and peer not in done_ranks: + candidate_ranks.add(peer) + + if major_v <= 2 and minor_v < 7: + match = match_coalesced_groups_p2p_only( + all_coalesced_entries, + group_size=_groups[pg_name].size, + groups=_groups, + memberships=_memberships, + _pg_guids=_pg_guids, + ) + else: + match = match_coalesced_groups_with_non_p2p( + copy.deepcopy( + all_coalesced_entries + ), # We want to keep a copy for cleanup. + pg_info=(pg_name, desc), + memberships=_memberships, + _pg_guids=_pg_guids, + mismatch=mismatch, + dumps_ranks=dumps_ranks, + version=version, + collectives=collectives, + match_record=match_record, + ) + + if match and mismatch[pg_name] == 0: + # We treat coalesced collectives as a single collective. + # TODO: we need to surface a merged collective info like input/output sizes to users. + collectives.append( + match_record.entry_state.to_collective(len(collectives)) + ) + else: + mismatch[pg_name] += 1 + for r in all_coalesced_entries: + idx_map = {r: i for i, _ in reversed(all_coalesced_entries[r])} # noqa: B035 + nccl_calls.extend( + reversed( + match_record.entry_state.to_nccl_call( + all_entries, + idx_map, + len(nccl_calls), + collectives[-1].id if match else None, + ) + ) + ) + # This extra cleanup is needed because we need to pop all collectives within a coalesced collective. + for i, k in idx_map.items(): + for _ in range(1, num_coalesced_entries): + all_entries[i].pop(k) + else: + # Iterate through all the ranks and check if there is a mismatch for the current entry. + check_current_entry_match( + all_entries, + _pg_guids, + (pg_name, desc), + current_entry, + _memberships, + mismatch, + match_record, + ) + + # Use heuristics to decide what type of errors and error messages we should print. + error_analysis( + all_entries, + match_record, + dumps_ranks, + first_rank, + current_entry, + mismatch, + get_version_detail(version), + pg_name, + ) + + # at this point there are 3 possibilities + # 1. we found a match on all the ranks that are members of the group + # -> we create a Collective and remove the individual entries from their original lists + if match_record.found_ranks == expected_ranks and mismatch[pg_name] == 0: + collectives.append( + match_record.entry_state.to_collective(len(collectives)) + ) + idx_map = { + r: match_record.found_idx[r] if r != first_rank else 0 + for r in match_record.found_ranks + } + nccl_calls.extend( + match_record.entry_state.to_nccl_call( + all_entries, idx_map, len(nccl_calls), collectives[-1].id + ) + ) + + # 2. we found a partial match but some ranks are missing + # 3. we found no match + # -> since its not a complete collective, no entry goes into collectives but we still record a nccl call + # TODO should there be a way to mark 'mismatches'? + else: + logger.debug("appending a non-matching collective") + idx_map = { + r: match_record.candidate_idx[r] if r != first_rank else 0 + for r in match_record.candidate_ranks + } + collectives.append( + match_record.entry_state.to_collective( + len(collectives), + errors=match_record.errors, + idx_map=idx_map, + all_entries=all_entries, + ) + ) + nccl_calls.extend( + match_record.entry_state.to_nccl_call( + all_entries, idx_map, len(nccl_calls), None + ) + ) + + if mismatch[pg_name] > mismatch_cap: + logger.error( + "Too many mismatches for process_group %s: %s aborting", pg_name, desc + ) + break + + return tracebacks, collectives, nccl_calls + + +def transform_ft( + details: dict[str, dict[str, Any]], group_world_size: int +) -> dict[str, dict[str, Any]]: + for dump_key, dump in details.items(): + rank = dump["rank"] + for key, pg_config in dump["pg_config"].items(): + if pg_config["desc"] == "default_pg": + ranks = eval(pg_config["ranks"]) + replica_id = rank // group_world_size + first_rank = replica_id * group_world_size + new_ranks = [r + first_rank for r in ranks] + details[dump_key]["pg_config"][key]["ranks"] = f"{new_ranks}" + + return details + + +def build_db( + details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str +) -> Database: + if args.verbose: + os.environ["FR_TRACE_VERBOSE_OUTPUT"] = "1" + # temporary state used for building database + entries = {} + pg_config = {} + version_by_ranks = {} + for dump in details.values(): + rank = dump["rank"] + entries[rank] = dump["entries"] + version_by_ranks[rank] = dump["version"] + pg_config[rank] = dump["pg_config"] + + # Ensure version is consistent across all ranks. + check_version(version_by_ranks, version) + entries = align_trace_from_beginning(entries) + stack_id_trace_map: dict[str, int] = {} + if args.just_print_entries: + entries, stack_id_trace_map = add_stack_id_in_entries(entries) + + # flattened database + groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( + pg_config + ) + logger.debug("built groups, memberships") + + if args.just_print_entries: + just_print_entries( + entries, _groups, _memberships, _pg_guids, args, stack_id_trace_map + ) + sys.exit(0) + + if not args.allow_incomplete_ranks: + check_no_missing_dump_files(entries, memberships) + + tracebacks, collectives, nccl_calls = build_collectives( + entries, _groups, _memberships, _pg_guids, version, args.mismatch_cap + ) + logger.debug("built collectives, nccl_calls") + if args.verbose: + logger.debug("Groups") + logger.debug(tabulate(groups, headers=Group._fields)) + logger.debug("Memberships") + logger.debug(tabulate(memberships, headers=Membership._fields)) + logger.debug("Collectives") + logger.debug(tabulate(collectives, headers=Collective._fields)) + logger.debug("NCCLCalls") + logger.debug(tabulate(nccl_calls, headers=NCCLCall._fields)) + db = Database( + tracebacks=tracebacks, + collectives=collectives, + ncclcalls=nccl_calls, + groups=groups, + memberships=memberships, + ) + return db diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/config_manager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b12966588215ce01118f9aea9f8bb771390c3c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/config_manager.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +from collections.abc import Sequence + +from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger + + +__all__ = ["JobConfig"] + + +logger: FlightRecorderLogger = FlightRecorderLogger() + + +class JobConfig: + """ + A helper class to manage the script configuration. + """ + + def __init__(self: "JobConfig"): + self.parser = argparse.ArgumentParser( + description="PyTorch Flight recorder analyzing script." + ) + self.parser.add_argument( + "trace_dir", + nargs="?", + help="Directory containing one trace file per rank, named with _.", + ) + self.parser.add_argument( + "--selected-ranks", + default=None, + nargs="+", + type=int, + help="List of ranks we want to show traces for.", + ) + self.parser.add_argument( + "--allow-incomplete-ranks", + action="store_true", + help=( + "FR trace require all ranks to have dumps for analysis. " + "This flag allows best-effort partial analysis of results " + "and printing of collected data." + ), + ) + self.parser.add_argument( + "--pg-filters", + default=None, + nargs="+", + type=str, + help=( + "List of filter strings, it could be pg name or pg desc. " + "If specified, only show traces for the given pg." + ), + ) + self.parser.add_argument("-o", "--output", default=None) + self.parser.add_argument( + "-p", + "--prefix", + help=( + "Common filename prefix to strip such that rank can be extracted. " + "If not specified, will attempt to infer a common prefix." + ), + default=None, + ) + self.parser.add_argument("-j", "--just_print_entries", action="store_true") + self.parser.add_argument("-v", "--verbose", action="store_true") + self.parser.add_argument("--print_stack_trace", action="store_true") + self.parser.add_argument( + "--mismatch_cap", + type=int, + default=10, + help="Maximum number of mismatches we print (from earliest).", + ) + self.parser.add_argument( + "--transform-ft", + action="store_true", + help="Transform PG config to use global ranks to analyze traces produced by torchft", + ) + self.parser.add_argument( + "--group-world-size", + type=int, + default=None, + help="The number of ranks in 1 torchft replica group. Must be specified if --transform-ft is True", + ) + + def parse_args(self: "JobConfig", args: Sequence[str] | None) -> argparse.Namespace: + # pyrefly: ignore [bad-assignment] + args = self.parser.parse_args(args) + # pyrefly: ignore [missing-attribute] + if args.selected_ranks is not None: + # pyrefly: ignore [missing-attribute] + assert args.just_print_entries, ( + "Not support selecting ranks without printing entries" + ) + # pyrefly: ignore [missing-attribute] + if args.pg_filters is not None: + # pyrefly: ignore [missing-attribute] + assert args.just_print_entries, ( + "Not support selecting pg filters without printing entries" + ) + # pyrefly: ignore [missing-attribute] + if args.verbose: + logger.set_log_level(logging.DEBUG) + # pyrefly: ignore [bad-return] + return args diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/fr_logger.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/fr_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e56634397bff9d6d1ec38eab43f1856f52e02829 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/fr_logger.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections.abc import Callable +from typing import Any + + +__all__ = ["FlightRecorderLogger"] + + +class FlightRecorderLogger: + _instance: Any | None = None + logger: logging.Logger + + def __init__(self) -> None: + self.logger: logging.Logger = logging.getLogger("Flight Recorder") + + def __new__(cls) -> Any: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.logger = logging.getLogger("Flight Recorder") + cls._instance.logger.setLevel(logging.INFO) + formatter = logging.Formatter("%(message)s") + ch = logging.StreamHandler() + ch.setFormatter(formatter) + cls._instance.logger.addHandler(ch) + return cls._instance + + def set_log_level(self, level: int) -> None: + self.logger.setLevel(level) + + @property + def debug(self) -> Callable[..., None]: + return self.logger.debug + + @property + def info(self) -> Callable[..., None]: + return self.logger.info + + @property + def warning(self) -> Callable[..., None]: + return self.logger.warning + + @property + def error(self) -> Callable[..., None]: + return self.logger.error + + @property + def critical(self) -> Callable[..., None]: + return self.logger.critical diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/loader.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ce361b103fe04488d0390df1b898d27016f2b47b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/loader.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import gc +import os +import pickle +import re +import time +from collections import defaultdict +from typing import Any + +from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger + + +__all__ = [ + "read_dump", + "read_dir", +] + + +logger: FlightRecorderLogger = FlightRecorderLogger() + + +def read_dump(prefix: str, filename: str) -> dict[str, str | int | list[Any]]: + basename = os.path.basename(filename) + + rank = int(basename[len(prefix) :]) + host_name = f"host_rank{rank}" + + with open(filename, "rb") as infile: + dump = pickle.load(infile) + + entries = dump["entries"] + version = dump["version"] + pg_config = dump["pg_config"] + + return { + "host_name": host_name, + "rank": rank, + "entries": entries, + "version": version, + "pg_config": pg_config, + } + + +exp = re.compile(r"([\w\-\_]*?)(\d+)$") + + +def _determine_prefix(files: list[str]) -> str: + """If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to + infer the common prefix most of the time. But if we can't confidently infer, just fall back to requiring the user + to specify it + """ + possible_prefixes: defaultdict[str, set[int]] = defaultdict(set) + for f in files: + m = exp.search(f) + if m: + p, r = m.groups() + possible_prefixes[p].add(int(r)) + if len(possible_prefixes) == 1: + prefix = next(iter(possible_prefixes)) + logger.debug("Inferred common prefix %s", prefix) + return prefix + else: + raise ValueError( + "Unable to automatically determine the common prefix for the trace file names. " + "Please specify --prefix argument manually" + ) + + +def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]: + gc.disable() + prefix = args.prefix + details = {} + t0 = time.time() + version = "" + filecount = 0 + assert os.path.isdir(args.trace_dir), f"folder {args.trace_dir} does not exist" + for root, _, files in os.walk(args.trace_dir): + if prefix is None: + prefix = _determine_prefix(files) + for f in files: + if (offset := f.find(prefix)) == -1: + continue + details[f] = read_dump(f[:offset] + prefix, os.path.join(root, f)) + filecount += 1 + if not version: + version = str(details[f]["version"]) + tb = time.time() + assert len(details) > 0, ( + f"no files loaded from {args.trace_dir} with prefix {prefix}" + ) + logger.debug("loaded %s files in %ss", filecount, tb - t0) + return details, version diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/types.py new file mode 100644 index 0000000000000000000000000000000000000000..7fdfd9d8838b5e6d24c96501ba5556dd001b1a6a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/types.py @@ -0,0 +1,661 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +from enum import auto, Enum +from typing import ( # type: ignore[attr-defined] + _eval_type, + Any, + Generic, + NamedTuple, + TypeVar, +) + +from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger + + +__all__ = [ + "Ref", + "TypeInfo", + "MatchState", + "MatchInfo", + "Group", + "Membership", + "Traceback", + "Collective", + "NCCLCall", + "Database", + "EntryState", + "Op", + "MatchStateRecord", +] + + +T = TypeVar("T", bound=NamedTuple) + + +class Ref(Generic[T]): + pass + + +class TypeInfo(NamedTuple): + name: str + fields: list[tuple[str, type]] # type: ignore[type-arg] + + @classmethod + def from_type(cls, c: T) -> "TypeInfo": + if hasattr(c, "__name__"): + name = c.__name__ + else: + name = str(c) + return cls( + name, + [(f, _eval_type(c.__annotations__[f], globals(), {})) for f in c._fields], + ) + + +class MatchState(Enum): + """ + Enum representing the possible states of matching for collective operations. + + - FULLY_MATCHED: Indicates that all aspects of the collective operations match. + - COLLECTIVE_TYPE_MISMATCH: The types of the collective operations differ. + - SIZE_OR_SYNTAX_MISMATCH: There is a mismatch in input/output sizes or violation of collective syntax. + - COLLECTIVE_STATE_MISMATCH: + The states of the collective not same, such as one finished while another just started or scheduled. + - COLLECTIVE_DTYPE_MISMATCH: The data types of the collective input/output differ. + - UNDECIDED: + The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base. + """ + + FULLY_MATCHED = auto() + COLLECTIVE_TYPE_MISMATCH = auto() + SIZE_OR_SYNTAX_MISMATCH = auto() + COLLECTIVE_STATE_MISMATCH = auto() + COLLECTIVE_DTYPE_MISMATCH = auto() + UNDECIDED = auto() + + +class MatchInfo: + """ + Aside from the match state, we also store some dynamic info for the match such as the culprit rank + or collective state that caused the mismatch. + """ + + def __init__(self, state: MatchState, culprit: str | None = None) -> None: + self._state = state + self.culprit = culprit + + def __str__(self) -> str: + details = f", {self.culprit}" if getattr(self, "culprit", None) else "" + return f"Error type: {self._state.name}{details}" + + @property + def state(self) -> MatchState: + return self._state + + +""" +Schema for flat DB + +TODO schemas not yet implemented +# threads as recorded at termination of process +Threads + id: int + traceback_id: int + process_id: int + +Process: + id: int # Same as world groups RANK + pid: int + hostname: str + +NCCLOp: + # nccl op implementation details (sends/recv) + id: int + nccl_call_id: int + +""" + + +class Group(NamedTuple): + id: str + desc: str + size: int + + +class Membership(NamedTuple): + group_id: str + global_rank: int + + +class Traceback(NamedTuple): + id: int + frames: str + + +class Collective(NamedTuple): + id: int + group_id: str + pass_check: bool + collective_seq_id: int + p2p_seq_id: int + record_id: int + pg_desc: str + collective_name: str + input_sizes: list[list[int]] + output_sizes: list[list[int]] + expected_ranks: set[int] + collective_state: str + collective_frames: list[dict[str, str]] + input_numel: int | None = None + output_numel: int | None = None + missing_ranks: set[int] | None = None + mismatch_collectives: dict[int, "Collective"] | None = None + type_of_mismatch: MatchInfo | None = None + + +class NCCLCall(NamedTuple): + id: int + collective_id: Ref[Collective] + group_id: str + global_rank: int # technically Ref[Process] once we have it + traceback_id: Ref[Traceback] + collective_type: str + sizes: list[list[int]] + + +class Database(NamedTuple): + groups: list[Group] + memberships: list[Membership] + tracebacks: list[Traceback] + collectives: list[Collective] + ncclcalls: list[NCCLCall] + + +# TODO: We need to add a schema for the following +types = [ + TypeInfo.from_type(t) # type: ignore[type-var] + for t in [Database, NCCLCall, Collective, Traceback, Membership, Group] + if ( + isinstance(t, type) + and issubclass(t, tuple) + and hasattr(t, "_fields") + and t is not TypeInfo + ) +] + +""" +Stacktrace cache +TODO +""" + + +""" +Collective Matching logic + +NOTE: For now, these collectives need to be supported by NCCL, +https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html. +""" +COLLECTIVES = { + "broadcast", + "_broadcast_oop", + "reduce", + "_reduce_oop", + "all_gather", + "all_reduce", + "_all_gather_base", + "all_gather_into_tensor_coalesced", + "reduce_scatter", + "reduce_scatter_tensor_coalesced", + "_reduce_scatter_base", + "gather", + "scatter", + "all_to_all", + "all_reduce_barrier", + "allreduce_coalesced", + "ALLGATHER_coalesced", + "REDUCE_SCATTER_coalesced", +} + +P2P = { + "send", + "recv", +} + + +class EntryState: + """ + Util class to keep track of the state of an entry and standardize the way we + log the error info during analysis. + """ + + def __init__(self, entry: dict[str, Any], expected_ranks: set[int]) -> None: + self.pg_name = entry["process_group"][0] + self.desc = entry["process_group"][1] + self.pg_desc = ( + f"{self.pg_name}:{self.desc}" if self.desc != "undefined" else self.pg_name + ) + self.profiling_name = entry["profiling_name"] + self.collective_seq_id = entry["collective_seq_id"] + self.p2p_seq_id = entry["p2p_seq_id"] + self.record_id = entry["record_id"] + self.input_sizes = entry["input_sizes"] + self.output_sizes = entry["output_sizes"] + self.collective_state = entry["state"] + self.collective_frames = entry.get("frames", []) + self.expected_ranks = expected_ranks + self.missing_ranks: set[int] + self.input_numel: int + self.output_numel: int + self.errors: set[tuple[int, MatchInfo]] + + def log( + self, + logger: FlightRecorderLogger, + logger_msg: str, + frame_formatter: Any, + total_numel: tuple[int, int] | None = None, + errors: set[tuple[int, MatchInfo]] | None = None, + missing_ranks: set[int] | None = None, + ) -> None: + logger.info( + logger_msg, + self.collective_seq_id, + ) + logger.info("internal record id: %s", self.record_id) + logger.info("group info: %s", self.pg_desc) + logger.info("collective: %s", self.profiling_name) + if missing_ranks: + self.missing_ranks = missing_ranks + logger.info("missing ranks: %s", missing_ranks) + if total_numel: + self.input_numel = total_numel[0] + self.output_numel = total_numel[1] + logger.info("total input numel: %d", total_numel[0]) + logger.info("total output numel: %d", total_numel[1]) + logger.info("input sizes: %s", self.input_sizes) + logger.info("output sizes: %s", self.output_sizes) + logger.info("world size: %d", len(self.expected_ranks)) + logger.info("expected ranks: %s", str(self.expected_ranks)) + logger.info("collective state: %s", self.collective_state) + if errors: + self.errors = errors + error_msg = ", ".join( + f"Culprit rank {error[0]}; {str(error[1])}" for error in errors + ) + logger.info("error msg: %s", error_msg) + logger.info( + "collective stack trace: \n %s", frame_formatter(self.collective_frames) + ) + + def to_collective( + self, + id: int, + errors: set[tuple[int, MatchInfo]] | None = None, + idx_map: dict[int, int] | None = None, + all_entries: dict[int, list[dict[str, Any]]] | None = None, + ) -> Collective: + if not errors: + return Collective( + id=id, + group_id=self.pg_name, + record_id=self.record_id, + pg_desc=self.pg_desc, + pass_check=True, + collective_seq_id=self.collective_seq_id, + p2p_seq_id=self.p2p_seq_id, + collective_name=self.profiling_name, + input_sizes=self.input_sizes, + output_sizes=self.output_sizes, + expected_ranks=self.expected_ranks, + collective_state=self.collective_state, + collective_frames=self.collective_frames, + missing_ranks=getattr(self, "missing_ranks", None), + ) + else: + assert idx_map is not None, "idx_map is None" + assert all_entries is not None, "all_entries is None" + mismatch_collectives = {} + for rank, error in errors: + idx = idx_map[rank] + entry = all_entries[rank][idx] + desc = entry["process_group"][1] + pg_name = entry["process_group"][0] + mismatch_collectives[rank] = Collective( + id=id, + group_id=entry["process_group"][0], + record_id=entry["record_id"], + pg_desc=f"{pg_name}:{desc}" if desc != "undefined" else pg_name, + pass_check=False, + collective_seq_id=entry["collective_seq_id"], + p2p_seq_id=entry["p2p_seq_id"], + collective_name=entry["profiling_name"], + input_sizes=entry["input_sizes"], + output_sizes=entry["output_sizes"], + expected_ranks=self.expected_ranks, + collective_state=entry["state"], + collective_frames=entry.get("frames", []), + type_of_mismatch=error, + ) + return Collective( + id=id, + group_id=self.pg_name, + record_id=self.record_id, + pg_desc=self.pg_desc, + pass_check=False, + collective_seq_id=self.collective_seq_id, + p2p_seq_id=self.p2p_seq_id, + collective_name=self.profiling_name, + input_sizes=self.input_sizes, + output_sizes=self.output_sizes, + expected_ranks=self.expected_ranks, + collective_state=self.collective_state, + collective_frames=self.collective_frames, + input_numel=self.input_numel if hasattr(self, "input_numel") else None, + output_numel=self.output_numel + if hasattr(self, "output_numel") + else None, + missing_ranks=self.missing_ranks + if hasattr(self, "missing_ranks") + else None, + mismatch_collectives=mismatch_collectives, + ) + + def to_nccl_call( + self, + all_entries: dict[int, list[dict[str, Any]]], + idx_map: dict[int, int], + nccl_call_id: int, + collective_id: Any, + ) -> list[NCCLCall]: + result = [] + for i, k in idx_map.items(): + all_entries[i].pop(k) + result.append( + NCCLCall( + id=nccl_call_id, + collective_id=collective_id, + group_id=self.pg_name, # type: ignore[arg-type] + global_rank=i, + traceback_id=0, # type: ignore[arg-type] + collective_type=self.profiling_name, + sizes=self.input_sizes, + ) + ) + nccl_call_id += 1 + return result + + +class Op: + """Parses relevant info about operation out of 'event' dict + + examples of supported `profiling_name`s: + nccl:broadcast + nccl:send 1->2 + nccl:recv 3<-0 + """ + + def __init__( + self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str + ): + self.profiling_name = event["profiling_name"] + comm_lib_backend, name = self.profiling_name.split(":") + assert comm_lib_backend in ["nccl", "xccl"], ( + f"name formatting error? {comm_lib_backend} != 'nccl' or 'xccl'" + ) + parts = name.split(" ") + type = parts[0] + meta = parts[1] if len(parts) == 2 else None + self.state = event["state"] + # Store the hashed pg_name for accessing memberships, and original pg info for display + self.pg_name = pg_name # This is the hashed version used for memberships lookup + self.original_pg_name, self.pg_desc = event["process_group"] + assert type in COLLECTIVES | P2P | {"coalesced"}, ( + f"{type} is not a supported operation" + ) + self.type = type + if type == "send": + assert isinstance(meta, str) + s, d = meta.split("->") + self._src, self._dst = int(s), int(d) + elif type == "recv": + assert isinstance(meta, str) + d, s = meta.split("<-") + self._dst, self._src = int(d), int(s) + else: + self._src, self._dst = -1, -1 + self._init_global_src_dst(memberships[pg_name]) + self.pg_size = len(memberships[pg_name]) + if type in P2P | COLLECTIVES: + self.input_sizes = event["input_sizes"] + self.output_sizes = event["output_sizes"] + else: + self.input_sizes, self.output_sizes = None, None + self.collective_seq_id = event["collective_seq_id"] + self.stack_id = event.get("stack_id", -1) + self.p2p_seq_id = event["p2p_seq_id"] + self.input_dtypes = event["input_dtypes"] + self.output_dtypes = event["output_dtypes"] + self.time_created_ns = event["time_created_ns"] + self.collective_frames = event.get("frames", []) + self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1" + + def _init_global_src_dst(self, pg_ranks: set[Any]) -> None: + pg_ranks_sorted = sorted(pg_ranks) + self._src_g = pg_ranks_sorted[self._src] if self._src is not None else None + self._dst_g = pg_ranks_sorted[self._dst] if self._dst is not None else None + + @property + def src(self) -> int: + assert self.type in P2P, "can't get src of non-p2p op" + return self._src + + @property + def dst(self) -> int: + assert self.type in P2P, "can't get dst of non-p2p op" + return self._dst + + def __repr__(self) -> str: + p2p_info = "" + if self.type in P2P: + p2p_info = f"s={self._src_g} d={self._dst_g}" + if self.is_verbose: + verbose_info = ( + f"timestamp_created={self.time_created_ns}", + p2p_info, + f"input_sizes={self.input_sizes}", + f"output_sizes={self.output_sizes}", + f"input_dtypes={self.input_dtypes}", + f"output_dtypes={self.output_dtypes}", + "collective_seq_id | p2p_seq_id=" + f"{self.p2p_seq_id if self.type in P2P else self.collective_seq_id}", + f"pg_name={self.pg_name}", + f"pg_description={self.pg_desc}", + f"pg_size={self.pg_size}", + f"stack_id={self.stack_id}", + f"state={self.state}", + ) + return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s) + return f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})" % ( + f"{p2p_info}, " if p2p_info else "" + ) + + def dtype_mismatch(self, other: "Op") -> bool: + if ( + ( + self.type not in ["scatter", "gather", "broadcast"] + and set(self.input_dtypes) != set(self.output_dtypes) + and self.input_sizes[0] + and self.output_sizes[0] + ) + or ( + self.type not in ["scatter", "broadcast"] + and set(self.input_dtypes) != set(other.input_dtypes) + and self.input_sizes[0] + and other.input_sizes[0] + ) + or ( + self.type not in ["gather"] + and set(self.output_dtypes) != set(other.output_dtypes) + and self.output_sizes[0] + and other.output_sizes[0] + ) + ): + return True + return False + + def match(self, other: "Op") -> MatchInfo: + # TODO: I think this can validly not match, + # e.g. if one PG was used for p2p ops between only some of the peers? + # if self.seq_id != other.seq_id: + # return False + + if self.type == "send": + # TODO: We need more states for p2p ops. + return ( + MatchInfo(MatchState.FULLY_MATCHED) + if ( + other.type == "recv" + and self.src == other.src + and self.dst == other.dst + and self.input_sizes == other.output_sizes + ) + else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) + ) + elif self.type == "recv": + return ( + MatchInfo(MatchState.FULLY_MATCHED) + if ( + other.type == "send" + and self.src == other.src + and self.dst == other.dst + and self.output_sizes == other.input_sizes + ) + else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) + ) + elif self.type in COLLECTIVES: + if self.type != other.type: + return MatchInfo( + MatchState.COLLECTIVE_TYPE_MISMATCH, + f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'", + ) + if ( + self.type not in ["all_to_all", "scatter"] + and self.input_sizes != other.input_sizes + ): + return MatchInfo( + MatchState.SIZE_OR_SYNTAX_MISMATCH, + f"Expected input sizes: '{self.input_sizes}' does not match found input sizes: " + f"'{other.input_sizes}'", + ) + if ( + self.type not in ["all_to_all", "gather"] + and self.output_sizes != other.output_sizes + ): + return MatchInfo( + MatchState.SIZE_OR_SYNTAX_MISMATCH, + f"Expected output sizes: '{self.output_sizes}' does not match found output sizes: " + f"'{other.output_sizes}'", + ) + if ( + self.type in ["all_reduce", "allreduce_coalesced"] + and self.input_sizes != other.output_sizes + ): + return MatchInfo( + MatchState.SIZE_OR_SYNTAX_MISMATCH, + f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'", + ) + if ( + self.type + in [ + "all_gather", + "all_gather_base", + "all_gather_into_tensor_coalesced", + ] + and math.prod(other.output_sizes[0]) + != math.prod(self.input_sizes[0]) * self.pg_size + ): + return MatchInfo( + MatchState.SIZE_OR_SYNTAX_MISMATCH, + f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' " + f"does not match output numel '{math.prod(other.output_sizes[0])}'", + ) + if ( + self.type + in [ + "reduce_scatter", + "_reduce_scatter_base", + "reduce_scatter_tensor_coalesced", + ] + and math.prod(other.input_sizes[0]) + != math.prod(self.output_sizes[0]) * self.pg_size + ): + return MatchInfo( + MatchState.SIZE_OR_SYNTAX_MISMATCH, + f"Found input numel '{math.prod(other.input_sizes[0])}' does not match output numel " + f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'", + ) + if self.dtype_mismatch(other): + return MatchInfo( + MatchState.COLLECTIVE_DTYPE_MISMATCH, + f"Expected dtypes: '{set(self.input_dtypes)}' does not " + f"match found dtype: '{set(self.output_dtypes)}/" + f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'", + ) + if self.state != other.state: + # MatchState() + return MatchInfo( + MatchState.COLLECTIVE_STATE_MISMATCH, + f"Expected state: '{self.state}' does not match found state: '{other.state}'", + ) + if self.type == "all_to_all": + return MatchInfo(MatchState.UNDECIDED) + elif self.type in [ + "coalesced", + "ALLGATHER_coalesced", + "REDUCE_SCATTER_coalesced", + ]: + return ( + MatchInfo(MatchState.FULLY_MATCHED) + if (other.type == self.type) + else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) + ) + return MatchInfo(MatchState.FULLY_MATCHED) + + +class MatchStateRecord: + def __init__( + self, + expected_ranks: set[int], + other_ranks: list[int], + entry_state: EntryState, + candidate_ranks: set[int], + candidate_idx: dict[int, int], + found_ranks: set[int], + found_idx: dict[int, int], + errors: set[tuple[int, MatchInfo]], + ) -> None: + self.expected_ranks = expected_ranks + self.other_ranks = other_ranks + self.entry_state = entry_state + self.candidate_ranks = candidate_ranks + self.candidate_idx = candidate_idx + self.found_ranks = found_ranks + self.found_idx = found_idx + self.errors = errors + self.has_undecided_case = False + + def reset_for_coalesced( + self, entry_state: EntryState, candidate_ranks: set[int] + ) -> None: + self.entry_state = entry_state + self.candidate_ranks = candidate_ranks + self.candidate_idx = {} + self.found_ranks = set() + self.found_idx = {} + self.errors = set() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab7919a2a24d81e7be692bc8bb9b0c326a99b28 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/components/utils.py @@ -0,0 +1,789 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import math +from typing import Any + +from torch.distributed.flight_recorder.components.fr_logger import FlightRecorderLogger +from torch.distributed.flight_recorder.components.types import ( + Collective, + EntryState, + Group, + MatchInfo, + MatchState, + MatchStateRecord, + Membership, + Op, + P2P, +) + + +__all__ = [ + "add_stack_id_in_entries", + "align_trace_from_beginning", + "check_current_entry_match", + "check_no_missing_dump_files", + "check_version", + "error_analysis", + "find_coalesced_group", + "find_coalesced_group_with_non_p2p", + "get_version_detail", + "just_print_entries", + "match_coalesced_groups_with_non_p2p", + "match_coalesced_groups", + "format_frame", + "format_frames", + "match_one_event", + "check_size_alltoall", +] + +logger: FlightRecorderLogger = FlightRecorderLogger() + + +try: + from tabulate import tabulate +except ModuleNotFoundError: + logger.debug("tabulate is not installed. Proceeding without it.") + + +def format_frame(frame: dict[str, str]) -> str: + name = frame["name"] + filename = frame["filename"] + line = frame["line"] + return f"{name} at {filename}:{line}" + + +def format_frames(frames: list[dict[str, str]]) -> str: + formatted_frames = [] + for frame in frames: + # pyrefly: ignore [bad-argument-type] + formatted_frames.append(format_frame(frame)) + return "\n".join(formatted_frames) + + +def match_one_event( + event_a: dict[Any, Any], + event_b: dict[Any, Any], + memberships: dict[str, set[Any]], + pg_name: str, +) -> MatchInfo: + op_a = Op(event_a, memberships, pg_name) + op_b = Op(event_b, memberships, pg_name) + return op_a.match(op_b) + + +def match_coalesced_groups( + all_rank_events: dict[Any, Any], + group_size: int, + groups: dict[str, Group], + memberships: dict[str, set[Any]], + _pg_guids: dict[tuple[str, int], str], +) -> bool: + """ + all_rank_events: { + rank: [ + (idx, event_dict) + ] + } + + Note: it is possible for event dicts in a coalesced group to be asymmetric. + e.g. the following events lists form a valid coalescing group + events0 [send:1] + events1 [recv:0, send:2] + events2 [recv:1] + + Rule 1: all ops should find a match + Rule 2: relative ordering of sends and recvs in one event list can be arbitrary + e.g. + events1 [recv:0, send:2] —> okay + events1 [send:2, recv:0] —> also okay + Rule 3: sends to the same dest or recvs from the src should be in a consistent order + e.g. + rank0 [send:1 (100B), send:1 (1000B)] + rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay + """ + all_ops = { + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for i, e in all_rank_events[rank] + ] + for rank in all_rank_events + } + + def visualize_ops( + match: bool, + _pg_guids: dict[tuple[str, int], str], + ) -> None: + all_ops = { + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for i, e in all_rank_events[rank] + ] + for rank in all_rank_events + } + + i = 0 + row = [] + progress = True + table = [] + while progress: + progress = False + for r in all_ops: + if len(all_ops[r]) > i: + rank, event = all_rank_events[r][i] + # Check if the pg_guid exists for this rank and process group + pg_key = (event["process_group"][0], rank) + if pg_key in _pg_guids: + row.append( + Op( + event, + memberships, + _pg_guids[pg_key], + ) + ) + else: + # Skip this entry if pg_guid mapping doesn't exist + row.append(None) # type: ignore[arg-type] + progress = True + else: + row.append(None) # type: ignore[arg-type] + table.append(row) + row = [] + i += 1 + title = "Match" if match else "MISMATCH" + logger.info("%s \n", title) + logger.info("%s", tabulate(table)) # type: ignore[operator] + + # TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg. + for op_list in all_ops.values(): + if not op_list: + # print("TODO- not sure if its valid for only some ranks in a PG to participate in a coalesced op?") + return False + assert op_list[-1].type == "coalesced" + op_list.pop(-1) + + while all_ops: + first_rank = next(iter(all_ops)) + my_ops = all_ops[first_rank] + + if len(all_ops[first_rank]) == 0: + all_ops.pop(first_rank) + continue + + # lets match the first collective! we need to know which ranks are involved, and ensure that this same + # collective is also the first one on those ranks within that group + op = my_ops[0] + match_idx = -1 + if op.type in P2P: + dst_global_rank = sorted(memberships[op.pg_name])[op.dst] + peer_ops = all_ops[dst_global_rank] + for i, other in enumerate(peer_ops): + if op.match(other).state == MatchState.FULLY_MATCHED: + match_idx = i + break + elif op.dst == other.src: + # Rule 3 + break + else: + # Rule 1 + continue + else: + raise NotImplementedError("coalesced collective ops") + if match_idx >= 0: + my_ops.pop(0) + peer_ops.pop(match_idx) + else: + visualize_ops(False, _pg_guids) + return False + + visualize_ops(True, _pg_guids) + return True + + +# We enabled the creating FR entry for non-P2P slow path collective ops in v2.7. +def match_coalesced_groups_with_non_p2p( + all_rank_events: dict[Any, Any], + pg_info: tuple[str, str], + memberships: dict[str, set[Any]], + _pg_guids: dict[tuple[str, int], str], + mismatch: dict[str, int], + dumps_ranks: set[int], + version: str, + collectives: list[Collective], + match_record: MatchStateRecord, +) -> bool: + """ + all_rank_events: { + rank: [ + (idx, event_dict) + ] + } + + Note: it is possible for event dicts in a coalesced group to be asymmetric. + e.g. the following events lists form a valid coalescing group + events0 [send:1] + events1 [recv:0, send:2] + events2 [recv:1] + + Rule 1: all ops should find a match + Rule 2: relative ordering of sends and recvs in one event list can be arbitrary + e.g. + events1 [recv:0, send:2] —> okay + events1 [send:2, recv:0] —> also okay + Rule 3: sends to the same dest or recvs from the src should be in a consistent order + e.g. + rank0 [send:1 (100B), send:1 (1000B)] + rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay + """ + all_ops = { + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for _, e in all_rank_events[rank] + ] + for rank in all_rank_events + } + is_p2p = any(op.type in P2P for ops in all_ops.values() for op in ops) + pg_name = pg_info[0] + + def visualize_ops( + match: bool, + _pg_guids: dict[tuple[str, int], str], + ) -> None: + all_ops = { + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for _, e in all_rank_events[rank] + ] + for rank in all_rank_events + } + + i = 0 + row = [] + progress = True + table = [] + while progress: + progress = False + for r in all_ops: + if len(all_ops[r]) > i: + rank, event = all_rank_events[r][i] + # Check if the pg_guid exists for this rank and process group + pg_key = (event["process_group"][0], rank) + if pg_key in _pg_guids: + row.append( + Op( + event, + memberships, + _pg_guids[pg_key], + ) + ) + else: + # Skip this entry if pg_guid mapping doesn't exist + row.append(None) # type: ignore[arg-type] + progress = True + else: + row.append(None) # type: ignore[arg-type] + table.append(row) + row = [] + i += 1 + title = "Match" if match else "MISMATCH" + logger.info("%s \n", title) + logger.info("%s", tabulate(table)) # type: ignore[operator] + + # TODO Need to verify no seq_id deltas for P2P ops. + for rank, op_list in all_ops.items(): + if not op_list: + logger.error("Rank %s has an empty op list.", rank) + continue + if op_list[-1].type == "coalesced" and is_p2p: + op_list.pop(-1) + + while all_ops: + first_rank = next(iter(all_ops)) + my_ops = all_ops[first_rank] + + if len(all_ops[first_rank]) == 0: + all_ops.pop(first_rank) + continue + + # lets match the first collective! we need to know which ranks are involved, and ensure that this same + # collective is also the first one on those ranks within that group + op = my_ops[0] + match_idx = -1 + if is_p2p: + dst_global_rank = sorted(memberships[op.pg_name])[op.dst] + peer_ops = all_ops[dst_global_rank] + for i, other in enumerate(peer_ops): + if op.match(other).state == MatchState.FULLY_MATCHED: + match_idx = i + break + elif op.dst == other.src: + # Rule 3 + break + else: + # Rule 1 + continue + if match_idx >= 0: + my_ops.pop(0) + peer_ops.pop(match_idx) + else: + visualize_ops(False, _pg_guids) + return False + else: + all_coalesced_entries = { + rank: [e for _, e in all_rank_events[rank]] for rank in all_rank_events + } + current_entry = all_coalesced_entries[first_rank][0] + my_ops.pop(0) + + match_record.reset_for_coalesced( + EntryState(current_entry, match_record.expected_ranks), + {first_rank}, + ) + + # Iterate through all the ranks and check if there is a mismatch for the current entry. + check_current_entry_match( + all_coalesced_entries, + _pg_guids, + pg_info, + current_entry, + memberships, + mismatch, + match_record, + ) + + # Use heuristics to decide what type of errors and error messages we should print. + error_analysis( + all_coalesced_entries, + match_record, + dumps_ranks, + first_rank, + current_entry, + mismatch, + get_version_detail(version), + pg_info[0], + ) + + # TODO: For now, we only check the correctness of individual collective within a coalesced one in + # this script. We need to merge (e.g, input/output sizes) together + # for downstream consumer. + + # at this point there are 3 possibilities + # 1. we found a match on all the ranks that are members of the group + # -> we create a Collective and remove the individual entries from their original lists + if ( + match_record.found_ranks == match_record.expected_ranks + and mismatch[pg_name] == 0 + ): + # Just pop out this collective. + idx_map = { + r: match_record.found_idx[r] if r != first_rank else 0 + for r in match_record.found_ranks + } + for i, k in idx_map.items(): + all_rank_events[i].pop(k) + for r in match_record.found_ranks: + if r != first_rank: + all_ops[r].pop(0) + + # 2. we found a partial match but some ranks are missing + # 3. we found no match + # -> since its not a complete collective, no entry goes into collectives but we still record a nccl call + else: + logger.debug("Non-matching collective inside coalesced group") + idx_map = { + r: match_record.candidate_idx[r] if r != first_rank else 0 + for r in match_record.candidate_ranks + } + collectives.append( + match_record.entry_state.to_collective( + len(collectives), + errors=match_record.errors, + idx_map=idx_map, + all_entries=all_coalesced_entries, + ) + ) + return False + + if is_p2p: + visualize_ops(True, _pg_guids) + return True + + +def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]: + input_numel = 0 + output_numel = 0 + for e in alltoall_cases: + input_numel += math.prod(e["input_sizes"][0]) + output_numel += math.prod(e["output_sizes"][0]) + return input_numel != output_numel, input_numel, output_numel + + +def check_current_entry_match( + all_entries: dict[int, list[dict[str, Any]]], + _pg_guids: dict[tuple[str, int], str], + pg_info: tuple[str, str], + current_entry: dict[str, Any], + _memberships: dict[str, set[Any]], + mismatch: dict[str, int], + match_record: MatchStateRecord, +) -> None: + pg_name, desc = pg_info[0], pg_info[1] + for o in match_record.expected_ranks.intersection(set(match_record.other_ranks)): + for i, e in enumerate(all_entries[o]): # type: ignore[index] + # step over ops from other PGs + # only check match state when seq_id matches + if ( + _pg_guids[(e["process_group"][0], o)] == pg_name + and e["process_group"][1] == desc + and e["collective_seq_id"] == match_record.entry_state.collective_seq_id + ): + match_info = match_one_event(current_entry, e, _memberships, pg_name) + if ( + match_info.state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] + and mismatch[pg_name] == 0 + ): + match_record.found_ranks.add(o) + match_record.found_idx[o] = i + match_record.has_undecided_case = ( + match_info.state == MatchState.UNDECIDED + ) + else: + match_record.candidate_ranks.add(o) + match_record.candidate_idx[o] = i + if match_info.state not in [ + MatchState.FULLY_MATCHED, + MatchState.UNDECIDED, + ]: + # Here we assume the current rank is not the source of the error. + # But it's possible that the current rank is the culprit, then users will + # see lots of normal ranks reported as culprit. + # TODO: we need to figure out a better way to handle the case mentioned above. + match_record.errors.add((o, match_info)) + break + + +def error_analysis( + all_entries: dict[int, list[dict[str, Any]]], + match_record: MatchStateRecord, + dumps_ranks: set[int], + first_rank: int, + current_entry: dict[str, Any], + mismatch: dict[str, int], + version: tuple[int, int], + pg_name: str, +) -> None: + major_v, minor_v = version[0], version[1] + # case one: not every rank join the collective or in the flight recorder. + if ( + match_record.candidate_ranks | match_record.found_ranks + ) != match_record.expected_ranks and match_record.expected_ranks - ( + match_record.candidate_ranks | match_record.found_ranks + ) <= dumps_ranks: + mismatch[pg_name] += 1 + logger_msg = "Not all ranks joining collective, sequence number: %s" + missing_ranks = match_record.expected_ranks - ( + match_record.candidate_ranks | match_record.found_ranks + ) + match_record.entry_state.log( + logger, logger_msg, format_frames, missing_ranks=missing_ranks + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + # We didn't see any mismatch and all expected ranks are in the dump. + elif len( + match_record.candidate_ranks + ) == 1 and match_record.expected_ranks.issubset(dumps_ranks): + # case two: alltoall or alltoall_base case. + if match_record.has_undecided_case: + alltoall_cases = [current_entry] + [ + all_entries[o][match_record.found_idx[o]] + for o in match_record.found_ranks + ] + fail_check, total_input_numel, total_output_numel = check_size_alltoall( + alltoall_cases + ) + if major_v <= 2 and minor_v <= 3: + # We don't log the input/output sizes for alltoall before v2.4, + # so we don't consider the size mismatch as an error for now. + fail_check = False + if fail_check: + # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. + mismatch[pg_name] += 1 + logger_msg = ( + "Input/output mismatch in the collective sequence number: %s" + ) + match_record.entry_state.log( + logger, + logger_msg, + format_frames, + total_numel=(total_input_numel, total_output_numel), + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + match_record.errors.add( + (first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)) + ) + else: + match_record.found_ranks.update(match_record.candidate_ranks) + match_record.found_idx.update(match_record.candidate_idx) + match_record.candidate_idx.clear() + match_record.candidate_ranks.clear() + # case three: all joined and everything matches on all ranks. + else: + match_record.found_ranks.update(match_record.candidate_ranks) + match_record.found_idx.update(match_record.candidate_idx) + match_record.candidate_idx.clear() + match_record.candidate_ranks.clear() + # case four: mismatch cases due to not same type, size mismatch or state mismatch. + elif len(match_record.errors) > 0: + mismatch[pg_name] += 1 + logger_msg = "Collective sequence number: %s has errors" + match_record.entry_state.log( + logger, logger_msg, format_frames, errors=match_record.errors + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + # partial analysis case when we cannot decide what's wrong with this collective entry. + else: + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + # if any element in expected_ranks not in dumps_ranks. + if match_record.expected_ranks - dumps_ranks: + mismatch[pg_name] += 1 + logger.info( + "We cannot decide what's wrong with this collective entry " + "because we missed FR dumps from ranks (%s) so we don't have enough " + "information. If you want to debug further use -j to dump all raw trace", + str(match_record.expected_ranks - dumps_ranks), + ) + else: + logger.info( + "No errors found for this collective entry, There could be some " + "other reasons why we see collective timeout." + ) + + +def find_coalesced_group( + pg_name: str, + entries: list[dict[str, Any]], + _pg_guids: dict[tuple[str, int], str], + rank: int, +) -> list[tuple[int, dict[str, Any]]]: + """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones, + build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id + """ + found = [] + collective_seq_id = None + for i, e in enumerate(entries): + if _pg_guids[(e["process_group"][0], rank)] != pg_name: + continue + elif collective_seq_id is None: + collective_seq_id = ( + e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"] + ) + found.append((i, e)) + elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id: + found.append((i, e)) + elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id: + found.append((i, e)) + else: + break + + if len(found) > 1: + assert found[-1][1]["profiling_name"] == "nccl:coalesced" + return found + return [] + + +# We enabled the creating FR entry for non-P2P slow path collective ops in v2.7. +def find_coalesced_group_with_non_p2p( + pg_name: str, + entries: list[dict[str, Any]], + _pg_guids: dict[tuple[str, int], str], + rank: int, +) -> list[tuple[int, dict[str, Any]]]: + """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones, + build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id + """ + found = [] + collective_seq_id = None + for i, e in enumerate(entries): + if _pg_guids[(e["process_group"][0], rank)] != pg_name: + continue + elif collective_seq_id is None: + collective_seq_id = ( + e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"] + ) + found.append((i, e)) + elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id: + found.append((i, e)) + elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id: + found.append((i, e)) + else: + break + + if len(found) > 1: + name = found[-1][1]["profiling_name"] + if name.startswith("nccl:") and not name.endswith("_coalesced"): + logger.error("Rank %s does not have a coalesced end.", rank) + return found + return [] + + +def just_print_entries( + all_entries: dict[int, list[dict[str, Any]]], + _groups: dict[str, Group], + _memberships: dict[str, set[Any]], + _pg_guids: dict[tuple[str, int], str], + args: argparse.Namespace, + stack_id_trace_map: dict[str, int], +) -> None: + rows = [] + ranks = sorted(all_entries.keys()) + headers = [ + f"Rank {rank}" + for rank in ranks + if args.selected_ranks is None or rank in args.selected_ranks + ] + progress = True + while progress: + progress = False + row = [] + for rank in ranks: + if args.selected_ranks is not None and rank not in args.selected_ranks: + continue + if len(all_entries[rank]) == 0: + row.append("") + else: + entry = all_entries[rank].pop(0) + pg_name = _pg_guids[(entry["process_group"][0], rank)] + if ( + args.pg_filters is None + or entry["process_group"][1] in args.pg_filters + or entry["process_group"][0] in args.pg_filters + ): + row.append(str(Op(entry, _memberships, pg_name))) + else: + row.append("") + progress = True + if progress: + rows.append(row) + + logger.info(tabulate(rows, headers=headers)) + + if stack_id_trace_map and args.print_stack_trace: + headers = ["stack_id", "frame_stack"] + rows = [] + + for frame, stack_id in sorted( + stack_id_trace_map.items(), key=lambda item: item[1] + ): + rows.append([str(stack_id), frame]) + + logger.info(tabulate(rows, headers=headers)) + + +def check_no_missing_dump_files( + entries: dict[int, Any], memberships: list[Membership] +) -> None: + all_ranks = set() + for membership in memberships: + all_ranks.add(int(membership.global_rank)) + dumps_ranks = {int(key) for key in entries} + missing = all_ranks - dumps_ranks + assert len(missing) == 0, f"Missing dump files from ranks {missing}" + + +def check_version(version_by_ranks: dict[str, str], version: str) -> None: + for rank, v in version_by_ranks.items(): + assert v == version, ( + f"Rank {rank} has different version {v} from the given version {version}" + ) + + +def get_version_detail(version: str) -> tuple[int, int]: + # pyrefly: ignore [bad-assignment] + version = version.split(".") + assert len(version) == 2, f"Invalid version {version}" + major, minor = map(int, version) + return major, minor + + +def add_stack_id_in_entries( + entries: dict[int, list[dict[str, Any]]], +) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]: + stack_id = 0 + stack_id_trace_map = {} + for rank in entries: + for dump in entries[rank]: + if dump.get("frames", []): + frames = str(dump["frames"]) + if frames not in stack_id_trace_map: + stack_id_trace_map[frames] = stack_id + dump["stack_id"] = stack_id + stack_id += 1 + else: + dump["stack_id"] = stack_id_trace_map[frames] + else: + dump["stack_id"] = -1 + + return entries, stack_id_trace_map + + +def align_trace_from_beginning( + entries: dict[int, list[dict[str, Any]]], +) -> dict[int, list[dict[str, Any]]]: + """ + Align the trace entries by record ID for entries. + This function takes a dictionary of rank names to lists of trace entries as input. + Each trace entry is a dictionary containing information about a collective operation, + including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer). + The function finds the largest starting point across all ranks by taking the maximum + `record_id` value of the first entry in each rank. Finally, it filters out any + entries with `record_id` values less than the maximum starting point. + The function returns the updated dictionary of sorted and filtered trace entries. + + Args: + entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries. + + Returns: + entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point. + """ + + maximum_starting_record_id = 0 + for rank in entries: + # Although this is a ring buffer, we already sort the entries by `record_id` when dumping, we just + # need to find the largest starting point. For example, if the buffer has the following entries: + # Rank 0: [0, 1, 2, 3, 4, 5, 6] + # Rank 1: [1, 2, 3, 4, 5, 6, 7] + # Rank 2: [2, 3, 4, 5, 6, 7, 8] + # Rank 3: [0, 1, 2, 3, 4, 5, None] + # Then we should start from collective 2 not 0 because any collective before, + # we don't have complete records from all ranks so we need to ignore them. + # If we don't have any trace from some ranks, ignore them + # as well. + if len(entries[rank]) == 0: + continue + first_record_id = entries[rank][0]["record_id"] + maximum_starting_record_id = max(maximum_starting_record_id, first_record_id) + + for rank in entries: + entries[rank] = [ + entry + for entry in entries[rank] + if entry["record_id"] >= maximum_starting_record_id + ] + + return entries diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/fr_trace.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/fr_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..ab338d1503ae0ac4359728ba3a5983041e678f3d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/flight_recorder/fr_trace.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +"""Flight Recorder Trace Analyzer + +This script primarily merges data from individual flight recorder buffers from individual ranks in a +PyTorch Distributed program into a flattened database format that can be used for further analysis. + +However as part of the merging process, it is necessary to perform some analysis in order to match operators +on one rank with corresponding operators on other ranks and register them as one 'collective' entry. During this +process, a significant amount of useful information can already be extracted such as where the first mismatch occurs +in cases of desync (when not all ranks issue a compatible collective in a particular process group). + + +Not Yet Implemented +- TODO- tracebacks aren't implemented + +Known Issues +- Flight Recorder buffer sequence_id information is not sufficient to match collectives and coalesced collectives + unless we have the trace data from the beginning of the program. To enable confident analysis of trace buffers that + do not start from zero (and to simplify the script's matching logic) we need to add more information to the recorder. +- Currently, the script omits checking the 'status' of collectives. We can look for the first 'non completed' + collective easily enough and report that. + +Usage +python fr_trace.py [-o ] + +- Omitting the optional output file will still yield analysis information to stdout +- The output file is a pickle of the flat DB, which may change in format in the future. +- This script is versioned so that we can ensure our future changes to flight recorder are backwards compatible. +""" + +import pickle +from collections.abc import Sequence + +from torch.distributed.flight_recorder.components.builder import build_db, transform_ft +from torch.distributed.flight_recorder.components.config_manager import JobConfig +from torch.distributed.flight_recorder.components.loader import read_dir +from torch.distributed.flight_recorder.components.types import types + + +__all__ = ["main"] + + +def main(args: Sequence[str] | None = None) -> None: + config = JobConfig() + # pyrefly: ignore [bad-assignment] + args = config.parse_args(args) + # pyrefly: ignore [missing-attribute] + assert args.trace_dir, "Trace directory trace_dir is required" + # pyrefly: ignore [bad-argument-type] + details, version = read_dir(args) + # pyrefly: ignore [missing-attribute] + if args.transform_ft: + # pyrefly: ignore [missing-attribute] + assert args.group_world_size, "World size is required for transform_ft" + # pyrefly: ignore [bad-argument-type] + details = transform_ft(details, args.group_world_size) + # pyrefly: ignore [bad-argument-type] + db = build_db(details, args, version) + # pyrefly: ignore [missing-attribute] + if args.output: + # pyrefly: ignore [no-matching-overload] + with open(args.output, "wb") as f: + pickle.dump((types, db), f) + + +if __name__ == "__main__": + main() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4219250c39dc44dd0c1132e4e1b263de08f5c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__init__.py @@ -0,0 +1,69 @@ +from ._flat_param import FlatParameter as FlatParameter +from ._fully_shard import ( + CPUOffloadPolicy, + FSDPModule, + fully_shard, + MixedPrecisionPolicy, + OffloadPolicy, + register_fsdp_forward_method, + share_comm_ctx, + UnshardHandle, +) +from .fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel, + LocalOptimStateDictConfig, + LocalStateDictConfig, + MixedPrecision, + OptimStateDictConfig, + OptimStateKeyType, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictSettings, + StateDictType, +) + + +__all__ = [ + # FSDP1 + "BackwardPrefetch", + "CPUOffload", + "FullOptimStateDictConfig", + "FullStateDictConfig", + "FullyShardedDataParallel", + "LocalOptimStateDictConfig", + "LocalStateDictConfig", + "MixedPrecision", + "OptimStateDictConfig", + "OptimStateKeyType", + "ShardedOptimStateDictConfig", + "ShardedStateDictConfig", + "ShardingStrategy", + "StateDictConfig", + "StateDictSettings", + "StateDictType", + # FSDP2 + "CPUOffloadPolicy", + "FSDPModule", + "fully_shard", + "MixedPrecisionPolicy", + "OffloadPolicy", + "register_fsdp_forward_method", + "UnshardHandle", + "share_comm_ctx", +] + +# Set namespace for exposed private names +CPUOffloadPolicy.__module__ = "torch.distributed.fsdp" +FSDPModule.__module__ = "torch.distributed.fsdp" +fully_shard.__module__ = "torch.distributed.fsdp" +MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp" +OffloadPolicy.__module__ = "torch.distributed.fsdp" +register_fsdp_forward_method.__module__ = "torch.distributed.fsdp" +UnshardHandle.__module__ = "torch.distributed.fsdp" +share_comm_ctx.__module__ = "torch.distributed.fsdp" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bac91ab2c4492c155449076660ed3941cdc1e9d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adec1ded99768b4fe571fdd64eda75624b229749 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2981e19044aaba4a9255008682fdd2c1f7b5114d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b7b65837838c493e96da35081b985f691957b8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9e8c6a1b1777720c546b9a7da3ff44a4473ecf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44b38ae43ad2efb134284a6d5a588be682328ced Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6483ea3ae2c0f36164d49bc5cc66de02945d3f19 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aaddb8dbb47fc04a53e22ff8014d0f7af0a7305 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba88a297d79b982f7690aecdbb7d89ac362a242e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25145c230162edc868cb2fe2aca518d59c5c52cf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f04e75d2ef7cf9fa72e2124a592e6aed864af1b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ada579b2520281bfa7a8bdbf80eabf8f0eaf820 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ed9c9a3049a700142cc08665526affa4f76e7c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f39a49a7d8fa9d334e22df6c8fbc6300de2ed92 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db65ad48a1fe6198c16a73e78a822f404e2f03a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e1552dc3518b8484261e412a34c0bd44d704118 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a47c6f1159cd161b9f32ffea4401ecce6377804f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdce49192d15d1881e200f4c5ed2c107323b2d3f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61585b092b255550bf0d57c33910617ba4ccb074 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_common_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54d6c974caedf83a473148b7eb85da267f2be070 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_common_utils.py @@ -0,0 +1,550 @@ +# mypy: allow-untyped-defs +""" +This file includes private common utilities for FSDP. +""" + +import logging +import traceback +import warnings +import weakref +from collections.abc import Callable, Generator, Iterable +from enum import auto, Enum +from functools import partial +from itertools import chain +from typing import Any, cast, no_type_check, Optional, TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._flat_param as flat_param_file +import torch.nn as nn +from torch.distributed._composable_state import _get_module_state, _State +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.utils import _apply_to_tensors +from torch.utils._mode_utils import no_dispatch + +from .api import ( + FullOptimStateDictConfig, + FullStateDictConfig, + OptimStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictType, +) + + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions + + from ._flat_param import FlatParamHandle + +FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" +FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." +FSDP_FLATTENED = "_fsdp_flattened" + +# Save a global mapping from module to its input tensor dtype to be populated +# during the forward pre-hook and consumed in the forward post-hook when +# overriding a module's mixed precision +# NOTE: We currently take the last input tensor's dtype in the case of multiple +# floating-point input tensors, which may be incorrect. However, since there is +# not a 1:1 correspondence between input and output tensors, we must use *some* +# heuristic like this to predict the desired output dtype. +_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + + +class _FSDPDeviceHandle: + """ + This is a simple abstraction for FSDP computing devices, + which enables custom backends that implement CUDA-like + semantics to be integrated with FSDP. + """ + + def __init__(self, device: torch.device, backend: Any = None): + if backend is None: + try: + self.__backend = getattr(torch, device.type) + # pyrefly: ignore [read-only] + self.__device = device + except AttributeError as exc: + raise AttributeError( + f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'." + ) from exc + else: + self.__backend = backend + + @classmethod + def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle": + """ + Return a device handle corresponding to the device, and through this handle, + operations with the same semantics as CUDA can be performed on the device. + Just return torch.cuda if the device is cuda to make attribute-access faster. + Custom backend must first register a module with the same name with {device.type} on torch. + """ + if device.type == "cuda": + return cast(_FSDPDeviceHandle, torch.cuda) + elif device.type == "mtia": + return cast(_FSDPDeviceHandle, torch.mtia) + return cls(device) + + def __getattr__(self, name: str, /) -> Any: + try: + return getattr(self.__backend, name) + except AttributeError as exc: + raise AttributeError( + f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{name}'" + ) from exc + + +class _UninitializedDeviceHandle(_FSDPDeviceHandle): + def __init__(self) -> None: + pass + + def __getattribute__(self, name: str, /) -> Any: + raise RuntimeError("Trying to use an uninitialized device handle.") + + +class _FSDPState(_State): + def __init__(self) -> None: + # TODO: Move all the attributes to this class to enable typing for + # FSDP/fully_shard. + self._ignored_modules: set[nn.Module] = set() + self._ignored_params: set[nn.Parameter] = set() + # Buffer names are cleaned (without wrapper prefixes) + self._ignored_buffer_names: set[str] = set() + self.process_group: Optional[dist.ProcessGroup] = None + self.rank: int = -1 + self.world_size: int = -1 + self._device_mesh: Optional[DeviceMesh] = None + self.sharding_strategy = ShardingStrategy.FULL_SHARD + self._use_orig_params: bool = False + self.training_state = TrainingState.IDLE + self._unshard_params_ctx: dict[nn.Module, Generator] = {} + self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT + self._state_dict_config: StateDictConfig = FullStateDictConfig() + self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig() + self._is_root: Optional[bool] = None + self._handle: Optional[flat_param_file.FlatParamHandle] = None + self._fully_sharded_module_to_handle: dict[ + nn.Module, Optional[flat_param_file.FlatParamHandle] + ] = {} + self.compute_device: Optional[torch.device] = None + self._gradient_predivide_factor: int = 0 + self._gradient_postdivide_factor: int = 0 + self._comm_hook: Optional[Callable] = None + self._comm_hook_state: Optional[Any] = None + self._unshard_event: Optional[torch.Event] = None + # Abstract device handle for fsdp compute device. For now, + # the compute device must implement cuda semantics used by fsdp + self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle() + # All following attributes should only be used for root states: + # Save these static lists to avoid the repeated tree traversals + self._all_fsdp_states: list[_FSDPState] = [] + self._all_handles: list[flat_param_file.FlatParamHandle] = [] + self._fsdp_extension: Optional[FSDPExtensions] = None + + +def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]: + state = _get_module_state(module) + if state is None or not isinstance(state, _FSDPState): + return None + return state + + +def _get_module_fsdp_state_if_fully_sharded_module( + module: nn.Module, +) -> Optional[_FSDPState]: + state = _get_module_fsdp_state(module) + if state is None: + return None + if state == module: # FullyShardedDataParallel module case. + return state + if module in state._fully_sharded_module_to_handle: # fully_shard case. + return state + return None + + +class TrainingState(Enum): + """ + An enum that indicates the state of a ``FullyShardedDataParallel` instance. + """ + + IDLE = auto() + FORWARD_BACKWARD = auto() + SUMMON_FULL_PARAMS = auto() + + +class HandleTrainingState(Enum): + """ + An enum that indicates the state of a ``FlatParamHandle`. + """ + + IDLE = auto() + FORWARD = auto() + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + SUMMON_FULL_PARAMS = auto() + + +def _is_composable(state: _FSDPState): + # TODO: This is a temporary hack for differentiate between code paths. + return not isinstance(state, nn.Module) + + +@no_type_check +def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]: + """ + Returns the ``FlatParamHandle`` s corresponding to ``module``. This is + the handle that contains some parameter in ``module``. + """ + if _is_composable(state): + # A valid FSDP state may have no managed parameters and hence no + # handles, meaning no entry in `_fully_sharded_module_to_handles` + if state._handle is None: + return None + if module not in state._fully_sharded_module_to_handle: + raise AssertionError( + f"Expects a fully sharded module but got {module} on rank {state.rank}" + ) + return state._fully_sharded_module_to_handle[module] + else: + # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance. + return module._handle + + +@no_type_check +def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool: + """Returns if ``module`` has parameters managed by FSDP.""" + return _module_handle(state, module) is not None + + +def _get_sharding_strategy(handle): + """ + Returns the sharding strategy of the handle. + """ + return handle._sharding_strategy if handle else None + + +def clean_tensor_name(tensor_name: str) -> str: + """ + Cleans the parameter or buffer name by removing any module wrapper + prefixes. + """ + tensor_name = tensor_name.replace(FSDP_PREFIX, "") + # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as + # it couples `CheckpointWrapper` and FSDP and also does not scale for more + # module wrappers. + tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "") + return tensor_name + + +def _set_fsdp_flattened(tensor: torch.Tensor) -> None: + """ + Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to + avoid re-flattening it during nested construction. + """ + setattr(tensor, FSDP_FLATTENED, True) + + +def _is_fsdp_flattened(tensor: torch.Tensor) -> bool: + """Returns if ``tensor`` has been marked as flattened by FSDP.""" + return getattr(tensor, FSDP_FLATTENED, False) + + +def _named_parameters_with_duplicates( + module: nn.Module, **kwargs: Any +) -> list[tuple[str, nn.Parameter]]: + """ + This API is required as some modules overwrite `named_parameters()` but do not support + `remove_duplicate`. + """ + if "remove_duplicate" in kwargs: + raise AssertionError( + "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." + ) + kwargs["remove_duplicate"] = False + try: + ret = list(module.named_parameters(**kwargs)) + except AssertionError: + kwargs.pop("remove_duplicate") + ret = list(module.named_parameters(**kwargs)) + return ret + + +def _get_param_to_fqns( + model: torch.nn.Module, + dedup_shared_params: bool = True, +) -> dict[nn.Parameter, list[str]]: + """ + Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here, + we use canonical to mean the fully-qualified name assigned to the parameter + based on its position in the original nn.Module hierarchy before any wrapper + or parallelism has been applied to it. This is in contrast to FQNs that may be + generated after parallelisms or wrappers have been applied to the model. + + Each normal parameter maps to a singleton list containing its FQN, while each + ``FlatParameter`` maps to a list of its original parameter FQNs, which may + have length greater than one. All FQNs are prefixed starting from ``model``. + + In the case where FSDP was applied with ``use_orig_params=True``, there should be no + ``FlatParameter`` s registered to the model's modules and this mapping will only + contain mappings from ``nn.Parameter`` s to singleton FQN lists. + + It is only in the case where FSDP was applied with ``use_orig_params=False`` where + a ``FlatParameter`` will be registered in place of the original parameters and there + will be mappings from each ``FlatParameter`` to lists of FQNs corresponding to the + original parameters. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance). + dedup_shared_params (bool): For shared parameters, if ``True``, only + includes the FQNs corresponding to the first encounter of the + shared parameter in the module traversal; if ``False``, then + includes the FQNs across all encounters. (Default: ``True``) + """ + + def module_fn(module, prefix, tree_level, param_to_fqns): + for param_name, param in _named_parameters_with_duplicates( + module, recurse=False + ): + local_fqns = ( + param._fqns + if isinstance(param, flat_param_file.FlatParameter) + else [param_name] + ) # prefixed from `module` + global_fqns = [ + clean_tensor_name(prefix + name) for name in local_fqns + ] # prefixed from the top level `model` (i.e. including `prefix`) + is_shared_param = param in param_to_fqns + if not is_shared_param: + param_to_fqns[param] = global_fqns + else: + if isinstance(param, flat_param_file.FlatParameter): + # DMP overwrites `named_parameters` and skip (advance to + # the next child module) the wrapped_module (e.g., + # _dmp_wrapped_module and _fsdp_wrapped_module). When a user + # calls `named_child` to traverse the module recursively and + # calls `named_parameters` with `recurse=False`, parameters + # will be traversed more than once. + # This hack is specified designed for DMP + FSDP. We + # overwrite the flat_parameters traversal result to only obtain + # the last one, which happens to be the correct one. + # + # TODO: Remove this hack once DMP + FSDP is not supported. + warnings.warn( + "FlatParameter is being traversed more than once. " + "This case should only happen when using " + "DistributedModelParallel with FullyShardedDataParallel.", + stacklevel=2, + ) + param_to_fqns[param] = global_fqns + elif not dedup_shared_params: + param_to_fqns[param].extend(global_fqns) + + def return_fn(param_to_fqns): + return param_to_fqns + + param_to_unflat_param_names: dict[torch.nn.Parameter, list[str]] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [key for key, _ in _named_parameters_with_duplicates(model)], + param_to_unflat_param_names, + ) + + +@no_type_check +def _log_post_backward_hook( + state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger +) -> None: + # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for. + # Below logging of module names this post-bwd hook fires for can help debug certain + # cases where hooks don't fire, such as under certain activation checkpoint configs. + if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO: + param_fqns = _get_handle_fqns_from_root(state, handle) + logger.warning("FSDP firing post-backward hooks for parameters %s", param_fqns) + + +@no_type_check +def _get_handle_fqns_from_root( + state: _FSDPState, handle: "FlatParamHandle" +) -> Optional[list[str]]: + if handle is None: + return None + param_to_fqn = state._exec_order_data.param_to_fqn + handle_params = handle.flat_param._params # only populated for use_orig_params + param_fqns = [*chain.from_iterable(param_to_fqn[p] for p in handle_params)] + return param_fqns + + +def _apply_to_modules( + root_module: torch.nn.Module, + module_fn: Callable, + return_fn: Callable, + filter_fqns: Optional[list[str]] = None, + *args, + **kwargs, +): + """ + Performs a pre-order traversal of the modules in the hierarchy rooted at + ``root_module``, applying ``module_fn`` at each module and finally + returning a value using ``return_fn``. The traversal constructs the full + module prefix name (e.g. "module.submodule." just like in model state dict) + and makes that available to ``module_fn``. + + ``filter_fqns`` is used because some module may have its own prefix similar + to ``FullyShardedDataParallel`` and the ``named_parameters()`` is overwritten + to remove the prefix. + """ + + def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs): + # Call the module function before recursing over children (pre-order) + module_fn(module, prefix, tree_level, *args, **kwargs) + for submodule_name, submodule in module.named_children(): + if submodule is None: + continue + new_prefix = prefix + submodule_name + "." + new_tree_level = tree_level + 1 + if filter_fqns is not None: + for fqn in filter_fqns: + if fqn.startswith(new_prefix): + break + else: + # DMP's named_parameter() will mess up the traversal with + # ``named_children`` + `named_parameter(recurse=False)``. + # This hack is a must to make the traversal work. + # TODO: Remove this hack once DMP + FSDP is not supported. + # It turns out that recursive wrapping may trigger this as + # well. + if ( + submodule_name == "_fsdp_wrapped_module" + or submodule_name == "_dmp_wrapped_module" + ): + new_prefix = prefix + elif submodule_name == "module": + new_prefix = prefix + f(submodule, new_prefix, new_tree_level, *args, **kwargs) + + f(root_module, "", 0, *args, **kwargs) + return return_fn(*args, **kwargs) + + +@no_type_check +def _assert_in_training_states( + state: _FSDPState, + training_states: list[TrainingState], +) -> None: + """Asserts that FSDP is in the states ``_training_states``.""" + # Raise a `ValueError` instead of using `assert` to ensure that these + # logical assertions run even if `assert`s are disabled + if state.training_state not in training_states: + msg = ( + f"expected to be in states {training_states} but current state is " + f"{state.training_state}" + ) + # Print the error on rank 0 in case this is called in the backward pass + if state.rank == 0: + if isinstance(state, nn.Module): + print(f"Asserting FSDP instance is: {state}") + print(f"ERROR: {msg}") + traceback.print_stack() + raise ValueError(msg) + + +def _get_root_modules(modules: set[nn.Module]) -> set[nn.Module]: + """ + Returns: + Set[nn.Module]: The subset of ``modules`` that are root modules (i.e. + parent-less) with respect to the modules in the set itself. In other + words, these are the modules in ``modules`` that are not the child of + any other module in ``modules``. + """ + root_modules: set[nn.Module] = set() + module_to_submodules = {module: set(module.modules()) for module in modules} + for candidate_module in modules: + is_root_module = True + for module, submodules in module_to_submodules.items(): + is_child_module = ( + candidate_module is not module and candidate_module in submodules + ) + if is_child_module: + is_root_module = False + break + if is_root_module: + root_modules.add(candidate_module) + return root_modules + + +def _override_module_mixed_precision( + root: torch.nn.Module, + module_classes_to_override: Iterable[type[nn.Module]], + wrap_override_dict: dict[str, Any] = {"mixed_precision": None}, # noqa: B006 +) -> set[type[nn.Module]]: + module_classes_to_override = tuple(set(module_classes_to_override)) + # Return a set of the actually overridden module classes + overridden_module_classes: set[type[nn.Module]] = set() + for mod in root.modules(): + if isinstance(mod, module_classes_to_override): + overridden_module_classes.add(type(mod)) + mod._wrap_overrides = wrap_override_dict # type: ignore[assignment] + # TODO: We need to run this mixed precision ignored module in fp32, + # but ensure subsequent modules, that may possibly be running with + # mixed precision, still receive the appropriate precision inputs + # without user having to adjust mixed precision config too much. + # As a result, we attach pre and post forward hooks to up / down + # cast. We should revisit this design. + + def cast_fn( + dtype: torch.dtype, module: nn.Module, x: torch.Tensor + ) -> torch.Tensor: + if not torch.is_floating_point(x) or x.dtype == dtype: + return x + _MODULE_TO_INP_DTYPE[module] = x.dtype + return x.to(dtype) + + def forward_pre_hook(module, args): + return _apply_to_tensors(partial(cast_fn, torch.float32, module), args) + + def forward_post_hook(module, args, output): + # NOTE: If the forward did not have any floating-point tensors, + # then the dtype will not be set for this module, and we do not + # upcast the dtype. + if module in _MODULE_TO_INP_DTYPE: + old_dtype = _MODULE_TO_INP_DTYPE[module] + return _apply_to_tensors( + partial(cast_fn, old_dtype, module), output + ) + + # We intentionally append both of these hooks so that they run after + # all other hooks. + mod.register_forward_pre_hook(forward_pre_hook, prepend=False) + mod.register_forward_hook(forward_post_hook, prepend=False) + return overridden_module_classes + + +def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None: + # FIXME record_stream doesn't work with non-cuda/mtia/xpu tensors + if tensor.device.type not in [ + "cuda", + "mtia", + "xpu", + torch._C._get_privateuse1_backend_name(), + ]: + return + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + return + # from @ezyang: + # The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin + # Looking over the PR, it looks like this is because we don't actually support Stream arguments + # in torch dispatch, so it just chokes. + # If Dynamo is able to answer "are there any torch dispatch modes" active (it should answer False), + # a better version of this would just be to check if there are any modes before disabling dispatch. + # TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here. + tensor.record_stream(stream) + else: + with no_dispatch(): + tensor.record_stream(stream) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_debug_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf5a411f8c556ff1922775514cb2361a87bb492d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_debug_utils.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +import logging +import time +from collections import defaultdict +from collections.abc import Iterator +from contextlib import contextmanager +from enum import Enum + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._flat_param as flat_param_file +from torch.distributed.fsdp._common_utils import ( + _apply_to_modules, + _get_module_fsdp_state, + clean_tensor_name, +) + + +logger = logging.getLogger(__name__) + + +class SimpleProfiler: + class Type(str, Enum): + ALL = "all" + ALLGATHER = "all_gather" + ALLGATHER_OBJ = "all_gather_object" + RESHARDING = "resharding" + H2D = "H2D" + D2H = "D2H" + + results: dict[str, float] = defaultdict(float) + profiling: set[str] = set() + + @classmethod + def reset(cls) -> None: + cls.results.clear() + cls.profiling.clear() + + @classmethod + @contextmanager + def profile(cls, profile_type: str) -> Iterator[None]: + if profile_type in cls.profiling: + raise AssertionError( + f"{profile_type} is already being profiled. " + "SimpleProfiler does not support profiling multiple instances at " + "the same time. " + ) + + cls.profiling.add(profile_type) + begin = time.monotonic() + try: + yield + finally: + end = time.monotonic() + cls.results[profile_type] += end - begin + cls.profiling.remove(profile_type) + + @classmethod + def dump_and_reset(cls, msg: str) -> None: + # This cannot be combined with DETAIL distributed log + # as the profiling will be very incorrect. + if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO: + logger.info("%s %s", msg, cls.results) + cls.reset() + + +def _get_sharded_module_tree_with_module_name_to_fqns( + model: torch.nn.Module, +) -> tuple[str, dict[str, list[str]]]: + """ + It is used for composable fully_shard() code path, it returns + 1. sharded module tree info: each line represents a submodule name that contains the + submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`, + the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree + level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model + is like this: + [CompositeModel] FULLY SHARDED + l1[Linear] + u1[UnitModule] FULLY SHARDED + u1.l1[Linear] + u1.seq[Sequential] + u1.seq.0[ReLU] + u1.seq.1[Linear] + u1.seq.2[ReLU] + u1.l2[Linear] + u2[UnitModule] FULLY SHARDED + u2.l1[Linear] + u2.seq[Sequential] + u2.seq.0[ReLU] + u2.seq.1[Linear] + u2.seq.2[ReLU] + u2.l2[Linear] + l2[Linear] + 2. a dict mapping from the concated module FQN and class name to a list of its managed + original parameters' FQNs. An example of the dict for the above toy sharded model is like this: + {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'], + 'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'], + 'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias'] + } + All FQNs are prefixed starting from ``model``. + + Args: + model (torch.nn.Module): Root module (which may or may not be passed to + composable `fully_shard()`). + """ + + def module_fn( + module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns + ): + num_spaces = tree_level * 4 + trimed_prefix = ( + prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix + ) + prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]" + printed_prefixed_module_name = " " * num_spaces + prefixed_module_name + + state = _get_module_fsdp_state(module) + if state is None: + sharded_tree_info[0] += printed_prefixed_module_name + "\n" + return + + handle = state._fully_sharded_module_to_handle.get(module, None) + + if handle: + sharded_tree_info[0] += ( + printed_prefixed_module_name + " FULLY SHARDED" + "\n" + ) + else: + sharded_tree_info[0] += printed_prefixed_module_name + "\n" + + if handle: + param = handle.flat_param + if not isinstance(param, flat_param_file.FlatParameter): + raise AssertionError(f"Expected FlatParameter, got {type(param)}") + global_fqns = [ + clean_tensor_name(prefix + name) for name in param._fqns + ] # prefixed from the top level `model` (i.e. including `prefix`) + + if prefixed_module_name in sharded_module_name_to_fqns: + sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns) + else: + sharded_module_name_to_fqns[prefixed_module_name] = global_fqns + + def return_fn(sharded_tree_info, sharded_module_name_to_fqns): + return sharded_tree_info[0], sharded_module_name_to_fqns + + # Use List to mutate its value in place while running the recursive functions + sharded_tree_info: list[str] = [ + "", + ] + sharded_module_name_to_fqns: dict[str, list[str]] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [key for key, _ in model.named_parameters()], + sharded_tree_info, + sharded_module_name_to_fqns, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_dynamo_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_dynamo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77bcd43b63be27da8e8b79f877ce7cb9d67c74b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_dynamo_utils.py @@ -0,0 +1,43 @@ +import torch.nn as nn + + +def _annotate_modules_for_dynamo( + module: nn.Module, + ignored_modules: set[nn.Module], + use_orig_params: bool, +) -> None: + """ + Annotates the submodules in ``module`` 's tree, except those in + ``ignored_modules``, indicating that the submodules are FSDP-managed and + saving the ``use_orig_params`` setting passed to the FSDP constructor. + """ + for submodule in module.modules(): + if submodule not in ignored_modules: + """[note: Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + + Dynamo doesn't get to see this instance (FullyShardedDataParallel) during tracing, since + it skips tracing all the torch.distributed.fsdp code. + - Why? Running the FSDP code eagerly avoids lots of issues trying to trace complex hooks, and also + gets us graph-breaks on FSDP module boundaries which we want anyway for comm ops. + - However, we _also_ want dynamo to treat the wrapped module inside FSDP 'unspecially' (*), + and we need a way to indicate to dynamo which modules are wrapped by FSDP. + + (*) UnspecializedNNModules in dynamo are traced-through without any assumptions, and with thorough + guards. NNModules otherwise are 'specialized', meaning there is less overhead due to assuming + their code is well-behaved. + + One particular issue with specialized NNModules for FSDP is that the + views created for orig_params are captured into the compiled graph on the first iteration, and while + they are always going to point to the correct flatparameter and give correct results, their order + of creation influences the order of backward execution, preventing overlap of comm and computation + during backward. We need to _use_ the new parameter views created on each forward iteration, in + order for backward to interleave hooks with compute per layer. UnspecializedNNModule lets us achieve + this by capturing the module code more 'functionally' and passing parameters in as inputs each time. + """ + submodule._is_fsdp_managed_module = True # type: ignore[assignment] + + # Dynamo only supports FSDP with use_orig_params=True. + # This is hacky, but I could not think of another way to add an assertion to dynamo + # for this, since Dynamo skips all the FSDP code frames and thus can't inspect the + # FSDP module directly + submodule._fsdp_use_orig_params = use_orig_params # type: ignore[assignment] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_exec_order_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_exec_order_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db2ea7bfae0b92a6a103ac35655a6da627761e7e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_exec_order_utils.py @@ -0,0 +1,366 @@ +# mypy: allow-untyped-defs +import itertools +import warnings +from enum import auto, Enum +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.fsdp._common_utils import _FSDPState, _get_param_to_fqns +from torch.distributed.fsdp._flat_param import FlatParamHandle + + +class _ExecOrderWarnStatus(Enum): + """Used internally for execution order validation.""" + + NONE = auto() # no deviation yet + WARNING = auto() # deviated this iteration; currently issuing warnings + WARNED = auto() # deviated in a previous iteration + + +class _ExecOrderData: + """ + This contains the data structures to track the execution order. We track + the pre-forward order on the *first* iteration for forward prefetching + (which thus assumes static graph) and the post-forward order on *every* + iteration for backward prefetching (which thus does not assume static + graph but may be provide an incorrect order). + """ + + def __init__( + self, + debug_level: dist.DebugLevel, + backward_prefetch_limit: int, + forward_prefetch_limit: int, + ) -> None: + # Tracks the (static) pre-forward order for execution order validation + # and forward prefetching + self.handles_pre_forward_order: list[FlatParamHandle] = [] + # Tracks the post-forward order for pre-backward prefetching + self.handles_post_forward_order: list[Optional[FlatParamHandle]] = [] + self._iter = 0 + + # Gives the max number of backward/forward prefetched all-gathers by a + # single module + self._backward_prefetch_limit = backward_prefetch_limit + self._forward_prefetch_limit = forward_prefetch_limit + + # Data structures for execution order validation + self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL + self.process_group: Optional[dist.ProcessGroup] = None + self.world_size: Optional[int] = None + self.all_handles: list[FlatParamHandle] = [] + # Names are prefixed from the root module + self.param_to_fqn: dict[nn.Parameter, list[str]] = {} + # Current index in the pre-forward execution order + self.current_order_index = 0 + self.warn_status = _ExecOrderWarnStatus.NONE + + def init( + self, + state: _FSDPState, + root_module: nn.Module, + process_group: dist.ProcessGroup, + ) -> None: + """ + Initializes the data structures needed for checking the forward order. + This should be called after a root FSDP instance has been set during + lazy initialization. + """ + self.process_group = process_group + self.rank = process_group.rank() + self.world_size = process_group.size() + # Fix an order over the handles, which should be the same across ranks + for handle in traversal_utils._get_fsdp_handles(root_module): + index = len(self.all_handles) + self.all_handles.append(handle) + handle._handle_index = index + self.param_to_fqn = _get_param_to_fqns(root_module) + # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles` + # to check that all ranks have the same handles in the same order. + # https://github.com/pytorch/pytorch/issues/79620 + + @property + def is_first_iter(self) -> bool: + return self._iter == 0 + + def get_handle_to_backward_prefetch( + self, + current_handle: FlatParamHandle, + ) -> Optional[FlatParamHandle]: + """ + Returns a :class:`list` of the handles keys of the handles to backward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. + """ + current_index = current_handle._post_forward_index + if current_index is None: + return None + target_index = current_index - 1 + target_handle: Optional[FlatParamHandle] = None + for _ in range(self._backward_prefetch_limit): + if target_index < 0: + break + target_handle = self.handles_post_forward_order[target_index] + target_index -= 1 + return target_handle + + def get_handle_to_forward_prefetch( + self, + current_handle: FlatParamHandle, + ) -> Optional[FlatParamHandle]: + """ + Returns a :class:`list` of the handles keys of the handles to forward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. + """ + current_index = current_handle._pre_forward_order_index + if current_index is None: + return None + target_index = current_index + 1 + target_handle: Optional[FlatParamHandle] = None + for _ in range(self._forward_prefetch_limit): + if target_index >= len(self.handles_pre_forward_order): + break + target_handle = self.handles_pre_forward_order[target_index] + target_index += 1 + return target_handle + + def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None: + """ + Records ``handles`` in the post-forward order, where ``handles`` should + be a group of handles used in the same module's forward. If ``handles`` + is empty, then it is omitted. + + Unlike :meth:`record_pre_forward`, this records the order *every* + iteration with the expectation that the recorded order is reset in + :meth:`next_iter`. + """ + if not handle: + return + # Only record the first usage of a handles key + if handle._post_forward_index: + self.handles_post_forward_order.append(handle) + return + index = len(self.handles_post_forward_order) + handle._post_forward_index = index + self.handles_post_forward_order.append(handle) + + def record_pre_forward( + self, handle: Optional[FlatParamHandle], is_training: bool + ) -> None: + """ + Records ``handles`` in the pre-forward order, where ``handles`` should + be a group of handles used in the same module's forward. If ``handles`` + is empty, then it is omitted. + + On the first iteration, this checks the execution order across ranks. + See :meth:`_check_order` for details. + """ + if not handle: + return + self._check_order(handle, is_training) + # Fix the order after the first iteration and only record the first + # usage of a handles key + if not self.is_first_iter or handle._pre_forward_order_index is not None: + return + index = len(self.handles_pre_forward_order) + handle._pre_forward_order_index = index + self.handles_pre_forward_order.append(handle) + + def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None: + """ + Checks the forward execution order as long as ``is_training`` is + ``True`` since checking in eval mode is not supported. This only checks + if the distributed debug level is DETAIL. + + - On the first iteration, this uses all-gathers to check that all ranks + are all-gathering the same handles and hence ``FlatParameter`` s, + raising an error if not. + - On subsequent iterations, this checks that each rank is locally + consistent with its own forward order from the first iteration, issuing + a warning if not. This issues a warning on the first deviating + iteration and stops warning thereafter. + """ + # Do not check order in eval mode since the post-backward callback does + # not run so it cannot be used to mark the end of an iteration + if not is_training or not self._checking_order: + return + if self.is_first_iter: + msg_prefix = "Forward order differs across ranks:" + optional_local_indices: tuple[Optional[int], ...] = ( + self._get_handle_indices(handle) + ) + device = handle.device # guaranteed to be non-CPU + num_valid_indices = sum( + (index is not None) for index in optional_local_indices + ) + tensor_kwargs: dict[str, Union[torch.dtype, torch.device]] = { + "dtype": torch.int32, + "device": device, + } + world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs) # type: ignore[arg-type, call-overload] + local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs) # type: ignore[arg-type, call-overload] + dist.all_gather_into_tensor( + world_num_valid_indices, + local_num_valid_indices, + group=self.process_group, + ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_num_valid_indices = world_num_valid_indices.cpu() + # Check that all ranks plan to all-gather the same number of + # parameters + # TODO (awgu): Since every module has at most one handle in the + # current implementation, this should never raise the error. + if self.world_size is None: + raise AssertionError("Expected world_size to not be None") + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2 + # tensor comparison control flow. + # https://github.com/pytorch/pytorch/issues/107055 + for (r1, n1), (r2, n2) in itertools.combinations( + ( + (rank, world_num_valid_indices[rank]) + for rank in range(self.world_size) + ), + 2, + ): + if n1 != n2: + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering {n1} parameters " + f"while rank {r2} is all-gathering {n2} parameters" + ) + world_indices = torch.zeros( # type: ignore[call-overload] + self.world_size * num_valid_indices, **tensor_kwargs + ) + local_indices = torch.tensor(optional_local_indices, **tensor_kwargs) # type: ignore[arg-type] + dist.all_gather_into_tensor( + world_indices, local_indices, group=self.process_group + ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_indices = world_indices.cpu() + # Check that all ranks plan to all-gather the same index parameters + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # TODO(voz): Don't graph break on this - dynamo hates the i1 != i2 + # tensor comparison control flow. + # https://github.com/pytorch/pytorch/issues/107055 + for (r1, i1), (r2, i2) in itertools.combinations( + ( + ( + rank, + world_indices[ + rank * num_valid_indices : (rank + 1) + * num_valid_indices + ], + ) + for rank in range(self.world_size) + ), + 2, + ): + if i1 != i2: + r1_param_names = self._get_names_from_handle_indices(i1) + r2_param_names = self._get_names_from_handle_indices(i2) + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering parameters " + f"for {r1_param_names} while rank {r2} is all-gathering " + f"parameters for {r2_param_names}" + ) + else: + # Only issue warnings on the first deviating iteration and stop + # checking thereafter to avoid flooding the console + if self.warn_status == _ExecOrderWarnStatus.WARNED: + return + msg_prefix = None # non-`None` means we should warn + if self.current_order_index >= len(self.handles_pre_forward_order): + # This iteration sees extra all-gather(s) compared to the first + msg_prefix = ( + "Expected to not all-gather any more parameters in the " + "forward but trying to all-gather parameters for " + ) + else: + expected_handle = self.handles_pre_forward_order[ + self.current_order_index + ] + if expected_handle != handle: + expected_param_names = self._get_names_from_handles(expected_handle) + msg_prefix = ( + f"Expected to all-gather for {expected_param_names} " + "but trying to all-gather parameters for " + ) + if msg_prefix is not None: + param_names = self._get_names_from_handles(handle) + msg_suffix = ( + f"{param_names}" + if param_names + else "a newly-added parameter since construction time" + ) + warnings.warn( + "Forward order differs from that of the first iteration " + f"on rank {self.rank}. Collectives are unchecked and may " + f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}", + stacklevel=2, + ) + self.warn_status = _ExecOrderWarnStatus.WARNING + self.current_order_index += 1 + + def _get_handle_indices( + self, + handle: FlatParamHandle, + ) -> tuple[Optional[int], ...]: + """ + Returns the handle indices (i.e. indices into ``self.all_handles``) + corresponding to the handles in ``handle``. An entry in the + returned tuple is ``None`` if the handle is invalid. + """ + indices: list[Optional[int]] = [] + if handle: + indices.append(handle._handle_index) + return tuple(indices) + + def _get_names_from_handle_indices( + self, + handle_indices: tuple[int, ...], + ) -> list[list[str]]: + """ + Returns a list of FQNs for each handle in ``handle_indices``. If a + handle index is invalid, then its FQNs are omitted from the returned + list. + """ + fqns: list[list[str]] = [] + for index in handle_indices: + if index is None or index < 0 or index >= len(self.all_handles): + continue + handle = self.all_handles[index] + flat_param = handle.flat_param + fqns.append(self.param_to_fqn[flat_param]) + return fqns + + def _get_names_from_handles( + self, + handle: FlatParamHandle, + ) -> list[list[str]]: + """ + Returns a list of FQNs for each handle in ``handles_key``. If a handle + is invalid, then its FQNs are omitted from the returned list. + """ + fqns: list[list[str]] = [] + if handle: + flat_param = handle.flat_param + if flat_param in self.param_to_fqn: + fqns.append(self.param_to_fqn[flat_param]) + return fqns + + def next_iter(self): + """ + Advances the internal data structures per iteration. This should be + called in the post-backward callback since that marks the true end of + an iteration. + """ + self._iter += 1 + self.handles_post_forward_order.clear() + if self._checking_order: + self.current_order_index = 0 + if self.warn_status == _ExecOrderWarnStatus.WARNING: + self.warn_status = _ExecOrderWarnStatus.WARNED diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py new file mode 100644 index 0000000000000000000000000000000000000000..85e4c23d509f8c8751ac60572a7a4a78da0fc9cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py @@ -0,0 +1,2841 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import logging +import os +import warnings +from collections.abc import Callable, Generator, Iterator, Sequence +from enum import auto, Enum +from itertools import accumulate, chain +from typing import Any, cast, NamedTuple, no_type_check, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed.fsdp._common_utils import ( + _FSDPDeviceHandle, + _named_parameters_with_duplicates, + _no_dispatch_record_stream, + _set_fsdp_flattened, + HandleTrainingState, +) +from torch.distributed.utils import ( + _alloc_storage, + _data_ptr_allocated, + _free_storage, + _p_assert, +) +from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined] +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + +from ._fsdp_extensions import ( + _ext_post_unflatten_transform, + _ext_pre_flatten_transform, + FSDPExtensions, +) + + +__all__ = [ + "FlatParameter", + "FlatParamHandle", + "FlatParamShardMetadata", + "ParamInfo", + "SharedParamInfo", + "HandleShardingStrategy", +] + +logger = logging.getLogger(__name__) + + +""" +[Note: Fully Sharded Module] +We define the "fully sharded module" to be the original ``nn.Module`` that owns +a ``FlatParamHandle``. It is the *single* module logically responsible for the +*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given +forward or backward pass. The fully sharded module should be passed to the +``FlatParamHandle`` constructor. + +For the wrapper code path: +- The ``FullyShardedDataParallel`` module wrapping the fully sharded module +runs the unshard/reshard on behalf of the fully sharded module by overriding +``nn.Module.forward``. +- The fully sharded module is exactly the module passed to the +``FullyShardedDataParallel`` constructor's ``module`` argument. + +For the non-wrapper code path: +- Hooks registered on the fully sharded module run the unshard/reshard. +- The fully sharded module may either be the direct argument to ``fully_shard`` +or a submodule chosen by the provided wrapping policy. +""" + +# Environment variable toggling whether to use unsafe `setattr()` for view +# setting in `_use_sharded_views()` and `_use_unsharded_views()` +# We should use 'safe' by default since it respects method overrides, but for +# special cases such as for high CPU overhead or for intentionally bypassing +# checks in the overrides, we may use 'unsafe'. +_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR" + +# Environment variable toggling whether to check for parameter/gradient +# writeback in case their storages change after FSDP initialization +# We should check by default since it prevents silent correctness errors, but +# since such changes are atypical, we may want to skip the check to save CPU +# overhead, especially since the check happens in the pre-forward and +# pre-backward each iteration. +_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK" + +# Env var toggling whether when model is in .eval() mode, should we run in fp32 +# or the reduced precision. +_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL" + +# Some value to set padding in tensors to for debuggability +_FLAT_PARAM_PADDING_VALUE = 42 + +# Environment variables for disabling the all-gather and reduce-scatter +# communication ops for ablation studies. Note that without these communication +# ops the training won't converge, and you probably need to disable correctness +# checks in your model. +_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER" +_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE" + + +# TODO: Define this for now to avoid circular imports. See if we can remove. +class HandleShardingStrategy(Enum): + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + HYBRID_SHARD = auto() + _HYBRID_SHARD_ZERO2 = auto() + + +RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = ( + HandleShardingStrategy.FULL_SHARD, + HandleShardingStrategy.HYBRID_SHARD, +) +NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = ( + HandleShardingStrategy.SHARD_GRAD_OP, + HandleShardingStrategy._HYBRID_SHARD_ZERO2, +) + + +class ParamInfo(NamedTuple): + """Information for an original parameter.""" + + param_name: str # unprefixed + module: nn.Module + module_name: str + + +class SharedParamInfo(NamedTuple): + """ + Additional information for a shared parameter. + + For each shared parameter, we designate one module and its parameter + variable to be the primary owner, determined as the first one encountered + in the parameter walk. These are prefixed with "prim". The primary module + and parameter do not have their own :class:`SharedParamInfo` instance. + """ + + param_name: str # unprefixed + module: nn.Module + module_name: str + prim_param_name: str # unprefixed + prim_module: nn.Module + prim_module_name: str + + +class _ShardParamInfo(NamedTuple): + """Shard-related information for an original parameter.""" + + in_shard: bool + # Use to index into the sharded flat parameter, e.g. + # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]` + offset_in_shard: Optional[int] + numel_in_shard: Optional[int] + # Use to get part of the parameter in the local shard from a flattened + # version of the unsharded parameter, e.g. either + # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]` or + # `param.as_strided((param.numel(),), (1,))[intra_param_start_idx : intra_param_end_idx + 1]` + intra_param_start_idx: Optional[int] + intra_param_end_idx: Optional[int] # inclusive + + +class FlatParamShardMetadata(NamedTuple): + """ + This holds metadata specific to this rank's shard of the flat parameter. + + Attributes: + param_names (Tuple[str, ...]): Prefixed parameter names of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_strides (Tuple[torch.Size, ...]): Parameter strides of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_contiguities (Tuple[bool, ...]): Parameter `.contiguous` call results + of this rank's shard of the parameters; see :class:`FlatParameter`. + param_numels (Tuple[int, ...]): Parameter numels of this rank's shard + of the parameters; see :class:`FlatParameter`. + param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in + units of numels) giving this rank's part of each flattened + original parameter. + """ + + param_names: tuple[str, ...] + param_shapes: tuple[torch.Size, ...] + param_strides: tuple[tuple[int, ...], ...] + param_contiguities: tuple[bool, ...] + param_numels: tuple[int, ...] + param_offsets: tuple[tuple[int, int], ...] + + +class _FlatParameterMeta(_ParameterMeta): + # Make `isinstance(t, FlatParameter)` return True for custom tensor + # instances that have the _is_flat_param flag for BC + def __instancecheck__(self, instance): + # NB: do NOT test the super implementation + return isinstance(instance, torch.Tensor) and getattr( + instance, "_is_flat_param", False + ) + + +class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): + """ + This is the flat parameter used by :class:`FullyShardedDataParallel`. + + It is comprised of one or more original parameters, which are flattened and + concatenated to construct the flat parameter. + + Under the current design, this parameter logically represents both the + unsharded and sharded flat parameter, and its data changes storages + dynamically. + - In the :class:`FullyShardedDataParallel` constructor, the parameter + is initialized as unsharded and then sharded in-place. + - At runtime, the parameter is lazily (re)-initialized. The sharded + parameter data is saved in ``self._local_shard``, and a new ``Tensor`` + ``self._full_param_padded`` is created, which is the all-gather + destination and owns the unsharded parameter storage thereafter. (See + :meth:`FlatParamHandle.init_flat_param_attributes`.) + - Throughout runtime, the parameter data changes storages as needed, + e.g. to the sharded flat parameter, low precision sharded flat + parameter, or the unsharded flat parameter. + + NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter`` + padding, we have two versions of the per-parameter numels, one that + includes the padding (``_numels_with_padding``) and one that does not + (``_numels``). The former may have length longer than the other data + structures, while the latter has the same length as the number of actual + original parameters like the other per-parameter data structures. + + NOTE: This is not a real class; instead, you will always get a Parameter + back out if you try to create one of these. This is similar to the trick + we implemented for Parameter to get it to work with subclasses; this + is primarily so that FlatParameter supports combination with FakeTensor. + + Attributes: + _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size + without right-hand-side padding for divisibility by the world size. + For ``use_orig_params=True``, this includes alignment padding. + _padded_unsharded_size (torch.Size): Unsharded flat parameter's size + with right-hand-side padding for divisibility by the world size. + For ``use_orig_params=True``, this includes alignment padding. This + is only set for sharded strategies since they require padding for + the all-gather. + _sharded_size (torch.Size): Sharded flat parameter's size with padding. + This is also set for ``NO_SHARD``, in which case it is the same as + the unsharded sizes. (We omit "padded" because there is no + analogous unpadded one.) + + _num_params (int): Number of original parameters flattened into this + flat parameter. This is the length of the per-parameter data + structures. + _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info + entry; see :class:`ParamInfo` for details. + _shapes (Tuple[torch.Size, ...]): Each parameter's original shape. + _strides (Tuple[torch.Size, ...]): Each parameter's original stride. + _contiguities (Tuple[bool, ...]): Each parameter's ``contiguous()`` + call result. + _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN) + prefixed from the ``_fully_sharded_module``. The names are + guaranteed to be unique in the subtree rooted at that module. + _param_extensions (Tuple[Optional[Any], ...]): Each parameter's + extension (i.e. some per-parameter state) used to customize + pre-flatten and post-unflatten behavior or ``None``. This is + experimental, and users should not depend on its existence in the + future. + _numels_with_padding (Tuple[int, ...]): Each parameter's numel + including entries for the padding. This is used to construct views + into the flat parameter via ``torch.split()``. This may have length + longer than ``_num_params``. + _numels (Tuple[int, ...]): Each parameter's numel excluding entries for + padding. This has length equal to ``_num_params``. + _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's + shard parameter info; see :class:`_ShardParamInfo` for details. + _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter + info entries; see :class:`SharedParamInfo` for details. + _modules (set[nn.Module]): Modules that contain some original parameter + that is flattened into the flat parameter. + + _shard_numel_padded (int): Numel padded for this rank's sharded flat + parameter. + _local_shard (Tensor): Sharded flat parameter with padding if using a + sharded strategy. If using ``NO_SHARD``, then this is the unpadded + unsharded flat parameter, and there is no notion of a sharded flat + parameter or padded unsharded flat parameter. + _full_param_padded (Tensor): Unsharded flat parameter with padding. + This is not defined for ``NO_SHARD``. When using mixed precision + for parameters, this has the low precision. + _full_prec_full_param_padded (Tensor): Full precision unsharded flat + parameter with padding. This is used for unsharding outside of + computation when using mixed precision for parameters. This is + never defined for ``NO_SHARD``. + _post_backward_hook_handle (RemovableHandle): + Flat parameter's post-backward hook handle. (Compile only) + _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]): + Flat parameter's :class:`AccumulateGrad` object and post-backward + hook handle. (Eager only) + _mp_shard (Tensor): Low precision sharded flat parameter with padding. + This is only defined when parameter mixed precision is enabled. For + ``NO_SHARD``, this is used for computation. + _cpu_grad (Tensor): Sharded gradient with padding stored on CPU. + This is only defined when offloading parameters is enabled. + _saved_grad_shard (Tensor): Sharded gradient with padding from previous + iterations for gradient accumulation without :meth:`no_sync`. + + _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``, + then each original parameter variable; otherwise, ``None``. This + does not include any padding tensors. + _shared_params (Optional[List[nn.Parameter]]): The original shared + parameter variables if ``use_orig_params=True`` and ``None`` + otherwise. + _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor`` + views created in the forward and tracked by autograd when + ``use_orig_params=True`` and is ``None`` otherwise. This is to + preserve those ``Tensor`` variables for the backward to ensure that + the ``FlatParameter`` 's ``AccumulateGrad`` object does not change + in which case the post-backward hook does not run. This is relevant + for cases like reentrant activation checkpointing. + _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``, + a mask over the original parameters' gradients indicating if it is + logically ``None`` or not; otherwise, ``None``. This does not + include entries for padding. This mask is needed because only some + of the parameters may have ``None`` gradient, in which case the + flat gradient must be non-``None`` and must use zeros to + approximate those original ``None`` gradients. This mask informs + FSDP to set the original parameter gradients to ``None`` (instead + of zeros) as needed. + """ + + _unpadded_unsharded_size: torch.Size + _padded_unsharded_size: torch.Size + _sharded_size: torch.Size + _num_params: int + _param_infos: tuple[ParamInfo, ...] + _shapes: tuple[torch.Size, ...] + _strides: tuple[tuple[int, ...], ...] + _contiguities: tuple[bool, ...] + _fqns: tuple[str, ...] + _param_extensions: tuple[Optional[Any], ...] + _numels_with_padding: tuple[int, ...] + _numels: tuple[int, ...] + _shard_param_infos: tuple[_ShardParamInfo, ...] + _shared_param_infos: tuple[SharedParamInfo, ...] + _modules: set[nn.Module] + _shard_numel_padded: int + _local_shard: Tensor + _full_param_padded: Tensor + _full_prec_full_param_padded: Tensor + # Eager only + _post_backward_hook_state: tuple[Any, Any] + # Compile only + _post_backward_hook_handle: Any + _mp_shard: Tensor + _cpu_grad: Tensor + _saved_grad_shard: Tensor + _params: Optional[list[nn.Parameter]] + _shared_params: Optional[list[nn.Parameter]] + _tensors: Optional[list[Optional[Tensor]]] + _is_grad_none_mask: Optional[list[bool]] + + _is_padding_mask: list[bool] + + def __new__(cls, data=None, requires_grad=True): + if cls is not FlatParameter: + raise AssertionError("subclasses FlatParameter not supported") + r = nn.Parameter.__new__(nn.Parameter, data, requires_grad) # type: ignore[call-arg] + r._is_flat_param = True # type: ignore[attr-defined] + return r + + # NB: This is not a regular method, because FlatParameters are not actually + # instances of this class (see __new__ above). So you must indirectly + # call this directly through the classmethod. + @classmethod + def _init_metadata( + cls, + self, + param_infos: list[ParamInfo], + numels: list[int], + shapes: list[torch.Size], + strides: list[tuple[int, ...]], + contiguities: list[bool], + fqns: list[str], + shared_param_infos: list[SharedParamInfo], + param_extensions: list[Optional[Any]], + params: Optional[list[nn.Parameter]], + shared_params: Optional[list[nn.Parameter]], + is_padding_mask: list[bool], + ) -> None: + """ + Initialize attributes holding metadata about the original parameters comprising the flat parameter. + + We expose this method separate from the constructor to keep the + constructor only responsible for the flat parameter's tensor data. This + method should only be called once per model, while the constructor may + be called multiple times, e.g. when reloading from a checkpoint, in + which case only the tensor data needs to be passed to the constructor. + Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the + metadata is correctly assumed to be unchanged. + + Args: + See the Attributes in the class docstring. + """ + if len(param_infos) != len(shapes): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match shapes length {len(shapes)}" + ) + if len(param_infos) != len(strides): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match strides length {len(strides)}" + ) + if len(param_infos) != len(contiguities): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match contiguities length {len(contiguities)}" + ) + if len(param_infos) != len(fqns): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match fqns length {len(fqns)}" + ) + if len(param_infos) != len(param_extensions): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match param_extensions length {len(param_extensions)}" + ) + self._num_params = len(param_infos) + self._param_infos = param_infos + self._shapes = shapes + self._strides = strides + self._contiguities = contiguities + self._fqns = fqns + self._param_extensions = param_extensions + self._is_padding_mask = is_padding_mask + + numels_without_padding: list[int] = [] + for numel, is_padding in zip(numels, is_padding_mask): + if not is_padding: + numels_without_padding.append(numel) + self._numels = tuple(numels_without_padding) + self._numels_with_padding = tuple(numels) + if len(self._numels) != self._num_params: + raise AssertionError( + f"Expected _numels length {len(self._numels)} to equal _num_params {self._num_params}" + ) + + self._shared_param_infos = tuple(shared_param_infos) + self._modules = {pi.module for pi in self._param_infos}.union( + {spi.module for spi in self._shared_param_infos} + ) + if (params is None) != (shared_params is None): + raise AssertionError( + "Expected params and shared_params to both be None or both be not None" + ) + if params is not None: + if shared_params is None or len(shared_params) != len(shared_param_infos): + raise AssertionError( + f"Expected shared_params to be not None and have length {len(shared_param_infos)}, got {shared_params}" + ) + self._params = [] + for param, is_padding in zip(params, is_padding_mask): + if not is_padding: + self._params.append(param) + if shared_params is not None: + self._shared_params = shared_params + else: + self._shared_params = [] + # Mark the original parameters to avoid flattening them into + # another `FlatParameter` during recursive construction + for param in chain(self._params, self._shared_params): + _set_fsdp_flattened(param) + self._is_grad_none_mask = [False for _ in range(self._num_params)] + self._tensors = [None for _ in range(self._num_params)] + else: + self._params = None + self._shared_params = None + self._is_grad_none_mask = None + self._tensors = None + self._unpadded_unsharded_size = self.size() + _set_fsdp_flattened(self) + # Tracks whether the `FlatParameter`'s post-backward hook has been + # called to modify the behavior of the post-backward callback + self._post_backward_called = False + + +class FlatParamHandle: + """ + A handle that manages a flat parameter (:class:`FlatParameter`). + + This includes sharding and view management. + + Args: + params (Sequence[nn.Parameter]): The parameters to flatten into the + flat parameter. + fully_sharded_module (nn.Module): See [Note: Fully Sharded Module]. + device (torch.device): The compute and communication device, which + should be a non-CPU device. We refer to it as the compute device. + sharding_strategy (ShardingStrategy): Sharding strategy to apply to + this handle's ``FlatParameter``. + offload_params (bool): Whether to offload the handle's + ``FlatParameter`` to CPU. + mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision + setting passed to the FSDP constructor. + mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed + precision setting passed to the FSDP constructor. + keep_low_precision_grads (bool): Whether to keep gradients in low + precision. + use_orig_params (bool): If ``True``, then FSDP preserves the original + parameter variables and returns them from ``named_parameters()`` + (e.g. to support different optimizer hyperparameters within one + :class:`FlatParameter`). If ``False``, then FSDP reconstructs the + parameters every iteration and returns the :class:`FlatParameter` s + from ``named_parameters()``. + """ + + ################## + # INITIALIZATION # + ################## + def __init__( + self, + params: Sequence[Union[nn.Parameter, Tensor]], + fully_sharded_module: nn.Module, + device: torch.device, + sharding_strategy: HandleShardingStrategy, + offload_params: bool, + mp_param_dtype: Optional[torch.dtype], + mp_reduce_dtype: Optional[torch.dtype], + keep_low_precision_grads: bool, + process_group: dist.ProcessGroup, + use_orig_params: bool, + *, + fsdp_extension: Optional[FSDPExtensions] = None, + ): + super().__init__() + params = list(params) + if len(params) == 0: + raise ValueError( + f"Cannot construct a {self.__class__.__name__} with an empty parameter list" + ) + self._init_setattr_fns() + self._skip_writeback_check = ( + os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1" + ) + self._use_full_prec_in_eval = ( + os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1" + ) + self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1" + self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1" + if self._skip_writeback_check: + _warn_skip_writeback_check( + logger, + f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check " + "for parameter or gradient writeback. Changing parameter or " + "gradient storages may lead to silent correctness errors.", + ) + if self._use_fake_all_gather: + _warn_use_fake_all_gather( + logger, + f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute " + "all-gather ops. Your training will be incorrect, but " + "can reveal how much time spent on all-gather ops.", + ) + if self._use_fake_reduce: + _warn_use_fake_reduce( + logger, + f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute " + "reduce-scatter ops. Your training will be incorrect, but " + "can reveal how much time spent on reduce-scatter ops.", + ) + # Only align addresses for `use_orig_params=True` (for now) + align_addresses = use_orig_params + self._init_get_unflat_views_fn(align_addresses) + # pyrefly: ignore [read-only] + self.device = device + self._device_handle = _FSDPDeviceHandle.from_device(self.device) + self.process_group = process_group + if self._use_fake_all_gather or self._use_fake_reduce: + self._fake_process_group = FakeProcessGroup._create_internal( + rank=process_group.rank(), world_size=process_group.size() + ) + self.rank = process_group.rank() + self.world_size = process_group.size() + self._sharding_strategy = sharding_strategy + self._offload_params = offload_params + self._use_orig_params = use_orig_params + self._keep_low_precision_grads = keep_low_precision_grads + self._training_state = HandleTrainingState.IDLE + self._debug_level = dist.get_debug_level() + self._fully_sharded_module = fully_sharded_module + # For strategies that do not free after forward, we skip using sharded + # views after forward since the unsharded data exists. We still switch + # `self.flat_param` to point to the sharded flat parameter since what + # it points to parameterizes behavior. We use the following attribute + # to track which tensor data the parameters are unsharded views into. + self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None + # The index in the state's `all_handles`, which must be the + # same across ranks for the execution order validation to work + self._handle_index: Optional[int] = None + # Index in handles_to_pre_forward_order + self._pre_forward_order_index: Optional[int] = None + # Index in `handles_post_forward_order` + self._post_forward_index: Optional[int] = None + # Used for guarding against mistargeted forward prefetches + self._needs_pre_forward_unshard = False + # Used for guarding against mistargeted backward prefetches + self._needs_pre_backward_unshard = False + # Was the handle prefetched? Set on successful _prefetch_handle and unshard + self._prefetched = False + # Optimistically assume a valid input `params` and set dtype attributes + # before `_init_flat_param()`, which performs the actual validation + self._orig_param_dtype = params[0].dtype + self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype) + if self._fwd_bwd_param_dtype is None: + raise AssertionError("Expected _fwd_bwd_param_dtype to be not None") # mypy + self._aligned_numel = ( + _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype) + if align_addresses + else 0 + ) + self._fsdp_extension = fsdp_extension + self._init_flat_param_and_metadata( + params, + fully_sharded_module, + self._aligned_numel, + use_orig_params, # type: ignore[arg-type] + ) + self._use_unsharded_views(as_params=False) + + def __repr__(self): + return f"FlatParamHandle(flat_param.fqns={self.flat_param._fqns})" + + def _init_setattr_fns(self): + use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1" + self._setattr_tensor: Callable[[nn.Module, str, Tensor], None] + self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None] + if use_unsafe_setattr: + self._setattr_tensor = _unsafe_setattr_tensor + self._setattr_param = _unsafe_setattr_param + else: + self._setattr_tensor = _safe_setattr_tensor_or_param + self._setattr_param = _safe_setattr_tensor_or_param + + def _init_get_unflat_views_fn(self, align_addresses: bool): + self._get_unflat_views = ( + self._get_unflat_views_aligned + if align_addresses + else self._get_unflat_views_unaligned + ) + + def _init_flat_param_and_metadata( + self, + params: list[Union[Tensor, nn.Parameter]], + module: nn.Module, + aligned_numel: int, + use_orig_params: bool, + ) -> None: + """ + Initialize the ``FlatParameter`` and its metadata. + + NOTE: This should only be called once at construction time, after which + the ``FlatParameter`` metadata is assumed to be static. + + NOTE: The elements of ``params`` should only be ``Tensor`` s when + composing with ``DTensor`` -based tensor parallelism, in which case the + elements may be ``DTensor`` local shards. + """ + if len(params) == 0: + raise ValueError("Expects non-empty `params`") + if aligned_numel < 0: + raise ValueError( + f"Expects non-negative `aligned_numel` but got {aligned_numel}" + ) + ( + dtype, + flat_param_requires_grad, + device, + ) = self._validate_tensors_to_flatten(params) + params_set = set(params) + # For alignment padding, only `numels` gets strictly non-`None` + # elements, and all other lists get `None` elements for padding. + param_infos: list[ParamInfo] = [] + numels: list[int] = [] + shapes: list[torch.Size] = [] + strides: list[tuple[int, ...]] = [] + contiguities: list[bool] = [] + fqns: list[str] = [] + shared_param_infos: list[SharedParamInfo] = [] + shared_param_memo: dict[ + Union[Tensor, nn.Parameter], tuple[nn.Module, str, str] + ] = {} + params_to_flatten: list[Union[Tensor, nn.Parameter]] = [] + shared_params: list[Union[Tensor, nn.Parameter]] = [] + param_extensions: list[Any] = [] + is_padding_mask: list[bool] = [] + total_numel = total_numel_without_padding = 0 + for submodule_name, submodule in module.named_modules(remove_duplicate=False): + for param_name, param in _named_parameters_with_duplicates( + submodule, recurse=False + ): + if param not in params_set: + continue + if param in shared_param_memo: # shared reference + prim_module, prim_module_name, prim_param_name = shared_param_memo[ + param + ] + shared_params.append(param) + shared_param_infos.append( + SharedParamInfo( + param_name, + submodule, + submodule_name, + prim_param_name, + prim_module, + prim_module_name, + ) + ) + else: + if aligned_numel > 0: + numel_to_pad = aligned_numel - (total_numel % aligned_numel) + if numel_to_pad > 0 and numel_to_pad < aligned_numel: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + params_to_flatten.append(padding_tensor) + is_padding_mask.append(True) + numels.append(numel_to_pad) + total_numel += numel_to_pad + transform_t, extension = _ext_pre_flatten_transform( + param, + self._fsdp_extension, + ) + param = cast(nn.Parameter, transform_t) + param_extensions.append(extension) + shared_param_memo[param] = (submodule, submodule_name, param_name) + params_to_flatten.append(param) + is_padding_mask.append(False) + param_infos.append(ParamInfo(param_name, submodule, submodule_name)) + numels.append(param.numel()) + shapes.append(param.shape) + strides.append(param.stride()) + contiguities.append(_is_truly_contiguous(param)) + fqn = ( + submodule_name + "." + param_name + if submodule_name + else param_name + ) + fqns.append(fqn) + total_numel += param.numel() + total_numel_without_padding += param.numel() + if len(params_to_flatten) == 0: + raise ValueError( + f"`params` were not found in `module`'s tree" + f"params: {params}\nmodule: {module}" + ) + if ( + self.rank == 0 + and aligned_numel > 0 + and total_numel != total_numel_without_padding + ): + logger.debug( + "FSDP FlatParameter address alignment created " + "%s numel of padding (%s vs. %s)", + total_numel - total_numel_without_padding, + total_numel, + total_numel_without_padding, + ) + if aligned_numel > 0: + # Pad to be divisible by world size to avoid a copy for the + # post-backward reduce-scatter + numel_to_pad = self.world_size - (total_numel % self.world_size) + if numel_to_pad > 0 and numel_to_pad < self.world_size: + if self.rank == 0: + logger.info( + "FSDP FlatParameter world size divisibility created " + "%s numel of padding", + numel_to_pad, + ) + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + params_to_flatten.append(padding_tensor) + is_padding_mask.append(True) + numels.append(numel_to_pad) + total_numel += numel_to_pad + # Pass `aligned_numel=0` since we already included padding tensors + self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param( + params_to_flatten, + aligned_numel=0, + requires_grad=flat_param_requires_grad, + ) + FlatParameter._init_metadata( + self.flat_param, + param_infos, + numels, + shapes, + strides, + contiguities, + fqns, + shared_param_infos, + param_extensions, + _convert_to_params(params_to_flatten) if use_orig_params else None, + _convert_to_params(shared_params) if use_orig_params else None, + is_padding_mask, + ) + + def _validate_tensors_to_flatten( + self, tensors: list[Union[Tensor, nn.Parameter]] + ) -> tuple: + """Validate the tensors to flatten and returns any necessary metadata.""" + dtype: Optional[torch.dtype] = None + # Return as the logical OR over each tensor's value + flat_param_requires_grad: Optional[bool] = None + device: Optional[torch.device] = None + # For `use_orig_params=True`, permit non-uniform `requires_grad` + for tensor in tensors: + if isinstance(tensor, FlatParameter): + raise ValueError("Cannot flatten a `FlatParameter`") + if dtype is None and not tensor.is_floating_point(): + raise ValueError("Cannot flatten integer dtype tensors") + if dtype is not None and tensor.dtype != dtype: + raise ValueError( + f"Must flatten tensors with uniform dtype but got {dtype} " + f"and {tensor.dtype}" + ) + if ( + not self._use_orig_params + and flat_param_requires_grad is not None + and tensor.requires_grad != flat_param_requires_grad + ): + raise ValueError( + "Must flatten tensors with uniform `requires_grad` when " + "`use_orig_params=False`" + ) + if device is not None and tensor.device != device: + raise ValueError( + "Must flatten tensors on the same device but got both " + f"{device} and {tensor.device}" + ) + dtype = tensor.dtype + flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad + device = tensor.device + if flat_param_requires_grad is None: + raise AssertionError("Requires non-empty `tensors` list") + return dtype, flat_param_requires_grad, device + + def flatten_tensors( + self, + tensors: list[Tensor], + aligned_numel: int, + ) -> Tensor: + """ + Flatten ``tensors`` into a single flat tensor. + + The flattening optionally includes + padding if ``aligned_numel`` is greater than 0, where ``aligned_numel`` + gives the numel required to have address alignment. + + NOTE: The padding alignment algorithm must be kept in sync with + :meth:`_init_flat_param_metadata`. We separate the two methods because + the initialization happens once, whereas this method may be called + multiple times throughout training (e.g. for checkpointing). + """ + if len(tensors) == 0: + raise ValueError("Expects non-empty `tensors`") + if aligned_numel < 0: + raise ValueError( + f"Expects non-negative `aligned_numel` but got {aligned_numel}" + ) + dtype, _, device = self._validate_tensors_to_flatten(tensors) + flat_tensors: list[Tensor] = [] + if aligned_numel > 0: + total_numel = 0 + for tensor in tensors: + numel_to_pad = aligned_numel - (total_numel % aligned_numel) + if numel_to_pad > 0 and numel_to_pad < aligned_numel: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + flat_tensors.append(padding_tensor) + total_numel += numel_to_pad + flat_tensors.append( + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + ) + total_numel += tensor.numel() + numel_to_pad = self.world_size - (total_numel % self.world_size) + if numel_to_pad > 0 and numel_to_pad < self.world_size: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + flat_tensors.append(padding_tensor) + total_numel += numel_to_pad + else: + flat_tensors = [ + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + for tensor in tensors + ] + return torch.cat(flat_tensors, dim=0) + + def flatten_tensors_into_flat_param( + self, + tensors: list[Tensor], + aligned_numel: int, + requires_grad: bool, + ) -> FlatParameter: + flat_param_data = self.flatten_tensors(tensors, aligned_numel) + return FlatParameter(flat_param_data, requires_grad=requires_grad) + + def _init_param_reduce_dtypes( + self, + mp_param_dtype: Optional[torch.dtype], + mp_reduce_dtype: Optional[torch.dtype], + ) -> None: + """ + Initialize param and reduce dtypes. + + Precondition: ``self.flat_param`` is set. This ensures that this + handle's parameters have a single dtype. + + Postcondition: This sets ``self._fwd_bwd_param_dtype`` and + ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype`` + is ``None``, then we assume the original parameter dtype. One special + case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype`` + is ``None``, in which case we assume the gradient reduction dtype + matches the forward/backward parameter dtype. + """ + # Save whether these dtypes were specified so that we permit the + # parameter dtype to change up until the lazy initialization + self._low_prec_param_dtype_specified = mp_param_dtype is not None + self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None + if ( + self._low_prec_param_dtype_specified + and not self._low_prec_reduce_dtype_specified + ): + # Special case: infer gradient reduction mixed precision + self._fwd_bwd_param_dtype = mp_param_dtype + self._reduce_dtype = self._fwd_bwd_param_dtype + else: + self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype + self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype + if self._fwd_bwd_param_dtype is None: + raise AssertionError("Expected _fwd_bwd_param_dtype to be not None") + if self._reduce_dtype is None: + raise AssertionError("Expected _reduce_dtype to be not None") + + ################################### + # SHARD INITIALIZATION & METADATA # + ################################### + @torch.no_grad() + def shard(self): + """ + Shard the handle's ``FlatParameter``. + + This allocates new memory for + the sharded flat parameter and frees the unsharded flat parameter's + storage. + + Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard + metadata attributes are set for all sharding strategies. + """ + flat_param = self.flat_param + if not self.uses_sharded_strategy: + self._init_shard_metadata(0, 0, flat_param.numel() - 1) + else: + _p_assert( + flat_param.storage_offset() == 0, + "The `FlatParameter` is not the sole occupant of its storage", + ) + sharded_flat_param, numel_padded = FlatParamHandle._get_shard( + flat_param, self.rank, self.world_size + ) + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + allocated = flat_param._typed_storage()._size() > 0 + if allocated: + flat_param._typed_storage()._resize_(0) + flat_param.set_(sharded_flat_param) # type: ignore[call-overload] + start_idx = sharded_flat_param.numel() * self.rank + end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive + self._init_shard_metadata(numel_padded, start_idx, end_idx) + if self._use_orig_params: + self._use_sharded_views() + + def _init_shard_metadata( + self, + numel_padded: int, + unsharded_start_idx: int, + unsharded_end_idx: int, + ) -> None: + """ + Initialize shard-related metadata for this rank's shard of the flat parameter. + + This includes ``_sharded_size``, ``_shard_param_infos``, and ``_shard_numel_padded``. + + Args: + numel_padded (int): Numel padded for this rank's sharded flat + parameter. + unsharded_start_idx (int): Start index in the unsharded flat + parameter assigned to this rank. + unsharded_end_idx (int): End index (inclusive) in the unsharded + flat parameter assigned to this rank. + + Precondition: ``self.flat_param`` 's data is the sharded flat + parameter. + """ + flat_param = self.flat_param + flat_param._sharded_size = flat_param.size() # type: ignore[attr-defined] + sharded_flat_param_numel = flat_param.numel() # includes `numel_padded` + _p_assert( + unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx, + f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}", + ) + _p_assert( + numel_padded <= sharded_flat_param_numel, + f"numel_padded: {numel_padded} " + f"sharded_flat_param_numel: {sharded_flat_param_numel}", + ) + shard_param_infos = self._get_shard_metadata( + unsharded_start_idx, unsharded_end_idx + ) + if len(shard_param_infos) != flat_param._num_params: + raise AssertionError( + f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}" + ) + flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined] + flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] + + def _get_shard_metadata( + self, + unsharded_start_idx: int, + unsharded_end_idx: int, + ) -> tuple[_ShardParamInfo, ...]: + """ + Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive). + + ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the + unsharded flat parameter specifying the shard. + """ + flat_param_offsets = self._get_flat_param_offsets() + if len(flat_param_offsets) != len(self.flat_param._numels_with_padding): + raise AssertionError( + f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}" + ) + shard_param_infos: list[_ShardParamInfo] = [] + sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1 + # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices + # into the unsharded flat parameter (inclusive) of the given parameter + for ( + (unsharded_param_start_idx, unsharded_param_end_idx), + is_padding, + ) in zip(flat_param_offsets, self.flat_param._is_padding_mask): + if is_padding: + continue + in_sharded_flat_param = ( + unsharded_start_idx <= unsharded_param_end_idx + and unsharded_end_idx >= unsharded_param_start_idx + ) + if not in_sharded_flat_param: + shard_param_info = _ShardParamInfo(False, None, None, None, None) + else: + if unsharded_start_idx <= unsharded_param_start_idx: + # This branch can only happen once since the rank's + # unsharded start index can only intersect one parameter + intra_param_start_idx = 0 + offset_in_shard = unsharded_param_start_idx - unsharded_start_idx + else: + intra_param_start_idx = ( + unsharded_start_idx - unsharded_param_start_idx + ) + offset_in_shard = 0 + if not ( + offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel + ): + raise AssertionError( + f"Invalid `offset_in_shard` of {offset_in_shard} for " + f"sharded flat parameter with {sharded_flat_param_numel} numel" + ) + intra_param_end_idx = ( + min(unsharded_param_end_idx, unsharded_end_idx) + - unsharded_param_start_idx + ) + numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1 + shard_param_info = _ShardParamInfo( + True, + offset_in_shard, + numel_in_shard, + intra_param_start_idx, + intra_param_end_idx, + ) + shard_param_infos.append(shard_param_info) + return tuple(shard_param_infos) + + @staticmethod + def _get_unpadded_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> tuple[Tensor, int]: + """ + Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``. + + The returned value is a tuple of the shard of ``tensor`` without any + padding and the numel to pad for that shard. + + If ``tensor`` is already flattened or may be viewed in the flattened + shape (which is true in the expected usage), then this method does not + allocate any new tensor memory. + """ + chunks = ( + torch.flatten(tensor).chunk(world_size) + if _is_truly_contiguous(tensor) + else tensor.as_strided((tensor.numel(),), (1,)).chunk(world_size) + ) + if len(chunks) < (rank + 1): + # This rank gets an empty chunk fully padded with zeros since there + # are not enough chunks across ranks + chunk = chunks[0].new_empty(0) + else: + chunk = chunks[rank] + numel_to_pad = chunks[0].numel() - chunk.numel() + if numel_to_pad < 0: + raise AssertionError( + "Chunk's size should be at most the first chunk's size" + ) + return chunk, numel_to_pad + + @staticmethod + def _get_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> tuple[Tensor, int]: + """ + Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard. + + This method allocates new memory (via :meth:`clone`) since the + unsharded ``tensor`` may be deallocated after this method returns. + """ + chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + shard = chunk.clone() + if numel_to_pad > 0: + shard = F.pad(shard, [0, numel_to_pad]) + return shard, numel_to_pad + + @staticmethod + def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size: + """ + Return the shape of ``tensor`` after sharding including padding. + + This requires ``tensor`` to have 1D shape and ensures that the returned + shape is 1D. + """ + if len(tensor.shape) != 1: + raise AssertionError(f"Expected 1D tensor shape, got {tensor.shape}") + unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + unpadded_sharded_size = unpadded_sharded_tensor.size() + if len(unpadded_sharded_size) != 1: + raise AssertionError( + f"Expected 1D unpadded_sharded_size, got {unpadded_sharded_size}" + ) + return torch.Size([unpadded_sharded_size[0] + numel_to_pad]) + + def _get_flat_param_offsets(self) -> list[tuple[int, int]]: + """ + Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding). + + NOTE: The returned list includes elements for alignment padding. + """ + cumulative_sum = list(accumulate(self.flat_param._numels_with_padding)) + starts = [0] + cumulative_sum[:-1] + ends = [end - 1 for end in cumulative_sum] # inclusive + param_offsets = list(zip(starts, ends)) + return param_offsets + + @no_type_check + def shard_metadata( + self, + ) -> FlatParamShardMetadata: + """ + Return the shard-related metadata specific to this rank's shard of the flat parameter. + + NOTE: The returned tuple does not include elements for alignment + padding but does account for the padding. + """ + fqns_list = [] + shapes_list = [] + strides_list = [] + contiguities_list = [] + numels_list = [] + shard_param_offsets = [] + for fqn, shape, stride, contiguous, numel, shard_param_info in zip( + self.flat_param._fqns, + self.flat_param._shapes, + self.flat_param._strides, + self.flat_param._contiguities, + self.flat_param._numels, + self.flat_param._shard_param_infos, + ): + if not shard_param_info.in_shard: + continue + fqns_list.append(fqn) + shapes_list.append(shape) + strides_list.append(stride) + contiguities_list.append(contiguous) + numels_list.append(numel) + shard_param_offsets.append( + ( + shard_param_info.intra_param_start_idx, + shard_param_info.intra_param_end_idx, + ) + ) + return FlatParamShardMetadata( + tuple(fqns_list), + tuple(shapes_list), + tuple(strides_list), + tuple(contiguities_list), + tuple(numels_list), + tuple(shard_param_offsets), + ) + + @no_type_check + @torch.no_grad() + def init_flat_param_attributes(self) -> None: + """ + This initializes some attributes on the handle's ``FlatParameter``. + This should be called during lazy initialization since it requires the + parameter to be on the compute device if not offloading to CPU and we + want to give users the chance to move the parameter appropriately after + the FSDP constructor. + + For each tensor attribute on the ``FlatParameter``, see the unshard and + reshard methods in this class for the allocation and free pattern. + """ + flat_param = self.flat_param + if flat_param.dtype != self._orig_param_dtype: + # Entering this branch means that the user changed the parameter + # dtype after FSDP initialization, in which case we may need to + # refresh some saved dtype attributes (dtypes specified as a part + # of mixed precision take precedence). + if not self._low_prec_param_dtype_specified: + self._fwd_bwd_param_dtype = flat_param.dtype + # For `reduce_dtype`, require `param_dtype` was not specified since + # then we infer the `reduce_dtype` from the specified `param_dtype` + if ( + not self._low_prec_reduce_dtype_specified + and not self._low_prec_param_dtype_specified + ): + self._reduce_dtype = flat_param.dtype + self._orig_param_dtype = flat_param.dtype + cpu_device = torch.device("cpu") + if self._offload_params: + _p_assert( + flat_param.device == cpu_device, + f"Expects the `FlatParameter` to be on CPU when parameter CPU " + f"offloading is enabled, not {flat_param.device}", + ) + else: + self._check_on_compute_device(self.flat_param) + flat_param._local_shard = flat_param.data + if self._offload_params: + # Pin the memory for faster H2D transfer + flat_param._local_shard = flat_param._local_shard.pin_memory() + # Pre-allocate the sharded gradient on CPU to enable non-blocking + # D2H transfer during the backward pass + flat_param._cpu_grad = torch.zeros_like( + flat_param._local_shard, device=cpu_device + ).pin_memory() + if self._uses_param_mixed_precision: + # For parameter mixed precision, we maintain a low precision + # sharded tensor on the compute device to be all-gathered (for + # sharded strategies) or directly used (for `NO_SHARD`) for + # computation. + flat_param._mp_shard = torch.empty_like( + flat_param._local_shard, + device=self.device, + dtype=self._fwd_bwd_param_dtype, + ) + _free_storage(flat_param._mp_shard) + if self.uses_sharded_strategy: + # We maintain a padded unsharded tensor that serves as the + # all-gather destination and owns the original parameter storages. + unsharded_param_dtype = ( + self._fwd_bwd_param_dtype + if self._uses_param_mixed_precision + else flat_param.dtype + ) # use low precision if parameter mixed precision is enabled + padded_unsharded_numel = flat_param.numel() * self.world_size + flat_param._full_param_padded = torch.empty( + padded_unsharded_numel, + device=self.device, + dtype=unsharded_param_dtype, + ) + flat_param._padded_unsharded_size = flat_param._full_param_padded.size() + _free_storage(flat_param._full_param_padded) + + if self._uses_param_mixed_precision: + # For parameter mixed precision, we maintain a full precision + # padded unsharded tensor for when we force full precision. + flat_param._full_prec_full_param_padded = torch.empty( + padded_unsharded_numel, + device=self.device, + dtype=flat_param.dtype, # full precision + ) + _free_storage(flat_param._full_prec_full_param_padded) + + ################### + # UNSHARD/RESHARD # + ################### + def pre_unshard(self) -> bool: + """ + Return ``False`` if this is a no-op and ``True`` otherwise. + + Postcondition: ``self.flat_param`` 's data is on the device for + communication and is what should be all-gathered. This means that it + matches the dtype of the expected unsharded parameter. + """ + if ( + self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS + and self._skipped_use_sharded_views + ): + # Since this path imposes special semantics for the unsharded flat + # parameter (e.g. forcing full precision), use sharded views to + # reuse the existing logic for that special handling + self._use_sharded_views() + ret = False + if self._use_orig_params and not self._skip_writeback_check: + ret = self._writeback_orig_params() + if ( + self.uses_sharded_strategy + and not self._offload_params + and not self.needs_unshard() + ): + pass # no-op + elif self._uses_param_mixed_precision and not self._force_full_precision: + self._use_low_precision_shard() + ret = True + elif self._offload_params and self.flat_param.device != self.device: + # NOTE: This creates a new tensor distinct from any attributes. + self.flat_param_to(self.device, non_blocking=True) + ret = True + self._check_on_compute_device(self.flat_param) + return ret + + def _use_low_precision_shard(self): + """Allocate on the compute device and switch to using the low precision sharded flat parameter.""" + self._check_low_precision_shard() + flat_param = self.flat_param + _alloc_storage( + flat_param._mp_shard, + flat_param._local_shard.size(), # type: ignore[attr-defined] + ) + # `copy_()` implicitly casts to the low precision + flat_param._mp_shard.copy_( # type: ignore[attr-defined] + flat_param._local_shard.to( # type: ignore[attr-defined] + self.device, non_blocking=True + ) + ) + # Invariant: `_mp_shard` is always on the compute device. + flat_param.data = flat_param._mp_shard # type: ignore[attr-defined] + + def unshard(self): + """ + Run the unshard logic. + + This includes all-gathering the flat parameter + and switching to using the unsharded flat parameter. If the handle does + not need unsharding, then this only switches to using the unsharded + flat parameter. For ``NO_SHARD``, this is a no-op. + + If FSDP is in :meth:`summon_full_params` and the handle uses parameter + mixed precision, then the parameter is forced to full precision. + """ + if not self.needs_unshard(): + # Even when not needing an unshard, we should switch to using + # the unsharded flat parameter + unsharded_flat_param = ( + self._get_padded_unsharded_flat_param() + if self.uses_sharded_strategy + else self.flat_param + ) + self._use_unsharded_flat_param(unsharded_flat_param) + return + unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def needs_unshard(self) -> bool: + """Return if the handle's flat parameter needs to be unsharded.""" + if not self.uses_sharded_strategy: + return False + unsharded_flat_param = self._get_padded_unsharded_flat_param() + already_unsharded = _same_storage_size( + unsharded_flat_param, unsharded_flat_param.numel() + ) + return not already_unsharded + + def _alloc_padded_unsharded_flat_param(self): + """ + Allocate the *padded* unsharded flat parameter. + + The unpadded unsharded + flat parameter is always a view into the padded one. This padded + parameter is saved to a different attribute on the ``FlatParameter`` + depending on if we force full precision. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_storage_freed(unsharded_flat_param) + _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined] + return unsharded_flat_param + + def _get_padded_unsharded_flat_param(self) -> torch.Tensor: + """ + Return a reference to the padded unsharded flat parameter depending on the calling context. + + This should only be called if using a sharded strategy. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + if self._force_full_precision and self._uses_param_mixed_precision: + # When parameter mixed precision is enabled, we use a different + # tensor as the all-gather destination to preserve the invariant + # that `_full_param_padded` is in the low precision + unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined] + _p_assert( + unsharded_flat_param.dtype != self._fwd_bwd_param_dtype, + f"Expects full precision but got {self._fwd_bwd_param_dtype}", + ) + # For no-reshard-after-forward strategies, `_full_param_padded` may + # still be allocated from a previous forward. As we are forcing + # full precision here, the full-precision unsharded copy may be + # modified, invalidating the existing low-precision unsharded copy, + # so we should free it here to ensure a new all-gather for the next + # forward/backward computation to persist the modifications. + if flat_param._full_param_padded.untyped_storage().size() > 0: + _free_storage(flat_param._full_param_padded) + else: + unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined] + return unsharded_flat_param + + def _all_gather_flat_param( + self, + padded_unsharded_flat_param: Tensor, + ) -> Tensor: + """ + All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``. + + Then switch to use the all-gathered tensor. + """ + _p_assert( + hasattr(self, "process_group") and hasattr(self, "world_size"), + "Expects a process group and world size to have been set via `shard()`", + ) + sharded_flat_param = self.flat_param.data + expected_numel = sharded_flat_param.numel() * self.world_size + _p_assert( + padded_unsharded_flat_param.numel() == expected_numel, + f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", + ) + + pg = ( + self._fake_process_group + if self._use_fake_all_gather + else self.process_group + ) + + # HACK this should be handled by C10D + if sharded_flat_param.is_cpu: # type: ignore[attr-defined] + tensor_list = list( + torch.chunk( + padded_unsharded_flat_param, + dist.get_world_size(pg), # type: ignore[arg-type] + ) + ) + dist.all_gather(tensor_list, sharded_flat_param, group=pg) + else: + dist.all_gather_into_tensor( + padded_unsharded_flat_param, + sharded_flat_param, + pg, + ) + + if self._offload_params: + # In case of offloading, `flat_param.data` (i.e. sharded param) is + # created on the pre-unshard stream. We need to hand it over to the + # unshard stream for all-gather + _no_dispatch_record_stream( + sharded_flat_param, + self._device_handle.current_stream(), # unshard_stream + ) + return padded_unsharded_flat_param + + def _use_unsharded_flat_param( + self, + padded_unsharded_flat_param: torch.Tensor, + ) -> None: + """ + Switch to use the *unpadded* unsharded flat parameter. + + This is a view into the *padded* unsharded flat parameter. + """ + unsharded_size = self.flat_param._unpadded_unsharded_size + flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()] + # slicing [:] is not visible to autograd because of .data + self.flat_param.data = flat_param_part + in_forward = self._training_state == HandleTrainingState.FORWARD + in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE + if self._use_orig_params: + if self._skipped_use_sharded_views and in_pre_backward: + # This call corresponds to the complementary pre-backward + # `_use_unsharded_views()` to the skipped pre-forward + # `_use_sharded_views()`, so we should skip this one too. + return + # We use `Tensor` views in the forward so that they are tracked by + # autograd. We use them in the pre-backward as well to support + # reentrant activation checkpointing, which needs the views to be + # tracked by autograd in the backward pass's recomputed forward. + self._use_unsharded_views( + as_params=(not in_forward and not in_pre_backward) + ) + elif in_forward: + self._use_unsharded_views(as_params=False) + + def post_unshard(self): + """ + Run the post-unshard logic. + + This includes freeing the low precision shard if needed. + """ + if self._uses_param_mixed_precision and self.uses_sharded_strategy: + self._free_low_precision_sharded_param() + self._check_on_compute_device(self.flat_param) + + def _free_low_precision_sharded_param(self): + """Frees the low precision sharded flat parameter.""" + self._check_low_precision_shard() + # `_mp_shard` is allocated in the pre-unshard stream, consumed in the + # unshard stream for sharded strategies, and consumed in both the + # unshard and default streams for `NO_SHARD`. For sharded strategies, + # the current stream here is the unshard stream, and for `NO_SHARD`, + # it is the default stream. For `NO_SHARD`, only recording for the + # default stream suffices since the default stream waits for the + # unshard stream. + _no_dispatch_record_stream( + self.flat_param._mp_shard, + self._device_handle.current_stream(), # type: ignore[attr-defined] + ) + _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined] + + @torch.no_grad() + def unshard_grad(self): + """ + Unshard the handle's ``FlatParameter``'s gradient. + + If all ranks have + ``None`` gradient, then all original parameters will as well. This + method performs an all-reduce and an all-gather. The additional + all-reduce is tolerable since this method is not meant to be used on + the computation critical path. + + Postcondition: ``_saved_grad_shard`` is defined and contains the value + to set ``flat_param.grad`` after gradients are resharded. + """ + if not self.uses_sharded_strategy: + self._use_unsharded_grad_views() + return + flat_param = self.flat_param + self._check_unsharded(flat_param) + + # Check if all ranks have a `None` gradient + num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device) + num_grad_none[0] = flat_param.grad is None + dist.all_reduce(num_grad_none, group=self.process_group) + if num_grad_none[0] == self.world_size: + flat_param._saved_grad_shard = None # type: ignore[assignment] + self._use_unsharded_grad_views() + return + + if flat_param.grad is None: + # In the case that only some ranks have `None` gradient, we use + # zeros to approximate as a best effort attempt + if self._debug_level == dist.DebugLevel.INFO: + warnings.warn( + f"[Rank {self.rank}] Only some but not all ranks have a " + "`None` `FlatParameter` gradient, so FSDP is using zeros to " + "approximate those ranks' sharded gradients being `None`", + stacklevel=2, + ) + flat_param._saved_grad_shard = None # type: ignore[assignment] + sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined] + else: + self._check_sharded(flat_param.grad) + flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined] + sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + padded_unsharded_grad = torch.empty( + flat_param._padded_unsharded_size, # type: ignore[attr-defined] + device=self.device, + dtype=sharded_grad.dtype, + ) + dist.all_gather_into_tensor( + padded_unsharded_grad, sharded_grad, self.process_group + ) + unsharded_size = self.flat_param._unpadded_unsharded_size + flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view( + unsharded_size + ) + self._use_unsharded_grad_views() + + def reshard_grad(self): + if self._use_orig_params: + self._use_sharded_grad_views() + if not self.uses_sharded_strategy: + return + self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined] + delattr(self.flat_param, "_saved_grad_shard") + + def prepare_gradient_for_backward(self): + """ + Prepare the gradient for the backward computation. + + This is done by saving and clearing any existing sharded gradient + in ``.grad`` to enable computing a new unsharded gradient. + """ + _p_assert( + self._training_state + in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE), + "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)", + ) + flat_param = self.flat_param + if flat_param.grad is not None and ( + flat_param.grad.size() != flat_param._unpadded_unsharded_size + or flat_param.grad.device != flat_param.device # grad on CPU + ): + self._check_on_compute_device(self.flat_param) + grad_offloaded = flat_param.grad.device != self.device + _p_assert( + not grad_offloaded or self._offload_params, + f"Expects the sharded gradient to be on {self.device} " + f"but got {flat_param.grad.device}", + ) + prev_iter_synced_gradients = ( + flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined] + ) + if prev_iter_synced_gradients: + # TODO (awgu): Gradient accumulation outside `no_sync()` + # does not work with CPU offloading. The issue should be + # that, in the post-backward hook, we cannot do an addition + # between a CPU tensor (the existing sharded gradient) and + # a GPU tensor (the new sharded gradient). + if not grad_offloaded: + flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined] + sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + else: + _p_assert( + hasattr(flat_param, "_cpu_grad"), + "`_cpu_grad` should be defined if the gradient is on CPU", + ) + sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined] + # If user specified to keep the gradient in low precision, then + # the gradient may still be of the low precision dtype if the + # user did not set the gradient to `None` after the previous + # backward, in which case FSDP should cast back to the full + # precision dtype so that FSDP can accumulate in that dtype in + # the post-backward hook and assign to `.grad` in that dtype in + # the post-backward callback. + local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined] + if ( + self._keep_low_precision_grads + and sharded_grad.dtype != local_shard_dtype + ): + sharded_grad.data = sharded_grad.to(local_shard_dtype) + else: + padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined] + _p_assert( + flat_param.grad.size() == padded_unsharded_size, + "Expects `.grad` to be the unsharded gradient in " + f"`no_sync()` with size {padded_unsharded_size} " + f"but got size {flat_param.grad.size()}", + ) + flat_param.grad = None + + def prepare_gradient_for_optim(self): + """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute.""" + + def cast_grad_to_param_dtype_if_needed(flat_param): + # TODO (rohan-varma): test for full precision with keep_low_precision_grads + if not self._force_full_precision and self._keep_low_precision_grads: + _p_assert(flat_param.grad is not None, "Unexpected None grad!") + if flat_param.grad.dtype != self._fwd_bwd_param_dtype: + flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype) + if self._use_orig_params: + self._use_sharded_grad_views() + + flat_param = self.flat_param + # TODO (awgu): We should replace these conditional checks to encode + # the logical intention more directly. + if hasattr(flat_param, "_cpu_grad"): + # NOTE: This branch includes `NO_SHARD`. + self._check_sharded(flat_param) + self._check_on_cpu(flat_param) + flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined] + cast_grad_to_param_dtype_if_needed(flat_param) + elif hasattr(flat_param, "_saved_grad_shard"): + self._check_sharded(flat_param) + self._check_on_compute_device(flat_param) + if flat_param._saved_grad_shard is not None: + self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined] + # If no sharded gradient was computed this iteration, then there is + # no need to forward `_saved_grad_shard` to `grad` + if flat_param._post_backward_called: # type: ignore[attr-defined] + flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + if flat_param.grad is not None: + cast_grad_to_param_dtype_if_needed(flat_param) + else: + _p_assert( + not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined] + "All sharded parameters that received a gradient in the " + "post-backward should use `_saved_grad_shard`", + ) + # Delete `_saved_grad_shard` since its existence indicates a previous + # gradient to accumulate with in the post-backward hook + if hasattr(flat_param, "_saved_grad_shard"): + delattr(flat_param, "_saved_grad_shard") + + @contextlib.contextmanager + def to_cpu(self): + """ + Move the unpadded unsharded flat parameter to CPU while in the context and moves it back to the previous device upon exit. + + For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter + since (1) there is no reason to include the padding in the copy and (2) + there is no use case for the sharded flat parameter. + + Precondition: ``self.flat_param`` 's data is the unpadded unsharded + flat parameter on the compute device, and the handle uses a sharded + strategy. + Postcondition: Same as the precondition. + """ + self._check_sharded_strategy() + _p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + self._check_on_compute_device(self.flat_param) + # Check that the unpadded unsharded flat parameter is a view into the + # padded unsharded flat parameter as expected + # NOTE: This check is not strictly needed for correctness but is a + # useful sanity check since the tensor should only be used internally. + _p_assert( + _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()), + "Expects the unpadded parameter to be a view into the padded parameter", + ) + self.flat_param_to(torch.device("cpu")) + self._free_unsharded_flat_param() + try: + yield + finally: + _p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + # Copy from CPU to the compute device + padded_unsharded_flat_param[: self.flat_param.numel()].copy_( + self.flat_param + ) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def reshard(self, free_unsharded_flat_param: bool): + """ + Run the reshard logic. + + This includes freeing the unsharded flat + parameter if ``free_unsharded_flat_param`` and switching to using the + sharded flat parameter. Note that this also implicitly offloads + the sharded flat parameter (if CPU offload is enabled) by pointing + it to the ``_local_shard`` attribute which resides on CPU. + """ + # Switch to the sharded `FlatParameter` before freeing to prevent + # "use-after-free"-type bugs with external profiling tools, where for + # `use_orig_params=True`, the `param` does not point to valid memory + # when setting `param.data = ...` in `_use_sharded_views()`. + self._use_sharded_flat_param() + if free_unsharded_flat_param: + self._free_unsharded_flat_param() + + def post_reshard(self): + """ + Run the post-reshard logic. + + This includes freeing any memory that + can now be freed given that the ``FlatParameter`` points to the full + precision sharded flat parameter. + + Precondition: ``self.flat_param`` 's data points to the full precision + sharded flat parameter. + """ + # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it + # is also the low precision *unsharded* flat parameter. Hence, we delay + # the free until the reshard. + if ( + self._uses_param_mixed_precision + and not self.uses_sharded_strategy + and not self._force_full_precision # did not use the low precision shard + ): + self._free_low_precision_sharded_param() + + def _free_unsharded_flat_param(self): + """ + Free the padded unsharded flat parameter. We allow this + function to be called even when storage is not allocated + + The tensor to free depends + on the calling context since the unshard may have forced full + precision, in which case a different tensor is used. + """ + self._check_sharded_strategy() + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_on_compute_device(unsharded_flat_param) + # Do not free the memory until all ops in the current stream finish + _no_dispatch_record_stream( + unsharded_flat_param, self._device_handle.current_stream() + ) + _free_storage(unsharded_flat_param) + + def _use_sharded_flat_param(self) -> None: + """Switches to using the sharded flat parameter.""" + flat_param = self.flat_param + if self._use_orig_params: + in_forward = self._training_state == HandleTrainingState.FORWARD + skip_use_sharded_views = ( + torch.is_grad_enabled() + and in_forward + and self._sharding_strategy + in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + # Only incur the extra `.data` call if needed + if skip_use_sharded_views: + unsharded_flat_param = flat_param.data + if self._offload_params: + device = flat_param._local_shard.device # type: ignore[attr-defined] + _p_assert( + device == torch.device("cpu"), + f"Expects the local shard to be on CPU but got {device}", + ) + flat_param.data = flat_param._local_shard # type: ignore[attr-defined] + if self._use_orig_params: + if skip_use_sharded_views: # type: ignore[possibly-undefined] + self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined] + else: + self._use_sharded_views() + # For the post-forward reshard, we may try to use sharded gradient + # views (or unsharded gradient views if a gradient was accumulated + # in `no_sync()`), but for the post-backward reshard, we delay the + # call to after the reduce-scatter. + if ( + in_forward # type: ignore[possibly-undefined] + # Skip using gradient views if skipped using sharded views + # since exposing unsharded parameters with sharded gradients + # may be confusing to the user + and not self._skipped_use_sharded_views + ): + # TODO: Change `_unpadded_unsharded_size` if we change the + # gradient to be computed directly with padding. + accumulated_grad_in_no_sync = ( + flat_param.grad is not None + and self.uses_sharded_strategy + and flat_param.grad.shape == flat_param._unpadded_unsharded_size + ) + if accumulated_grad_in_no_sync: + self._use_unsharded_grad_views() + else: + self._use_sharded_grad_views() + + ######### + # VIEWS # + ######### + @no_type_check + def _get_unflat_views_unaligned( + self, + tensor: Optional[torch.Tensor] = None, + ) -> Iterator[Tensor]: + """ + Return unflattened ``Tensor`` views into ``tensor``. + + If `tensor`` is ``None``, ``flat_param`` is used. The unflattening is based + on ``flat_param`` 's metadata. + + Examples for ``tensor`` include ``flat_param.grad`` or unsharded + tensor optimizer state. + """ + flat_param = self.flat_param + if tensor is None: + tensor = flat_param + views = ( + _ext_post_unflatten_transform( + subtensor.view(shape) + if contiguous + else subtensor.as_strided(shape, stride), + param_extension, + self._fsdp_extension, + ) + for (subtensor, shape, stride, contiguous, param_extension) in zip( + torch.split(tensor, flat_param._numels, dim=0), + flat_param._shapes, + flat_param._strides, + flat_param._contiguities, + flat_param._param_extensions, + ) + ) + return views + + @no_type_check + def _get_unflat_views_aligned( + self, + tensor: Optional[Tensor] = None, + ) -> list[Tensor]: + """ + Return unflattened ``Tensor`` views into ``tensor`` with handling for padding. + + This method has the same contract as :meth:`_get_unflat_views_unaligned` + except it checks for ``None`` placeholders representing padding for + alignment, which may incur slightly more CPU overhead. + """ + flat_param = self.flat_param + if tensor is None: + tensor = flat_param + splits: list[Tensor] = torch.split( + tensor, flat_param._numels_with_padding, dim=0 + ) + idx = 0 + views: list[Tensor] = [] + for split, is_padding in zip(splits, flat_param._is_padding_mask): + if is_padding: + continue + views.append( + _ext_post_unflatten_transform( + split.view(flat_param._shapes[idx]) + if flat_param._contiguities[idx] + else split.as_strided( + flat_param._shapes[idx], flat_param._strides[idx] + ), + flat_param._param_extensions[idx], + self._fsdp_extension, + ) + ) + idx += 1 + return views + + @no_type_check + @torch.enable_grad() + def _use_unsharded_views(self, as_params: bool) -> None: + """ + Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it. + + Args: + as_params (bool): If ``True``, then registers the original + parameters as ``nn.Parameter`` s; if ``False``, then registers + the original parameters only as ``Tensor`` s. ``False`` should + be used during forward/backward computation and when hiding the + original parameters from :meth:`nn.Module.named_parameters`. + + Note: + when prefetching for next forward, current forward may be + annotated with `@torch.no_grad()` + `@torch.enable_grad()` ensures non-empty `view.grad_fn` + otherwise `_post_backward_hook` will not get called + """ + flat_param = self.flat_param + self._check_unsharded(flat_param) + views = self._get_unflat_views() + from torch.distributed.tensor import DTensor + + for i, (view, (param_name, module, _)) in enumerate( + zip(views, flat_param._param_infos) + ): + if self._use_orig_params and as_params: + if type(view) is DTensor: + # A `DTensor` `view` is not compatible with assigning + # `param.data = view`, so we cannot preserve the parameter + # variable. + self._setattr_param( + module, + param_name, + nn.Parameter(view, requires_grad=flat_param.requires_grad), + ) + continue + param = self.flat_param._params[i] + self._setattr_param(module, param_name, param) + param.data = view + elif as_params: + self._setattr_param( + module, + param_name, + nn.Parameter(view, requires_grad=flat_param.requires_grad), + ) + else: # `as_params=False` + param_var: Tensor = view + if self._use_orig_params: + if self._training_state == HandleTrainingState.FORWARD: + # Save the `Tensor` for the pre-backward + self.flat_param._tensors[i] = view # save for pre-backward + elif self._training_state == HandleTrainingState.BACKWARD_PRE: + # Use the saved `Tensor` variable from the forward to + # preserve the autograd graph so that the post-backward + # hook fires (e.g. for reentrant AC) + tensor = self.flat_param._tensors[i] + tensor.data = view + param_var = tensor + self._setattr_tensor(module, param_name, param_var) + if ( + self._use_orig_params + and self._training_state == HandleTrainingState.FORWARD + ): + module._parameters[param_name] = param_var + for i, ( + param_name, + module, + _, + prim_param_name, + prim_module, + _, + ) in enumerate(self.flat_param._shared_param_infos): + prim_param: Union[Tensor, nn.Parameter] = getattr( + prim_module, prim_param_name + ) + _p_assert( + not as_params or isinstance(prim_param, nn.Parameter), + f"as_params={as_params} type(prim_param)={type(prim_param)}", + ) + if self._use_orig_params and as_params: + shared_param = self.flat_param._shared_params[i] + self._setattr_param(module, param_name, shared_param) + shared_param.data = prim_param + elif as_params: + self._setattr_param(module, param_name, prim_param) + else: + self._setattr_tensor(module, param_name, prim_param) + if ( + self._use_orig_params + and self._training_state == HandleTrainingState.FORWARD + ): + module._parameters[param_name] = prim_param + + @no_type_check + def _use_unsharded_grad_views(self) -> None: + """ + Unflatten the unsharded flat parameter's gradient. + + The original parameter variables' gradients are set to be views into + the unsharded flat parameter's gradient. + """ + # Expects the gradient to be in `flat_param.grad` + if self.flat_param.grad is None: + for param in chain(self.flat_param._params, self.flat_param._shared_params): + param.grad = None + return + self._check_unsharded(self.flat_param.grad) + views = self._get_unflat_views(self.flat_param.grad) + for i, (view, (param_name, module, _)) in enumerate( + zip(views, self.flat_param._param_infos) + ): + _p_assert( + hasattr(module, param_name), + f"{self.flat_param._fqns[i]} is missing", + ) + param = getattr(module, param_name) + if ( + param.shape != view.shape + or param.dtype != view.dtype + or param.device != view.device + ): + # NOTE: This is a hack using `.data` to side step the check + # that parameter/gradient sizes/dtypes/devices match. From + # calling `reshard()`, `param` has the sharded size, has the + # full precision dtype, and if CPU offloading is enabled, is on + # CPU. Thus, one or more of the following cases can hold when + # in `no_sync()`, where `view` is the original parameter's + # gradient: + # 1. `view` can have the unsharded size. + # 2. `view` can have the parameter low precision dtype. + # 3. `view` can be on GPU. + if param.grad is None: + param.grad = torch.empty_like(param) + param.grad.data = view + else: + param.grad = view + for ( + param_name, + module, + module_name, + prim_param_name, + prim_module, + _, + ) in self.flat_param._shared_param_infos: + _p_assert( + hasattr(module, param_name), + f"{module_name + '.' + param_name if module_name else param_name} is missing", + ) + param = getattr(module, param_name) + prim_param = getattr(prim_module, prim_param_name) + if ( + param.shape != prim_param.grad.shape + or param.dtype != prim_param.grad.dtype + or param.device != prim_param.grad.device + ): + # NOTE: This is the same hack to use `.data` to side step the + # size check. + if param.grad is None: + param.grad = torch.empty_like(param) + param.grad.data = prim_param.grad + else: + param.grad = prim_param.grad + + @contextlib.contextmanager + def unflatten_as_params(self) -> Generator: + """ + Unflatten the original parameters. + + The function assumes that the flat parameter is unsharded. When in the context, + unflattens the original parameters as ``nn.Parameter`` views into the + flat parameter, and after the context, restores the original parameters + as ``Tensor`` views into the flat parameter. + """ + self._use_unsharded_views(as_params=True) + try: + yield + finally: + self._use_unsharded_views(as_params=False) + + @no_type_check + @torch.no_grad() + def _use_sharded_views(self) -> None: + """ + Set the original parameter variables' data to be flattened views into the sharded flat parameter. + + The views are kept as flattened to simplify the case where a parameter + is sharded across ranks. Parameters whose data is not present in the + sharded flat parameter have their data set to a size-0 empty tensor. We + do not delete them to ensure to preserve expected behaviors like model + printability. Parameters whose data is present must preserve their + variables to be passable to an optimizer. + """ + self._unsharded_flat_param_for_skipped_views = None + if not self.uses_sharded_strategy: + # For `NO_SHARD`, use the *unflattened* unsharded views since we + # have the unsharded parameter + self._use_unsharded_views(as_params=True) + return + flat_param = self.flat_param + self._check_sharded(flat_param) + # Construct once and reuse for all parameters not in the local shard + size_0_empty_tensor = torch.empty( + 0, + dtype=self.flat_param.dtype, # in case `flat_param` changed dtype + device=self.flat_param.device, + requires_grad=False, + ) + for param, shard_param_info, (param_name, module, _) in zip( + flat_param._params, flat_param._shard_param_infos, flat_param._param_infos + ): + self._setattr_param(module, param_name, param) + if not shard_param_info.in_shard: + # Allow the original data to be freed via garbage collection + param.data = size_0_empty_tensor + else: + offset = shard_param_info.offset_in_shard + numel_in_shard = shard_param_info.numel_in_shard + param.data = flat_param[offset : offset + numel_in_shard] + if self.flat_param._shared_params is None: + raise AssertionError("Expected _shared_params to be not None") + for param, (param_name, module, _, prim_param_name, prim_module, _) in zip( + self.flat_param._shared_params, self.flat_param._shared_param_infos + ): + self._setattr_param(module, param_name, param) + prim_param = getattr(prim_module, prim_param_name) + param.data = prim_param # could be both empty and non-empty + if self._training_state == HandleTrainingState.BACKWARD_POST: + # Clear the saved `Tensor`s since they are unneeded now + for i in range(len(self.flat_param._tensors)): + self.flat_param._tensors[i] = None + + @no_type_check + @torch.no_grad() + def _use_sharded_grad_views(self) -> None: + """ + Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient. + + This is a no-op if there is no gradient. + + Parameters whose data is not present in the sharded flat parameter and + parameters with ``requires_grad=False`` have their gradients set to + ``None``. Since the gradient variables do not need to be preserved, + this method does not manipulate existing ``Tensor`` data directly and + creates new ``Tensor`` variables instead. + """ + flat_param = self.flat_param + self._check_sharded(flat_param) + grad = self.sharded_grad + if grad is None: + for param in chain(flat_param._params, flat_param._shared_params): + param.grad = None + return + self._check_sharded(grad) + for param, shard_param_info, is_grad_none in zip( + flat_param._params, + flat_param._shard_param_infos, + flat_param._is_grad_none_mask, + ): + if not shard_param_info.in_shard: + param.grad = None + else: + numel_in_shard = shard_param_info.numel_in_shard + if param.requires_grad and not is_grad_none: + offset = shard_param_info.offset_in_shard + if self._keep_low_precision_grads or param.dtype != grad.dtype: + # NOTE: This is a hack using `.data` to side step the + # check that parameter/gradient dtypes match. Here, + # `param` has full precision; `grad` has low precision. + if param.grad is None: + # `.grad` must have the same shape as `param` + param.grad = torch.empty_like(param) + param.grad.data = grad[ + offset : offset + numel_in_shard + ].reshape(param.shape) + else: + param.grad = grad[offset : offset + numel_in_shard].reshape( + param.shape + ) + else: + param.grad = None + if flat_param._shared_params is None: + raise AssertionError("Expected _shared_params to be not None") + for param, (_, _, _, prim_param_name, prim_module, _) in zip( + flat_param._shared_params, flat_param._shared_param_infos + ): + in_sharded_flat_param = hasattr(prim_module, prim_param_name) + if in_sharded_flat_param and param.requires_grad: + prim_param = getattr(prim_module, prim_param_name) + param.grad = prim_param.grad # share the same reference + else: + param.grad = None + + @no_type_check + @torch.no_grad() + def _writeback_orig_params(self) -> bool: + """ + Write back any parameters that changed storage to the handle's ``FlatParameter``. + + Iterates over the original parameters and writes back any parameters + that changed storages (due to a non-inplace operator) to the handle's + ``FlatParameter``. This method preserves the ``FlatParameter` 's + device even if an original parameter's device changes. + + Raises: + RuntimeError: If an original parameter or gradient changes storages + but no longer has the expected flattened shape. + Returns: ``True`` if some writeback happened, and ``False`` otherwise. + """ + if ( + self.uses_sharded_strategy + and not self.is_sharded(self.flat_param) + and not self._skipped_use_sharded_views + ): + # For `NO_SHARD`, we may still need to writeback + return False + flat_param = self.flat_param + wroteback = False + if self._skipped_use_sharded_views and self.uses_sharded_strategy: + # NOTE: We must use the unsharded flat parameter from which the + # unsharded views were computed, not the one from the current + # calling context (`_get_padded_unsharded_flat_param()`) since that + # may be different (e.g. the model changed from train to eval). + flat_param_tensor = self._unsharded_flat_param_for_skipped_views + _p_assert( + _data_ptr_allocated(flat_param_tensor), + "If skipped using sharded views, the unsharded flat parameter " + "should be allocated", + ) + else: + flat_param_tensor = flat_param + # NOTE: Since this method is called in the pre-unshard, which is only + # called during computation in the pre-forward or pre-backward, the + # sharded gradient should be guaranteed to be in `.grad`, not in + # `._saved_grad_shard`. + flat_param_grad = ( + flat_param.grad + if self.uses_sharded_strategy or not self._offload_params + else flat_param._cpu_grad + ) + for i, ( + param, + (in_shard, offset_in_shard, numel_in_shard, _, _), + (param_name, module, _), + ) in enumerate( + zip( + flat_param._params, + flat_param._shard_param_infos, + flat_param._param_infos, + ) + ): + if not in_shard: + continue + if not hasattr(module, param_name): + # Do not writeback if original parameters are deregistered + # (e.g. during model checkpointing) + continue + + # Check for parameter writeback + if self._skipped_use_sharded_views: + param = flat_param._tensors[i] + _p_assert( + param is not None, + f"Expects to have saved tensor for {flat_param._fqns[i]}", + ) + param_changed = getattr(module, param_name) is not param + needs_param_writeback = ( + param_changed # changed parameter variable itself + or not _same_storage(param, flat_param_tensor) + ) + if self._skipped_use_sharded_views and ( + param_changed or needs_param_writeback + ): + raise AssertionError( + "FSDP does not support changing the parameters between " + f"forward and backward for {self._sharding_strategy}" + ) + if param_changed: + # NOTE: The gradient is not preserved after a parameter change. + param = getattr(module, param_name) + flat_param._params[i] = param + if needs_param_writeback: + expected_shape = torch.Size([numel_in_shard]) + src = param if self.uses_sharded_strategy else param.view(-1) + self._writeback_tensor( + src, flat_param, i, expected_shape, offset_in_shard, True + ) + wroteback = True + + # Check for gradient writeback + if self._skipped_use_sharded_views: + # Skip the writeback check because we do not expose gradients + # when we skipped using sharded views + continue + if param.grad is None and flat_param.grad is not None: + expected_shape = torch.Size([numel_in_shard]) + self._writeback_tensor( + None, flat_param.grad, i, expected_shape, offset_in_shard, False + ) + elif param.grad is not None: + # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in + # memory and owns the gradient storage, so it will never + # require gradient writeback. + if not self.uses_sharded_strategy and self._offload_params: + # Explicitly continue to handle the case of `no_sync()`, + # where `param.grad` is a view into the GPU gradient + # referenced by `flat_param.grad`, while `flat_param_grad` + # is `flat_param._cpu_grad`, which is on CPU + continue + + needs_grad_writeback = flat_param_grad is None or not _same_storage( + param.grad, flat_param_grad + ) + if needs_grad_writeback: + if flat_param_grad is None: + flat_param_grad = torch.zeros_like(flat_param) + expected_shape = torch.Size([numel_in_shard]) + src = ( + param.grad + if self.uses_sharded_strategy + else param.grad.view(-1) + ) + self._writeback_tensor( + src, + flat_param_grad, + i, + expected_shape, + offset_in_shard, + False, + ) + flat_param.grad = flat_param_grad + flat_param_grad = flat_param.grad + + # TODO: If we want to handle shared parameters, we need to re-generate + # the shared parameter data structures in case sharedness changed. + for ( + param_name, + module, + _, + prim_param_name, + prim_module, + _, + ) in flat_param._shared_param_infos: + if getattr(module, param_name) is not getattr(prim_module, prim_param_name): + raise NotImplementedError( + "Changing shared parameters is not supported yet" + ) + return wroteback + + def _writeback_tensor( + self, + src_tensor: Optional[Tensor], + dst_tensor: Tensor, + tensor_index: int, + expected_shape: torch.Size, + offset: int, + is_param: bool, # else gradient + ) -> None: + """ + Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``. + + ``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if + ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing + instead of copying. ``tensor_index`` gives the index of ``src_tensor`` + in the metadata structures. + + Raises: + RuntimeError: If the ``src_tensor`` does not have the expected + shape. + """ + _p_assert( + len(expected_shape) == 1, + f"Expects a 1D expected shape but got {expected_shape}", + ) + if self._debug_level == dist.DebugLevel.INFO: + rank = self.rank if hasattr(self, "rank") else dist.get_rank() + src_shape = src_tensor.shape if src_tensor is not None else None + src_device = src_tensor.device if src_tensor is not None else None + warnings.warn( + f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs " + f"writeback in {self._training_state}\n" + f"expected shape={expected_shape} shape={src_shape} " + f"expected device={dst_tensor.device} device={src_device}", + stacklevel=2, + ) + if src_tensor is not None and src_tensor.shape != expected_shape: + # NOTE: Gradient shape mismatch is not possible in practice since + # the gradient shape is enforced to match that of the parameter and + # we already check for parameter shape mismatch. + raise RuntimeError( + f"Cannot writeback when the {'parameter' if is_param else 'gradient'} " + f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}" + ) + if src_tensor is not None: + dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor) + else: + dst_tensor[offset : offset + expected_shape.numel()].zero_() + if self.flat_param._is_grad_none_mask is None: + raise AssertionError("Expected _is_grad_none_mask to be not None") + self.flat_param._is_grad_none_mask[tensor_index] = True + + def _reset_flat_param_grad_info_if_needed(self): + """ + Reset ``flat_param.grad`` if needed. + + When ``use_orig_params=True``: + (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the + original parameters' ``.grad`` are ``None``, and + (2) sets ``flat_param.requires_grad=False`` if *none* of the original + parameters require gradient. + For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in + which case we want to free the gradients as soon after the + ``zero_grad()`` call as possible. + """ + if not self._use_orig_params: + return + flat_param = self.flat_param + if flat_param._params is None: + raise AssertionError("Expected _params to be not None") # mypy + all_grad_none = True + requires_grad = False + for param in flat_param._params: + all_grad_none &= param.grad is None + requires_grad |= param.requires_grad + if all_grad_none: + flat_param.grad = None + # As long as one parameter requires gradient, then the flat parameter + # must require gradient + flat_param.requires_grad = requires_grad + + def _deregister_orig_params(self): + for param_info in self.flat_param._param_infos: + param_name, module, _ = param_info + if hasattr(module, param_name): + delattr(module, param_name) + for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos: + if hasattr(module, param_name): + delattr(module, param_name) + + ########### + # HELPERS # + ########### + def flat_param_to(self, *args, **kwargs): + """Wrap an in-place call to ``.to()`` for ``self.flat_param``.""" + # pyrefly: ignore [not-iterable] + self.flat_param.data = self.flat_param.to(*args, **kwargs) + if self._use_orig_params: + # Refresh the views because their storage may have changed + if self.is_sharded(self.flat_param): + self._use_sharded_views() + else: + self._use_unsharded_views(as_params=True) + + def _get_modules(self) -> set[nn.Module]: + """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter.""" + return {pi.module for pi in self.flat_param._param_infos}.union( + {spi.module for spi in self.flat_param._shared_param_infos} + ) + + def is_sharded(self, tensor: Tensor) -> bool: + """ + Return whether ``tensor`` is *currently* sharded. + + For ``NO_SHARD``, we choose to have this always return ``False`` for clarity. + """ + if ( + not hasattr(self.flat_param, "_sharded_size") + or not self.uses_sharded_strategy + ): + # `_sharded_size` is defined iff `handle.shard()` has been called + return False + sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] + return tensor.size() == sharded_size + + def param_module_names(self) -> Iterator[tuple[str, str]]: + shared_param_infos = [ + ParamInfo(param_name, module, module_name) + for ( + param_name, + module, + module_name, + _, + _, + _, + ) in self.flat_param._shared_param_infos + ] + for param_info in chain(self.flat_param._param_infos, shared_param_infos): + param_name, _, module_name = param_info # type: ignore[misc] + yield (param_name, module_name) + + def shared_param_module_names(self) -> Iterator[tuple[str, str]]: + for param_name, _, module_name in [ + ParamInfo(param_name, module, module_name) + for ( + param_name, + module, + module_name, + _, + _, + _, + ) in self.flat_param._shared_param_infos + ]: + yield (param_name, module_name) + + @property + def _fqns_in_shard(self) -> list[str]: + """Return the FQNs of the parameters present in this rank's shard.""" + fqns_in_shard: list[str] = [] + for fqn, shard_param_info in zip( + self.flat_param._fqns, + self.flat_param._shard_param_infos, # type: ignore[attr-defined] + ): + if shard_param_info.in_shard: + fqns_in_shard.append(fqn) + return fqns_in_shard + + @property + def sharded_grad(self) -> Optional[Tensor]: + """Return the handle's sharded gradient.""" + flat_param = self.flat_param + # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad` + # - CPU offloading: `_cpu_grad` + # - No CPU offloading + sharded strategies: `_saved_grad_shard` + # - No CPU offloading + `NO_SHARD`: `grad` + grad: Optional[Tensor] + if hasattr(flat_param, "_cpu_grad"): + grad = flat_param._cpu_grad # type: ignore[attr-defined] + elif hasattr(flat_param, "_saved_grad_shard"): + # In the post-backward hook, the sharded gradient is still in + # `_saved_grad_shard`. + grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + else: + # If in IDLE or in FORWARD states, then there may be an + # (accumulated) gradient. If accessed in IDLE, then this should + # be due to re-registering the original parameters (e.g. in state + # dict load). + _p_assert( + flat_param.grad is None + or not self.uses_sharded_strategy + or self._training_state + in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE), + "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` " + "unless in IDLE or FORWARD", + ) + grad = flat_param.grad + return grad + + def _reset_is_grad_none(self) -> None: + """ + Reset ``_is_grad_none_mask`` as needed. + + This method should only be + called in the post-backward after gradient computation, in which case + if a parameter requires gradient, then it will surely receive a + gradient and we may reset its mask entry to ``False``. + """ + if not self._use_orig_params: + return + _p_assert( + self._training_state == HandleTrainingState.BACKWARD_POST, + "Expects to only be called in the post-backward after gradient computation", + ) + flat_param = self.flat_param + if flat_param._params is None: + raise AssertionError("Expected _params to be not None") # mypy + for i, param in enumerate(flat_param._params): # type: ignore[arg-type] + # As long as the parameter requires gradient, it should receive a + # meaningful gradient (even if the gradient happens to be zeros) + if param.requires_grad: + if flat_param._is_grad_none_mask is None: + raise AssertionError( + "Expected _is_grad_none_mask to be not None" + ) # mypy + flat_param._is_grad_none_mask[i] = False + + ####################### + # CHECKS & INVARIANTS # + ####################### + def _check_sharded_strategy(self): + _p_assert(self.uses_sharded_strategy, "Expects sharded strategy") + + def _check_on_compute_device(self, tensor: Tensor): + _p_assert( + tensor.device == self.device, + f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}", + ) + + def _check_on_cpu(self, tensor: Tensor): + _p_assert( + tensor.device == torch.device("cpu"), + f"Expects tensor to be on CPU but got {tensor.device}", + ) + + @staticmethod + def _check_storage_freed(tensor: Tensor): + # Compile does not resize during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + _p_assert( + _same_storage_size(tensor, 0), + "Expects storage to be freed but got storage with size > 0", + ) + + @staticmethod + def _check_storage_allocated(tensor: Tensor): + _p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated") + + def _check_low_precision_shard(self): + _p_assert( + self._uses_param_mixed_precision, + "Not using low precision for parameters", + ) + _p_assert( + getattr(self.flat_param, "_mp_shard", None) is not None, + "Expects `_mp_shard` to exist", + ) + device = self.flat_param._mp_shard.device # type: ignore[attr-defined] + _p_assert( + device == self.device, + f"Expects the low precision shard to be on {self.device} but got {device}", + ) + + def _check_unsharded(self, tensor: Tensor): + msg_prefix = "Expects tensor to be unsharded " + _p_assert(tensor is not None, msg_prefix + "but got `None`") + unsharded_size = self.flat_param._unpadded_unsharded_size + _p_assert( + tensor.size() == unsharded_size, + msg_prefix + f"with size {unsharded_size} but got {tensor.size()}", + ) + + def _check_sharded(self, tensor: Tensor): + msg_prefix = "Expects tensor to be sharded " + _p_assert(tensor is not None, msg_prefix + "but got `None`") + sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] + _p_assert( + tensor.size() == sharded_size, + msg_prefix + f"with size {sharded_size} but got {tensor.size()}", + ) + + ############## + # PROPERTIES # + ############## + @property + def uses_sharded_strategy(self) -> bool: + return self._sharding_strategy != HandleShardingStrategy.NO_SHARD + + @property + def _uses_param_mixed_precision(self) -> bool: + return self._fwd_bwd_param_dtype != self._orig_param_dtype + + @property + def _uses_reduce_mixed_precision(self) -> bool: + return self._reduce_dtype != self._orig_param_dtype + + @property + def _force_full_precision(self) -> bool: + return ( + self._uses_param_mixed_precision or self._uses_reduce_mixed_precision + ) and ( + self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS + or + # Also disable mixed precision in model eval mode, if configured + (not self._fully_sharded_module.training and self._use_full_prec_in_eval) + ) + + @property + def _skipped_use_sharded_views(self) -> bool: + """ + This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``. + + This returns if this handle is + currently in a state where it has skipped using sharded views, in which + case it can restore view invariants via ``_use_sharded_views()``. + """ + return self._unsharded_flat_param_for_skipped_views is not None + + +# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks. +def _unsafe_setattr_param( + module: nn.Module, param_name: str, param: nn.Parameter +) -> None: + module._parameters[param_name] = param + # This bypasses any overrides in case `module` is an instance of an + # `nn.Module` subclass + super(nn.Module, module).__setattr__(param_name, param) + + +def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None: + module._parameters.pop(param_name, None) + # This bypasses any overrides in case `module` is an instance of an + # `nn.Module` subclass + super(nn.Module, module).__setattr__(param_name, tensor) + + +def _safe_setattr_tensor_or_param( + module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter] +): + # Call `delattr()` and `setattr()` to go through `nn.Module` checks + if hasattr(module, param_name): + delattr(module, param_name) + setattr(module, param_name, tensor_or_param) + + +def _convert_to_params( + tensors: list[Union[torch.Tensor, nn.Parameter]], +) -> list[nn.Parameter]: + return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] + + +def _is_truly_contiguous(x: Tensor) -> bool: + # Special case: Pytorch thinks that 1x1 channels_last convolution weights are + # both contiguous and channels_last contiguous at the same time. + # CuDNN does not agree though and refuses to select faster kernels. + # It is the reason of having the extra check here. + return x.stride(-1) == 1 and x.is_contiguous() + + +def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor: + return ( + param_or_tensor.detach() + if isinstance(param_or_tensor, nn.Parameter) + else param_or_tensor + ) + + +def _get_aligned_numel(unsharded_dtype: torch.dtype): + # NOTE: This alignment constraint comes from TorchInductor. + ALIGNMENT = 16 # bytes + unsharded_dtype_size = _get_dtype_size(unsharded_dtype) + aligned_numel = ALIGNMENT // unsharded_dtype_size + return aligned_numel + + +@functools.lru_cache(8) +def _get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() + + +def _construct_padding_tensor( + padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device +): + # NOTE: Set the padding value as a magic number for debuggability. The + # value itself should never be used in any user-facing computation. + return ( + torch.ones( + (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device + ) + * _FLAT_PARAM_PADDING_VALUE + ) + + +# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning +# message is passed in) +@functools.lru_cache(1) +def _warn_skip_writeback_check(log: logging.Logger, warning: str): + logger.warning(warning) + + +# Use `lru_cache(1)` to only log the warning once +@functools.lru_cache(1) +def _warn_use_fake_all_gather(log: logging.Logger, warning: str): + logger.warning(warning) + + +# Use `lru_cache(1)` to only log the warning once +@functools.lru_cache(1) +def _warn_use_fake_reduce(log: logging.Logger, warning: str): + logger.warning(warning) + + +def _same_storage(a, b): + # Params are DTensors in backward + # with SHARD_GRAD_OP + TP + from torch.distributed.tensor import DTensor + + if isinstance(a, DTensor): + a = a._local_tensor + if isinstance(b, DTensor): + b = b._local_tensor + return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr() + + +def _same_storage_size(a: torch.Tensor, b: int): + return a.untyped_storage().size() // a.element_size() == b + + +def _storage_size_allocated(tensor: Tensor): + storage_size: int = tensor.untyped_storage().size() + return storage_size > 0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fsdp_extensions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fsdp_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..699274ba50f9a57f26120bd15f5c49b4679f0e9e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fsdp_extensions.py @@ -0,0 +1,180 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed.fsdp._shard_utils import ( + _all_gather_dtensor, + _create_chunk_dtensor, + _create_chunk_sharded_tensor, +) +from torch.distributed.tensor import DeviceMesh, DTensor + + +class FSDPExtensions(ABC): + """ + This enables some customizable hooks to enable composability with tensor + parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to + set a custom :class:`FSDPExtensions` that implements the hooks. + """ + + @abstractmethod + def pre_flatten_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, Optional[Any]]: + """E.g. converting ``DistributedTensor`` to local tensor.""" + ... + + @abstractmethod + def post_unflatten_transform( + self, + tensor: torch.Tensor, + param_extension: Any, + ) -> torch.Tensor: + """E.g. converting local tensor to ``DistributedTensor``.""" + ... + + @abstractmethod + def chunk_tensor( + self, + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """Shards a tensor to chunks and returns the local chunk.""" + ... + + @abstractmethod + def chunk_dtensor( + self, + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + ) -> torch.Tensor: + """Shards a tensor/DTensor to DTensor and returns the local DTensor.""" + ... + + @abstractmethod + def pre_load_state_dict_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, list[Shard]]: + """ + This is to be called before loading a *sharded* model state dict and + should return the tensor and list of shards from which to load data. + """ + ... + + @abstractmethod + def all_gather_dtensor( + self, + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + ) -> torch.Tensor: + """ + This is to be called before loading a *sharded* DTensor state dict. + This gathers tensor in FSDP dimension and returns local tensor of + TP DTensor. + """ + ... + + +_extensions: Optional[FSDPExtensions] = None + + +def _set_fsdp_extensions(flattener: FSDPExtensions) -> None: + global _extensions + _extensions = flattener + + +def _ext_pre_flatten_transform( + tensor: torch.Tensor, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> tuple[torch.Tensor, Optional[Any]]: + if fsdp_extension is not None: + new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor) + if param_extension is not None: + return new_tensor, param_extension + return tensor, None + + +def _ext_post_unflatten_transform( + tensor: torch.Tensor, + param_extension: Any, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + if fsdp_extension is not None and param_extension is not None: + return fsdp_extension.post_unflatten_transform(tensor, param_extension) + return tensor + + +def _ext_chunk_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + chunk_tensor_fn = ( + fsdp_extension.chunk_tensor + if fsdp_extension is not None + else _create_chunk_sharded_tensor + ) + return chunk_tensor_fn( + tensor, + rank, + world_size, + num_devices_per_node, + pg, + ) + + +def _ext_chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + chunk_dtensor_fn = ( + fsdp_extension.chunk_dtensor + if fsdp_extension is not None + else _create_chunk_dtensor + ) + return chunk_dtensor_fn( + tensor, + rank, + device_mesh, + ) + + +def _ext_pre_load_state_dict_transform( + tensor: torch.Tensor, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> tuple[torch.Tensor, list[Shard]]: + if fsdp_extension is not None: + return fsdp_extension.pre_load_state_dict_transform(tensor) + + if type(tensor) is not ShardedTensor: + raise AssertionError(f"Expected ShardedTensor, got {type(tensor)}") + shards = tensor.local_shards() + return (tensor, shards) + + +def _ext_all_gather_dtensor( + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + all_gather_dtensor_fn = ( + fsdp_extension.all_gather_dtensor + if fsdp_extension is not None + else _all_gather_dtensor + ) + return all_gather_dtensor_fn(tensor, parent_mesh) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d0b341a3f82b35fc903ccffd5208d8fdade399 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py @@ -0,0 +1,20 @@ +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fully_shard import ( + FSDPModule, + fully_shard, + register_fsdp_forward_method, + share_comm_ctx, + UnshardHandle, +) + + +__all__ = [ + "CPUOffloadPolicy", + "FSDPModule", + "fully_shard", + "MixedPrecisionPolicy", + "OffloadPolicy", + "register_fsdp_forward_method", + "UnshardHandle", + "share_comm_ctx", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f874e668ee5c52d3e160dc286f27ac0aea5bad2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2969dbef38bca7150fe7f633d7900e8148d21663 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f27af434b9e8bbe8d4d9ab964f211c2a32e243b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d683f1b2e96e0e736504bcaadeb37ea4cdc82a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d634b0f64d0ed6197602ba216ed777d9f16a89af Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e189eeeb3bc69693322fd697c8d8c581fde292c6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eafdb59b9e8c83ec604110457765962e9182913 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..389b29a29864818a883b093d68e5693d7a4879b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bad0e48876bf3622bc620ac7a55a4fa59c671129 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py new file mode 100644 index 0000000000000000000000000000000000000000..38650323f5e99727f04964ca59fb268ca8e7b65c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py @@ -0,0 +1,155 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.distributed as dist + + +_ReduceOp = Union[dist.ReduceOp, dist.ReduceOp.RedOpType] + + +@dataclass(frozen=True) +class MixedPrecisionPolicy: + """ + This configures FSDP's mixed precision. Unlike autocast, this applies mixed + precision at the module level, not op level, which means low-precision + activations are saved for backward and high-to-low-precision casts are + incurred only at module boundaries. + + FSDP works well with module-level mixed precision since it keeps the + high-precision sharded parameters in memory anyway. In other words, FSDP + does not require any extra memory to keep a high-precision copy of the + parameters for the optimizer step. + + Attributes: + param_dtype (Optional[torch.dtype]): This specifies the dtype for + the unsharded parameter and hence the dtype for forward/backward + computation and the parameter all-gather. If this is ``None``, then + the unsharded parameter uses the original dtype. The optimizer step + uses the sharded parameter in the original dtype. (Default: + ``None``) + reduce_dtype (Optional[torch.dtype]): This specifies the dtype for + gradient reduction (i.e. reduce-scatter or all-reduce). If this is + ``None`` but ``param_dtype`` is not ``None``, then the reduction + uses the compute dtype. This can be used to run gradient reduction + in full precision while using low precision for compute. If also + gradient reduction is disabled via :meth:`set_requires_gradient_sync`, + then FSDP will accumulate gradients using ``reduce_dtype``. + (Default: ``None``) + output_dtype (Optional[torch.dtype]): This specifies the dtype for + casting floating-point forward outputs. This can be used to + help implement cases where different modules have different mixed + precision policies. (Default: ``None``) + cast_forward_inputs (bool): This specifies whether FSDP should cast the + forward's floating-point input tensors to ``param_dtype`` or not. + """ + + param_dtype: Optional[torch.dtype] = None + reduce_dtype: Optional[torch.dtype] = None + output_dtype: Optional[torch.dtype] = None + cast_forward_inputs: bool = True + + +class Comm(ABC): + """ + Interface for communication primitives. + A primitive primarily needs to handle 3 tasks, namely: + + 1. How to allocate memory for communication + Depending on the goal, an implementation can choose to: + a. associate each call to a temporary buffer + (best for flexibility and simplicity) + b. reuse an persistent buffer for efficiency reasons + + 2. Where to allocate memory + (e.g. NCCL mem pool or regular cuda caching allocator) + + 3. What to do/call upon the comm is called + (see `AllGather` interface as an example) + """ + + @abstractmethod + def allocate( + self, + size: Sequence[Union[int, torch.SymInt]], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """ + This handles the "how to allocate memory" part. + + A default implementation could be simply: + + .. code-block:: python + with self.mem_pool: + torch.empty(...) + + Args: + size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer + dtype (torch.dtype): dtype of the tensor buffer + device (torch.device): which device to allocate the tensor onto + """ + ... + + +class AllGather(Comm): + """ + Interface for all_gather comm primitive + """ + + @abstractmethod + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + async_op: bool = False, + ) -> Optional[dist.Work]: ... + + +class ReduceScatter(Comm): + """ + Interface for reduce_scatter comm primitive + """ + + @abstractmethod + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + op: _ReduceOp, + async_op: bool = False, + ) -> Optional[dist.Work]: ... + + +@dataclass +class OffloadPolicy: + """ + This base class represents the policy of no offloading and is only used as + the default value for the ``offload_policy`` arg. + """ + + +@dataclass +class CPUOffloadPolicy(OffloadPolicy): + """ + This offload policy offloads parameters, gradients, and optimizer states to + CPU. Sharded parameters are copied host-to-device before all-gather. The + all-gathered parameters are freed according to ``reshard_after_forward``. + Sharded gradients are copied device-to-host in backward, and the optimizer + step runs on CPU with CPU optimizer states. + + Attributes: + pin_memory (bool): Whether to pin sharded parameter and gradient + memory. Pinning memory allows both more efficient H2D/D2H copies + and for the copies to overlap with compute. However, the pinned + memory cannot be used by other processes. Set this to ``False`` if + you have insufficient CPU memory. (Default: ``True``) + """ + + pin_memory: bool = True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd7d24cd7d3f2fc24b634d72197f8e51c4839e6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -0,0 +1,762 @@ +import math +from collections.abc import Callable, Sequence +from itertools import chain +from typing import Any, cast, NamedTuple, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.distributed_c10d import ReduceOp +from torch.distributed.fsdp._fully_shard._fsdp_api import AllGather, ReduceScatter +from torch.distributed.tensor import DTensor + +from ._fsdp_api import _ReduceOp +from ._fsdp_common import ( + _get_dim0_padded_size, + _raise_assert_with_print, + _to_dtype_if_needed, + compiled_autograd_enabled, +) +from ._fsdp_param import FSDPParam, ShardedState + + +class AllGatherResult(NamedTuple): + all_gather_output: torch.Tensor + all_gather_event: Optional[torch.Event] + all_gather_work: Optional[dist.distributed_c10d.Work] + # For each parameter, the all-gather input dtype for each input + param_all_gather_input_dtypes: list[list[torch.dtype]] + # For each parameter, the all-gather input numel for each input + param_all_gather_input_numels: list[list[int]] + # 1D flattened version of `param_all_gather_input_numels` saved to avoid + # CPU overhead from recomputing + all_gather_input_split_sizes: list[int] + + +lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define( + """ + all_gather_copy_in( + Tensor[] all_gather_inputs, + Tensor all_gather_output, + SymInt[] inp_split_sizes, + SymInt all_gather_input_numel, + SymInt rank + ) -> (Tensor, Tensor) + """ +) + + +class DefaultAllocMixin: + def allocate( + self, + size: Sequence[Union[int, torch.SymInt]], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + return torch.empty(*size, dtype=dtype, device=device) + + +class ProcessGroupAllocMixin: + def __init__(self, group: dist.ProcessGroup, *args: Any, **kwargs: Any): + self._group = group + super().__init__(*args, **kwargs) + + def allocate( + self, + size: Sequence[Union[int, torch.SymInt]], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + backend = self._group._get_backend(device) + if backend.supports_tensor_alloc(device): + size_1d = math.prod(int(s) for s in size) + return backend.allocate_tensor(size_1d, dtype=dtype, device=device) + return torch.empty(*size, dtype=dtype, device=device) + + +class DefaultAllGather(DefaultAllocMixin, AllGather): + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + async_op: bool = False, + ) -> Optional[dist.Work]: + return dist.all_gather_into_tensor( + output_tensor, + input_tensor, + group=group, + async_op=async_op, + ) + + +class ProcessGroupAllocAllGather(ProcessGroupAllocMixin, AllGather): + def __init__(self, group: dist.ProcessGroup) -> None: + super().__init__(group) + + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + async_op: bool = False, + ) -> Optional[dist.Work]: + return dist.all_gather_into_tensor( + output_tensor, + input_tensor, + group=group, + async_op=async_op, + ) + + +class DefaultReduceScatter(DefaultAllocMixin, ReduceScatter): + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + op: _ReduceOp, + async_op: bool = False, + ) -> dist.Work: + return dist.reduce_scatter_tensor( + output=output_tensor, + input=input_tensor, + group=group, + op=op, + async_op=async_op, + ) + + +class ProcessGroupAllocReduceScatter(ProcessGroupAllocMixin, ReduceScatter): + def __init__(self, group: dist.ProcessGroup) -> None: + super().__init__(group) + + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + op: _ReduceOp, + async_op: bool = False, + ) -> dist.Work: + return dist.reduce_scatter_tensor( + output=output_tensor, + input=input_tensor, + group=group, + op=op, + async_op=async_op, + ) + + +@torch.library.impl(lib, "all_gather_copy_in", "Meta") +def all_gather_copy_in_meta( + all_gather_inputs: list[torch.Tensor], + all_gather_output: torch.Tensor, + inp_split_sizes: list[int], + all_gather_input_numel: int, + rank: int, +) -> tuple[torch.Tensor, torch.Tensor]: + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + return all_gather_input, all_gather_output + + +@torch.library.impl(lib, "all_gather_copy_in", "CUDA") +@torch.library.impl(lib, "all_gather_copy_in", "XPU") +@torch.library.impl(lib, "all_gather_copy_in", "HPU") +@torch.library.impl(lib, "all_gather_copy_in", "CPU") +@torch.library.impl(lib, "all_gather_copy_in", "MTIA") +@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1") +def all_gather_copy_in_cuda( + all_gather_inputs: list[torch.Tensor], + all_gather_output: torch.Tensor, + inp_split_sizes: list[int], + all_gather_input_numel: int, + rank: int, +) -> tuple[torch.Tensor, torch.Tensor]: + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) + with torch.no_grad(): + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) + return all_gather_input, all_gather_output + + +lib.define( + "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()" +) + + +@torch.library.impl(lib, "split_with_sizes_copy", "Meta") +@torch.library.impl(lib, "split_with_sizes_copy", "CUDA") +@torch.library.impl(lib, "split_with_sizes_copy", "XPU") +@torch.library.impl(lib, "split_with_sizes_copy", "HPU") +@torch.library.impl(lib, "split_with_sizes_copy", "CPU") +@torch.library.impl(lib, "split_with_sizes_copy", "MTIA") +@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1") +def split_with_sizes_copy( + all_gather_output: torch.Tensor, + all_gather_input_split_sizes: list[int], + dim: int, + out: list[torch.Tensor], +) -> None: + torch.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=dim, out=out + ) + + +lib.define( + "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()" +) + + +@torch.library.impl(lib, "chunk_cat", "Meta") +@torch.library.impl(lib, "chunk_cat", "CUDA") +@torch.library.impl(lib, "chunk_cat", "XPU") +@torch.library.impl(lib, "chunk_cat", "HPU") +@torch.library.impl(lib, "chunk_cat", "CPU") +@torch.library.impl(lib, "chunk_cat", "MTIA") +@torch.library.impl(lib, "chunk_cat", "PrivateUse1") +def chunk_cat( + tensors: list[torch.Tensor], + dim: int, + num_chunks: int, + out: torch.Tensor, +) -> None: + torch._chunk_cat(tensors, dim, num_chunks, out=out) + + +@torch.no_grad() +def foreach_all_gather( + fsdp_params: list[FSDPParam], + group: dist.ProcessGroup, + async_op: bool, + all_gather_copy_in_stream: torch.Stream, + all_gather_stream: torch.Stream, + device: torch.device, + all_gather_comm: AllGather, +) -> Optional[AllGatherResult]: + world_size, rank = group.size(), group.rank() + device_handle = _get_device_handle(device.type) + with device_handle.stream(all_gather_copy_in_stream): + param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params) + ( + param_all_gather_input_dtypes, + param_all_gather_input_numels, + dtype, + ) = _get_all_gather_input_metadatas(param_all_gather_inputs) + if dtype == torch.uint8: + all_gather_inputs = [ + t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts + ] + else: + all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)] + inp_split_sizes = [t.numel() for t in all_gather_inputs] + all_gather_input_numel = sum(inp_split_sizes) + all_gather_output = all_gather_comm.allocate( + (all_gather_input_numel * world_size,), dtype=dtype, device=device + ) + all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + all_gather_output, + inp_split_sizes, + all_gather_input_numel, + rank, + ) + del param_all_gather_inputs + all_gather_stream.wait_stream(all_gather_copy_in_stream) + with device_handle.stream(all_gather_stream): + all_gather_work = all_gather_comm( + output_tensor=all_gather_output, + input_tensor=all_gather_input, + group=group, + async_op=async_op, + ) + all_gather_event = all_gather_stream.record_event() + return AllGatherResult( + all_gather_output, + all_gather_event, + all_gather_work, + param_all_gather_input_dtypes, + param_all_gather_input_numels, + inp_split_sizes, + ) + + +@torch.no_grad() +def _get_param_all_gather_inputs( + fsdp_params: list[FSDPParam], +) -> list[list[torch.Tensor]]: + if compiled_autograd_enabled(): + return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params] + + # Intentionally try to run a fast-path that bypasses abstractions for the + # common FSDP case of bf16/fp32 mixed precision in order to use foreach + # copy for lower CPU overhead and more efficient copying in eager + def use_foreach_copy(fsdp_param: FSDPParam) -> bool: + return ( + fsdp_param.param_dtype is not None + and not fsdp_param.offload_to_cpu + and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather") + ) + + param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params] + foreach_copy_indices: list[int] = [] + foreach_copy_inputs: list[torch.Tensor] = [] + foreach_copy_input_numels: list[int] = [] + + # 1st pass: for foreach-copy parameters, get inputs and metadata for the + # foreach copy, and for the others, actually get their all-gather inputs + for i, fsdp_param in enumerate(fsdp_params): + if use_foreach_copy(fsdp_param): + foreach_copy_indices.append(i) + all_gather_input = ( + fsdp_param._sharded_param_data + if fsdp_param.sharded_state == ShardedState.SHARDED + else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data) + ) + foreach_copy_inputs.append(all_gather_input) + foreach_copy_input_numels.append(all_gather_input.numel()) + else: + param_all_gather_inputs[i] = fsdp_param.all_gather_inputs + + # 2nd pass: use foreach copy to compute the remaining all-gather inputs + if foreach_copy_inputs: + fsdp_param_0 = fsdp_params[foreach_copy_indices[0]] + param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device + flat_foreach_copy_input = torch.empty( + (sum(foreach_copy_input_numels),), device=device, dtype=param_dtype + ) + splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels) + torch._foreach_copy_(splits, foreach_copy_inputs) + for i, split in zip(foreach_copy_indices, splits): + param_all_gather_inputs[i] = [split] + + return param_all_gather_inputs + + +@torch.no_grad() +def foreach_all_gather_copy_out( + all_gather_result: AllGatherResult, + fsdp_params: list[FSDPParam], + group: dist.ProcessGroup, +) -> None: + ( + all_gather_output, + all_gather_event, + all_gather_work, + param_all_gather_input_dtypes, + param_all_gather_input_numels, + all_gather_input_split_sizes, + ) = all_gather_result + _dtype, device = all_gather_output.dtype, all_gather_output.device + device_handle = _get_device_handle(device.type) + if all_gather_event is not None: # sync op + device_handle.current_stream().wait_event(all_gather_event) + if isinstance(all_gather_work, dist.distributed_c10d.Work): # async op + all_gather_work.wait() + world_size, device = group.size(), all_gather_output.device + + split_with_sizes_out: list[torch.Tensor] = [] + shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = [] + for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip( + param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params + ): + # NOTE: Under compile, make sure we always recreate all_gather_outputs + # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2]. + force_recreate = compiled_autograd_enabled() + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, + all_gather_input_dtypes, + world_size, + device, + force_recreate=force_recreate, + ) + if not force_recreate: + fsdp_param.alloc_all_gather_outputs() + param_all_gather_outputs = fsdp_param.all_gather_outputs + if fsdp_param.fsdp_placement.dim != 0: + # Copy to a temporary and then chunk-cat into the final all-gather + # output tensors + param_all_gather_outputs = [ + torch.empty_like(t) for t in param_all_gather_outputs + ] + shard_i_copy_infos.append((fsdp_param, param_all_gather_outputs)) + split_with_sizes_out.extend(param_all_gather_outputs) + + all_gather_output = all_gather_output.view(world_size, -1) + if all_gather_output.dtype == torch.uint8: + out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out] + else: + out = [t.view(world_size, -1) for t in split_with_sizes_out] + + # only avoid VC bump if we are not in inference mode + if torch._dynamo.is_compiling(): + # For torch.compile, we turn off inference_mode for fake tensor + # propagation, and therefore graph break on is_inference. For `compile`, + # we don't care about VCs, so just skip the optimization. + non_inference_outs = [] + else: + non_inference_outs = [o for o in out if not o.is_inference()] + + if len(non_inference_outs) > 0: + with torch.autograd._unsafe_preserve_version_counter(tuple(non_inference_outs)): + torch.ops.fsdp.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=1, out=out + ) + else: + torch.ops.fsdp.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=1, out=out + ) + + for fsdp_param, param_all_gather_outputs in shard_i_copy_infos: + # Chunk-cat from the temporary to the final all-gather output tensors + shard_dim = fsdp_param.fsdp_placement.dim + + with torch.autograd._unsafe_preserve_version_counter( + tuple(fsdp_param.all_gather_outputs) + ): + for param_all_gather_output, target_all_gather_output in zip( + param_all_gather_outputs, fsdp_param.all_gather_outputs + ): + padded_sharded_size = ( + fsdp_param.padded_sharded_param_size + if fsdp_param.sharded_state == ShardedState.SHARDED + else cast( + torch.Tensor, fsdp_param._sharded_post_forward_param_data + ).size() + ) + pre_param_size = list(padded_sharded_size) + pre_param_size[0] *= world_size + chunks = torch.chunk( + param_all_gather_output.view(pre_param_size), world_size, dim=0 + ) + post_param_size = list(padded_sharded_size) + post_param_size[shard_dim] *= world_size + cat_out = target_all_gather_output.view(post_param_size) + torch.cat(chunks, dim=shard_dim, out=cat_out) + + +@torch.no_grad() +def foreach_reduce( + fsdp_params: list[FSDPParam], + unsharded_grads: list[torch.Tensor], + reduce_scatter_group: dist.ProcessGroup, + reduce_scatter_stream: torch.Stream, + reduce_scatter_comm: ReduceScatter, + orig_dtype: Optional[torch.dtype], + reduce_dtype: Optional[torch.dtype], + device: torch.device, + gradient_divide_factor: Optional[float], + all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP + all_reduce_stream: torch.Stream, + all_reduce_grads: bool, + partial_reduce_output: Optional[torch.Tensor], # only used for HSDP + all_reduce_hook: Optional[Callable[[torch.Tensor], None]], + force_sum_reduction_for_comms: bool = False, +) -> tuple[ + torch.Tensor, + torch.Event, + torch.Event, + Optional[torch.Tensor], + Optional[torch.Event], + Optional[torch.Tensor], +]: + """ + ``unsharded_grads`` owns the references to the gradients computed by + autograd, so clearing the list frees the gradients. + """ + + grad_dtypes = {grad.dtype for grad in unsharded_grads} + if len(grad_dtypes) != 1: + # Check this at runtime since it could be a real runtime error if e.g. + # fp8 weights do not produce the correct higher precision gradients + _raise_assert_with_print( + f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}" + ) + grad_dtype = unsharded_grads[0].dtype + reduce_dtype = reduce_dtype or grad_dtype + (predivide_factor, postdivide_factor, reduce_scatter_op, all_reduce_op) = ( + _get_gradient_divide_factors( + reduce_scatter_group, + all_reduce_group, + reduce_dtype, + device.type, + gradient_divide_factor, + force_sum_reduction_for_comms, + ) + ) + + if reduce_scatter_group is None: + world_size = 1 + else: + world_size = reduce_scatter_group.size() + device_handle = _get_device_handle(device.type) + current_stream = device_handle.current_stream() + + if world_size > 1: + for i, (fsdp_param, unsharded_grad) in enumerate( + zip(fsdp_params, unsharded_grads) + ): + if (shard_dim := fsdp_param.fsdp_placement.dim) == 0: + continue + if unsharded_grad.size(shard_dim) % world_size != 0: + raise AssertionError( + f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}" + ) + chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim) + unsharded_grads[i] = torch.cat(chunks, dim=0) + + padded_unsharded_sizes = tuple( + _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads + ) + reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) + reduce_scatter_output_numel = reduce_scatter_input_numel // world_size + reduce_scatter_input = reduce_scatter_comm.allocate( + (reduce_scatter_input_numel,), + dtype=reduce_dtype, + device=device, + ) + + foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size) + + # Only after the copy-in finishes can we free the gradients + unsharded_grads.clear() + reduce_scatter_stream.wait_stream(current_stream) + all_reduce_input = None + all_reduce_event = None + + with device_handle.stream(reduce_scatter_stream): + reduce_output = reduce_scatter_comm.allocate( + (reduce_scatter_output_numel,), + dtype=reduce_dtype, + device=device, + ) + _div_if_needed(reduce_scatter_input, predivide_factor) + if world_size > 1: + reduce_scatter_comm( + output_tensor=reduce_output, + input_tensor=reduce_scatter_input, + group=reduce_scatter_group, + op=reduce_scatter_op, + ) + else: + # For single GPU, just copy the input to output (no actual reduce-scatter needed), and + # account for a possible gradient_divide_factor. + if gradient_divide_factor is not None: + reduce_output.copy_(reduce_scatter_input / gradient_divide_factor) + else: + reduce_output.copy_(reduce_scatter_input) + reduce_scatter_event = reduce_scatter_stream.record_event() + post_reduce_stream = reduce_scatter_stream + if all_reduce_group is not None: # HSDP or DDP/replicate + # Accumulations must run in the reduce-scatter stream + if not all_reduce_grads: + if partial_reduce_output is not None: + partial_reduce_output += reduce_output + else: + partial_reduce_output = reduce_output + return ( + reduce_scatter_input, + reduce_scatter_event, + post_reduce_stream.record_event(), + all_reduce_input, + all_reduce_event, + partial_reduce_output, + ) + if partial_reduce_output is not None: + reduce_output += partial_reduce_output + post_reduce_stream = all_reduce_stream + if world_size >= 1: + all_reduce_stream.wait_stream(reduce_scatter_stream) + else: + all_reduce_stream.wait_stream(current_stream) + with device_handle.stream(all_reduce_stream): + dist.all_reduce( + reduce_output, + group=all_reduce_group, + op=all_reduce_op, + ) + all_reduce_input = reduce_output + all_reduce_event = all_reduce_stream.record_event() + # -- END: ops in reduce_scatter stream + + if all_reduce_hook is not None: + # Execute user-specified all reduce hook. + # If native HSDP is used, this is executed after the HSDP all reduce. + # If 1-d FSDP is used, this is executed post reduce-scatter. + post_reduce_stream = all_reduce_stream + all_reduce_stream.wait_stream(reduce_scatter_stream) + with device_handle.stream(all_reduce_stream): + all_reduce_hook(reduce_output) + # -- END: ops post reduce_scatter + + with device_handle.stream(post_reduce_stream): + _div_if_needed(reduce_output, postdivide_factor) + reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype) + # View out and accumulate sharded gradients + flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] + for padded_unsharded_size, fsdp_param in zip( + padded_unsharded_sizes, fsdp_params + ): + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides + new_sharded_grad = torch.as_strided( + reduce_output, + size=fsdp_param.sharded_size, + stride=fsdp_param.contiguous_sharded_stride, + storage_offset=flat_grad_offset, + ) + to_accumulate_grad = fsdp_param.sharded_param.grad is not None + if fsdp_param.offload_to_cpu: + # Only overlap the D2H copy (copying to pinned memory) if not + # accumulating gradients since the CPU add kernel depends on + # the copy result and we cannot run the add as a callback + non_blocking = fsdp_param.pin_memory and not to_accumulate_grad + # Since the GPU sharded gradient is allocated in the RS stream, + # we can free it here by not keeping a ref without waiting for + # the D2H copy since future RS-stream ops run after the copy + new_sharded_grad = new_sharded_grad.to( + torch.device("cpu"), non_blocking=non_blocking + ) + if non_blocking: + # Record an event on which to block the CPU thread to + # ensure that the D2H copy finishes before the optimizer + fsdp_param.grad_offload_event = post_reduce_stream.record_event() + if to_accumulate_grad: + if not isinstance(fsdp_param.sharded_param.grad, DTensor): + raise AssertionError( + f"Expected fsdp_param.sharded_param.grad to be DTensor, got {type(fsdp_param.sharded_param.grad)}" + ) + fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad + else: + new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor( + new_sharded_grad + ) + fsdp_param.sharded_param.grad = new_sharded_dtensor_grad + if not compiled_autograd_enabled(): + for hook in ( + getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {}) + or {} + ).values(): + hook(fsdp_param.sharded_param) + padded_sharded_numel = padded_unsharded_size.numel() // world_size + flat_grad_offset += padded_sharded_numel + post_reduce_event = post_reduce_stream.record_event() + # The RS output is allocated in the RS stream and used in the default + # stream (for optimizer). To ensure its memory is not reused for later + # RSs, we do not need extra synchronization since the sharded parameters + # hold refs through the end of backward. + return ( + reduce_scatter_input, + reduce_scatter_event, + post_reduce_event, + all_reduce_input, + all_reduce_event, + None, + ) + + +def foreach_reduce_scatter_copy_in( + unsharded_grads: list[torch.Tensor], + reduce_scatter_input: torch.Tensor, + world_size: int, +) -> None: + reduce_scatter_input = reduce_scatter_input.view(world_size, -1) + torch.ops.fsdp.chunk_cat( + unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input + ) + + +def _get_all_gather_input_metadatas( + param_all_gather_inputs: list[list[torch.Tensor]], +) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]: + param_all_gather_input_dtypes: list[list[torch.dtype]] = [] + param_all_gather_input_numels: list[list[int]] = [] + all_gather_dtype = param_all_gather_inputs[0][0].dtype + for all_gather_inputs in param_all_gather_inputs: + input_dtypes: list[torch.dtype] = [] + input_numels: list[int] = [] + for all_gather_input in all_gather_inputs: + if all_gather_input.dtype != all_gather_dtype: + all_gather_dtype = torch.uint8 + input_dtypes.append(all_gather_input.dtype) + input_numels.append(all_gather_input.numel()) + param_all_gather_input_dtypes.append(input_dtypes) + param_all_gather_input_numels.append(input_numels) + return ( + param_all_gather_input_dtypes, + param_all_gather_input_numels, + all_gather_dtype, + ) + + +def _get_gradient_divide_factors( + reduce_scatter_group: Optional[dist.ProcessGroup], + all_reduce_group: Optional[dist.ProcessGroup], + reduce_dtype: torch.dtype, + device_type: str = "", + factor: Optional[float] = None, + force_sum_reduction_for_comms: bool = False, +) -> tuple[ + Optional[float], + Optional[float], + Union[dist.ReduceOp, dist.ReduceOp.RedOpType], + Union[dist.ReduceOp, dist.ReduceOp.RedOpType], +]: + # MTIA appears to only support SUM reduction, hence we force it implicitly + if device_type == "mtia": + force_sum_reduction_for_comms = True + + # For fp32/bf16, we do not need to worry about overflow/underflow, so we + # use NCCL's built-in division to avoid separate div kernels + overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16) + if reduce_scatter_group is not None: + data_parallel_size = reduce_scatter_group.size() + else: + data_parallel_size = 1 + + if all_reduce_group is not None: + data_parallel_size *= all_reduce_group.size() + + if not overflow_risk and not force_sum_reduction_for_comms: + if factor is None: + # Warning: NCCL ReduceOp.AVG may produce incorrect results with + # world size 1. + if data_parallel_size == 1: + return None, None, ReduceOp.SUM, ReduceOp.SUM + return None, None, ReduceOp.AVG, ReduceOp.AVG + if reduce_scatter_group is not None and factor == reduce_scatter_group.size(): + reduce_scatter_op = ReduceOp.AVG + else: + reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) + return None, None, reduce_scatter_op, ReduceOp.SUM + + if factor is None: + factor = float(data_parallel_size) + pre_factor: Optional[float] + if overflow_risk: + # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid + # overflow/underflow. For N data parallel workers, each worker computes + # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid + # overflow/underflow, we divide by ~sqrt(N) before/after the reduction. + pre_factor = 1 + while factor % pre_factor == 0 and factor / pre_factor > pre_factor: + pre_factor *= 2 + post_factor = factor / pre_factor + else: + # Prefer post-multiplying as it operates on less data and is thus faster + pre_factor, post_factor = None, factor + + return pre_factor, post_factor, ReduceOp.SUM, ReduceOp.SUM + + +def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None: + if div_factor is not None and div_factor != 1: + tensor.div_(div_factor) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py new file mode 100644 index 0000000000000000000000000000000000000000..85addad83b3b08cbed358f3eb31b2bf4f2a2c9e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +import math +import traceback +from dataclasses import dataclass +from enum import auto, Enum +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable.contract import _get_registry +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec + + +_compiled_autograd_enabled: bool = False + + +def detect_compiled_autograd(): + if torch.compiler.is_compiling(): + raise AssertionError( + "`detect_compiled_autograd()` is designed to be called in eager mode" + ) + global _compiled_autograd_enabled + import torch._dynamo.compiled_autograd as ca + + _compiled_autograd_enabled = ( + ca.compiled_autograd_enabled + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region + ) + + +def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled + + +@dataclass +class DataParallelMeshInfo: + mesh: DeviceMesh + shard_mesh_dim: Optional[int] = None + replicate_mesh_dim: Optional[int] = None + + def __post_init__(self): + if self.shard_mesh_dim is None and self.replicate_mesh_dim is None: + raise AssertionError( + "At least one of shard_mesh_dim and replicate_mesh_dim must not be None" + ) + + +@dataclass +class FSDPMeshInfo(DataParallelMeshInfo): + def __post_init__(self): + super().__post_init__() + if self.shard_mesh_dim is None: + raise AssertionError("Expects non-None shard_mesh_dim") + self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim) + self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim) + self.shard_mesh_rank: int = self.shard_process_group.rank() + + +@dataclass +class DDPMeshInfo(DataParallelMeshInfo): + def __post_init__(self): + super().__post_init__() + if self.replicate_mesh_dim is None: + raise AssertionError("Expects non-None replicate_mesh_dim") + self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim) + self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim) + self.replicate_mesh_rank: int = self.replicate_process_group.rank() + + +@dataclass +class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo): + def __post_init__(self): # pylint:disable=useless-parent-delegation + # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo` + super().__post_init__() + + +class TrainingState(Enum): + """Describes the training state of one FSDP state / parameter group.""" + + # Transition to forward starting pre-forward until post-forward + FORWARD = auto() + # Transition to pre-backward when unsharding in backward + PRE_BACKWARD = auto() + # Transition to post-backward when resharding and reducing gradients + POST_BACKWARD = auto() + # Idle before/after forward or before pre-backward/after post-backward + IDLE = auto() + + +def _raise_assert_with_print(*args: Any, **kwargs: Any): + print(f"[Rank {dist.get_rank()}] ", end="") + print(*args, **kwargs) + traceback.print_stack() + raise AssertionError(*args, **kwargs) + + +def _is_composable_with_fsdp(module: nn.Module) -> bool: + registry = _get_registry(module) + if registry is None: + return True + # Registry keys by function name + return "replicate" not in registry + + +def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size: + padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor + return torch.Size([padded_dim0]) + tensor_size[1:] + + +def _chunk_with_empty( + tensor: torch.Tensor, num_chunks: int, dim: int +) -> list[torch.Tensor]: + chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) + while len(chunks) < num_chunks: + chunks.append(chunks[0].new_empty(0)) + return chunks + + +def _get_dim_chunked_size( + chunk: torch.Tensor, unchunked_size: torch.Size, dim: int +) -> torch.Size: + if chunk.numel() > 0: + return chunk.size() + # For 0 numel, we need to preserve nonzero-sized dims for DTensor APIs + return unchunked_size[:dim] + torch.Size([0]) + unchunked_size[dim + 1 :] + + +def _from_local_no_grad( + local_tensor: torch.Tensor, + sharding_spec: DTensorSpec, +) -> DTensor: + """ + This method is similar to ``DTensor.from_local()`` except that in eager mode + it avoids some CPU overhead by avoiding default args and not being differentiable. + """ + + if not compiled_autograd_enabled(): + # pyrefly: ignore [bad-argument-type] + return DTensor( + # Use the local tensor directly instead of constructing a new tensor + # variable, e.g. with `view_as()`, since this is not differentiable + # pyrefly: ignore [bad-argument-count] + local_tensor, + sharding_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=local_tensor.requires_grad, + ) + else: + return DTensor.from_local( + local_tensor, + sharding_spec.mesh, + sharding_spec.placements, + shape=sharding_spec.shape, + stride=sharding_spec.stride, + ) + + +def _to_dtype_if_needed( + tensor: torch.Tensor, dtype: Optional[torch.dtype] +) -> torch.Tensor: + if dtype is not None and tensor.dtype != dtype: + return tensor.to(dtype) + return tensor + + +def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor: + if ( + not isinstance(x, torch.Tensor) + or not torch.is_floating_point(x) + or x.dtype == dtype + ): + return x + return x.to(dtype) + + +def is_bw() -> bool: + return torch._C._current_graph_task_id() != -1 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py new file mode 100644 index 0000000000000000000000000000000000000000..01d196795c3d8f9270138f757b3e7f3de9e10f11 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py @@ -0,0 +1,243 @@ +import itertools +import logging +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch._logging import warning_once +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo +from ._fsdp_state import _get_module_fsdp_state + + +logger = logging.getLogger("torch.distributed.fsdp.fully_shard") + + +def _get_post_forward_mesh_info( + reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo +) -> Optional[FSDPMeshInfo]: + shard_mesh_size = mesh_info.shard_mesh_size + if not isinstance(reshard_after_forward, (bool, int)): + raise ValueError( + "reshard_after_forward should be a bool or an int representing the " + f"group size to reshard to, not {reshard_after_forward}" + ) + # NOTE: `isinstance(False, int)` returns `True`. + if not isinstance(reshard_after_forward, bool) and isinstance( + reshard_after_forward, int + ): + if ( + reshard_after_forward < 1 + or reshard_after_forward > shard_mesh_size + or shard_mesh_size % reshard_after_forward != 0 + ): + raise ValueError( + "If passing reshard_after_forward as an int, it should be a " + f"factor of {shard_mesh_size}, not {reshard_after_forward}" + ) + elif reshard_after_forward == 1: + msg = ( + "reshard_after_forward=1 (int) means resharding parameters to world size 1, " + "instead of reshard_after_forward=True (bool)" + ) + warning_once(logger, msg, stacklevel=2) + reshard_after_forward = False + elif reshard_after_forward == shard_mesh_size: + reshard_after_forward = True + post_forward_mesh_info = None + if reshard_after_forward is True: + post_forward_mesh_info = mesh_info + elif reshard_after_forward is not False: # int case + # For HSDP, we can flatten the two replicate dims into the 0th dim + post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward) + post_forward_mesh = DeviceMesh( + mesh_info.mesh.device_type, post_forward_mesh_tensor + ) + post_forward_mesh_info = HSDPMeshInfo( + post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0 + ) + return post_forward_mesh_info + + +def _init_default_fully_shard_mesh() -> DeviceMesh: + """Default to global CUDA mesh if possible else global CPU mesh.""" + if not dist.distributed_c10d.is_initialized(): + dist.distributed_c10d.init_process_group() + default_pg = dist.distributed_c10d._get_default_group() + device = torch._C._get_accelerator() + mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),)) + return mesh + + +def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device: + if mesh.device_type == "cpu": + return torch.device("cpu") + device_handle = _get_device_handle(mesh.device_type) + return torch.device(mesh.device_type, device_handle.current_device()) + + +def _ignore_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignore_decision: dict[nn.Module, bool], +) -> bool: + """ + Decide if it is safe to ignore a module for applying fully_shard. + """ + if module in ignore_decision: + return ignore_decision[module] + + if len(list(module.buffers(recurse=False))) > 0: + # Cannot ignore a module with any buffer + ignore_decision[module] = False + return False + + for _, param in module.named_parameters(recurse=False): + if param not in ignored_params: + # at least one param is not ignored. So this module shouldn't be. + ignore_decision[module] = False + return False + + # Need to consider descendants of module + for child in list(module.children()): + ignore_child = _ignore_module(child, ignored_params, ignore_decision) + if not ignore_child: + # Cannot ignore module if one of its children is not ignored + ignore_decision[module] = False + return False + + # Safe to ignore module + ignore_decision[module] = True + return True + + +def _adjust_managed_modules( + modules: list[nn.Module], ignored_params: set[nn.Parameter] +) -> list[nn.Module]: + """ + Adjust the given list of managed modules by removing those with all parameters ignored. + """ + ignore_decision: dict[nn.Module, bool] = {} + new_modules = [] + for module in modules: + ignored = _ignore_module(module, ignored_params, ignore_decision) + if not ignored: + new_modules.append(module) + return new_modules + + +def _get_managed_modules( + root_modules: tuple[nn.Module, ...], + ignored_params: Optional[set[nn.Parameter]] = None, +) -> list[nn.Module]: + modules: list[nn.Module] = [] + root_modules_set = set(root_modules) + # Track visisted modules to avoid visiting shared modules multiple times + visited_modules: set[nn.Module] = set() + + def dfs(module: nn.Module) -> None: + """ + Runs a DFS to collect managed modules, not recursing into modules with + a non-composable API or ``fully_shard`` already applied. + """ + if not _is_composable_with_fsdp(module): + return + elif ( + module not in root_modules_set + and _get_module_fsdp_state(module) is not None + ): + return # nested `fully_shard` module + visited_modules.add(module) + for submodule in module.children(): + if submodule not in visited_modules: + dfs(submodule) + modules.append(module) + + for root_module in root_modules: + dfs(root_module) + + if ignored_params is None: + return modules + + adjusted_modules = _adjust_managed_modules(modules, ignored_params) + return adjusted_modules + + +def _verify_managed_param(name: str, param: nn.Parameter) -> None: + """ + Verify if the parameter is accepted by fully_shard. The only restriction now + is that the parameter cannot be a scalar tensor (param.numel == 0) since we + need at least one dim to shard. + """ + if len(param.shape) == 0: + raise ValueError( + "fully_shard doesn't support scalar parameters. " + f"Change {name} to a 1D tensor with numel equal to 1." + ) + + +def _get_managed_states( + modules: list[nn.Module], ignored_params: Optional[set[nn.Parameter]] = None +) -> tuple[list[nn.Parameter], list[torch.Tensor]]: + params: list[nn.Parameter] = [] + buffers: list[torch.Tensor] = [] + # Track visited parameters/buffers to avoid visiting shared parameters and + # buffers multiple times + visited_params: set[nn.Parameter] = set() + visited_buffers: set[torch.Tensor] = set() + if ignored_params is None: + ignored_params = set() + + for module in modules: + for name, param in module.named_parameters(recurse=False): + if param in ignored_params: + # do not include an ignored parameters + continue + if param not in visited_params: + _verify_managed_param(name, param) + params.append(param) + visited_params.add(param) + for buffer in module.buffers(recurse=False): + if buffer not in visited_buffers: + buffers.append(buffer) + visited_buffers.add(buffer) + return params, buffers + + +def _move_states_to_device( + params: list[nn.Parameter], + buffers: list[torch.Tensor], + device: torch.device, +) -> None: + """ + We have FSDP move states to device for simpler and faster initialization + since FSDP almost always uses CUDA for training. We move parameters/buffers + rather than modules since modules to support ignoring parameters/buffers in + the future. + """ + # Follow the logic in `nn.Module._apply` + # pyrefly: ignore [bad-argument-type] + for tensor in itertools.chain(params, buffers): + if tensor.device == device or tensor.device.type == "meta": + # Keep meta-device tensors on meta device for deferred init + continue + if isinstance(tensor, DTensor): + if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type: + raise ValueError( + "Requires DTensor to have mesh of the same type as the FSDP mesh " + f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP" + ) + raise AssertionError( + f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}" + ) + tensor_ = tensor + if is_traceable_wrapper_subclass(tensor_): + with torch.no_grad(): # avoid autograd increasing C++ refcount by 1 + tensor_on_device = nn.Parameter(tensor.to(device)) + torch.utils.swap_tensors(tensor, tensor_on_device) + else: + tensor.data = tensor.to(device) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py new file mode 100644 index 0000000000000000000000000000000000000000..476fbd94928947bc95cf13eab10b85d76e554164 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -0,0 +1,966 @@ +# mypy: allow-untyped-defs +import inspect +import itertools +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from enum import auto, Enum +from typing import Any, cast, Optional + +import torch +import torch.nn as nn +from torch._prims_common import make_contiguous_strides_for +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.placement_types import _StridedShard, Placement + +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_common import ( + _chunk_with_empty, + _from_local_no_grad, + _get_dim_chunked_size, + _raise_assert_with_print, + _to_dtype_if_needed, + compiled_autograd_enabled, + FSDPMeshInfo, + HSDPMeshInfo, +) + + +""" +[Note: FSDP tensors] +FSDP considers the following tensors: +- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one + on the module when applying FSDP +- Sharded parameter: sharding the original parameter on dim-0 (or a + user-specified dim) as a DTensor over the main mesh +- All-gather inputs: the ``torch.Tensor`` or ``Tensor`` s passed to all-gather, + derived from the sharded parameter +- All-gather output: the ``torch.Tensor`` or ``Tensor`` s resulting from + all-gathering the all-gather inputs +- Unsharded parameter: parameter used for forward/backward computation, derived + from the all-gather output; autograd leaf + +We define these tensors to describe the general framework that can accommodate +extensions, where: +- all-gather-inputs = pre-all-gather-transform(sharded-parameter) +- unsharded-parameter = post-all-gather-transform(all-gather-outputs) + +For the default ``torch.Tensor`` case, there is only one all-gather input, and +it shares the same underlying tensor data as the sharded parameter, meaning +that they can be thought of as the same tensors. The same applies for the +all-gather output and unsharded parameter. For non-``torch.Tensor`` extensions, +these equivalences may no longer hold due to the pre/post-all-gather +transforms, and some may have multiple all-gather inputs/outputs (e.g. +quantized data and scales). + +[Note: FSDP and autograd] +FSDP dynamically frees and allocates the unsharded parameter. Since autograd +can pack a reference to it or a view to save for backward, we use storage +resizing to implement the freeing/allocation since that preserves the aliasing. +This implies that we construct the unsharded parameter object once and write to +it in-place thereafter. For the default ``torch.Tensor` original parameter +case, the all-gather output and unsharded parameter share the same +data, so we use storage resizing on the all-gather output. +""" + +lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()") + + +@torch.library.impl(lib, "copy_", "Meta") +@torch.library.impl(lib, "copy_", "CUDA") +@torch.library.impl(lib, "copy_", "XPU") +@torch.library.impl(lib, "copy_", "HPU") +@torch.library.impl(lib, "copy_", "CPU") +@torch.library.impl(lib, "copy_", "MTIA") +def copy_(tensor, data): + tensor.copy_(data) + + +""" +[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_] + +Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op +(i.e. they show up as a mutation op in the middle of the AOT joint graph). + +Reason: +Traceable FSDP2 compiled autograd BWD graph have the following traits: +(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors). +(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param). +(3) They are both subclasses. +The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing). +So this doesn't work at all for Traceable FSDP2. + +The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops. +This avoids the problem above, because from AOTAutograd point-of-view there are no mutations +that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.) + +We can avoid this functionalization because: +(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created), +so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream. +(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops. +So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay +(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore). + +Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance? +A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops +for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0). + +TODO: +Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward, +it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are: + +Cons (of functionalizing those ops): +(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time +in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them +in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse +peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_: +https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089) + +Pros (of functionalizing): +(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations +mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more). +(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it. +But maybe there are few enough mutations induced by FSDP for this to matter. +""" + + +@torch.library.impl(lib, "copy_", "Functionalize") +def copy__functionalize(tensor, data): + torch._sync(tensor) + torch._sync(data) + tensor_inner = torch._from_functional_tensor(tensor) + data_inner = torch._from_functional_tensor(data) + with torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ): + torch.ops.fsdp.copy_.default(tensor_inner, data_inner) + + +torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) + + +class ShardedState(Enum): + """ + - ``SHARDED``: The sharded parameter is registered to the module. It is the + only contributor to parameter memory. + - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a + smaller world size. Since this data should not be used for computation, + we do not register it to the module. Users should reshard the module + before any in-place modifications. Both it and the sharded parameter + contribute to parameter memory. + - ``UNSHARDED``: The unsharded parameter is registered to the module. Both + it and the sharded parameter contribute to parameter memory. + """ + + SHARDED = auto() + SHARDED_POST_FORWARD = auto() + UNSHARDED = auto() + + +@dataclass +class ParamModuleInfo: + """ + For a parameter, this stores the module and the parameter name to be able + to do a parameter swap via ``setattr(module, param_name, ...)`` or to get + the parameter via ``getattr(module, param_name)``. We additionally save + shared modules and shared parameter names to update them accordingly. + """ + + # Parameter names are unprefixed, e.g. "weight", not "lin.weight" + module: nn.Module + param_name: str + shared_modules: list[nn.Module] = field(default_factory=list) + shared_param_names: list[str] = field(default_factory=list) + + +@dataclass +class ExtensionsData: + # User-defined metadata passed from pre to post-all-gather + all_gather_metadata: Optional[Any] = None + # Save the all-gather input sizes to unflatten the all-gather outputs to ND + all_gather_input_sizes: Sequence[torch.Size] = () # ND + + def clear(self): + self.all_gather_metadata = None + self.all_gather_input_sizes = () + + +class FSDPParam: + """ + This class manages a parameter with FSDP or FSDP variants applied, + implementing dim-0 per-parameter sharding. + """ + + orig_dtype: torch.dtype + param_dtype: Optional[torch.dtype] + reduce_dtype: Optional[torch.dtype] + _orig_size: torch.Size # ND + sharded_size: torch.Size # ND + contiguous_sharded_stride: tuple[int, ...] + padded_sharded_param_size: torch.Size # ND + sharded_post_forward_size: torch.Size # ND + contiguous_sharded_post_forward_stride: tuple[int, ...] + _sharded_param_data: torch.Tensor # 1D + sharded_param: nn.Parameter # ND + _sharded_post_forward_param_data: Optional[torch.Tensor] # 1D + _sharded_post_forward_param: Optional[nn.Parameter] # ND + _unsharded_param: nn.Parameter # ND + unsharded_accumulated_grad: Optional[torch.Tensor] # ND + _sharding_spec: DTensorSpec + # DTensor attributes (only defined for DTensor `param`): + _tp_spec: DTensorSpec + all_gather_outputs: list[torch.Tensor] # 1D + # All-gather extension attributes + _extensions_data: ExtensionsData + _unsharded_inner_tensors: list[torch.Tensor] + + def __init__( + self, + param: nn.Parameter, + module_info: ParamModuleInfo, + mesh_info: FSDPMeshInfo, + post_forward_mesh_info: Optional[FSDPMeshInfo], + device: torch.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], + mp_policy: MixedPrecisionPolicy, + offload_policy: OffloadPolicy, + ): + self._module_info: ParamModuleInfo = module_info + self.mesh_info = mesh_info + self.post_forward_mesh_info = post_forward_mesh_info + # pyrefly: ignore [read-only] + self.device = device + self.mp_policy = mp_policy + self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy) + self.pin_memory = ( + self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory + ) + self.grad_offload_event: Optional[torch.Event] = None + self._init_sharded_param(param, device, shard_placement_fn) + if self.post_forward_mesh_info: + self._init_sharded_post_forward_param_metadata(param) + self._init_extensions() + self.all_gather_outputs: list[torch.Tensor] = [] + self.unsharded_accumulated_grad = None + self._param_fqn: Optional[str] = None # prefixed from root module + # TODO: Remove this padding logic once DTensor pads the local tensor: + # https://github.com/pytorch/pytorch/issues/113045 + self._post_load_hook_handle = ( + module_info.module.register_load_state_dict_post_hook( + lambda *args, **kwargs: self.reset_sharded_param() + ) + ) + + @torch.no_grad() + def _init_sharded_param( + self, + param: nn.Parameter, + device: torch.device, + shard_placement_fn: Optional[Callable], + ): + if param.device != device and param.device.type != "meta": + raise AssertionError( + f"Expects the parameter to already be moved to device {device} but got {param.device}" + ) + if not param.is_contiguous(): + raise NotImplementedError( + f"FSDP does not support non-contiguous parameters yet: {param.shape=} {param.stride()=}" + ) + fsdp_placement = shard_placement_fn(param) if shard_placement_fn else None + if fsdp_placement is None: + fsdp_placement = Shard(0) + elif fsdp_placement.dim < 0: + fsdp_placement = Shard(fsdp_placement.dim + param.ndim) + if not isinstance(fsdp_placement, Shard): + raise AssertionError( + f"Expected Shard, got {type(fsdp_placement)}: {fsdp_placement}" + ) + self.fsdp_placement = fsdp_placement + shard_dim = fsdp_placement.dim + # TODO: Replace the sharded DTensor parameter construction logic with + # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101 + # TODO: Simplify the following sharded parameter padding logic after + # https://github.com/pytorch/pytorch/issues/113045 + self.is_dtensor = isinstance(param, DTensor) + if self.is_dtensor: + self._tp_spec = cast(DTensor, param)._spec + dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh) + if dp_mesh is None or tp_mesh is None: + raise AssertionError( + "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n" + f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}" + ) + self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh]) + if len(self._tp_spec.placements) > 2: + raise NotImplementedError( + f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}" + ) + split_factor = self._tp_spec.num_shards_map[shard_dim] + if not (2 <= self._spmd_mesh.ndim <= 4): + raise AssertionError( + "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " + f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." + ) + self._spmd_placements: tuple[Placement, ...] + if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP + dp_shard_tp_placement = ( + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else fsdp_placement + ), + *self._tp_spec.placements, + ) + else: # DDP + dp_shard_tp_placement = ( + (Replicate()), + *self._tp_spec.placements, + ) + if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP + if self.mesh_info.replicate_mesh_dim != 0: + raise AssertionError( + f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}" + ) + self._spmd_placements = (Replicate(),) + dp_shard_tp_placement + else: # FSDP or DDP + self._spmd_placements = dp_shard_tp_placement + + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._tp_spec.tensor_meta, + ) + param_data = cast(DTensor, param)._local_tensor + else: + self._spmd_mesh = self.mesh_info.mesh + if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP + self._spmd_placements = (Replicate(), fsdp_placement) + elif isinstance(self.mesh_info, FSDPMeshInfo): # FSDP + self._spmd_placements = (fsdp_placement,) + elif isinstance(self.mesh_info, DDPMeshInfo): # DDP + self._spmd_placements = (Replicate(),) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), + ) + param_data = param + if not param_data.is_contiguous(): + raise AssertionError( + f"Expected contiguous tensor, got {param_data.shape=} {param_data.stride()=}" + ) + shard_dim = fsdp_placement.dim + if shard_dim >= param_data.ndim: + raise AssertionError( + f"Shard dim {shard_dim} is invalid for {param_data.ndim}D tensor: {param.shape}" + ) + self._orig_size = param_data.size() + self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) + if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP + shard_rank = self.mesh_info.shard_mesh_rank + shard_world_size = self.mesh_info.shard_mesh_size + else: # DDP + shard_rank = 0 + shard_world_size = 1 + + if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0: + # If sharding on nonzero dim, require even sharding for now because + # the uneven sharding (1) requires extra copies before/after FSDP + # collectives and (2) introduces extra complexity to handle padding + # and unpadding + raise NotImplementedError( + f"FSDP does not support uneven sharding on dim {shard_dim}: " + f"{param_data.size()} (world size: {shard_world_size})" + ) + chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim) + sharded_param = chunks[shard_rank] + self.sharded_size = _get_dim_chunked_size( + sharded_param, param_data.size(), dim=shard_dim + ) + self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) + padded_sharded_size = chunks[0].size() # 0th always padded + self.padded_sharded_param_size = padded_sharded_size + # Pre-pad the sharded parameter to avoid padding before all-gather + padded_sharded_param = param_data.new_zeros(padded_sharded_size) + if sharded_param.numel() > 0: + padded_sharded_param.narrow( + dim=shard_dim, start=0, length=sharded_param.size(shard_dim) + ).copy_(sharded_param) + if self.offload_to_cpu and not padded_sharded_param.is_meta: + padded_sharded_param = padded_sharded_param.cpu() + if self.pin_memory: + padded_sharded_param = padded_sharded_param.pin_memory() + self._sharded_param_data = padded_sharded_param.view(-1) + length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0 + sharded_param = padded_sharded_param.narrow( + dim=shard_dim, start=0, length=length + ) + if not sharded_param.is_contiguous(): + raise AssertionError( + f"Expected contiguous tensor with {self.fsdp_placement=}" + ) + self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) + self.sharded_param.requires_grad_(param.requires_grad) + # Let `param_data` be freed normally when its ref count reaches 0 when + # the `fully_shard` call returns to allow provided parameters to alias + self._setattr_on_modules(self.sharded_param) + self.sharded_state = ShardedState.SHARDED + + def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None: + mesh_info = self.post_forward_mesh_info + if mesh_info is None: + raise AssertionError("Expected post_forward_mesh_info to not be None") + param_data = param._local_tensor if isinstance(param, DTensor) else param + if isinstance(mesh_info, FSDPMeshInfo): + chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0) + self.sharded_post_forward_size = _get_dim_chunked_size( + chunks[mesh_info.shard_mesh_rank], + param_data.size(), + dim=self.fsdp_placement.dim, + ) + else: # DDP + chunks = _chunk_with_empty(param_data, 1, dim=0) + self.sharded_post_forward_size = _get_dim_chunked_size( + chunks[0], + param_data.size(), + dim=self.fsdp_placement.dim, + ) + self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for( + self.sharded_post_forward_size + ) + + def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy): + param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype) + self.orig_dtype = self.sharded_param.dtype + # Clamp `reduce_dtype` to `None` if no casting is required: since + # gradients are computed in `param_dtype`, if `reduce_dtype` matches, + # then we do not need extra casting + if reduce_dtype == param_dtype: + reduce_dtype = None + # Clamp `param_dtype` to `None` if no casting is required + if param_dtype == self.orig_dtype: + param_dtype = None + self.param_dtype = param_dtype + self.reduce_dtype = reduce_dtype + # None indicates that the mixed precision is not enabled + + def _init_extensions(self) -> None: + inner_tensor = self._sharded_local_tensor + has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather") + has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather") + if has_fsdp_pre_all_gather != has_fsdp_post_all_gather: + raise AssertionError( + "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined " + f"if using all-gather extensions: {inner_tensor}" + ) + if has_fsdp_pre_all_gather: + self._extensions_data = ExtensionsData() + self._unsharded_inner_tensors: list[torch.Tensor] = [] + + def init_all_gather_outputs( + self, + all_gather_input_numels: list[int], + all_gather_input_dtypes: list[torch.dtype], + world_size: int, + device: torch.device, + force_recreate: bool = False, + ): + if not force_recreate and len(self.all_gather_outputs) > 0: + return # already initialized + self.all_gather_outputs = [ + torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device) + for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes) + ] + + def init_unsharded_param(self): + """ + [Note: Invariants for torch.compile Traceable FSDP2] + 1. Under compile, we always re-populate the content of `self._unsharded_param` + per AllGather using the slow path. + 2. Under compile, we always recreate `self.all_gather_outputs` per AllGather. + This is to ensure the buffer creation is internal to the graph and + avoid `self.all_gather_outputs` being captured as a graph input. + 3. Under compile, at the end of `free_unsharded_param()`, we always clean up + `self.all_gather_outputs` and `self._unsharded_inner_tensors`, + to avoid them being captured as graph output. + + With these invariants, only these tensors will be inputs to the graph: + - Sharded parameters + - Placeholders for the `self._unsharded_param` nn.Parameter + """ + if not compiled_autograd_enabled() and hasattr( + self, "_unsharded_param" + ): # after the 1st all-gather + inner_tensor = self._sharded_local_tensor + if not hasattr(inner_tensor, "fsdp_post_all_gather"): + return # already initialized + for tensor in self._unsharded_inner_tensors: + alloc_storage(tensor) + all_gather_outputs = self._unflatten_all_gather_outputs() + inner_tensor.fsdp_post_all_gather( + all_gather_outputs, + self._extensions_data.all_gather_metadata, + self.param_dtype or self.orig_dtype, + out=self._unsharded_param, + ) + self._extensions_data.clear() + return + inner_tensor = self._sharded_local_tensor + if not compiled_autograd_enabled() and hasattr( + inner_tensor, "fsdp_post_all_gather" + ): + all_gather_outputs = self._unflatten_all_gather_outputs() + ( + unsharded_tensor, + self._unsharded_inner_tensors, + ) = inner_tensor.fsdp_post_all_gather( + all_gather_outputs, + self._extensions_data.all_gather_metadata, + self.param_dtype or self.orig_dtype, + ) + self._extensions_data.clear() + else: + # For the default path (no post-all-gather), the all-gather output + # gives the unsharded parameter data directly + if len(self.all_gather_outputs) != 1: + raise AssertionError( + f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}" + ) + unsharded_tensor = self.all_gather_outputs[0] + unsharded_param = torch.as_strided( + unsharded_tensor, + self._orig_size, + self._contiguous_orig_stride, + storage_offset=0, + ) + if self.is_dtensor: + unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) + if hasattr(self, "_unsharded_param"): + if not compiled_autograd_enabled(): + raise AssertionError("Expected compiled_autograd to be enabled") + with ( + torch.no_grad(), + torch.autograd._unsafe_preserve_version_counter(self._unsharded_param), + ): + # NOTE: Under compile, if an unsharded param goes through + # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those + # resize_ and copy_ ops in a compiler graph pass + # `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance. + self._unsharded_param.untyped_storage().resize_( + self._unsharded_param.numel() * self._unsharded_param.itemsize + ) + torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param) + else: + self._unsharded_param = nn.Parameter( + unsharded_param, requires_grad=self.sharded_param.requires_grad + ) + + def _unflatten_all_gather_outputs(self) -> tuple[torch.Tensor, ...]: + return tuple( + t.view(-1, *s[1:]) + for t, s in zip( + self.all_gather_outputs, self._extensions_data.all_gather_input_sizes + ) + ) + + def to_sharded(self) -> None: + self._setattr_on_modules(self.sharded_param) + self.free_unsharded_param() + self.sharded_state = ShardedState.SHARDED + + def to_sharded_post_forward(self) -> None: + if self.is_dtensor: + raise NotImplementedError( + "Resharding to smaller mesh with TP is not supported yet" + ) + self._assert_in_states(ShardedState.UNSHARDED) + if self.post_forward_mesh_info is None: + raise AssertionError("Expected post_forward_mesh_info to not be None") + if len(self.all_gather_outputs) != 1: + raise AssertionError( + f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}" + ) + shard_world_size = self.post_forward_mesh_info.shard_mesh_size + if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0: + _raise_assert_with_print( + f"All-gather output size ({numel}) must be divisible by the shard " + f"world size ({shard_world_size})" + ) + shard_rank = self.post_forward_mesh_info.shard_mesh_rank + # pyrefly: ignore [unbound-name] + sharded_numel = numel // shard_world_size + self._sharded_post_forward_param_data = ( + self.all_gather_outputs[0].narrow( + 0, sharded_numel * shard_rank, sharded_numel + ) + ).clone() # clone to be able to free all-gather output + sharded_post_forward_tensor = torch.as_strided( + self._sharded_post_forward_param_data, + size=self.sharded_post_forward_size, + stride=self.contiguous_sharded_post_forward_stride, + storage_offset=0, + ) + self._sharded_post_forward_param = nn.Parameter( + self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor) + ) + self._setattr_on_modules(self._sharded_post_forward_param) + self.free_unsharded_param() + self.sharded_state = ShardedState.SHARDED_POST_FORWARD + + def to_unsharded(self) -> None: + # Assume that the data has been allocated and all-gathered + set_requires_grad_if_needed(self.sharded_param, self._unsharded_param) + self._setattr_on_modules(self._unsharded_param) + if self.sharded_state == ShardedState.SHARDED_POST_FORWARD: + # The data is allocated in the default stream via the post-forward + # reshard and must be kept alive for the next all-gather copy-in. + # Since we call this method after the copy-out, the data's lifetime + # is ensured without further synchronization. + self._sharded_post_forward_param = None + self._sharded_post_forward_param_data = None # free + self.sharded_state = ShardedState.UNSHARDED + + def _setattr_on_modules(self, param: nn.Parameter) -> None: + unsafe_setattr_param( + self._module_info.module, self._module_info.param_name, param + ) + for shared_module, shared_param_name in zip( + self._module_info.shared_modules, self._module_info.shared_param_names + ): + unsafe_setattr_param(shared_module, shared_param_name, param) + + def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: + """ + Converts a local tensor representing either the sharded parameter or + sharded gradient to DTensor. + """ + if tensor.shape != self.sharded_size: + _raise_assert_with_print( + f"Expects size {self.sharded_size} but got {tensor.shape}" + ) + return _from_local_no_grad( + tensor, + self._sharding_spec, + ) + + def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor: + if tensor.shape != self.sharded_post_forward_size: + _raise_assert_with_print( + f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}" + ) + if not isinstance(self.post_forward_mesh_info, HSDPMeshInfo): + raise AssertionError( + f"Expected HSDPMeshInfo, got {type(self.post_forward_mesh_info)}" + ) + # TODO: Prefer this DTensor to be read-only and generalize the + # placement once we support TP. + post_forward_sharding_spec = DTensorSpec( + self.post_forward_mesh_info.mesh, + (Replicate(), Shard(0)), + tensor_meta=self._sharding_spec.tensor_meta, + ) + return _from_local_no_grad(tensor, post_forward_sharding_spec) + + def to_accumulated_grad_if_needed(self) -> None: + # Access `_unsharded_param` to bypass the sharded state check since we + # prefer to reshard before upcasting the gradient to save memory + if ( + self.reduce_dtype is None + or self._unsharded_param.grad is None + or self._unsharded_param.grad.dtype == self.reduce_dtype + ): + return + unsharded_grad = self._unsharded_param.grad + self._unsharded_param.grad = None + self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype) + + def accumulate_unsharded_grad_if_needed(self) -> None: + if ( + self.unsharded_accumulated_grad is not None + and self.unsharded_param.grad is not None + ): + self.unsharded_accumulated_grad += self.unsharded_param.grad + self.unsharded_param.grad = None + + def alloc_all_gather_outputs(self) -> None: + for tensor in self.all_gather_outputs: + alloc_storage(tensor) + + def free_unsharded_param(self) -> None: + if compiled_autograd_enabled(): + """ + Assumptions under compile: + - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`. + Instead, we resize `self._unsharded_param` storage size to full and then + explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param` + in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove + the resize_ and copy_ ops in a compiler graph pass to recover performance.) + - `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT + graph inputs. They are created within the graph and is guaranteed to be freed + by the end of the graph. They don't leak outside of the graph. + """ + self._unsharded_param.untyped_storage().resize_(0) + self.all_gather_outputs = [] + self._unsharded_inner_tensors = [] + else: + for tensor in itertools.chain( + self.all_gather_outputs, self._unsharded_inner_tensors + ): + free_storage(tensor) + + @property + def all_gather_inputs(self) -> list[torch.Tensor]: # 1D + self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) + if self.sharded_state == ShardedState.SHARDED: + if not compiled_autograd_enabled() and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): + sharded_local_tensor = self._sharded_local_tensor + if self.offload_to_cpu: + sharded_local_tensor = sharded_local_tensor.to( + self.device, non_blocking=True + ) + pre_all_gather_signature = inspect.signature( + # pyrefly: ignore [missing-attribute] + sharded_local_tensor.fsdp_pre_all_gather + ) + num_fn_params = len(pre_all_gather_signature.parameters) + # Old signature only passes mesh; keep for BC for now + if num_fn_params not in (1, 5): + raise AssertionError( + f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n" + "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, " + "outer_size: torch.Size, outer_stride: tuple[int, ...], " + "module: nn.Module, mp_policy: MixedPrecisionPolicy)" + ) + if num_fn_params == 1: + ( + all_gather_inputs, + self._extensions_data.all_gather_metadata, + # pyrefly: ignore [missing-attribute] + ) = sharded_local_tensor.fsdp_pre_all_gather( + self.shard_mesh_from_root + ) + else: + ( + all_gather_inputs, + self._extensions_data.all_gather_metadata, + # pyrefly: ignore [missing-attribute] + ) = sharded_local_tensor.fsdp_pre_all_gather( + self.shard_mesh_from_root, + self._orig_size, + self._contiguous_orig_stride, + self._module_info.module, + self.mp_policy, + ) + if ( + sharded_local_tensor.size() != self.padded_sharded_param_size + and any( + all_gather_input.size() != self.padded_sharded_param_size + for all_gather_input in all_gather_inputs + ) + ): + # NOTE: Since this error can only be raised on the + # ranks that have padding, this can manifest as a NCCL + # watchdog timeout, as the other ranks will not error. + raise AssertionError( + "When a parameter is unevenly sharded by FSDP " + f"(orig size={self._orig_size}, FSDP world size={self.mesh_info.mesh.size()}), " + "fsdp_pre_all_gather must return all-gather inputs with the padded sharded size " + f"{self.padded_sharded_param_size} but got {[t.size() for t in all_gather_inputs]}" + ) + self._extensions_data.all_gather_input_sizes = [ + t.size() for t in all_gather_inputs + ] + return [t.view(-1) for t in all_gather_inputs] + sharded_param_data = self._sharded_param_data + if self.offload_to_cpu: + sharded_param_data = sharded_param_data.to( + self.device, non_blocking=True + ) + return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)] + elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD: + if not compiled_autograd_enabled() and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): + raise NotImplementedError + all_gather_input = _to_dtype_if_needed( + cast(torch.Tensor, self._sharded_post_forward_param_data), + self.param_dtype, + ) + return [all_gather_input] + return [torch.empty(0)] # mypy + + @property + def unsharded_param(self) -> nn.Parameter: # ND + return self._unsharded_param + + @property + def unsharded_grad_data(self) -> torch.Tensor: + grad = self.unsharded_param.grad + if grad is None: + raise AssertionError("Expects unsharded_param.grad to not be None") + return self._get_grad_inner_tensor(grad) + + @property + def unsharded_accumulated_grad_data(self) -> torch.Tensor: + grad = self.unsharded_accumulated_grad + if grad is None: + raise AssertionError("Expects unsharded_accumulated_grad to not be None") + return self._get_grad_inner_tensor(grad) + + def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor: + if self.is_dtensor: + if isinstance(grad, AsyncCollectiveTensor): + grad = grad.wait() + if not isinstance(grad, DTensor): + raise AssertionError(f"Expected DTensor, got {type(grad)}") + placements = self._tp_spec.placements + if placements != grad.placements: + if len(self._tp_spec.placements) != len(grad.placements): + raise AssertionError( + f"Expected same placement length: {self._tp_spec=} {grad.placements=}" + ) + grad = grad.redistribute(placements=placements) + grad = grad._local_tensor + return grad + + @property + def _sharded_local_tensor(self) -> torch.Tensor: + return cast(DTensor, self.sharded_param)._local_tensor + + @property + def shard_mesh(self): + mesh = self.mesh_info.mesh + if mesh.ndim == 1: + return mesh + elif mesh.ndim == 2: + if mesh.mesh_dim_names is None: + raise AssertionError("Expected mesh_dim_names to not be None") + return mesh[mesh.mesh_dim_names[-1]] + raise ValueError(f"Invalid mesh: {mesh}") + + @property + def shard_mesh_from_root(self): + mesh = self.mesh_info.mesh + + if mesh.ndim == 1: + return mesh + else: + if mesh.mesh_dim_names is None: + raise AssertionError("Expected mesh_dim_names to not be None") + shard_dim_name = mesh.mesh_dim_names[-1] + return mesh[shard_dim_name] + + def _assert_in_states(self, *states: ShardedState) -> None: + if self.sharded_state not in states: + _raise_assert_with_print( + f"Expects to be in one of {states}, not {self.sharded_state}" + ) + + def reset_sharded_param(self): + # For ops like `nn.Module._apply` or `load_state_dict(assign=True)` + # that change the sharded parameter tensor, we may need to re-pad the + # sharded local tensor and re-save the reference. + module_info = self._module_info + new_param = getattr(module_info.module, module_info.param_name) + if new_param is not self.sharded_param: + if torch.__future__.get_swap_module_params_on_conversion(): + raise AssertionError( + f"Expects swap_tensors to preserve object but got {new_param} " + f"instead of {self.sharded_param}" + ) + self.sharded_param = new_param + # pyrefly: ignore [missing-attribute] + local_tensor = new_param._local_tensor + if local_tensor.is_meta: + return + updated_local_tensor = False + # local_tensor can be padded twice + # 1st time in fully_shard(model) + # 2nd time in model(input) lazy_init + # 2nd time should be no-op if parameters remain unchanged + # 2nd time shouldn't be no-op if people call model.load_state_dict(...) before lazy_init + # this makes it possible for trainer to call `sd = model.state_dict()` before the training loop + # and use `sd` without calling .state_dict() per iteration + same_local_tensor = False + # TODO: need to support tensor subclass + if type(self._sharded_param_data) is torch.Tensor: + same_local_tensor = ( + # when sharding param with shape (1, ...) over 2 ranks + # local_tensor on rank 1 can be size 0, data_ptr() can be 0 + self._sharded_param_data.untyped_storage().data_ptr() > 0 + and self._sharded_param_data.untyped_storage().data_ptr() + == local_tensor.untyped_storage().data_ptr() + ) + padded_sharded_size = self.padded_sharded_param_size + shard_dim = self.fsdp_placement.dim + length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 + if local_tensor.size() != padded_sharded_size and not same_local_tensor: + if shard_dim != 0: + raise AssertionError( + f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}" + ) + padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) + padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_( + local_tensor + ) + local_tensor = padded_local_tensor + updated_local_tensor = True + if self.pin_memory and not local_tensor.is_pinned(): + local_tensor = local_tensor.cpu().pin_memory() + updated_local_tensor = True + if not same_local_tensor: + self._sharded_param_data = local_tensor.view(-1) + if not isinstance(self.sharded_param, DTensor): + raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}") + if updated_local_tensor: + # Only change the local tensor object if needed + self.sharded_param._local_tensor = local_tensor.narrow( + dim=shard_dim, start=0, length=length + ) + if not self.sharded_param._local_tensor.is_contiguous(): + raise AssertionError( + "Expected sharded_param._local_tensor to be contiguous" + ) + self._sharding_spec = self.sharded_param._spec + + def __repr__(self): + return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})" + + +def alloc_storage(tensor: torch.Tensor) -> None: + size = tensor.numel() * tensor.itemsize + if (storage := tensor.untyped_storage()).size() != size: + storage.resize_(size) + + +def free_storage(tensor: torch.Tensor) -> None: + if (storage := tensor.untyped_storage()).size() != 0: + storage.resize_(0) + + +# NOTE: These bypass `nn.Module.__setattr__` checks, which incur non-trivial +# CPU overhead, if the module did not override it. For FSDP, we know we do not +# need those checks when transitioning between sharded/unsharded parameters. +def unsafe_setattr_param( + module: nn.Module, param_name: str, param: nn.Parameter +) -> None: + if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__: + module._parameters[param_name] = param + else: # slow path + setattr(module, param_name, param) + + +def set_requires_grad_if_needed( + src_tensor: torch.Tensor, dst_tensor: torch.Tensor +) -> None: + # Only call `requires_grad_` if needed to avoid the Python <> C++ context + # switch overhead + if src_tensor.requires_grad != dst_tensor.requires_grad: + dst_tensor.requires_grad_(src_tensor.requires_grad) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py new file mode 100644 index 0000000000000000000000000000000000000000..b70a5f06f4ae9b982b0f8e3a486573f79176c30b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -0,0 +1,901 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +from collections.abc import Callable +from typing import Any, cast, NamedTuple, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates +from torch.distributed.tensor import Shard +from torch.profiler import record_function +from torch.utils._pytree import tree_flatten, tree_unflatten +from torch.utils.hooks import RemovableHandle + +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_collectives import ( + AllGather, + AllGatherResult, + DefaultAllGather, + DefaultReduceScatter, + foreach_all_gather, + foreach_all_gather_copy_out, + foreach_reduce, + ProcessGroupAllocAllGather, + ProcessGroupAllocReduceScatter, + ReduceScatter, +) +from ._fsdp_common import ( + compiled_autograd_enabled, + DDPMeshInfo, + FSDPMeshInfo, + HSDPMeshInfo, + is_bw, + TrainingState, +) +from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState + + +logger = logging.getLogger("torch.distributed.fsdp.fully_shard") + +_ModuleToHandleDict = dict[nn.Module, RemovableHandle] # for state dict + + +""" +[Note: Overlapping all-gather copy-in and all-gather] +For implicit forward prefetching, we want to overlap the next copy-in with the +current all-gather. We do so using a separate copy-in stream. However, since +we have the all-gather input as a view into the output, we must make sure to +copy into different memory from the current all-gather's output. Thus, we keep +a reference to the current all-gather's output and have the next FSDP parameter +group free it after its copy-in. Finally, we have the last FSDP state flush the +reference to avoid holding onto memory after forward. +""" + + +class FSDPCommContext: + """This has the communication state shared across FSDP states/parameter groups.""" + + def lazy_init(self, device: torch.device): + self.device_handle = _get_device_handle(device.type) + # Setting the all-gather/reduce-scatter streams to be higher priority + # can help avoid some issues where their copies in/out are delayed and + # block computation (this is different from high-pri NCCL streams) + high_priority = -1 + # All-gather state and copy-in stream allow overlapping the next + # copy-in with the current all-gather in forward; copy-in overlaps with + # reduce-scatter in backward without the separate copy-in stream + self.all_gather_copy_in_stream = self.device_handle.Stream( + priority=high_priority + ) + # All-gather stream allows overlapping next all-gather with current + # forward compute + self.all_gather_stream = self.device_handle.Stream(priority=high_priority) + # Reduce-scatter stream gives separate execution "thread" for post- + # backward logic like pre/post-gradient division and reduce-scatter + self.reduce_scatter_stream = self.device_handle.Stream(priority=high_priority) + # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter + # since collectives use different network resources and can overlap + # in the typical intra-node sharding / inter-node replication case + self.all_reduce_stream = self.device_handle.Stream() + # All-gather/reduce-scatter states keep references to collective + # tensors produced in one stream and used in another and accompanying + # CUDA events for synchronization + self.all_gather_state: Optional[AllGatherState] = None + self.reduce_scatter_state: Optional[ReduceScatterState] = None + # Post-forward order for explicit backward prefetching + self.post_forward_order: list[FSDPParamGroup] = [] # will cause ref cycles + + def get_all_gather_streams( + self, async_op: bool, training_state: TrainingState + ) -> tuple[torch.Stream, torch.Stream]: + if not async_op and training_state in ( + TrainingState.FORWARD, + TrainingState.PRE_BACKWARD, + ): + # Use separate streams for implicit prefetching + return self.all_gather_copy_in_stream, self.all_gather_stream + current_stream = self.device_handle.current_stream() + return current_stream, current_stream + + +# See [Note: Overlapping all-gather copy-in and all-gather] +class AllGatherState(NamedTuple): + all_gather_result: AllGatherResult + event: Optional[torch.Event] # all-gather copy-out + + +class ReduceScatterState(NamedTuple): + reduce_scatter_input: torch.Tensor + event: Optional[torch.Event] # reduce-scatter event + + +class AllReduceState(NamedTuple): + all_reduce_input: torch.Tensor + event: Optional[torch.Event] # all-reduce event + + +class FSDPParamGroup: + """This class represents a parameter group to communicate together.""" + + _orig_dtype: Optional[torch.dtype] + _reduce_dtype: Optional[torch.dtype] + + def __init__( + self, + params: list[nn.Parameter], + modules: tuple[nn.Module, ...], + mesh_info: FSDPMeshInfo, + post_forward_mesh_info: Optional[FSDPMeshInfo], + device: torch.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], + mp_policy: MixedPrecisionPolicy, + offload_policy: OffloadPolicy, + ): + self.modules = modules # permit ref cycle because 1:1 lifetime + param_module_infos = _get_param_module_infos(params, modules) + + self.fsdp_params = [ + FSDPParam( + param, + module_info, + mesh_info, + post_forward_mesh_info, + device, + shard_placement_fn, + mp_policy, + offload_policy, + ) + for param, module_info in zip(params, param_module_infos) + ] + self.mesh_info = mesh_info + self.post_forward_mesh_info = post_forward_mesh_info + # pyrefly: ignore [read-only] + self.device = device + self.device_handle = _get_device_handle(device.type) + self.mp_policy = mp_policy + self.offload_policy = offload_policy + self._training_state = TrainingState.IDLE + # Group's sharded state always matches its parameters' sharded states + self._sharded_state = ShardedState.SHARDED + self._module_fqn: Optional[str] = None # prefixed from root module + # Only consider resetting sharded parameters once in lazy init since it + # can incur nontrivial overhead to reset them + self._reset_sharded_params: bool = False + + # - Hook state + self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {} + self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {} + self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None + self._all_gather_comm: AllGather = DefaultAllGather() + self._all_gather_output = torch.empty(0, device=self.device) + self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter() + # Optional stream to run the user-defined all-reduce hook in + # Saved here and not in the comm. context because we allow the user to + # specify it, possibly at construction time before lazy init + self._all_reduce_hook_stream: Optional[torch.cuda.Stream] = None + + # - Communication and communication/computation overlap + self.comm_ctx = FSDPCommContext() + # Group's indices in the shared post-forward order + self._post_forward_indices: list[int] = [] + # Whether to reduce gradients at all (whether for FSDP or HSDP) + self.reduce_grads: bool = True + # Whether to all-reduce gradients for HSDP; only used if + # `self.reduce_grads` is true, in which case setting this to false + # means reduce-scatter but no all-reduce + self.all_reduce_grads: bool = True + # Whether to reshard parameters after backward (only useful for + # gradient accumulation) + self.reshard_after_backward: bool = True + # Optional custom factor for the gradient reduction op (e.g. to divide + # by a factor other than the world size) + self.gradient_divide_factor: Optional[float] = None + # Whether reduce-scatter and all-reduce should be issued using only + # summations, potentially with separate pre-/post-scaling. + self.force_sum_reduction_for_comms: bool = False + # `async_op` arg used for pre-forward/pre-backward unshard; can be + # overridden to only do explicit prefetching and avoid inter-stream + # fragmentation from using separate unshard streams + self.unshard_async_op: bool = False + # Whether to unshard in backward: can be overridden by the user if the + # parameters in this group are not needed for backward (e.g. embedding) + self.unshard_in_backward: bool = True + + # - CUDA events for stream synchronization + # Holds the all-gather output buffer, sync objects, and metadata + self._all_gather_result: Optional[AllGatherResult] = None + # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of + # the group's post-backward (e.g. reduce-scatter, all-reduce and div), which + # should be waited on at the end of backward + self._post_reduce_event: Optional[torch.Event] = None + # Holds the reshard-after-forward CUDA event when resharding to a + # different world size, which should be waited on in the next unshard + self._reshard_after_forward_event: Optional[torch.Event] = None + + # Only for HSDP, if accumulating gradients without all-reduce, save the + # partial reduce output (only reduce-scattered but not all-reduced) + self._partial_reduce_output: Optional[torch.Tensor] = None + # Holds the all-reduce input and all-reduce event to keep it alive + # until the end of backward (critical when doing bf16 reduction with + # fp32 parameters since the all-reduce input is allocated in the RS + # stream and will have no refs to it after being upcast to fp32) + self._all_reduce_state: Optional[AllReduceState] = None + + # Initialization # + def _init_mp_dtypes(self) -> None: + for fsdp_param in self.fsdp_params: + fsdp_param.init_dtype_attrs(self.mp_policy) + trainable_params: list[FSDPParam] = [ + p for p in self.fsdp_params if p.sharded_param.requires_grad + ] + orig_dtypes = {p.orig_dtype for p in trainable_params} + reduce_dtypes = {p.reduce_dtype for p in trainable_params} + if len(trainable_params) > 0 and len(orig_dtypes) != 1: + # Models may have no grad params + raise AssertionError( + f"FSDP expects uniform original parameter dtype but got {orig_dtypes}" + ) + self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None + if len(trainable_params) > 0 and len(reduce_dtypes) != 1: + # This can be relaxed if we issue one reduce-scatter per reduce + # dtype (but we would need a way for users to specify multiple + # reduce dtypes) + raise AssertionError( + f"FSDP expects uniform reduce dtype but got {reduce_dtypes}" + ) + self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None + + def lazy_init(self): + # Lazy init should be idempotent + # Users may change or register parameters after construction time. + # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on + # other parameters (e.g. loaded from the state dict). + if not hasattr(self.comm_ctx, "device_handle"): + self.comm_ctx.device_handle = _get_device_handle(self.device.type) + if self.is_sharded and not self._reset_sharded_params: + for fsdp_param in self.fsdp_params: + fsdp_param.reset_sharded_param() + fsdp_param._init_extensions() # allow monkey patch after init + self._reset_sharded_params = True + self._validate_no_meta_params() + self._validate_cpu_offload_params() + # Initialize mixed precision attributes lazily in case the user changes + # the parameter dtypes after construction time but before forward + self._init_mp_dtypes() + self._register_state_dict_hooks() + + def set_allocate_memory_from_process_group(self, enable: bool) -> None: + """ + Whether to (try to) use the ProcessGroup's allocate_tensor method for + the staging buffers for collective comms. + """ + if not isinstance( + self._all_gather_comm, (DefaultAllGather | ProcessGroupAllocAllGather) + ): + raise AssertionError( + "cannot call set_allocate_memory_from_process_group() " + f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}" + ) + self._all_gather_comm = ( + ProcessGroupAllocAllGather(self._all_gather_process_group) + if enable + else DefaultAllGather() + ) + + if not isinstance( + self._reduce_scatter_comm, + (DefaultReduceScatter | ProcessGroupAllocReduceScatter), + ): + raise AssertionError( + "cannot call set_allocate_memory_from_process_group() " + f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}" + ) + self._reduce_scatter_comm = ( + ProcessGroupAllocReduceScatter(self._reduce_scatter_process_group) + if enable + else DefaultReduceScatter() + ) + + # Runtime # + def unshard(self, async_op: bool = False): + if self._all_gather_result is not None: # already called, pending wait + return + if self.is_unsharded: + return # no-op + if ( + not self.unshard_in_backward + and self._training_state == TrainingState.PRE_BACKWARD + ): + return + if self._reshard_after_forward_event is not None: + # Resharded parameter data is allocated in the default stream and + # used in the all-gather streams + self._wait_all_gather_streams_on_event(self._reshard_after_forward_event) + self._reshard_after_forward_event = None + + if isinstance(self.mesh_info, FSDPMeshInfo): + world_size = self._all_gather_process_group.size() + else: + world_size = 1 + if world_size == 1: + # can't skip due to early return in wait_for_unshard if + # no self._all_gather_result + self._all_gather_result = AllGatherResult( + all_gather_output=self._all_gather_output, + all_gather_event=self.device_handle.Event().record(), + all_gather_work=None, + param_all_gather_input_dtypes=[], + param_all_gather_input_numels=[], + all_gather_input_split_sizes=[], + ) + + return + + with record_function(self._with_fqn("FSDP::all_gather")): + self._all_gather_result = foreach_all_gather( + self.fsdp_params, + self._all_gather_process_group, + async_op, + *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), + self.device, + self._all_gather_comm, + ) + + def wait_for_unshard(self): + """ + 1. In forward with implicit prefetching, to overlap the current copy-out + with the next all-gather, we save a reference to the current all-gather + result to free after the next copy-out. + 2. Otherwise (explicit prefetching or in backward), we free the + all-gather result immediately after the current copy-out since we can + already overlap the current copy-out with the previous reduce-scatter. + """ + if not self._all_gather_result: + return # no preceding unshard + async_op = self._all_gather_result.all_gather_work is not None + if self._training_state == TrainingState.FORWARD: # implicit prefetch + if prev_all_gather_state := self.comm_ctx.all_gather_state: + self._wait_all_gather_streams_on_event(prev_all_gather_state.event) + self.comm_ctx.all_gather_state = None # free the all-gather result + if isinstance(self.mesh_info, FSDPMeshInfo): + world_size = self._all_gather_process_group.size() + else: + world_size = 1 + if world_size == 1: + # directly initialize unsharded parameters from sharded parameters + + for fsdp_param in self.fsdp_params: + # Use all_gather_inputs which already handles conversion to param_dtype + # This is consistent with the world_size > 1 path + all_gather_input = fsdp_param.all_gather_inputs[0] + + # Make sure the all_gather_outputs has proper storage size before using it + # First ensure we have at least one tensor in all_gather_outputs + fsdp_param.init_all_gather_outputs( + [all_gather_input.numel()], + [all_gather_input.dtype], + world_size, + self.device, + force_recreate=False, + ) + + tensor = fsdp_param.all_gather_outputs[0] + alloc_storage(tensor) + + # find alternative way to check if tensor.is_inference + with torch.autograd._unsafe_preserve_version_counter(tensor): + tensor.copy_(all_gather_input) + + else: + with record_function(self._with_fqn("FSDP::all_gather_copy_out")): + foreach_all_gather_copy_out( + self._all_gather_result, + self.fsdp_params, + self._all_gather_process_group, + ) + + for fsdp_param in self.fsdp_params: + fsdp_param.init_unsharded_param() + + self._to_unsharded() + all_gather_copy_out_event = self.device_handle.Event() + all_gather_copy_out_event.record() + + if ( + not async_op + and self._training_state == TrainingState.FORWARD + and world_size > 1 + ): + # Defer free to allow for overlap of this copy-out with next + # all-gather collective + self.comm_ctx.all_gather_state = AllGatherState( + self._all_gather_result, all_gather_copy_out_event + ) + else: + self._wait_all_gather_streams_on_event(all_gather_copy_out_event) + + self._all_gather_result = None # free unless saved in `all_gather_state` + + def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]): + # Calling `unshard` before lazy init means streams are not initialized + if hasattr(self.comm_ctx, "all_gather_copy_in_stream") and event is not None: + self.comm_ctx.all_gather_copy_in_stream.wait_event(event) + if hasattr(self.comm_ctx, "all_gather_stream") and event is not None: + self.comm_ctx.all_gather_stream.wait_event(event) + + def reshard(self): + if self._training_state == TrainingState.FORWARD: + if not self._reshard_after_forward: + return + if self._use_post_forward_mesh: + self._to_sharded_post_forward() + self._reshard_after_forward_event = self.device_handle.Event() + if self._reshard_after_forward_event is not None: + self._reshard_after_forward_event.record() + return + self._to_sharded() + + def pre_forward( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::pre_forward")) + with record_function(self._with_fqn("FSDP::pre_forward")): + self._training_state = TrainingState.FORWARD + self.unshard(self.unshard_async_op) + self.wait_for_unshard() + args, kwargs = self._register_post_backward_hook(args, kwargs) + return args, kwargs + + def post_forward(self, module: nn.Module, input: Any, output: Any): + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::post_forward")) + with record_function(self._with_fqn("FSDP::post_forward")): + if not compiled_autograd_enabled(): + # for AC(fully_shard(model)), AC runs fsdp's _pre_forward + # it shouldn't change post_forward_order + if not is_bw(): + self.reshard() + self._record_post_forward() + else: + self.reshard() + self._record_post_forward() + self._training_state = TrainingState.IDLE + return output + + def _record_post_forward(self) -> None: + # Since a group has one pre-backward unshard for each forward call + # before the backward, we record each usage (with multiplicity) + post_forward_index = len(self.comm_ctx.post_forward_order) + self.comm_ctx.post_forward_order.append(self) + self._post_forward_indices.append(post_forward_index) + + def pre_backward(self, default_prefetch: bool, *unused: Any): + if ( + compiled_autograd_enabled() + and self._training_state == TrainingState.PRE_BACKWARD + ): + # Traceable FSDP2 cannot trigger the param group's `post_backward` immediately after param usage; + # instead it relies on this to trigger the previously unexecuted `post_backward`. + self.post_backward() + if self._training_state == TrainingState.PRE_BACKWARD: + return + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::pre_backward")) + with record_function(self._with_fqn("FSDP::pre_backward")): + self._training_state = TrainingState.PRE_BACKWARD + self.unshard(self.unshard_async_op) # no-op if prefetched + self.wait_for_unshard() + if default_prefetch and not compiled_autograd_enabled(): + self._backward_prefetch() + + def post_backward(self, *unused: Any): + # This method should be idempotent and safe to call even when this + # FSDP parameter group was not used in backward (should be a no-op) + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::post_backward")) + self._training_state = TrainingState.POST_BACKWARD + with record_function(self._with_fqn("FSDP::post_backward_accumulate")): + for fsdp_param in self.fsdp_params: + fsdp_param.accumulate_unsharded_grad_if_needed() + with record_function(self._with_fqn("FSDP::post_backward_reshard")): + if not self.reduce_grads: + if self.reshard_after_backward: + self.reshard() + for fsdp_param in self.fsdp_params: + fsdp_param.to_accumulated_grad_if_needed() + return + # Save the autograd-computed gradients before resharding to only + # access the unsharded parameters when their data is present + fsdp_params_with_grad: list[FSDPParam] = [] + unsharded_grads: list[torch.Tensor] = [] + for fsdp_param in self.fsdp_params: + if not hasattr(fsdp_param, "_unsharded_param"): + continue + # May have an accumulated gradient of the reduce dtype if the + # previous backward did not reduce-scatter + if fsdp_param.unsharded_accumulated_grad is not None: + fsdp_params_with_grad.append(fsdp_param) + unsharded_grads.append(fsdp_param.unsharded_accumulated_grad_data) + fsdp_param.unsharded_accumulated_grad = None + elif fsdp_param.unsharded_param.grad is not None: + fsdp_params_with_grad.append(fsdp_param) + unsharded_grads.append(fsdp_param.unsharded_grad_data) + fsdp_param.unsharded_param.grad = None + if self.reshard_after_backward: + self.reshard() + if len(fsdp_params_with_grad) == 0: + return + with record_function(self._with_fqn("FSDP::post_backward_reduce")): + if ( + self.comm_ctx.reduce_scatter_state is not None + and self.comm_ctx.reduce_scatter_state.event is not None + ): + self.device_handle.current_stream().wait_event( + self.comm_ctx.reduce_scatter_state.event + ) + self.comm_ctx.reduce_scatter_state = None + all_reduce_pg = ( + self._all_reduce_process_group + if isinstance(self.mesh_info, DDPMeshInfo) + else None + ) + all_reduce_stream: torch.cuda.Stream + if all_reduce_pg is None and self._all_reduce_hook_stream is not None: + # this means the native HSDP is not enabled, + # but user may want to have a custom HSDP setup + if self._all_reduce_hook is None: + raise AssertionError( + "all reduce hook stream is specified but hook itself is missing." + ) + all_reduce_stream = self._all_reduce_hook_stream + else: + all_reduce_stream = self.comm_ctx.all_reduce_stream + + self._wait_for_post_backward() + ( + reduce_scatter_input, + reduce_scatter_event, + self._post_reduce_event, + all_reduce_input, + all_reduce_event, + self._partial_reduce_output, + ) = foreach_reduce( + fsdp_params_with_grad, + unsharded_grads, + ( + self._reduce_scatter_process_group + if isinstance(self.mesh_info, FSDPMeshInfo) + else None # pyre-fixme[6] + ), + self.comm_ctx.reduce_scatter_stream, + self._reduce_scatter_comm, + self._orig_dtype, + self._reduce_dtype, + self.device, + self.gradient_divide_factor, + ( + self._all_reduce_process_group + if isinstance(self.mesh_info, DDPMeshInfo) + else None + ), + all_reduce_stream, + self.all_reduce_grads, + self._partial_reduce_output, + self._all_reduce_hook, + self.force_sum_reduction_for_comms, + ) + self.comm_ctx.reduce_scatter_state = ReduceScatterState( + reduce_scatter_input, reduce_scatter_event + ) + if all_reduce_input is not None: + if self.device.type != "cpu": + if all_reduce_event is None: + raise AssertionError( + "Expected all_reduce_event to be set for non-CPU device" + ) + self._all_reduce_state = AllReduceState( + all_reduce_input, all_reduce_event + ) + + def finalize_backward(self): + self._wait_for_post_backward() + for fsdp_param in self.fsdp_params: + if fsdp_param.grad_offload_event is not None: + fsdp_param.grad_offload_event.synchronize() + fsdp_param.grad_offload_event = None + if self._all_gather_result is not None: + # If there was a mistargeted unshard without a corresponding wait, + # then we wait here and clear the unshard + if (event := self._all_gather_result.all_gather_event) is not None: + torch.accelerator.current_stream().wait_event(event) + work = self._all_gather_result.all_gather_work + if isinstance(work, dist.distributed_c10d.Work): + work.wait() + self._all_gather_result = None + self._post_forward_indices.clear() + + def _wait_for_post_backward(self): + if self._post_reduce_event is not None: + self.device_handle.current_stream().wait_event(self._post_reduce_event) + self._post_reduce_event = None + if ( + self._all_reduce_state is not None + and self._all_reduce_state.event is not None + ): + self.device_handle.current_stream().wait_event(self._all_reduce_state.event) + self._all_reduce_state = None + + def _backward_prefetch(self) -> None: + if self._training_state == TrainingState.PRE_BACKWARD: + if not self._post_forward_indices: + # Can be cleared if running multiple `backward`s + return + curr_index = self._post_forward_indices.pop() + if (target_index := curr_index - 1) < 0: + return + # Prefetch naively using the reverse post-forward order, which may + # have mistargeted prefetches if not all modules used in forward + # are used in this backward + # pyrefly: ignore [unbound-name] + target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] + self._prefetch_unshard(target_fsdp_param_group, "backward") + + @staticmethod + def _prefetch_unshard( + target_fsdp_param_group: "FSDPParamGroup", pass_type: str + ) -> None: + if pass_type == "backward": + training_state = TrainingState.PRE_BACKWARD + elif pass_type == "forward": + training_state = TrainingState.FORWARD + else: + raise ValueError(f"Unknown pass type: {pass_type}") + target_fqn = target_fsdp_param_group._module_fqn + with ( + record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"), + target_fsdp_param_group.use_training_state(training_state), + ): + async_op = target_fsdp_param_group.unshard_async_op + target_fsdp_param_group.unshard(async_op) + + # Utilities # + def _to_sharded(self): + if not self.is_sharded: + for fsdp_param in self.fsdp_params: + fsdp_param.to_sharded() + self._sharded_state = ShardedState.SHARDED + + def _to_sharded_post_forward(self): + if not self.is_sharded_post_forward: + for fsdp_param in self.fsdp_params: + fsdp_param.to_sharded_post_forward() + self._sharded_state = ShardedState.SHARDED_POST_FORWARD + + def _to_unsharded(self): + if not self.is_unsharded: + for fsdp_param in self.fsdp_params: + fsdp_param.to_unsharded() + self._sharded_state = ShardedState.UNSHARDED + + @property + def is_sharded(self) -> bool: + return self._sharded_state == ShardedState.SHARDED + + @property + def is_sharded_post_forward(self) -> bool: + return self._sharded_state == ShardedState.SHARDED_POST_FORWARD + + @property + def is_unsharded(self) -> bool: + return self._sharded_state == ShardedState.UNSHARDED + + @contextlib.contextmanager + def use_training_state(self, training_state: TrainingState): + old_training_state = self._training_state + self._training_state = training_state + try: + yield + finally: + self._training_state = old_training_state + + # Hook Registration # + def _register_post_backward_hook( + self, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + # Traceable FSDP2 relies on `root_post_backward_callback` to call each + # `FSDPParamGroup.post_backward` + if (not torch._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled(): + return args, kwargs + if not torch.is_grad_enabled(): + return args, kwargs + args_list, args_spec = tree_flatten(args) + kwargs_list, kwargs_spec = tree_flatten(kwargs) + args_kwargs_list = list(args_list) + list(kwargs_list) + inp_tensor_indices: list[int] = [] + inp_tensors: list[torch.Tensor] = [] + for i, obj in enumerate(args_kwargs_list): + if torch.is_tensor(obj) and obj.requires_grad: + inp_tensor_indices.append(i) + inp_tensors.append(obj) + if len(inp_tensors) == 0: + return args, kwargs # no tensors that require gradients + inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors) + for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors): + args_kwargs_list[inp_tensor_idx] = inp_tensor + args_list = args_kwargs_list[: len(args_list)] + kwargs_list = args_kwargs_list[len(args_list) :] + args = tree_unflatten(args_list, args_spec) + kwargs = tree_unflatten(kwargs_list, kwargs_spec) + return args, kwargs + + def _register_state_dict_hooks(self) -> None: + num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle) + num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle) + if num_pre_save_hooks != num_pre_load_hooks: + raise AssertionError( + f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}" + ) + if num_pre_save_hooks > 0: + return # already registered + modules_with_fsdp_params: set[nn.Module] = { + fsdp_param._module_info.module for fsdp_param in self.fsdp_params + } + + def to_sharded_hook(*args: Any, **kwargs: Any) -> None: + self._to_sharded() + + for module in modules_with_fsdp_params: + self._module_to_pre_save_state_dict_hook_handle[module] = ( + module.register_state_dict_pre_hook(to_sharded_hook) + ) + self._module_to_pre_load_state_dict_hook_handle[module] = ( + module._register_load_state_dict_pre_hook(to_sharded_hook) + ) + + # Properties # + @property + def _reshard_after_forward(self) -> bool: + return self.post_forward_mesh_info is not None + + @property + def _use_post_forward_mesh(self) -> bool: + return ( + self._reshard_after_forward + and self.mesh_info != self.post_forward_mesh_info + ) + + @property + def _is_hsdp(self) -> bool: + return isinstance(self.mesh_info, HSDPMeshInfo) + + @property + def _all_gather_process_group(self) -> dist.ProcessGroup: + mesh_info = ( + cast(FSDPMeshInfo, self.post_forward_mesh_info) + if self.is_sharded_post_forward + else self.mesh_info + ) + if not isinstance(mesh_info, FSDPMeshInfo): + raise AssertionError( + f"Expected mesh_info to be FSDPMeshInfo, got {type(mesh_info)}" + ) + return mesh_info.shard_process_group + + @property + def _reduce_scatter_process_group(self) -> dist.ProcessGroup: + if not isinstance(self.mesh_info, FSDPMeshInfo): + raise AssertionError( + f"Expected mesh_info to be FSDPMeshInfo, got {type(self.mesh_info)}" + ) + return self.mesh_info.shard_process_group + + @property + def _all_reduce_process_group(self) -> dist.ProcessGroup: + if not isinstance(self.mesh_info, DDPMeshInfo): + raise AssertionError( + f"Expected mesh_info to be DDPMeshInfo or HSDPMeshInfo, got {type(self.mesh_info)}" + ) + return self.mesh_info.replicate_process_group + + def _with_fqn(self, label: str) -> str: + if self._module_fqn: + return f"{label} ({self._module_fqn})" + return label + + def __repr__(self): + return f"FSDPParamGroup(fqn={self._module_fqn})" + + def _validate_no_meta_params(self): + param_names_on_meta = [ + fsdp_param._param_fqn + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type == "meta" + ] + if param_names_on_meta: + raise RuntimeError( + "FSDP parameters should be materialized from meta device before training, " + f"but the following were still on meta device: {param_names_on_meta}\n" + "For example, call module.to_empty(device) to materialize to device and " + "call module.reset_parameters() on each module to initialize values." + ) + + def _validate_cpu_offload_params(self): + if not isinstance(self.offload_policy, CPUOffloadPolicy): + return + fsdp_params_not_on_cpu = [ + fsdp_param + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type != "cpu" + ] + if fsdp_params_not_on_cpu: + raise RuntimeError( + "FSDP parameters should be materialized on CPU when enabling CPU offloading. " + 'For example, load a CPU state dict or call module.to_empty(device="cpu"). ' + "Found following parameters on non-CPU device: " + f"{[(fsdp_param._param_fqn, fsdp_param.sharded_param.device) for fsdp_param in fsdp_params_not_on_cpu]}\n" + ) + + +def _get_param_module_infos( + params: list[nn.Parameter], modules: tuple[nn.Module, ...] +) -> list[ParamModuleInfo]: + """ + Shared parameter: lin1.weight = lin2.weight + Shared module: mlp.lin1 = mlp.lin2 + We do not remove duplicates when traversing both modules and parameters to + find shared modules' parameters and shared parameters within a module. + """ + params_set = set(params) + param_to_module_info: dict[nn.Parameter, ParamModuleInfo] = {} + for module in modules: + for _, submodule in module.named_modules(remove_duplicate=False): + for param_name, param in _named_parameters_with_duplicates( + submodule, recurse=False + ): + if param in params_set: + if param not in param_to_module_info: + param_to_module_info[param] = ParamModuleInfo( + submodule, param_name + ) + else: + param_to_module_info[param].shared_modules.append(submodule) + param_to_module_info[param].shared_param_names.append( + param_name + ) + if len(param_to_module_info) != len(params): + raise AssertionError(f"Some parameters are not in the module tree of {modules}") + return [param_to_module_info[param] for param in params] + + +class RegisterPostBackwardFunction(torch.autograd.Function): + @staticmethod + def _assert_not_tracing_fsdp(): + if compiled_autograd_enabled(): + # TODO: Find a way to print the offending FSDP2 module. + msg = """\ +When Traceable FSDP2 is enabled, we should not be calling into `RegisterPostBackwardFunction`. +Instead, we rely on the param group's next `pre_backward` hook to trigger its previously unexecuted +`post_backward`, and we rely on FSDPState's `root_post_backward_callback` to trigger the resharding +of any leftover unsharded param groups. +If you are here, it means the forward part of this FSDP2 instance is not compiled, and you must also +compile the forward part if you want to use Traceable FSDP2.""" + torch._dynamo.comptime.comptime.print(msg) + raise RuntimeError(msg) + + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor): + # All tensors in `inputs` should require gradient + RegisterPostBackwardFunction._assert_not_tracing_fsdp() + ctx.param_group = param_group + return inputs + + @staticmethod + def backward(ctx, *grads: torch.Tensor): + RegisterPostBackwardFunction._assert_not_tracing_fsdp() + ctx.param_group.post_backward() + return (None,) + grads diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py new file mode 100644 index 0000000000000000000000000000000000000000..d68dfbf2ddcb0faaf1888fc912ba09bc599e2c5c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -0,0 +1,408 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import logging +from collections.abc import Callable, Sequence +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch.nn as nn +from torch._logging import warning_once +from torch.autograd import Variable +from torch.autograd.graph import _MultiHandle +from torch.distributed._composable_state import ( + _get_module_state, + _insert_module_state, + _State, +) +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.utils import _apply_to_tensors, _to_kwargs +from torch.utils._pytree import tree_flatten + +from ._fsdp_api import MixedPrecisionPolicy +from ._fsdp_common import ( + _cast_fp_tensor, + compiled_autograd_enabled, + detect_compiled_autograd, + TrainingState, +) +from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup + + +if TYPE_CHECKING: + from ._fsdp_param import FSDPParam + + +logger = logging.getLogger("torch.distributed.fsdp.fully_shard") + + +class FSDPStateContext: + """This has state shared across FSDP states.""" + + def __init__(self) -> None: + # All FSDP states in the root state's module tree + self.all_states: list[FSDPState] = [] + # Iteration's forward root runs the once-per-forward logic; this root + # may not be the overall root set by lazy initialization in cases where + # only a submodule runs forward (e.g. encoder-only for eval) + self.iter_forward_root: Optional[FSDPState] = None + # Final callback should only be queued once per backward + self.post_backward_final_callback_queued: bool = False + # Whether to finalize backward in this backward's final callback + self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[torch.Event] = None + + +def disable_if_config_true(func): + @functools.wraps(func) + def fsdp_hook_wrapper(*args, **kwargs): + if torch._dynamo.config.skip_fsdp_hooks: + return torch._dynamo.disable( + func, + recursive=True, + reason="skipping FSDP hooks since torch._dynamo.config.skip_fsdp_hooks is set", + )(*args, **kwargs) + else: + return func(*args, **kwargs) + + return fsdp_hook_wrapper + + +class FSDPState(_State): + def __init__(self) -> None: + super().__init__() + self._fsdp_param_group: Optional[FSDPParamGroup] = None + self._is_root: Optional[bool] = None # root set during lazy init + self._state_ctx = FSDPStateContext() + self._comm_ctx = FSDPCommContext() + self._training_state: TrainingState = TrainingState.IDLE + self._states_to_forward_prefetch: list[FSDPState] = [] + self._states_to_backward_prefetch: list[FSDPState] = [] + self._modules_to_run_forward: set[nn.Module] = set() + # ``False`` when user set reshard_after_forward + # through ``fully_shard`` or ``set_reshard_after_forward`` + self._auto_reshard_after_forward: Optional[bool] = True + + # Define a separate init since `__init__` is called in the contract + def init( + self, + modules: tuple[nn.Module, ...], + device: torch.device, + mp_policy: MixedPrecisionPolicy, + auto_reshard_after_forward: bool, + ) -> None: + for module in modules: + _insert_module_state(module, self) + self._modules = modules + # pyrefly: ignore [read-only] + self._device = device + self._device_handle = _get_device_handle(device.type) + self._mp_policy = mp_policy + self._auto_reshard_after_forward = auto_reshard_after_forward + if len(modules) == 1: + self._pre_forward_hook_handle = modules[0].register_forward_pre_hook( + self._pre_forward, prepend=True, with_kwargs=True + ) + self._post_forward_hook_handle = modules[0].register_forward_hook( + self._post_forward, prepend=False + ) + else: + hook_handle = _register_group_forward_hooks( + modules, + self._pre_forward, + self._post_forward, + self._modules_to_run_forward, + ) + self._pre_forward_hook_handle = hook_handle + self._post_forward_hook_handle = hook_handle + + def _root_pre_forward( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + self._lazy_init() + if self._state_ctx.iter_forward_root is not None: + return args, kwargs + if not compiled_autograd_enabled(): + logger.debug("FSDP::root_pre_forward") + self._state_ctx.iter_forward_root = self + with torch.profiler.record_function("FSDP::root_pre_forward"): + # Wait for optimizer before implicitly prefetched all-gathers + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = self._device_handle.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if self._device.type in [ + "cuda", + "hpu", + "xpu", + "mtia", + torch._C._get_privateuse1_backend_name(), + ]: + with torch.profiler.record_function("FSDP::inputs_to_device"): + args_tuple, kwargs_tuple = _to_kwargs( + args, kwargs, self._device, False + ) # same as DDP + args, kwargs = args_tuple[0], kwargs_tuple[0] + return args, kwargs + + def _lazy_init(self) -> None: + """ + Lazy initialization represents when all modules' parallelisms have + finalized (e.g. FSDP has been applied to all desired modules). This + means that we can determine which state is the root, and we do so by + the 1st state to run forward. + """ + if self._is_root is not None: + return # no-op: already initialized + self._is_root = True + if len(self._modules) > 1: + raise RuntimeError( + f"FSDP requires a single root module but got {self._modules}" + ) + detect_compiled_autograd() + root_module = self._modules[0] + visited_states: set[FSDPState] = set() + for module_name, module in root_module.named_modules(): + if (state := _get_module_fsdp_state(module)) is None: + continue + if module is not root_module: + if state not in visited_states and state._is_root is not None: + raise RuntimeError( + "FSDP state has already been lazily initialized for " + f"{module_name}\nFSDP requires running forward through " + "the root module first" + ) + state._is_root = False + self._state_ctx.all_states.append(state) + visited_states.add(state) + if self._fsdp_param_group and self._auto_reshard_after_forward: + # For the root, do not reshard after forward since for training, + # the parameters would be freed and all-gathered immediately + self._fsdp_param_group.post_forward_mesh_info = None + self._init_fqns() + self._init_shared_state() + # Run parameter group lazy inits after initializing FQNs for improved + # error messages + for state in self._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.lazy_init() + + def _init_shared_state(self) -> None: + self._comm_ctx.lazy_init(self._device) + for state in self._state_ctx.all_states: + state._state_ctx = self._state_ctx + state._comm_ctx = self._comm_ctx + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.comm_ctx = self._comm_ctx + + def _init_fqns(self) -> None: + """Sets module and parameter FQN attributes for debugging.""" + if not self._is_root: + raise AssertionError("Expected _is_root to be True") + root_module = self._modules[0] + param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {} + module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {} + for state in self._state_ctx.all_states: + if fsdp_param_group := state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param + for module in fsdp_param_group.modules: + module_to_fsdp_param_group[module] = fsdp_param_group + for param_name, param in root_module.named_parameters(): + if param in param_to_fsdp_param: + param_to_fsdp_param[param]._param_fqn = param_name + for module_name, module in root_module.named_modules(): + if module in module_to_fsdp_param_group: + module_fqn = module_to_fsdp_param_group[module]._module_fqn + if module_fqn is None: + module_to_fsdp_param_group[module]._module_fqn = module_name + else: + if not isinstance(module_fqn, str): + raise AssertionError( + f"Expected module_fqn to be str, got {type(module_fqn)}: {module_fqn}" + ) + module_fqn += f", {module_name}" + module_to_fsdp_param_group[module]._module_fqn = module_fqn + + @disable_if_config_true + def _pre_forward( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + # When composing with module-hook-based activation checkpointing, the + # pre-backward hook is responsible for the unshard + if self._training_state == TrainingState.PRE_BACKWARD: + return args, kwargs + self._training_state = TrainingState.FORWARD + args, kwargs = self._root_pre_forward(module, args, kwargs) + if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype: + with torch.profiler.record_function("FSDP::cast_forward_inputs"): + cast_fn = functools.partial( + _cast_fp_tensor, self._mp_policy.param_dtype + ) + args, kwargs = ( + _apply_to_tensors(cast_fn, args), + _apply_to_tensors(cast_fn, kwargs), + ) + if self._fsdp_param_group: + args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs) + for fsdp_state in self._states_to_forward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "forward") + return args, kwargs + + @disable_if_config_true + def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any: + # When composing with module-hook-based activation checkpointing, the + # post-backward hook is responsible for the reshard + if self._training_state == TrainingState.PRE_BACKWARD: + return output + if self._fsdp_param_group: + output = self._fsdp_param_group.post_forward(module, input, output) + output = self._register_pre_backward_hook(output) + self._training_state = TrainingState.IDLE + if self._state_ctx.iter_forward_root is self: + if all_gather_state := self._comm_ctx.all_gather_state: + # Free the last all-gather result if needed; refer to + # [Note: Overlapping all-gather copy-in and all-gather] + self._comm_ctx.all_gather_copy_in_stream.wait_event( + all_gather_state.event + ) + self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event) + self._comm_ctx.all_gather_state = None # free the all-gather result + self._state_ctx.iter_forward_root = None + if self._mp_policy.output_dtype is not None: + with torch.profiler.record_function("FSDP::cast_forward_outputs"): + output = _apply_to_tensors( + functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype), + output, + ) + return output + + def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: + self._training_state = TrainingState.PRE_BACKWARD + self._register_root_post_backward_final_callback() + if self._fsdp_param_group: + default_prefetch = len(self._states_to_backward_prefetch) == 0 + self._fsdp_param_group.pre_backward(default_prefetch) + for fsdp_state in self._states_to_backward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "backward") + return grad + + def _root_post_backward_final_callback(self) -> None: + if not compiled_autograd_enabled(): + logger.debug("FSDP::root_post_backward") + with torch.profiler.record_function("FSDP::root_post_backward_callback"): + for state in self._state_ctx.all_states: + fsdp_param_group = state._fsdp_param_group + if ( + fsdp_param_group + and fsdp_param_group._training_state != TrainingState.POST_BACKWARD + ): + # Run post-backward in case forward inputs did not require + # gradient so the autograd backward did not run + fsdp_param_group.post_backward() + state._training_state = TrainingState.IDLE + if fsdp_param_group: + fsdp_param_group._training_state = TrainingState.IDLE + if self._state_ctx.is_last_backward: + state._finalize_backward() + if self._state_ctx.is_last_backward: + self._comm_ctx.post_forward_order.clear() + if self._comm_ctx.reduce_scatter_state is not None: + self._device_handle.current_stream().wait_event( + self._comm_ctx.reduce_scatter_state.event + ) + self._comm_ctx.reduce_scatter_state = None + self._state_ctx.post_backward_final_callback_queued = False + + def _finalize_backward(self) -> None: + if self._modules_to_run_forward: + msg = ( + f"{len(self._modules_to_run_forward)} of the {len(self._modules)} " + f"modules passed to fully_shard did not run forward before backward, " + "which is error-prone since FSDP post-forward/pre-backward logic " + "will not run for these modules. We recommend passing only modules " + "that run forward together. Modules that did not run forward: " + f"{list(self._modules_to_run_forward)}" + ) + warning_once(logger, msg, stacklevel=2) + # Clear since we want the next forward to run + self._modules_to_run_forward.clear() + if self._fsdp_param_group: + self._fsdp_param_group.finalize_backward() + + def _register_pre_backward_hook(self, output: Any) -> Any: + if not torch.is_grad_enabled(): + return output + flat_outputs, _ = tree_flatten(output) + for t in flat_outputs: + if torch.is_tensor(t) and t.requires_grad: + t.register_hook(self._pre_backward) + return output + + def _register_root_post_backward_final_callback(self): + if self._state_ctx.post_backward_final_callback_queued: + return + self._state_ctx.post_backward_final_callback_queued = True + Variable._execution_engine.queue_callback( + self._root_post_backward_final_callback + ) + + +def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]: + state = _get_module_state(module) + if isinstance(state, FSDPState): + return state + return None + + +def _register_group_forward_hooks( + modules: Sequence[nn.Module], + pre_hook: Callable, + post_hook: Callable, + modules_to_run: set[nn.Module], +): + """ + Registers group forward pre and post-hooks. The pre-hook runs upon the + first module pre-forward, and the post-hook runs upon the last. If at least + one module does not run forward, then the post-hook does not run. + """ + modules_set = set(modules) + + @disable_if_config_true + @functools.wraps(pre_hook) + def wrapped_pre_hook(*args: Any, **kwargs: Any): + if len(modules_to_run) == 0: # first to run + modules_to_run.update(modules_set) + return pre_hook(*args, **kwargs) + + @disable_if_config_true + def get_wrapped_post_hook(module: nn.Module): + @functools.wraps(post_hook) + def wrapped_post_hook(*args: Any, **kwargs: Any): + modules_to_run.discard(module) + if len(modules_to_run) == 0: + return post_hook(*args, **kwargs) + + return wrapped_post_hook + + pre_handles = [ + module.register_forward_pre_hook( + wrapped_pre_hook, prepend=True, with_kwargs=True + ) + for module in modules + ] + post_handles = [ + module.register_forward_hook( + get_wrapped_post_hook(module), prepend=False, always_call=True + ) + for module in modules + ] + return _MultiHandle(tuple(pre_handles + post_handles)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py new file mode 100644 index 0000000000000000000000000000000000000000..998a33746f961fbf65f43b2c4245a6f12a9d3893 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -0,0 +1,746 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +from __future__ import annotations + +import functools +from contextlib import contextmanager +from typing import Any, cast, NoReturn, Optional, overload, TYPE_CHECKING, Union +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch.distributed._composable import contract +from torch.distributed.utils import _get_root_modules + +from ._fsdp_api import AllGather, MixedPrecisionPolicy, OffloadPolicy, ReduceScatter +from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo +from ._fsdp_init import ( + _get_device_from_mesh, + _get_managed_modules, + _get_managed_states, + _get_post_forward_mesh_info, + _init_default_fully_shard_mesh, + _move_states_to_device, +) +from ._fsdp_param_group import FSDPParamGroup +from ._fsdp_state import _get_module_fsdp_state, FSDPState + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator + + from torch.distributed.tensor import DeviceMesh, Shard + +__all__ = [ + "fully_shard", + "FSDPModule", + "UnshardHandle", + "register_fsdp_forward_method", + "get_cls_to_fsdp_cls", + "disable_fsdp_module_new_init", + "share_comm_ctx", +] + + +cls_to_fsdp_cls: dict[type, type] = {} + + +def get_cls_to_fsdp_cls() -> dict[type, type]: + return cls_to_fsdp_cls + + +@overload +# pyrefly: ignore [inconsistent-overload] +def fully_shard( + module: nn.Module, + *, + mesh: Optional[DeviceMesh] = ..., + reshard_after_forward: Union[bool, int] = ..., + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: Optional[set[nn.Parameter]] = ..., +) -> FSDPModule: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def fully_shard( + module: list[nn.Module], + *, + mesh: Optional[DeviceMesh] = ..., + reshard_after_forward: Union[bool, int] = ..., + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: Optional[set[nn.Parameter]] = ..., +) -> list[FSDPModule]: ... + + +# The decorator adds a state object to `module` that can be accessed via +# `fully_shard.state(module)`. The state object and module are 1:1. +# [1] Python runtime decorator does not play well with static type checking +# so suppressing some type checks to support type overloads +# such that caller can still get correct return types based on input type +@contract(state_cls=FSDPState) # type: ignore[misc] # see [1] +def fully_shard( + module, + *, + mesh: Optional[DeviceMesh] = None, + reshard_after_forward: Optional[Union[bool, int]] = None, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None, + mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), + offload_policy: OffloadPolicy = OffloadPolicy(), + ignored_params: Optional[set[nn.Parameter]] = None, +): + """ + Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP + shards module parameters, gradients, and optimizer states across data + parallel workers to save memory at the cost of communication. + + At initialization, FSDP shards the module's parameters across the data + parallel workers given by ``mesh``. Before forward, FSDP all-gathers the + sharded parameters across the data-parallel workers to get the unsharded + parameters for forward computation. If ``reshard_after_forward`` is + ``True``, then FSDP frees the unsharded parameters after forward and + re-all-gathers them in backward before gradient computation. After gradient + computation, FSDP frees the unsharded parameters and reduce-scatters the + unsharded gradients across data-parallel workers. + + This implementation represents the sharded parameters as :class:`DTensor` s + sharded on dim-0, while the unsharded parameters will be like the original + parameters on ``module`` (e.g. :class:`torch.Tensor` if originally + :class:`torch.Tensor`). A module + `forward pre-hook `_ + on ``module`` all-gathers the parameters, and a module + `forward hook `_ + on ``module`` frees them (if needed). Similar backward hooks all-gather + parameters and later free parameters and reduce-scatter gradients. + + Since grouping multiple tensors together for one collective is critical for + communication efficiency, this implementation makes this grouping first + class. Calling :meth:`fully_shard` on ``module`` constructs one group that + includes the parameters in ``module.parameters()`` except those already + assigned to a group from an earlier call on a submodule. This means that + :meth:`fully_shard` should be called bottom-up on your model. Each group's + parameters are all-gathered in one collective, and its gradients are + reduce-scattered in one collective. Partitioning the model into multiple + groups ("layer by layer") allows for peak memory savings and communication/computation + overlap. Users generally should *not* call :meth:`fully_shard` only on the + topmost root module. + + Args: + module (Union[nn.Module, List[nn.Module]): The module or modules to + shard with FSDP and group together for communication. + mesh (Optional[DeviceMesh]): This data parallel mesh defines the + sharding and device. If 1D, then parameters are fully sharded + across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D, + then parameters are sharded across the 1st dim and replicated + across the 0th dim (HSDP) with ``(Replicate(), Shard(0))`` + placement. The mesh's device type gives the device type used for + communication; if a CUDA or CUDA-like device type, then we use the + current device. + reshard_after_forward (Optional[Union[bool, int]]): This controls the parameter + behavior after forward and can trade off memory and communication: + + - If ``True``, then this reshards parameters after forward and + re-all-gathers in backward. + - If ``False``, then this keeps the unsharded parameters in memory + after forward and avoids the all-gather in backward. For best performance, + we usually set ``False`` for the root module, because the root module + is typically required immediately when the backward pass begins. + - If ``None``, it is set to ``True`` for non-root modules and ``False`` + for root modules. + - If an ``int``, then this represents the world size to reshard to + after forward. It should be a non-trivial divisor of the ``mesh`` + shard dim size (i.e. excluding 1 and the dim size itself). A + choice may be the intra-node size (e.g. ``torch.cuda.device_count()``). + This allows the all-gather in backward to be over a smaller world + size at the cost of higher memory usage than setting to ``True``. + - After forward, the parameters registered to the module depend on + to this: The registered parameters are the sharded parameters if + ``True``; unsharded parameters if ``False``; and the parameters + resharded to the smaller mesh otherwise. To modify the parameters + between forward and backward, the registered parameters must be + the sharded parameters. For ``False`` or an ``int``, this can be + done by manually resharding via :meth:`reshard`. + shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]): + This callable can be used to override the sharding placement for a + parameter to shard a parameter on a dimension other than dim-0. If + this callable returns a :class:`Shard` placement (not ``None``), + then FSDP will shard according to that placement (e.g. ``Shard(1)``). + If sharding on a nonzero dim, we currently require even sharding, + i.e. the tensor dim size on that dim must be divisible by the FSDP + shard mesh size. + mp_policy (MixedPrecisionPolicy): This controls the mixed precision + policy, which offers parameter/reduction mixed precision for this + module. See :class:`MixedPrecisionPolicy` for details. + offload_policy (OffloadPolicy): This controls the offloading policy, + which offers parameter/gradient/optimizer state offloading. See + :class:`OffloadPolicy` and its subclasses for details. + ignored_params: Optional(Set[nn.Parameter]): The set of parameters to be + ignored by FSDP. They will not be sharded, nor moved to the device + during init, nor have their gradients reduced in backward. + + Returns: + FSDPModule: The module with FSDP applied (in-place). + """ + torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard") + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + raise ValueError( + f"fully_shard does not support containers that do not implement forward: {module}" + ) + mesh = mesh or _init_default_fully_shard_mesh() + if mesh.ndim not in (1, 2): + raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}") + elif mesh.ndim == 1: + mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0) + else: + if mesh.mesh_dim_names is None: + raise AssertionError( + "Please init the 2D mesh for HSDP with mesh_dim_names specified" + ) + mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) + device = _get_device_from_mesh(mesh) + auto_reshard_after_forward = reshard_after_forward is None + # If the user does not provide ``reshard_after_forward``, we set it to True. + # During lazy_init, we identify which module is the root and override its value to False + post_forward_mesh_info = _get_post_forward_mesh_info( + reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] + mesh_info, + ) + + arg_module = module + modules = ( + (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module)) + ) + state = fully_shard.state(modules[0]) # type: ignore[attr-defined] # see [1] + state.init(modules, device, mp_policy, auto_reshard_after_forward) + + managed_modules = _get_managed_modules(modules, ignored_params) + params, buffers = _get_managed_states(managed_modules, ignored_params) + + _move_states_to_device(params, buffers, device) + if params: + state._fsdp_param_group = FSDPParamGroup( + params, + modules, + mesh_info, + post_forward_mesh_info, + device, + shard_placement_fn, + mp_policy, + offload_policy, + ) + + # For Dynamo + for managed_module in managed_modules: + managed_module._is_fsdp_managed_module = True # type: ignore[assignment] + managed_module._fsdp_use_orig_params = True # type: ignore[assignment] + + # Place FSDP leftmost for highest priority in the method resolution order + for module in modules: + cls = module.__class__ + new_cls = cls_to_fsdp_cls.get(cls) + if not new_cls: + dct = {"__deepcopy__": _unimplemented_deepcopy} + new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) + cls_to_fsdp_cls[cls] = new_cls + module.__class__ = new_cls + return arg_module + + +def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: + raise AssertionError( + "FSDP does not support deepcopy. Please use state dict for serialization." + ) + + +_enable_fsdp_module_new_init: bool = True + + +@contextmanager +def disable_fsdp_module_new_init() -> Iterator[None]: + global _enable_fsdp_module_new_init + prev, _enable_fsdp_module_new_init = _enable_fsdp_module_new_init, False + try: + yield + finally: + _enable_fsdp_module_new_init = prev + + +class FSDPModule: + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the FSDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class + # and index 1 is the `FSDPModule` class itself + orig_cls = cls.__mro__[2] + self = orig_cls.__new__(orig_cls, *args, **kwargs) + if _enable_fsdp_module_new_init: + self.__init__(*args, **kwargs) + return self + + def reshard(self) -> None: + """ + Reshards the module's parameters, freeing the unsharded parameters if + they are allocated and registering the sharded parameters to the + module. This method is *not* recursive. + """ + state = self._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reshard() + + def unshard(self, async_op: bool = False) -> Optional[UnshardHandle]: + """ + Unshards the module's parameters by allocating memory and all-gathering + the parameters. This method is *not* recursive. The unshard follows the + :class:`MixedPrecisionPolicy`, so it will all-gather following + ``param_dtype`` if set. + + Args: + async_op (bool): If ``True``, then returns a :class:`UnshardHandle` + that has a :meth:`wait` method to wait on the unshard op. If + ``False``, then returns ``None`` and waits on the handle inside + this function. + + .. note:: If ``async_op=True``, then FSDP will wait on the pending + unshard in the module's pre-forward for the user. The user only + needs to call :meth:`wait` explicitly if the wait should happen + before pre-forward. + """ + state = self._get_fsdp_state() + fsdp_param_group = state._fsdp_param_group + if fsdp_param_group is not None: + fsdp_param_group.lazy_init() + fsdp_param_group.unshard(async_op=async_op) + handle = _UnshardHandleImpl(fsdp_param_group) + if async_op: + return handle + handle.wait() + return None + + def set_is_last_backward(self, is_last_backward: bool) -> None: + """ + Sets whether the next backward is the last one. On the last backward, + FSDP waits on pending gradient reduction and clears internal data + data structures for backward prefetching. This can be useful for + microbatching. + """ + state = self._get_fsdp_state() + state._state_ctx.is_last_backward = is_last_backward + + def set_requires_gradient_sync( + self, requires_gradient_sync: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should sync gradients. This can be used to implement + gradient accumulation *without communication*. For HSDP, this controls + both reduce-scatter and all-reduce together. This is the equivalence of + `no_sync` in FSDP1. + + Args: + requires_gradient_sync (bool): Whether to reduce gradients for the + module's parameters. + recurse (bool): Whether to set for all FSDP submodules or just the + passed-in module. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reduce_grads = requires_gradient_sync + fsdp_param_group.all_reduce_grads = requires_gradient_sync + + def set_requires_all_reduce( + self, requires_all_reduce: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should all-reduce gradients. This can be used to + implement gradient accumulation with only reduce-scatter but not + all-reduce for HSDP. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.all_reduce_grads = requires_all_reduce + + def set_reshard_after_forward( + self, reshard_after_forward: bool, recurse: bool = True + ) -> None: + """ + Sets if the module should reshard parameters after forward. This can be + used to change the ``reshard_after_forward`` FSDP arg at runtime. For + example, this can be used to set the FSDP root module's value to + ``True`` (since it is otherwise specially set to ``False``), or it can + set an FSDP module's value to ``False`` for running evals and set back + to ``True`` for training. + + Args: + reshard_after_forward (bool): Whether to reshard parameters after + forward. + recurse (bool): Whether to set for all FSDP submodules or just the + passed-in module. + """ + if not isinstance(reshard_after_forward, bool): + raise ValueError( + f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}" + ) + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + state._auto_reshard_after_forward = False + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.post_forward_mesh_info = ( + _get_post_forward_mesh_info( + reshard_after_forward, fsdp_param_group.mesh_info + ) + ) + + def set_reshard_after_backward( + self, reshard_after_backward: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should reshard parameters after backward. This can + be used during gradient accumulation to trade off higher memory for + reduced communication since the unsharded parameters do not need to be + re-all-gathered before the next forward. + + Args: + reshard_after_backward (bool): Whether to reshard parameters after + backward. + recurse (bool): Whether to set for all FSDP submodules or just the + passed-in module. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reshard_after_backward = reshard_after_backward + + def set_modules_to_forward_prefetch(self, modules: list[FSDPModule]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in forward. The prefetching runs after this + module's all-gather copy-out. + + Passing a singleton list containing the next FSDP module gives the same + all-gather overlap behavior as the default overlap behavior, except the + prefetched all-gather is issued earlier from the CPU. Passing a list + with at least length two is required for more aggressive overlap and + will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_forward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_modules_to_backward_prefetch(self, modules: list[FSDPModule]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in backward. This overrides the default backward + pretching implementation that prefetches the next FSDP module based on + the reverse post-forward order. + + Passing a singleton list containing the previous FSDP module gives the + same all-gather overlap behavior as the default overlap behavior. + Passing a list with at least length two is required for more aggressive + overlap and will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_backward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_custom_all_gather(self, comm: AllGather) -> None: + """ + Overrides the default ``all_gather`` communication behavior, + to have better control over the communication and memory usage. + See `Comm` and `ReduceScatter` for details. + + Args: + comm (AllGather): Custom all-gather communication. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group._all_gather_comm = comm + + def set_custom_reduce_scatter(self, comm: ReduceScatter) -> None: + """ + Overrides the default ``reduce_scatter`` communication behavior, + to have better control over the communication and memory usage. + See `Comm` and `ReduceScatter` for details. + + Args: + comm (ReduceScatter): Custom reduce_scatter communication. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group._reduce_scatter_comm = comm + + def set_all_reduce_hook( + self, + hook: Callable[[torch.Tensor], None], + *, + stream: Optional[torch.cuda.Stream] = None, + ): + """ + Args: + hook (Callable[[torch.Tensor], None]): User-defined all-reduce hook + with expected signature ``hook(reduce_output: torch.Tensor) -> None`` + where ``reduce_output`` is the reduce-scatter output if only + using FSDP or the all-reduce output if using native HSDP. + stream (Optional[torch.cuda.Stream]): Stream to run the all-reduce + hook in. This should only be set if not using native HSDP. If + using native HSDP, the hook will run in the internally defined + all-reduce stream used by the native HSDP all-reduce. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group._all_reduce_hook = hook + if stream is not None: + if fsdp_param_group._is_hsdp: + raise ValueError("stream cannot be set when using native HSDP") + fsdp_param_group._all_reduce_hook_stream = stream + + def set_post_optim_event(self, event: torch.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (torch.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + + @deprecated("Use `set_gradient_divide_factor` instead") + def set_reduce_scatter_divide_factor(self, factor: float) -> None: + """Use :py:meth:`set_gradient_divide_factor` instead""" + self.set_gradient_divide_factor(factor) + + def set_gradient_divide_factor(self, factor: float) -> None: + """ + Sets a custom divide factor for the gradient reduction. This might use + a custom reduce op using NCCL's PreMulSum, which allows multiplying by + the factor before reduction. + + Args: + factor (float): Custom divide factor. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.gradient_divide_factor = factor + + def set_force_sum_reduction_for_comms(self, enable: bool) -> None: + """ + Sets whether to require the low-level collective communication + primitives to exclusively use "sum"-type reductions, even if it comes + at the cost of separate additional pre- or post-scaling operations. + This is needed for example because NCCL currently supports zero-copy + transfers only for this kind of collectives. + + NB: for MTIA devices, this is always implicitly enabled. + + NB: if `set_all_reduce_hook` is used under FSDP setup, the caller needs + to ensure the custom all-reduce across FSDP units follow this strategy + as well, as FSDP can no longer automatically handle that. + + Args: + enable (bool): Whether to only ever use ReduceOp.SUM for comms. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.force_sum_reduction_for_comms = enable + + def set_unshard_in_backward(self, unshard_in_backward: bool) -> None: + """ + Sets whether the FSDP module's parameters need to be unsharded in + backward. This can be used in expert cases when the user knows that all + parameters in this FSDP module's parameter group are not needed for + backward computation (e.g. embedding). + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.unshard_in_backward = unshard_in_backward + + def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None: + """ + Sets whether the temporary staging buffers used to send and receive data + over collective communications should be allocated using the custom + optimized allocator provided by the ProcessGroup itself (if any). This + might allow the ProcessGroup to be more efficient. For example, when + using NCCL, this enables it to leverage zero-copy transfers over SHARP + (for NVLink and/or InfiniBand). + + This cannot be used together with :meth:`set_custom_all_gather` or + :meth:`set_custom_reduce_scatter` as those APIs allow for + finer-grained control over each communication, and this method cannot + determine their staging buffer allocation strategy. + + Args: + enable (bool): Whether to turn on ProcessGroup allocation. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.set_allocate_memory_from_process_group(enable) + + def _set_unshard_async_op(self, async_op: bool): + """ + Sets whether to use ``async_op=True`` or ``False`` for the pre-forward + and pre-backward unshard op. This defaults to ``False`` but can be set + to ``True`` with this method. + + Setting this to ``True`` allows the all-gather allocations to happen in + the default stream, avoiding inter-stream memory fragmentation. + However, you must use explicit prefetching (e.g. via :meth:`unshard`) + in forward to still get overlap, and the pre-all-gather ops like dtype + casting and copy-in will not overlap with compute. + """ + self_module = cast(nn.Module, self) + for module in self_module.modules(): + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.unshard_async_op = async_op + + def _get_fsdp_state(self) -> FSDPState: + if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: + raise AssertionError(f"No FSDP state found on {self}") + return state + + def _apply(self, *args: Any, **kwargs: Any) -> Any: + # Reshard to ensure that sharded parameters are registered + self.reshard() + ret = super()._apply(*args, **kwargs) # type: ignore[misc] + state = self._get_fsdp_state() + if not (fsdp_param_group := state._fsdp_param_group): + return ret + # TODO: Remove this padding logic once DTensor pads the local tensor: + # https://github.com/pytorch/pytorch/issues/113045 + with torch.no_grad(): + for fsdp_param in fsdp_param_group.fsdp_params: + fsdp_param.reset_sharded_param() + return ret + + +class UnshardHandle: + """ + A handle to wait on a :meth:`FSDPModule.unshard` op. + """ + + def wait(self) -> None: + """ + Waits on the unshard op. This ensures that the current stream can use + the unsharded parameters, which are now registered to the module. + """ + return + + +class _UnshardHandleImpl(UnshardHandle): + def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]): + self._fsdp_param_group = fsdp_param_group + + def wait(self): + if self._fsdp_param_group is not None: + self._fsdp_param_group.wait_for_unshard() + # Avoid keeping a reference + self._fsdp_param_group = None + + +def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None: + """ + Registers a method on ``module`` to be considered a forward method for + FSDP. + + FSDP all-gathers parameters pre-forward and optionally frees parameters + post-forward (depending on ``reshard_after_forward``). FSDP only knows to + do this for :meth:`nn.Module.forward` by default. This function patches a + user-specified method to run the pre/post-forward hooks before/after the + method, respectively. If ``module`` is not an :class:`FSDPModule`, then + this is a no-op. + + Args: + module (nn.Module): Module to register the forward method on. + method_name (str): Name of the forward method. + """ + if not isinstance(module, FSDPModule): + # Make no-op to allow including both when using/not using FSDP + return + if not hasattr(module, method_name): + raise ValueError(f"{type(module)} does not have a method {method_name}") + orig_method = getattr(module, method_name) + + @functools.wraps(orig_method) + def wrapped_method(self, *args, **kwargs): + fsdp_state = self._get_fsdp_state() + args, kwargs = fsdp_state._pre_forward(self, args, kwargs) + out = orig_method(*args, **kwargs) + return fsdp_state._post_forward(self, args, out) + + # Use `__get__` to make `wrapped_method` an instance method + setattr( + module, + method_name, + wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] + ) + + +def share_comm_ctx(modules: list[FSDPModule]) -> None: + """ + Share cuda streams for multiple FSDPModules + + Example usage: + from torch.distributed.fsdp import share_comm_ctx + share_comm_ctx([fsdp_model_1, fsdp_model_2, ...]) + + For Pipeline Parallelism (PP), each model chunk is a FSDP root. We want + to share cuda streams for all-gather, reduce-scatter, and all-reduce. + This avoids allocating inter-stream memory framgmentation + + Args: + modules (List[FSDPModule]): modules to share cuda streams + """ + if len(modules) == 0: + return + for module in modules: + if not isinstance(module, FSDPModule): + raise ValueError(f"Expects list of FSDPModules but got {module}") + fsdp_states = [module._get_fsdp_state() for module in modules] + comm_ctx = fsdp_states[0]._comm_ctx + for fsdp_state in fsdp_states[1:]: + fsdp_state._comm_ctx = comm_ctx + if fsdp_param_group := fsdp_state._fsdp_param_group: + fsdp_param_group.comm_ctx = comm_ctx + + +def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None: + for module in modules: + if not isinstance(module, FSDPModule): + raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36bdc23e741c0bbee64d4c79e8b1b5e0c553263c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py @@ -0,0 +1,1206 @@ +# mypy: allow-untyped-defs +import collections +import itertools +import os +import warnings +from collections.abc import Callable, Generator, Iterable, Iterator +from typing import Any, no_type_check, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._exec_order_utils as exec_order_utils +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file +import torch.nn as nn +from torch.distributed.algorithms._comm_hooks import default_hooks +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.fsdp._common_utils import ( + _FSDPDeviceHandle, + _FSDPState, + _get_module_fsdp_state, + _is_fsdp_flattened, + _named_parameters_with_duplicates, + clean_tensor_name, + TrainingState, +) +from torch.distributed.fsdp._flat_param import ( + _FSDP_USE_FULL_PREC_IN_EVAL, + FlatParameter, + FlatParamHandle, + HandleShardingStrategy, +) +from torch.distributed.fsdp._limiter_utils import _FreeEventQueue +from torch.distributed.fsdp.api import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp.wrap import _Policy +from torch.distributed.tensor.parallel.fsdp import DTensorExtensions +from torch.distributed.utils import _sync_params_and_buffers +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +_TORCHDISTX_AVAIL = True +try: + from torchdistx import deferred_init, fake # type: ignore[import] +except ImportError: + _TORCHDISTX_AVAIL = False + +PARAM_BROADCAST_BUCKET_SIZE = 250 * 1024 * 1024 +FSDP_SYNCED = "_fsdp_synced" +# Specification of process groups for hybrid sharding strategies. +HybridShardProcessGroupType = tuple[dist.ProcessGroup, dist.ProcessGroup] +# Overall specification of process group. +ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]] + + +# TODO (awgu): Refactor this later +SHARDING_STRATEGY_MAP = { + ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD, + ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD, + ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD, + ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2, +} +HYBRID_SHARDING_STRATEGIES = [ + ShardingStrategy.HYBRID_SHARD, + ShardingStrategy._HYBRID_SHARD_ZERO2, +] +NO_RESHARD_AFTER_FORWARD_STRATEGIES = ( + ShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy._HYBRID_SHARD_ZERO2, +) + + +# NOTE: Since non-self attributes cannot be type annotated, several attributes +# on `state` are defined first as local variables before being assigned. + + +@no_type_check +def _init_process_group_state( + state: _FSDPState, + process_group: ProcessGroupType, + sharding_strategy: ShardingStrategy, + policy: Optional[_Policy], + device_mesh: Optional[DeviceMesh] = None, +) -> _FSDPState: + if process_group is not None and device_mesh is not None: + raise ValueError( + "Cannot pass both process_group and device_mesh at the " + "same time. Please just pass only one of them." + ) + is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES + if is_hybrid_strategy: + if process_group is None and policy is None and device_mesh is None: + # Raise an error here, since this is manual wrapping with no process group + # passed in, there is no way to ensure all wrapped FSDP instances use the same + # process groups. + raise ValueError( + f"Manual wrapping with {sharding_strategy} " + "requires explicit specification of process group or device_mesh." + ) + else: + state = _init_process_group_state_for_hybrid_shard( + state, process_group, device_mesh + ) + else: + if device_mesh: + state._device_mesh = device_mesh + state.process_group = device_mesh.get_group(mesh_dim=0) + else: + state.process_group = ( + process_group if process_group is not None else _get_default_group() + ) + + state.rank = state.process_group.rank() + state.world_size = state.process_group.size() + data_parallel_world_size = state.world_size + if is_hybrid_strategy: + data_parallel_world_size *= state._inter_node_pg.size() + state._gradient_predivide_factor = ( + default_hooks.DefaultState._get_gradient_predivide_factor( + data_parallel_world_size + ) + ) + state._gradient_postdivide_factor = ( + data_parallel_world_size / state._gradient_predivide_factor + ) + return state + + +@no_type_check +def _init_process_group_state_for_hybrid_shard( + state: _FSDPState, + process_group: ProcessGroupType, + device_mesh: DeviceMesh, +) -> _FSDPState: + if device_mesh: + if _is_valid_hybrid_shard_device_mesh(device_mesh): + state._device_mesh = device_mesh + # We currently only allow _inter_node_pg to be the outermost dimension, and the + # process_group(intra_node) to be the innermost dimension. + state._inter_node_pg = device_mesh.get_group(mesh_dim=0) + state.process_group = device_mesh.get_group(mesh_dim=1) + else: + raise ValueError( + f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}" + ) + elif process_group is None: + default_group = _get_default_group() + intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( + default_group, state._device_handle.device_count() + ) + # we shard across intra-node + state.process_group = intra_node_group + # save _inter_node_pg to allreduce across. + state._inter_node_pg = inter_node_group + else: + # Check type and assign state.process_group and state._inter_node_pg. + if _is_valid_hybrid_shard_pg_type(process_group): + # Assuming that user passed in as intra node group and inter node group + # as documented. + state.process_group, state._inter_node_pg = process_group + else: + raise ValueError( + "Expected process_group to be passed in as either None or " + f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}" + ) + # Create state for allreduce + state._inter_node_state = _get_default_comm_hook_state( + process_group=state._inter_node_pg, + ) + return state + + +@no_type_check +def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool: + return ( + isinstance(process_group, tuple) + and len(process_group) == 2 + and all(isinstance(pg, dist.ProcessGroup) for pg in process_group) + ) + + +@no_type_check +def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool: + return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2 + + +@no_type_check +def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup: + """ + Return a process group across the current node. + + For example, given each row is a distinct node: + 0 1 2 3 4 5 6 7 + 8 9 10 11 12 13 14 15 + This API would return an intra-node subgroup across + [0, 1, ..., 7] or [8, 9, ..., 15] depending on the process's rank. + For example, rank 3 would get [0, 1, ..., 7]. + """ + intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node) + return intra_node_subgroup + + +@no_type_check +def _init_inter_node_process_group( + global_process_group: dist.ProcessGroup, + num_devices_per_node: int, +) -> dist.ProcessGroup: + """ + Return an inter-node process group where each contained rank has the same local rank. + + For example, given each row is a distinct node: + 0 1 2 3 4 5 6 7 + 8 9 10 11 12 13 14 15 + This API would return inter-node process group [0, 8], [1, 9], [2, 10], and so forth + depending on the process's rank. For example, rank 1 would get [1, 9], rank 5 + would get [5, 13]. + """ + # the inter-node pg that is returned + inter_node_pg = None + sharding_backend = dist.get_backend(global_process_group) + world_size = dist.get_world_size(global_process_group) + # Assuming fully homogeneous setup + num_nodes = world_size // num_devices_per_node + my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node + for local_rank in range(num_devices_per_node): + ranks_for_inter_group = [ + local_rank + (i * num_devices_per_node) for i in range(num_nodes) + ] + # every rank always needs to call dist.new_group + grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend) + if local_rank == my_local_rank: + inter_node_pg = grp + + if inter_node_pg is None: + raise AssertionError( + f"{my_local_rank} expected to assign inter-node pg, but did not" + ) + return inter_node_pg + + +def _init_intra_and_inter_node_groups( + global_process_group: dist.ProcessGroup, + num_devices_per_node: int, +) -> tuple[dist.ProcessGroup, dist.ProcessGroup]: + """ + Initialize intra and inter-node process groups and return the ones corresponding to this process's rank. + + This function can be used to initialize process groups for ``HYBRID_SHARD`` or + ``_HYBRID_SHARD_ZERO2`` in FSDP. + This function assumes each node has an equal number of CUDA-enabled devices. + Returns: + Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group. + """ + return ( + _init_intra_node_process_group(num_devices_per_node), + _init_inter_node_process_group(global_process_group, num_devices_per_node), + ) + + +@no_type_check +def _init_ignored_module_states( + state: _FSDPState, + module: nn.Module, + ignored_modules: Optional[Iterable[torch.nn.Module]], + ignored_states: Union[ + Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] + ] = None, +) -> _FSDPState: + if ignored_modules is not None and ignored_states is not None: + raise ValueError( + "Cannot pass both ignored_modules and ignored_states at the " + "same time. Please just pass ignored_states." + ) + ignored_parameters = None + passed_as_ignored_states = ignored_states is not None + if passed_as_ignored_states: + ignored_states_list = list(ignored_states) + _check_ignored_states(ignored_states_list, True) + else: + ignored_states_list = [] + _check_ignored_states( + list(ignored_modules) if ignored_modules is not None else [], False + ) + if len(ignored_states_list) > 0: + if isinstance(ignored_states_list[0], nn.Parameter): + ignored_parameters = ignored_states_list + else: + ignored_modules = ignored_states_list + state._ignored_modules = _get_ignored_modules(module, ignored_modules) + state._ignored_params = _get_ignored_params( + module, + state._ignored_modules, + ignored_parameters, + ) + state._ignored_buffer_names = _get_ignored_buffer_names( + module, + state._ignored_modules, + ) + # TODO: FSDP's contract for buffers is not well-defined. They are + # implicitly ignored for most functionality since they are not sharded; + # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed + # precision). We should formalize this contract and decide if we need to + # compute and store `_ignored_buffers`. + return state + + +def _check_ignored_states( + ignored_states: list[Any], passed_as_ignored_states: bool +) -> None: + """ + Check that the ignored states are uniformly parameters or uniformly modules. + + We may remove this check in the future if we permit mixing. + """ + if len(ignored_states) == 0: + return + if passed_as_ignored_states: + all_params = all(isinstance(state, nn.Parameter) for state in ignored_states) + all_modules = all(isinstance(state, nn.Module) for state in ignored_states) + if not all_params and not all_modules: + # Sort for consistent ordering for unit test regex matching + sorted_types = sorted({type(state) for state in ignored_states}, key=repr) + raise ValueError( + "ignored_states expects all nn.Parameter or all nn.Module list " + f"elements but got types {sorted_types}" + ) + else: + if not all(isinstance(state, nn.Module) for state in ignored_states): + sorted_types = sorted({type(state) for state in ignored_states}, key=repr) + raise ValueError( + "ignored_modules expects nn.Module list elements but got " + f"types {sorted_types}" + ) + + +@no_type_check +def _init_device_handle( + state: _FSDPState, + module: nn.Module, + ignored_params: set[nn.Parameter], + device_id: Optional[Union[int, torch.device]], +) -> _FSDPState: + """ + Determine device handle used for initializing FSDP. + + If a device is specified by ``device_id``, + then returns device handle corresponds to that device type. Otherwise, If the + module is already on a non-CPU device, then the device type is that non-CPU device type. + If the module is on CPU or meta, then the device type is the current accelerator device. + See the :ref:`Accelerators` for details. + + + This method will be called once ignored parameters was determined, as the device handle maybe needed + for other initialization. + """ + determined_device = None + if device_id is not None: + determined_device = ( + device_id + if isinstance(device_id, torch.device) + else torch.device(device_id) + ) + if determined_device is None: + for param in _get_orig_params(module, ignored_params): + if param.device.type in {"cpu", "meta"}: + continue + if determined_device is None: + determined_device = param.device + else: + if param.device.type != determined_device.type: + raise RuntimeError( + f"FSDP does not support modules with different device types " + f"but got params on {determined_device.type} and {param.device.type}" + ) + determined_device = determined_device or torch._C._get_accelerator() + if determined_device.type == "cpu": + raise RuntimeError( + "FSDP needs a non-CPU accelerator device, but no accelerator device is detected." + ) + + state._device_handle = _FSDPDeviceHandle.from_device(determined_device) + return state + + +@no_type_check +def _init_buffer_state( + state: _FSDPState, + module: nn.Module, +) -> _FSDPState: + state._buffer_names = _get_buffer_names(module) + # Save a mapping from clean fully-qualified buffer name (starting from + # `module`) to its original dtype for restoring that dtype during model + # checkpointing when buffer mixed precision is enabled. The names should + # be clean since the casting happens in a `summon_full_params()` context. + _buffer_name_to_orig_dtype: dict[str, torch.dtype] = {} + for buffer_name, buffer in module.named_buffers(): + buffer_name = clean_tensor_name(buffer_name) + _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype + state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype + return state + + +@no_type_check +def _init_core_state( + state: _FSDPState, + sharding_strategy: Optional[ShardingStrategy], + mixed_precision: Optional[MixedPrecision], + cpu_offload: Optional[CPUOffload], + limit_all_gathers: bool, + use_orig_params: bool, + backward_prefetch_limit: int, + forward_prefetch_limit: int, +) -> _FSDPState: + # We clamp the strategy to `NO_SHARD` for world size of 1 since they are + # currently functionally equivalent. This may change if/when we integrate + # FSDP with MoE. + if state.world_size == 1: + if sharding_strategy != ShardingStrategy.NO_SHARD: + warnings.warn( + "FSDP is switching to use `NO_SHARD` instead of " + f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since " + "the world size is 1.", + stacklevel=2, + ) + sharding_strategy = ShardingStrategy.NO_SHARD + elif sharding_strategy == ShardingStrategy.NO_SHARD: + warnings.warn( + "The `NO_SHARD` sharding strategy is deprecated. If having issues, " + "please use `DistributedDataParallel` instead.", + FutureWarning, + # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and + # level 3 is from the true caller + stacklevel=3, + ) + state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD + state.mixed_precision = mixed_precision or MixedPrecision() + if mixed_precision is not None: + torch._C._log_api_usage_once( + f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}" + ) + state._use_full_prec_in_eval = ( + os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1" + ) + state.cpu_offload = cpu_offload or CPUOffload() + state.limit_all_gathers = limit_all_gathers + state._use_orig_params = use_orig_params + state.training_state = TrainingState.IDLE + state._is_root = None + state._free_event_queue = _FreeEventQueue() + state._debug_level = dist.get_debug_level() + state._exec_order_data = exec_order_utils._ExecOrderData( + state._debug_level, + backward_prefetch_limit, + forward_prefetch_limit, + ) + state._unshard_event = None + # Mapping from fully sharded module to the handles it is responsible to + # unshard and reshard (see [Note: Fully Sharded Module]) + _fully_sharded_module_to_handle: dict[nn.Module, FlatParamHandle] = {} + state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle + # Invariant: `state.params` contains exactly the `FlatParameter`s of the + # handles in `state._handle` + _handle: Optional[FlatParamHandle] = None + state._handle = _handle + params: list[FlatParameter] = [] + state.params = params + return state + + +@no_type_check +def _init_runtime_state( + state: _FSDPState, +) -> _FSDPState: + _root_pre_forward_handles: list[RemovableHandle] = [] + state._root_pre_forward_handles = _root_pre_forward_handles + _pre_forward_handles: list[RemovableHandle] = [] + state._pre_forward_handles = _pre_forward_handles + _post_forward_handles: list[RemovableHandle] = [] + state._post_forward_handles = _post_forward_handles + state._sync_gradients = True + state._comm_hook = None + state._comm_hook_state = None + # Used to prevent running the pre-backward hook multiple times + return state + + +@no_type_check +def _init_prefetching_state( + state: _FSDPState, + backward_prefetch: BackwardPrefetch, + forward_prefetch: bool, +) -> _FSDPState: + state.backward_prefetch = backward_prefetch + state.forward_prefetch = forward_prefetch + # The data structures use tuples of handles to generalize over the case + # where a module's forward involves multiple handles. + return state + + +@no_type_check +# pyrefly: ignore [bad-function-definition] +def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState: + # TODO: we need to add additional check once we support FSDP + PiPPy. + # This check is currently sufficient, since we only support FSDP + TP. + root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if device_mesh and root_mesh != state._device_mesh: + state._fsdp_extension = DTensorExtensions(state._device_handle) + else: + # We need to explicitly set _fsdp_extension to None. + # Otherwise, we will run into an infinite recursion when getting the attribute. + state._fsdp_extension = None + return state + + +@no_type_check +def _init_state_dict_state(state: _FSDPState) -> _FSDPState: + state._state_dict_type = StateDictType.FULL_STATE_DICT + state_dict_config: StateDictConfig = FullStateDictConfig() + state._optim_state_dict_config = FullOptimStateDictConfig() + state._state_dict_config = state_dict_config + unshard_params_ctx: dict[nn.Module, Generator] = {} + state._unshard_params_ctx = unshard_params_ctx + + return state + + +def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> None: + """ + Verify if the parameters are accepted by FSDP. The only restriction now + is that the parameter cannot be a scalar tensor (param.shape == []). + """ + for param in params: + if len(param.shape) == 0: + param_name = "" + for name, param_ in module.named_parameters(): + if param is param_: + param_name = name + break + if not param_name: + raise AssertionError("Expected param_name to be set") + raise ValueError( + "FSDP doesn't support scalar parameters. " + f"Change {param_name} to a 1D tensor with numel equal to 1." + ) + + +@no_type_check +def _init_param_handle_from_module( + state: _FSDPState, + fully_sharded_module: nn.Module, + device_id: Optional[Union[int, torch.device]], + param_init_fn: Optional[Callable[[nn.Module], None]], + sync_module_states: bool, +) -> _FSDPState: + """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``.""" + _check_single_device_module(fully_sharded_module, state._ignored_params, device_id) + device_from_device_id = _get_device_from_device_id( + device_id, state.rank, state._device_handle + ) + is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module( + fully_sharded_module, state._ignored_params, state._ignored_modules + ) + # Materialize the module if needed + if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None: + _materialize_with_param_init_fn( + fully_sharded_module, param_init_fn, state._ignored_modules + ) + elif is_meta_module: + _materialize_meta_module( + fully_sharded_module, + device_id, + state._ignored_modules, + state._device_handle, + ) + elif is_torchdistX_deferred_init: + deferred_init.materialize_module( + fully_sharded_module, + check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None + and submodule not in state._ignored_modules, + ) + + ignored_buffers = { + buffer + for ignored_module in state._ignored_modules + for buffer in ignored_module.buffers() + } + + _move_module_to_device( + fully_sharded_module, + state._ignored_params, + ignored_buffers, + device_from_device_id, + ) + state.compute_device = _get_compute_device( + fully_sharded_module, + state._ignored_params, + device_from_device_id, + state.rank, + state._device_handle, + ) + + managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params)) + _verify_managed_params(fully_sharded_module, managed_params) + if sync_module_states: + _sync_module_params_and_buffers( + fully_sharded_module, managed_params, state.process_group + ) + if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + _sync_module_params_and_buffers( + fully_sharded_module, managed_params, state._inter_node_pg + ) + _init_param_handle_from_params(state, managed_params, fully_sharded_module) + return state + + +@no_type_check +def _init_param_handle_from_params( + state: _FSDPState, + params: list[nn.Parameter], + fully_sharded_module: nn.Module, +): + if len(params) == 0: + return + handle = FlatParamHandle( + params, + fully_sharded_module, + state.compute_device, + SHARDING_STRATEGY_MAP[state.sharding_strategy], + state.cpu_offload.offload_params, + state.mixed_precision.param_dtype, + state.mixed_precision.reduce_dtype, + state.mixed_precision.keep_low_precision_grads, + state.process_group, + state._use_orig_params, + fsdp_extension=state._fsdp_extension, + ) + handle.shard() + if state._handle: + raise AssertionError("Expected state._handle to be None") + state.params.append(handle.flat_param) + state._handle = handle + state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle + cpu_device = torch.device("cpu") + if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device: + handle.flat_param_to(cpu_device) + + +def _get_ignored_modules( + root_module: nn.Module, + _ignored_modules: Optional[Iterable[torch.nn.Module]], +) -> set[nn.Module]: + """ + Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances. + + Return the modules contained in their module + subtrees as a :class:`set`. Nested FSDP instances are excluded, but their + already-computed ignored modules are included. + + ``_ignored_modules`` represents the argument passed by the user to FSDP. + """ + msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s " + try: + ignored_root_modules = ( + set(_ignored_modules) if _ignored_modules is not None else set() + ) + except TypeError as e: + raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e + for module in ignored_root_modules: + if not isinstance(module, torch.nn.Module): + raise TypeError(msg_prefix + f"but got an iterable with {type(module)}") + if _get_module_fsdp_state(module): + # TODO: We may relax this by taking the FSDP instance's wrapped + # module to provide more flexibility to the user. + raise ValueError("`ignored_modules` should not include FSDP modules") + # Treat modules that cannot compose with `fully_shard` as ignored modules, + # meaning that their subtrees are ignored + for module in root_module.modules(): + if not traversal_utils._composable(module): + ignored_root_modules.add(module) + # NOTE: Even if `ignored_root_modules` is empty, do not return early so + # that this FSDP instance can get any ignored modules from its children. + + # Include child modules and exclude nested FSDP modules themselves + ignored_modules = { + child + for module in ignored_root_modules + for child in module.modules() + if not isinstance(child, fsdp_file.FullyShardedDataParallel) + } + if root_module in ignored_modules: + warnings.warn( + "Trying to ignore the top-level module passed into the FSDP " + "constructor itself will result in all parameters being " + f"ignored and is not well-supported: {module}", + stacklevel=2, + ) + # Include nested FSDP modules' ignored modules + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + if not hasattr(optional_fsdp_state, "_ignored_modules"): + raise AssertionError( + "Expected optional_fsdp_state to have _ignored_modules attribute" + ) + ignored_modules.update(optional_fsdp_state._ignored_modules) + return ignored_modules + + +def _get_ignored_params( + root_module: torch.nn.Module, + ignored_modules: set[torch.nn.Module], + ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None, +) -> set[torch.nn.Parameter]: + """ + Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``. + + :class:`FlatParameter` s are excluded from the result. + """ + all_ignored_params: set[torch.nn.Parameter] = set() + + params_in_ignored_modules = { + p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p) + } + + all_ignored_params.update(params_in_ignored_modules) + + if ignored_parameters is not None: + params_in_ignored_parameters = { + p for p in ignored_parameters if not _is_fsdp_flattened(p) + } + all_ignored_params.update(params_in_ignored_parameters) + + # Always include nested FSDP modules' ignored parameters + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + if not hasattr(optional_fsdp_state, "_ignored_params"): + raise AssertionError( + "Expected optional_fsdp_state to have _ignored_params attribute" + ) + all_ignored_params.update(optional_fsdp_state._ignored_params) + + return all_ignored_params + + +def _get_ignored_buffer_names( + root_module: torch.nn.Module, + ignored_modules: set[torch.nn.Module], +) -> set[str]: + """Return the cleaned buffer FQNs in ``ignored_modules``.""" + all_ignored_buffer_names: set[str] = set() + + buffers_in_ignored_modules = { + buffer for m in ignored_modules for buffer in m.buffers() + } + + all_ignored_buffer_names.update( + { + clean_tensor_name(buffer_name) + for buffer_name, buffer in root_module.named_buffers() + if buffer in buffers_in_ignored_modules + } + ) + + # Always include nested FSDP modules' ignored buffer names + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + if not hasattr(optional_fsdp_state, "_ignored_buffer_names"): + raise AssertionError( + "Expected optional_fsdp_state to have _ignored_buffer_names attribute" + ) + all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names) + + return all_ignored_buffer_names + + +def _get_buffer_names(root_module: nn.Module) -> set[str]: + """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`.""" + return { + clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers() + } + + +def _check_single_device_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + device_id: Optional[Union[int, torch.device]], +) -> None: + """ + Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``. + + Thus, after this method, the + module must be either fully on the CPU or fully on a non-CPU device. + """ + devices = {param.device for param in _get_orig_params(module, ignored_params)} + # We allow module to be partially on CPU and partially on GPU if device_id is not + # None, since the device_id arg will result in the CPU portion being moved to + # GPU. This is useful in cases where part of the module may be parallelized + # by another algorithm and may already be on GPU. We'd like to enforce device_id + # to not be None, otherwise we'd flatten parameters in a mixed module which is + # not supported. + if len(devices) == 2 and torch.device("cpu") in devices: + if device_id is None: + raise RuntimeError( + "To support a module with both CPU and GPU params, " + "please pass in device_id argument." + ) + elif len(devices) > 1: + raise RuntimeError( + f"FSDP only supports single device modules but got params on {devices}" + ) + + +def _get_device_from_device_id( + device_id: Optional[Union[int, torch.device]], + rank: int, + device_handle: _FSDPDeviceHandle, +) -> Optional[torch.device]: + """ + Return a ``torch.device`` for the specified ``device_id``. + + Processes ``device_id`` and returns either the corresponding device or + ``None`` if ``device_id`` is ``None``. + """ + if device_id is None: + return None + device = ( + device_id if isinstance(device_id, torch.device) else torch.device(device_id) + ) + if device.type != "cpu" and device.index is None: + warnings.warn( + f"FSDP got the argument `device_id` {device_id} on rank " + f"{rank}, which does not have an explicit index. " + f"FSDP will use the current device {device_handle.current_device()}. " + f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` " + "before FSDP initialization or pass in the explicit device " + "index as the `device_id` argument.", + stacklevel=2, + ) + device = torch.device(device_handle.current_device()) + return device + + +def _need_to_materialize_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignored_modules: set[nn.Module], +) -> tuple[bool, bool]: + """ + Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization. + + At most of the returned bools can + be ``True``. If either is ``True``, then ``module`` needs to be + materialized. + """ + managed_params = list(_get_orig_params(module, ignored_params)) + is_meta_module = any(param.is_meta for param in managed_params) + # TODO: We need to establish a contract for FSDP and buffers. For now, we + # skip checking for meta buffers from ignored modules. We should consider + # refactoring the initialization holistically to avoid so many traversals. + for submodule in module.modules(): + if submodule in ignored_modules: + continue + for buf in submodule.buffers(recurse=False): + is_meta_module |= buf.is_meta + is_torchdistX_deferred_init = ( + not is_meta_module + and _TORCHDISTX_AVAIL + and any(fake.is_fake(param) for param in managed_params) + ) + return is_meta_module, is_torchdistX_deferred_init + + +def _materialize_with_param_init_fn( + root_module: nn.Module, + param_init_fn: Callable[[nn.Module], None], + ignored_modules: set[nn.Module], +) -> None: + if not callable(param_init_fn): + raise ValueError( + f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}" + ) + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + for module in modules_to_materialize: + param_init_fn(module) + + +def _materialize_meta_module( + root_module: nn.Module, + device_from_device_id: Optional[torch.device], + ignored_modules: set[nn.Module], + device_handle: _FSDPDeviceHandle, +): + # Run default meta device initialization + materialization_device = device_from_device_id or torch.device( + device_handle.current_device() + ) + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + module = None + try: + # Assume that each module's `reset_parameters()` only initializes its + # own parameters and not those of its children + with torch.no_grad(): + for module in modules_to_materialize: + # As a contract to the user, only call `reset_parameters()` if + # the module has directly managed parameters/buffers + module_state_iter = itertools.chain( + module.parameters(recurse=False), + # pyrefly: ignore [bad-argument-type] + module.buffers(recurse=False), + ) + has_module_states = len(list(module_state_iter)) > 0 + if has_module_states: + module.to_empty(device=materialization_device, recurse=False) + module.reset_parameters() # type: ignore[operator] + except BaseException as e: + warnings.warn( + "Unable to call `reset_parameters()` for module on meta " + f"device with error {str(e)}. Please ensure that your module of" + f"type {type(module)} implements a `reset_parameters()` method.", + stacklevel=2, # type: ignore[possibly-undefined] + ) + raise e + + +def _get_modules_to_materialize( + root_module: nn.Module, ignored_modules: set[nn.Module] +) -> list[nn.Module]: + # Run BFS to collect the modules to materialize via `reset_parameters()`, + # stopping at any module with FSDP already applied or at ignored modules. + modules_to_materialize: list[nn.Module] = [] + queue = collections.deque([root_module]) + visited_modules: set[nn.Module] = {root_module} + while queue: + module = queue.popleft() + modules_to_materialize.append(module) + for child_module in module.children(): + if ( + child_module not in visited_modules + and _get_module_fsdp_state(child_module) is None + and child_module not in ignored_modules + ): + visited_modules.add(child_module) + queue.append(child_module) + return modules_to_materialize + + +def _move_module_to_device( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignored_buffers: set[torch.Tensor], + device_from_device_id: Optional[torch.device], +) -> None: + """ + Move ``module`` depending on ``device_from_device_id`` and its current device. + + This includes moving ignored modules' parameters. + + - If ``device_from_device_id`` is not ``None``, then this moves + ``module`` to the device. + - If ``device_from_device_id`` is ``None``, then this does not move + ``module`` but warns the user if it is on CPU. + + Precondition: ``_check_single_device_module()``. + """ + cpu_device = torch.device("cpu") + if device_from_device_id is not None: + # BFS from `module` without traversing any nested FSDP instances to + # collect the parameters/buffers that have not yet been managed + queue: collections.deque[nn.Module] = collections.deque() + queue.append(module) + params: list[nn.Parameter] = [] + buffers: list[torch.Tensor] = [] + while queue: + curr_module = queue.popleft() + # NOTE: We include a check to only move parameters/buffers that are + # on CPU device. If they are on a CUDA device different from the + # one specified by `device_id`, then this does NOT move them. This + # is so that we can raise an error in `_get_compute_device()`. + params.extend( + param + for param in curr_module.parameters(recurse=False) + if param.device == cpu_device + ) + buffers.extend( + buffer + for buffer in curr_module.buffers(recurse=False) + if buffer.device == cpu_device + ) + for submodule in curr_module.children(): + if not isinstance(submodule, fsdp_file.FullyShardedDataParallel): + queue.append(submodule) + params_to_move = [p for p in params if p not in ignored_params] + bufs_to_move = [p for p in buffers if p not in ignored_buffers] + _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id) + return + param = next(_get_orig_params(module, ignored_params), None) + if param is not None and param.device == cpu_device: + _warn_cpu_init() + + +def _move_states_to_device( + params: list[nn.Parameter], + buffers: list[torch.Tensor], + device_from_device_id: Optional[torch.device], +) -> None: + """ + Move states to the specified device. + + Precondition: ``_check_single_device_module()`` and module's parameters and + buffers have been materialized if needed. + """ + if len(params) == 0 and len(buffers) == 0: + return + if len(params) > 0: + current_device = params[0].device + elif len(buffers) > 0: + current_device = buffers[0].device + cpu_device = torch.device("cpu") + if device_from_device_id is not None: + # Move the parameters and buffers like the `.data` code path in + # `nn.Module._apply()`, which underlies `nn.Module.to()` + for param in params: + with torch.no_grad(): + param.data = param.to(device_from_device_id) + if param.grad is not None: + param.grad.data = param.grad.to(device_from_device_id) + for buffer in buffers: + buffer.data = buffer.to(device_from_device_id) + elif current_device == cpu_device: # type: ignore[possibly-undefined] + _warn_cpu_init() + + +def _warn_cpu_init(): + warnings.warn( + "The passed-in `module` is on CPU and will thus have FSDP's sharding " + "initialization run on CPU, which may be slower than on GPU. We " + "recommend passing in the `device_id` argument for FSDP to move " + "`module` to GPU for the sharding initialization. `module` must also " + "be on GPU device to work with the `sync_module_states=True` flag " + "since that requires GPU communication.", + stacklevel=2, + ) + + +def _get_compute_device( + module: nn.Module, + ignored_params: set[nn.Parameter], + device_from_device_id: Optional[torch.device], + rank: int, + device_handle: _FSDPDeviceHandle, +) -> torch.device: + """ + Determine and return this FSDP instance's compute device. + + If the module is already on a non-CPU device, then the compute device is that non-CPU + device. If the module is on CPU, then the compute device is the current + device. + + Since this method should be called after materializing the module, any + non-CPU device should not be meta device. For now, the compute device is + always a CUDA or CUDA-like device with its explicit index. + + Precondition: ``_check_single_device_module()`` and + ``_move_module_to_device()``. + """ + param = next(_get_orig_params(module, ignored_params), None) + if param is not None and param.device.type != "cpu": + compute_device = param.device # Determined by model param placement + else: + compute_device = torch.device(device_handle.current_device()) + if device_from_device_id is not None and compute_device != device_from_device_id: + raise ValueError( + f"Inconsistent compute device and `device_id` on rank {rank}: " + f"{compute_device} vs {device_from_device_id}" + ) + return compute_device + + +# TODO: See how to deprecate! +def _sync_module_params_and_buffers( + module: nn.Module, + params: list[nn.Parameter], + process_group: dist.ProcessGroup, +) -> None: + """ + Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks. + + Precondition: ``sync_module_states == True`` and ``self.process_group`` has + been set. + """ + module_states: list[torch.Tensor] = [] + for buffer in module.buffers(): + # Avoid re-synchronizing buffers in case of nested wrapping + if not getattr(buffer, FSDP_SYNCED, False): + setattr(buffer, FSDP_SYNCED, True) + detached_buffer = buffer.detach() + if is_traceable_wrapper_subclass(detached_buffer): + # NOTE: Here we assume no nested subclasses, at most one level of subclass + # in both model's buffers and params + attrs, _ = detached_buffer.__tensor_flatten__() # type: ignore[attr-defined] + inner_buffers = [getattr(detached_buffer, attr) for attr in attrs] + module_states.extend(inner_buffers) + else: + module_states.append(detached_buffer) + + for param in params: + detached_param = param.detach() + if is_traceable_wrapper_subclass(detached_param): + attrs, _ = detached_param.__tensor_flatten__() # type: ignore[attr-defined] + inner_params = [getattr(detached_param, attr) for attr in attrs] + module_states.extend(inner_params) + else: + module_states.append(detached_param) + + _check_module_states_for_sync_module_states(module_states) + _sync_params_and_buffers( + process_group, + module_states, + PARAM_BROADCAST_BUCKET_SIZE, + src=0, + ) + + +def _check_module_states_for_sync_module_states( + module_states: list[torch.Tensor], +) -> None: + if module_states and any( + tensor.device == torch.device("cpu") for tensor in module_states + ): + raise ValueError( + "The module has CPU parameters or buffers when `sync_module_states=True`, " + "which requires them to be on GPU. Please specify the `device_id` argument " + "or move the module to GPU before passing it to FSDP." + ) + + +def _get_orig_params( + module: nn.Module, + ignored_params: set[nn.Parameter], +) -> Iterator[nn.Parameter]: + """ + Return an iterator over the original parameters in ``module``. + + The iterator does not return + the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be + present due to nested FSDP wrapping), or any original parameters already + flattened (only relevant when ``use_orig_params=True``). + """ + param_gen = module.parameters() + try: + while True: + param = next(param_gen) + if param not in ignored_params and not _is_fsdp_flattened(param): + yield param + except StopIteration: + pass + + +def _check_orig_params_flattened( + fsdp_module, + ignored_params: set[nn.Parameter], +) -> None: + """ + Check that original parameters in ``fsdp_module`` have been flattened. + + The flattened parameters are made + invisible to ``named_parameters()`` for the module hierarchy rooted at + ``fsdp_module``. This should be called as a sanity check after flattening + the wrapped module's parameters. + """ + for param_name, param in _named_parameters_with_duplicates(fsdp_module): + if param not in ignored_params and not _is_fsdp_flattened(param): + raise RuntimeError( + f"Found an unflattened parameter: {param_name}; " + f"{param.size()} {param.__class__}" + ) + + +def _get_default_comm_hook(sharding_strategy: ShardingStrategy): + return ( + default_hooks.allreduce_hook + if sharding_strategy == ShardingStrategy.NO_SHARD + else default_hooks.reduce_scatter_hook + ) + + +def _get_default_comm_hook_state( + process_group: dist.ProcessGroup, +) -> default_hooks.DefaultState: + return default_hooks.DefaultState(process_group=process_group) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_limiter_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_limiter_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b190585342ee267716abace19add022b4d6b3e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_limiter_utils.py @@ -0,0 +1,33 @@ +import collections +from typing import Optional + +import torch + + +class _FreeEventQueue: + """ + This tracks all pending frees corresponding to inflight all-gathers. The + queueing pattern is iterative enqueues with a single dequeue per iteration + once the limit ``_max_num_inflight_all_gathers`` is reached. + """ + + def __init__(self) -> None: + self._queue: collections.deque[torch.Event] = collections.deque() + self._max_num_inflight_all_gathers = 2 # empirically chosen + + def enqueue(self, free_event: torch.Event) -> None: + """Enqueues a free event.""" + self._queue.append(free_event) + + def dequeue_if_needed(self) -> Optional[torch.Event]: + """Dequeues a single event if the limit is reached.""" + if len(self._queue) >= self._max_num_inflight_all_gathers: + return self._dequeue() + return None + + def _dequeue(self) -> Optional[torch.Event]: + """Dequeues a free event if possible.""" + if self._queue: + event = self._queue.popleft() + return event + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_optim_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_optim_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..564cfeece48ee1e656ea4e06628c36c0d01c0af8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_optim_utils.py @@ -0,0 +1,2139 @@ +# mypy: allow-untyped-defs +import copy +import functools +import logging +import warnings +from collections.abc import Iterable, Iterator, Sequence +from contextlib import ExitStack +from dataclasses import dataclass, field +from itertools import chain +from typing import Any, cast, NamedTuple, no_type_check, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed._state_dict_utils import _gather_state_dict +from torch.distributed.distributed_c10d import _get_pg_default_device +from torch.distributed.fsdp._common_utils import ( + _apply_to_modules, + _FSDPState, + _get_module_fsdp_state_if_fully_sharded_module, + _get_param_to_fqns, + _module_handle, + _named_parameters_with_duplicates, + clean_tensor_name, +) +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._flat_param import FlatParameter, FlatParamHandle +from torch.distributed.fsdp._fsdp_extensions import ( + _ext_chunk_dtensor, + _ext_chunk_tensor, +) +from torch.distributed.fsdp._runtime_utils import ( + _lazy_init, + _reset_flat_param_grad_info_if_needed, +) +from torch.distributed.fsdp.api import ( + ShardingStrategy, + StateDictSettings, + StateDictType, +) +from torch.distributed.tensor import DTensor, Replicate +from torch.utils._pytree import tree_map_only + + +if TYPE_CHECKING: + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +logger = logging.getLogger(__name__) + + +@dataclass +class FSDPParamInfo: + state: _FSDPState + handle: FlatParamHandle + param_indices: dict[str, int] + param_requires_grad: list[bool] + + +def sorted_items(dictionary: dict[str, Any]) -> Iterator[tuple[str, Any]]: + keys = sorted(dictionary.keys()) + for k in keys: + yield k, dictionary[k] + + +@dataclass +class _ConsolidatedOptimState: + """ + This holds the consolidated optimizer state on the target rank. Positive- + dimension tensor state is communicated across ranks, while zero-dimension + tensor state and non-tensor state is taken directly from the target rank. + + PyTorch version 1.12 moved to using zero-dimension tensors for scalar + values, but user implemented optimizers may still use float (i.e. a + non-tensor). Thus, we support both and handle them identically. + + Attributes: + tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension + tensor state name to the unsharded flat tensor representing the + state. + zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero- + dimension tensor state name to its value. + non_tensor_state (Dict[str, Any]): Mapping from non-tensor state + name to its value. + """ + + tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + zero_dim_tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + non_tensor_state: dict[str, Any] = field(default_factory=dict) + + +class _PosDimTensorInfo(NamedTuple): + """ + Metadata for positive-dimension tensors used internally for + :meth:`scatter_full_optim_state_dict`. + + Attributes: + shape (torch.Size): Sharded tensor shape (which is equal to the + unsharded tensor shape if the tensor is optimizer state for a + non-FSDP parameter and is hence not sharded). + dtype (torch.dtype): Data type of the tensor. + """ + + shape: torch.Size + dtype: torch.dtype + + +class _OptimStateKey(NamedTuple): + """ + This represents an optimizer state key that may be used commonly across + ranks. It is based on the unflattened parameter names rather than parameter + IDs to make it independent of each rank's own optimizer construction. + """ + + unflat_param_names: tuple[str, ...] + is_fsdp_managed: bool + + +def _unflatten_optim_state( + fsdp_param_info: FSDPParamInfo, + flat_param_state: dict[str, Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool, +) -> list[dict[str, Any]]: + """ + Unflattens the optimizer state, consisting of the "state" part and the + "param_groups" part. Unflattening the "state" part involves consolidating + the state on the target rank and remapping from flattened to unflattened + parameter IDs, and the "param_groups" part only involves remapping from + flattened to unflattened parameter IDs. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + flat_param_state (Dict[str, Any]): Entry for the flat parameter in the + "state" part of the optimizer state dict. + to_save (bool): Whether to save the state on this rank. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flat parameter if on the target + rank or an empty :class:`list` otherwise. The final optimizer state + dict will need to map these entries using the proper unflattened + parameter IDs. + """ + if shard_state and not to_save: + raise AssertionError("If ``shard_state`` is True, ``to_save`` has to be True.") + consolidated_state = _communicate_optim_state( + fsdp_param_info, + flat_param_state, + ) + if to_save: + unflat_param_state = _unflatten_communicated_optim_state( + fsdp_param_info, + consolidated_state, + shard_state, + ) + for optim_state in unflat_param_state: + # We can't use .items() below cuz we'd run into a concurrent modification error + if cpu_offload: + for key in list(optim_state.keys()): + state = optim_state[key] + if not isinstance(state, torch.Tensor): + continue + optim_state[key] = state.cpu() + return unflat_param_state + else: + return [] + + +def _is_zero_dim_tensor(x: Any) -> bool: + return torch.is_tensor(x) and x.dim() == 0 + + +def _communicate_optim_state( + fsdp_param_info: FSDPParamInfo, + flat_param_state: dict[str, Any], +) -> _ConsolidatedOptimState: + """ + Communicates the optimizer state for a flat parameter across ranks. All + ranks will hold the entire non-sharded optimizer state on GPU. + + If ``N`` is the number of tensor optimizer states in the optimizer state + dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1`` + otherwise (where the plus 1 comes from all-gathering the padding per rank). + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + flat_param_state (Dict[str, Any]): The entry in the "state" part of the + optimizer state dict corresponding to the flat parameter. + + Returns: + ConsolidatedOptimState: Consolidated optimizer state for the target + flat parameter. + """ + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.handle.flat_param + state = _ConsolidatedOptimState() + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + + for state_name, value in sorted_items(flat_param_state): + # Positive-dimension tensor state: communicate across ranks + if torch.is_tensor(value) and value.dim() > 0: + # If the parameter is not sharded, then neither is the + # positive-dimension tensor state, so no need to communicate it -- + # we take the target rank's value + if ( + fsdp_state.world_size == 1 + or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + ): + tensor_state[state_name] = value + continue + if fsdp_state.compute_device is None: + raise AssertionError("compute_device has not been initialized") + if value.device.type != fsdp_state.compute_device.type: + value = value.to(fsdp_state.compute_device) + # Assume that positive-dimension tensor optimizer state + # has the same shape as the sharded flat parameter + buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined] + tensor_buffer = value.new_zeros(*buffer_size) + dist.all_gather_into_tensor( + tensor_buffer, value, group=fsdp_state.process_group + ) + fsdp_state._device_handle.synchronize() + unpadded_numel = cast( + nn.Parameter, flat_param._unpadded_unsharded_size + ).numel() + tensor_state[state_name] = tensor_buffer[:unpadded_numel] + # Zero-dimension tensor state and non-tensor state: take this rank's + # value directly + else: + if _is_zero_dim_tensor(value): + zero_dim_tensor_state[state_name] = value.detach().clone() + else: + non_tensor_state[state_name] = value + return state + + +def _unflatten_communicated_optim_state( + fsdp_param_info: FSDPParamInfo, + state: _ConsolidatedOptimState, + shard_state: bool, +) -> list[dict[str, Any]]: + """ + Unflattens the communicated optimizer state (given by ``tensor_state``, + ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat + parameter. This should only be called on the target rank. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + state (_ConsolidatedOptimState): Consolidated optimizer state. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flat parameter. The final + optimizer state dict will need to map these entries using the proper + unflattened parameter IDs. + """ + fsdp_state = fsdp_param_info.state + handle = fsdp_param_info.handle + flat_param = handle.flat_param + unflat_param_state: list[dict[str, Any]] = [] + flat_param_views: dict[str, Iterator] = {} + num_unflat_params = flat_param._num_params + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + + for _ in range(num_unflat_params): + unflat_state_param = {} + # Add positive-dimension tensor state: unflatten with views + for state_name, flat_tensor in sorted_items(tensor_state): + views_generated = state_name in flat_param_views + if not views_generated: + views = handle._get_unflat_views(flat_tensor) + flat_param_views[state_name] = views + else: + views = flat_param_views[state_name] + optim_state: Union[torch.Tensor, ShardedTensor, DTensor] = next(views) + if shard_state: + osd_config = fsdp_state._optim_state_dict_config + if getattr(osd_config, "_use_dtensor", False): + if fsdp_state._device_mesh is None: + raise AssertionError( + f"Expected _device_mesh to be not None, got {fsdp_state._device_mesh}" + ) + optim_state = _ext_chunk_dtensor( + optim_state, + fsdp_state.rank, + fsdp_state._device_mesh, + fsdp_state._fsdp_extension, + ) + else: + if fsdp_state.process_group is None: + raise AssertionError( + f"Expected process_group to be not None, got {fsdp_state.process_group}" + ) + optim_state = _ext_chunk_tensor( + optim_state, + fsdp_state.rank, + fsdp_state.world_size, + fsdp_state._device_handle.device_count(), + fsdp_state.process_group, + fsdp_state._fsdp_extension, + ) + unflat_state_param[state_name] = optim_state + + # Add zero-dimension tensor state: take the target rank's value + unflat_state_param.update(sorted_items(zero_dim_tensor_state)) + # Add non-tensor state: take the target rank's value + unflat_state_param.update(sorted_items(non_tensor_state)) + unflat_param_state.append(unflat_state_param) + return unflat_param_state + + +def _broadcast_processed_state( + fsdp_state: _FSDPState, + optim_state: dict[str, Any], + group: Optional[dist.ProcessGroup], +) -> dict[str, Any]: + objects: list[Any] = [None] + if dist.get_rank(group) == 0: + objects[0] = tree_map_only( + torch.Tensor, + lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), # type: ignore[union-attr] + optim_state, + ) + dist.broadcast_object_list(objects, src=0, group=group) + if dist.get_rank(group) == 0: + return optim_state + else: + return objects[0] + + +def _broadcast_state( + fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup] +) -> Any: + if dist.get_rank(group) == 0: + if not isinstance(state, torch.Tensor) or state.dim() == 0: + return state + tensor = state.to(fsdp_state.compute_device) + else: + if isinstance(state, torch.Tensor): + if state.dim() != 0: + raise AssertionError( + "For non-zero ranks, a tensor state should have zero dimension, " + f"but got the state with shape {state.shape}." + ) + return state + elif not isinstance(state, _PosDimTensorInfo): + return state + tensor = torch.zeros( + state.shape, dtype=state.dtype, device=fsdp_state.compute_device + ) + dist.broadcast(tensor, src=0, group=group) + return tensor + + +def _shard_orig_param_state( + fsdp_param_info: FSDPParamInfo, + fqn: str, + optim_state: dict[str, Any], +) -> dict[str, Any]: + """ + Shard the optimizer state for the original parameter with the name ``fqn``. + This API should only be used when ``use_orig_params`` is True. + """ + if not optim_state: + return {} + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.handle.flat_param + param_idx = fsdp_param_info.param_indices[fqn] + shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined] + optim_state = _gather_state_dict( + optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device + ) + if not shard_param_info.in_shard: + return {} + # Flatten and shard the state. + new_optim_state: dict[str, Any] = {} + intra_param_start_idx = shard_param_info.intra_param_start_idx + intra_param_end_idx = shard_param_info.intra_param_end_idx + for state_name, value in optim_state.items(): + if ( + torch.is_tensor(value) + and value.dim() > 0 + and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD + ): + value = value.flatten()[ + intra_param_start_idx : intra_param_end_idx # type: ignore[operator] + + 1 + ].clone() + new_optim_state[state_name] = value + return new_optim_state + + +def _flatten_optim_state_dict( + optim_state_dict: dict[str, Any], + model: nn.Module, + use_orig_params: bool = False, + optim: Optional[torch.optim.Optimizer] = None, + rank0_only: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> dict[str, Any]: + """ + Flattens the full optimizer state dict, still keying by unflattened parameter + names. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP know how to aggregate them. However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- it is managed by other parallelism and FSDP does not + know ho to handle/aggregate them. + + Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to + flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require + all the states even if the corresponding parameters are empty. To this end, + ``optim`` will be used to get the initial state of the empty parameters. + ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or + NamedOptimizer. + + Returns: + Dict[str, Any]: The flattened optimizer state dict. + """ + SimpleProfiler.reset() + + unflat_osd = optim_state_dict + if "state" not in unflat_osd and not rank0_only: + raise ValueError( + '`optim_state_dict` must have the keys "state"' + "to be a valid optimizer state dict" + ) + param_to_fqns = _get_param_to_fqns(model) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state + + # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. + if rank0_only: + unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) + + # Construct the "state" part + flat_osd_state: dict[Union[_OptimStateKey, str], Any] = {} + unflat_osd_state = unflat_osd["state"] + all_state_keys = set(unflat_osd_state.keys()) + + for param, fqns in param_to_fqns.items(): + fqn = fqns[0] + if fqn not in unflat_osd_state: + continue + all_state_keys.difference_update(fqns) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name in unflat_osd_state[fqn]: + unflat_osd_state[fqn][state_name] = _broadcast_state( + fsdp_state, unflat_osd_state[fqn][state_name], group=group + ) + fqn = fqns[0] + if fqn in fqn_to_fsdp_param_info: + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if use_orig_params: + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + flat_state = _shard_orig_param_state( + fsdp_param_info, + fqn, + unflat_osd_state[fqn], + ) + else: + flat_state = _flatten_optim_state( + fsdp_param_info, + unflat_osd_state, + fqns, + ) + key = _OptimStateKey(tuple(fqns), True) + # Only include non-empty states since as expected by + # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer + # or NamedOptimizer. + if flat_state: + flat_osd_state[key] = flat_state + elif use_orig_params: + if len(fqns) != 1: + raise AssertionError( + f"use_orig_params is True but there are multiple FQNs, {fqns}." + ) + if optim is not None: # NamedOptimizer or KeyedOptimizer case. + state = optim.state.get(param, None) # type: ignore[call-overload] + if state is not None: + flat_osd_state[key] = copy.deepcopy(state) + else: + warnings.warn( + f"optim_state[{key}] is not on rank{fsdp_state.rank}.", + stacklevel=2, + ) + + else: + raise RuntimeError( + f"The state of {key} is empty. This should happen when " + "use_orig_params=True." + ) + else: # do not flatten non-FSDP parameters' states + if len(fqns) != 1: + raise AssertionError(f"Expected len(fqns) == 1, got {len(fqns)}") + key = _OptimStateKey(tuple(fqns), False) + flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name, param_state in list(unflat_osd_state[fqn].items()): + if fsdp_state.rank > 0: + # Deference the tensor so that PyTorch can collect the memory. + del unflat_osd_state[fqn][state_name] + else: + # Move the tensor in the original osd back to CPU to make the + # original osd unaffected. + unflat_osd_state[fqn][state_name] = param_state.cpu() + + # Handle user-defined state, states that are not associated with parameters. + for key in all_state_keys: + user_state = unflat_osd_state[key] + if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: + user_state = _broadcast_state(fsdp_state, user_state, group=group) + flat_osd_state[key] = copy.copy(user_state) + + SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") + # Construct the "param_groups" part -- copy as is since it will be + # rekeyed later according to the target rank's optimizer + # Only copy param_groups if it exists in unflat_osd + if "param_groups" in unflat_osd: + flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) + return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} + else: + return {"state": flat_osd_state} + + +def _flatten_optim_state( + fsdp_param_info: FSDPParamInfo, + unflat_osd_state: dict[str, dict[str, Any]], + unflat_param_names: list[str], +) -> dict[str, Any]: + """ + Flattens the optimizer state in ``full_optim_state_dict`` for a single + flat parameter in ``fsdp_param_info`` corresponding to the unflattened + parameter names in ``unflat_param_names``. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the + optimizer state dict corresponding to the unflattened parameters. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the flat parameter ``flat_param``. + + Returns: + Dict[str, Any]: A :class:`dict` mapping state names to their values for + a particular flat parameter. The sharded optimizer state dict's "state" + part will map a key to this returned value. + """ + fsdp_state = fsdp_param_info.state + handle = fsdp_param_info.handle + flat_param = handle.flat_param + num_unflat_params = len(unflat_param_names) + if num_unflat_params <= 0: + raise AssertionError( + "Expects at least one unflattened parameter corresponding to the flat parameter" + ) + unflat_param_shapes = flat_param._shapes + num_unflat_param_shapes = len(unflat_param_shapes) + if num_unflat_params != num_unflat_param_shapes: + raise AssertionError( + f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" + ) + + # Check if these unflattened parameters have any optimizer state + has_state = [ + bool(unflat_param_name in unflat_osd_state) + for unflat_param_name in unflat_param_names + ] + # If none of the unflattened parameters comprising this flat parameter have + # any state, then we do not want an entry in the optimizer state dict + if not any(has_state): + return {} # no need to flatten any state + # There may still be some unflattened parameters with state and some + # without + unflat_param_states = [ + _gather_state_dict( + unflat_osd_state[unflat_param_name], + pg=fsdp_state.process_group, + device=fsdp_state.compute_device, + ) + if unflat_param_name in unflat_osd_state + else None + for unflat_param_name in unflat_param_names + ] + # Check that the unflattened parameters have the same state names + state_names = None + # pyrefly: ignore [bad-assignment] + for unflat_param_state in unflat_param_states: + if unflat_param_state is None: + continue + if state_names is None: + state_names = set(unflat_param_state.keys()) + else: + if state_names != set(unflat_param_state.keys()): + raise ValueError( + "Differing optimizer state names for the unflattened " + f"parameters: {unflat_param_names}" + ) + if state_names is None: + raise AssertionError(f"Expected state_names to be not None, got {state_names}") + + # Flatten the state + flat_state: dict[str, Optional[torch.Tensor]] = {} + for state_name in state_names: + state_values = [ + unflat_param_state[state_name] if unflat_param_state is not None else None + for unflat_param_state in unflat_param_states + ] + non_none_state_values = [v for v in state_values if v is not None] + # If all ranks have None, this is a None value + if not non_none_state_values: + flat_state[state_name] = None + continue + are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True + for v in non_none_state_values: + are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0 + are_zero_dim_tensors &= _is_zero_dim_tensor(v) + are_non_tensors &= not torch.is_tensor(v) + types = {type(v) for v in non_none_state_values} + if len(types) != 1 or not ( + are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors + ): + raise ValueError( + f"Differing optimizer state types for state {state_name}, " + f"values {non_none_state_values}, and unflattened parameter " + f"names {unflat_param_names}" + ) + if are_pos_dim_tensors: + flat_tensor = _flatten_tensor_optim_state( + state_name, + state_values, # type: ignore[arg-type] + unflat_param_names, + unflat_param_shapes, + handle, + ) + # Shard the flattened tensor immediately to minimize max memory + # usage + if ( + fsdp_state.world_size != 1 + and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD + ): + sharded_flat_tensor, _ = FlatParamHandle._get_shard( + flat_tensor, + fsdp_state.rank, + fsdp_state.world_size, + ) + else: + sharded_flat_tensor = flat_tensor + flat_state[state_name] = sharded_flat_tensor + elif are_zero_dim_tensors: + flat_state[state_name] = _flatten_zero_dim_tensor_optim_state( + state_name, + state_values, # type: ignore[arg-type] + unflat_param_names, + ) + else: + if not are_non_tensors: + raise AssertionError( + f"Expected are_non_tensors to be True, got {are_non_tensors}" + ) + flat_state[state_name] = _flatten_non_tensor_optim_state( + state_name, + state_values, + unflat_param_names, + ) + + return flat_state + + +def _flatten_tensor_optim_state( + state_name: str, + pos_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], + unflat_param_shapes: Sequence[torch.Size], + handle: FlatParamHandle, +) -> torch.Tensor: + """ + Flattens the positive-dimension tensor optimizer state given by the values + ``tensors`` for the state ``state_name`` for a single flat parameter + from ``handle`` corresponding to the unflattened parameter names + ``unflat_param_names`` and unflatted parameter shapes + ``unflat_param_shapes``. This flattens each unflattened parameter's tensor + state into one tensor. + + NOTE: We use zero tensors for any unflattened parameters without state + since some value is required to fill those entries. This assumes that the + zero tensor is mathematically equivalent to having no state, which is true + for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all + optimizers. + + Args: + state_name (str): Optimizer state name. + pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor + optimizer state values for the unflattened parameters corresponding + to the single flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes + corresponding to the single flat parameter. + handle (FlatParamHandle): The flat parameter's handle. + + Returns: + torch.Tensor: A flat tensor containing the optimizer state + corresponding to ``state_name`` constructed by concatenating the + unflattened parameter tensor states in ``pos_dim_tensors`` (using zero + tensors for any unflattened parameters without the state). + """ + flat_param = handle.flat_param + non_none_tensors = [t for t in pos_dim_tensors if t is not None] + # Check that all are tensors with the same dtype + dtypes = {t.dtype for t in non_none_tensors} + if len(dtypes) != 1: + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have positive-dimension tensor state with the " + f"same dtype but got dtypes {dtypes} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + dtype = next(iter(dtypes)) + # Check that each tensor state matches its parameter's shape + for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes): + if tensor is None and len(shape) == 0: + raise ValueError("Flattening a zero-dimension parameter is not supported") + elif tensor is not None and tensor.shape != shape: + raise ValueError( + "Tensor optimizer state does not have same shape as its " + f"parameter: {tensor.shape} {shape}" + ) + # Flatten the tensor states: we do not need to add any right-hand-side + # padding since the flat optimizer state tensor is sharded via + # `_get_shard()`, which pads the shard as needed (just like for the flat + # parameter) + cpu_device = torch.device("cpu") + tensors_to_flatten = [ + torch.flatten(state_value.to(cpu_device)) + if state_value is not None + else torch.flatten( + torch.zeros( + size=shape, + dtype=dtype, + device=cpu_device, + ) + ) + for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes) + ] + flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel) + flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined] + if flat_tensor.shape != flat_param_shape: + raise AssertionError( + f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}" + ) + return flat_tensor + + +def _flatten_zero_dim_tensor_optim_state( + state_name: str, + zero_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], +) -> torch.Tensor: + """ + Flattens the zero-dimension tensor optimizer state given by the values + ``zero_dim_tensors`` for the state ``state_name`` for a single flat + parameter corresponding to the unflattened parameter names + ``unflat_param_names`` by enforcing that all tensors are the same and using + that common value. + + NOTE: The requirement that the tensors are the same across all unflattened + parameters comprising the flat parameter is needed to maintain the + invariant that FSDP performs the same computation as its non-sharded + equivalent. This means that none of the unflattened parameters can be + missing this state since imposing a value may differ from having no value. + For example, for Adam's "step", no value means maximum bias correction, + while having some positive value means less bias correction. + + Args: + state_name (str): Optimizer state name. + zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state + for the unflattened parameters corresponding to the single + flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + + Returns: + torch.Tensor: A zero-dimensional tensor giving the value of the state + ``state_name`` for all unflattened parameters corresponding to the + names ``unflat_param_names``. + """ + non_none_tensors = [t for t in zero_dim_tensors if t is not None] + # Enforce that all have the same value and dtype + values_set = {t.item() if t is not None else None for t in zero_dim_tensors} + dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors} + if ( + len(non_none_tensors) != len(zero_dim_tensors) + or len(values_set) != 1 + or len(dtypes) != 1 + ): + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have scalar state with the same value and dtype " + f"but got values {values_set} and dtypes {dtypes} for state " + f"{state_name} and unflattened parameter names " + f"{unflat_param_names}" + ) + value = next(iter(values_set)) + dtype = next(iter(dtypes)) + return torch.tensor(value, dtype=dtype, device=torch.device("cpu")) + + +def _flatten_non_tensor_optim_state( + state_name: str, + non_tensors: list[Any], + unflat_param_names: list[str], +) -> Any: + """ + Flattens the non-tensor optimizer state given by the values ``non_tensors`` + for the state ``state_name`` for a single flat parameter corresponding + to the unflattened parameter names ``unflat_param_names`` by enforcing that + all values are the same and using that common value. + + See the note in :func:`_flatten_zero_dim_tensor_optim_state`. + + Args: + state_name (str): Optimizer state name. + non_tensors (List[Any]): Non-tensor optimizer state for the unflattened + parameters corresponding to the single flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + + Returns: + Any: A non-tensor giving the value of the state ``state_name`` for all + unflattened parameters corresponding to the names + ``unflat_param_names``. + """ + non_none_non_tensors = [nt for nt in non_tensors if nt is not None] + # Enforce that all have the same value (same type already checked) + non_tensor_set = set(non_tensors) + if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1: + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have scalar state with the same value and dtype " + f"but got values {non_tensor_set} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + non_tensor = next(iter(non_tensor_set)) + return non_tensor + + +def _rekey_sharded_optim_state_dict( + sharded_osd: dict[str, Any], + model: nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ], + using_optim_input: bool, + is_named_optimizer: bool = False, +) -> dict[str, Any]: + """ + Rekeys the optimizer state dict from unflattened parameter names to flat + parameter IDs according to the calling rank's ``optim``, which may be + different across ranks. In particular, the unflattened parameter names are + represented as :class:`_OptimStateKey` s. + """ + param_to_fqns = _get_param_to_fqns(model) + flat_param_to_fqn = _get_flat_param_to_fqn(model) + param_to_param_key: dict[nn.Parameter, Union[int, str]] = cast( + dict[nn.Parameter, Union[int, str]], + ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_key( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + ), + ) + # All parameter keys in `param_to_param_key` should be in + # `param_to_fqns` -- strict inequality follows when not all parameters are + # passed to the optimizer + if len(param_to_param_key) > len(param_to_fqns): + raise AssertionError( + f"Expected len(param_to_param_key) <= len(param_to_fqns), got {len(param_to_param_key)} > {len(param_to_fqns)}" + ) + + unflat_param_names_to_flat_param_key: dict[ + tuple[str, ...], Union[int, str] + ] = {} # for "state" + unflat_param_name_to_flat_param_key: dict[ + str, Union[int, str] + ] = {} # for "param_groups" + for param, unflat_param_names in param_to_fqns.items(): + if param not in param_to_param_key: + # This parameter was not passed to the optimizer + continue + flat_param_key = param_to_param_key[param] + unflat_param_names_to_flat_param_key[tuple(unflat_param_names)] = flat_param_key + for unflat_param_name in unflat_param_names: + unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key + + sharded_osd_state = sharded_osd["state"] + rekeyed_osd_state: dict[Union[str, int], Any] = {} + for key, param_state in sharded_osd_state.items(): + if isinstance(key, str): + rekeyed_osd_state[key] = param_state + continue + flat_param_key = unflat_param_names_to_flat_param_key.get( + key.unflat_param_names, key.unflat_param_names + ) + # pyrefly: ignore [unsupported-operation] + rekeyed_osd_state[flat_param_key] = param_state + + # Only process param_groups if it exists in sharded_osd + if "param_groups" in sharded_osd: + rekeyed_osd_param_groups: list[dict[str, Any]] = [] + for unflat_param_group in sharded_osd["param_groups"]: + flat_param_group = copy.deepcopy(unflat_param_group) + flat_param_keys = sorted( + { + unflat_param_name_to_flat_param_key[unflat_param_name] + for unflat_param_name in unflat_param_group["params"] + } + ) + flat_param_group["params"] = flat_param_keys + rekeyed_osd_param_groups.append(flat_param_group) + return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups} + else: + return {"state": rekeyed_osd_state} + + +def _get_param_id_to_param_from_optim_input( + model: nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ] = None, +) -> dict[int, nn.Parameter]: + """ + Constructs a mapping from parameter IDs to parameters. This may be used + both for models with ``FlatParameter`` s and without. + + NOTE: This method is only preserved for backward compatibility. The method + :meth:`_get_param_key_to_param` is the preferred code path that does not + rely on ``optim_input``. + + NOTE: We critically assume that, whether the optimizer input is a list of + parameters or a list of parameter groups, :class:`torch.optim.Optimizer` + enumerates the parameter IDs in order. In other words, for a parameter list + input, the parameter IDs should be in that list order, and for a parameter + groups input, the parameter IDs should be in order within each parameter + group and in order across parameter groups. + + Args: + model (nn.Module): Model whose parameters are passed into the + optimizer. + optim_input (Optional[Union[List[Dict[str, Any]], + Iterable[nn.Parameter]]]): Input passed into the optimizer + representing either a :class:`list` of parameter groups or an + iterable of parameters; if ``None``, then this method assumes the + input was ``model.parameters()``. (Default: ``None``) + + Returns: + List[nn.Parameter]: Mapping from parameter IDs to parameters, + where the parameter ID is implicitly the index in the :class:`list`. + """ + # Assume the standard case of passing `model.parameters()` to the optimizer + # if `optim_input` is not specified + if optim_input is None: + return dict(enumerate(model.parameters())) + try: + # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [redundant-cast] + params = cast(list[nn.Parameter], list(optim_input)) + except TypeError as e: + raise TypeError( + "Optimizer input should be an iterable of Tensors or dicts, " + f"but got {optim_input}" + ) from e + if len(params) == 0: + raise ValueError("Optimizer input should not be empty") + + # Check if the optimizer input represents tensors or parameter groups + all_tensors = True + all_dicts = True + for param in params: + all_tensors &= isinstance(param, torch.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError("Optimizer input should be an iterable of Tensors or dicts") + if all_tensors: + return dict(enumerate(params)) + if not all_dicts: + raise AssertionError(f"Expected all_dicts to be True, got {all_dicts}") + param_id_to_param: list[nn.Parameter] = [] + for param_group in params: + has_params_key = "params" in param_group # type: ignore[operator] + if not has_params_key: + raise AssertionError( + 'A parameter group should map "params" to a list of the parameters in the group' + ) + # Implicitly map `flat_param_id` (current length of the list) to + # `param` + param_id_to_param.extend(param_group["params"]) # type: ignore[index] + return dict(enumerate(param_id_to_param)) + + +def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]: + """ + Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes + from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical" + because ``FlatParameter`` s do not come from the original module but are + registered only after FSDP has been applied. This function returns the FSDP-given + name for the ``FlatParameter`` (usually module._flat_param) as opposed to the + canonical FQNs returned for ``FlatParameter`` s in ``_common_utils._get_param_to_fqns(...)``). + + Consequently, this function will only return a non-empty mapping if FSDP was + applied with ``use_orig_params=False`` as, otherwise, the original parameters + are used within the module and there would be no ``FlatParameter`` s in the module. + + """ + + def module_fn(module, prefix, tree_level, flat_param_to_fqn): + for param_name, param in _named_parameters_with_duplicates( + module, recurse=False + ): + if not isinstance(param, FlatParameter): + continue + fqn = clean_tensor_name(prefix + param_name) + flat_param_to_fqn[param] = fqn + + def return_fn(flat_param_to_fqn): + return flat_param_to_fqn + + flat_param_to_fqn_ret: dict[FlatParameter, str] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [fqn for fqn, _ in _named_parameters_with_duplicates(model)], + flat_param_to_fqn_ret, + ) + + +def _get_param_key_to_param( + optim: torch.optim.Optimizer, + model: Optional[nn.Module] = None, + is_named_optimizer: bool = False, + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[Union[int, str], nn.Parameter]: + """ + Constructs a mapping from parameter keys to parameters. For the regular + optimizers, the keys are parameter IDs. For NamedOptimizer, the keys + are FQNs. This API may be used both for models with ``FlatParameter`` s and + without. + """ + clean_fqn_to_curr_fqn: dict[str, str] = {} + if is_named_optimizer: + if param_to_fqns is None or flat_param_to_fqn is None: + raise AssertionError( + "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None." + ) + if model is None: + raise AssertionError(f"Expected model to be not None, got {model}") + for key, _ in _named_parameters_with_duplicates(model): + clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key + + param_key_to_param: dict[Union[str, int], nn.Parameter] = {} + pid = 0 + for param_group in optim.param_groups: + if is_named_optimizer: + for param in param_group["params"]: + if flat_param_to_fqn is None: + raise AssertionError( + f"Expected flat_param_to_fqn to be not None, got {flat_param_to_fqn}" + ) + if param in flat_param_to_fqn: + # FlatParameter case + key = flat_param_to_fqn[param] + else: + if param_to_fqns is None: + raise AssertionError( + f"Expected param_to_fqns to be not None, got {param_to_fqns}" + ) + # use_orig_params case + if len(param_to_fqns[param]) != 1: + raise AssertionError( + f"Expected len(param_to_fqns[param]) == 1, got {len(param_to_fqns[param])}" + ) + key = param_to_fqns[param][0] + try: + key = clean_fqn_to_curr_fqn[key] + except KeyError as e: + raise KeyError( + f"Can't find {key} from {list(clean_fqn_to_curr_fqn.keys())}." + ) from e + param_key_to_param[key] = param + else: + for param in param_group["params"]: + param_key_to_param[pid] = param + pid += 1 + + return param_key_to_param + + +def _get_param_to_param_key( + optim: torch.optim.Optimizer, + model: Optional[nn.Module] = None, + is_named_optimizer: bool = False, + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[nn.Parameter, Union[int, str]]: + """ + Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API + only supports the case where `optim` is a regular optimizer, not NamedOptimizer. + So the parameter keys will be parameter ids. + """ + param_id_to_param = _get_param_key_to_param( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + return {param: param_id for param_id, param in param_id_to_param.items()} + + +def _get_param_to_param_id_from_optim_input( + model: nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ] = None, +) -> dict[nn.Parameter, int]: + """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`.""" + param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input) + return {param: param_id for param_id, param in param_id_to_param.items()} + + +def _check_missing_keys_on_rank( + r0_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[str, int]], + param_key_to_param: dict[Union[str, int], nn.Parameter], + group: Optional[dist.ProcessGroup], +) -> None: + # Ensure that all ranks have at least the optimizer states needed by + # rank 0's optimizer + missing_keys: list[_OptimStateKey] = [] + for r0_optim_state_key in r0_optim_state_keys: + if r0_optim_state_key not in optim_state_key_to_param_key: + # A parameter from rank 0's optimizer does not exist for this + # rank's optimizer + missing_keys.append(r0_optim_state_key) + continue + param_key = optim_state_key_to_param_key[r0_optim_state_key] + if isinstance(param_key, int): + if not (param_key >= 0 and param_key < len(param_key_to_param)): + raise AssertionError("Check the `param_key_to_param` construction") + # We cannot use FSDPState.compute_device as this API is a global view. + device = _get_pg_default_device(group) + num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) + dist.all_reduce(num_missing, group=group) + if num_missing.item() > 0: + obj_list = [None for _ in range(dist.get_world_size(group))] + dist.all_gather_object(obj_list, missing_keys, group=group) + error_msg = ( + "FSDP currently requires each rank to have at least the " + "optimizer states needed by rank 0's optimizer but some ranks " + "are missing some of those states" + ) + for rank, keys in enumerate(obj_list): + keys = cast(list[_OptimStateKey], keys) + if len(keys) > 0: + error_msg += ( + f"\nRank {rank} is missing states for the parameters: " + f"{[key.unflat_param_names for key in keys]}" + ) + raise RuntimeError(error_msg) + + +def _map_param_key_to_optim_keys( + optim_state_dict: dict[str, Any], + group: Optional[dist.ProcessGroup], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + merge_keys: bool = False, +) -> tuple[list[_OptimStateKey], dict[_OptimStateKey, Union[int, str]]]: + """ + Construct the local mapping between the ``_OptimStateKey`` and parameter keys + and all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0 + must contain all the ``_OptimStateKey``, an exception will be raised otherwise. + Note that ``merge_keys`` should equal to ``use_orig_params``. + """ + rank = dist.get_rank(group) + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]] = {} # local + all_optim_state_keys: list[_OptimStateKey] = [] + + for param_key, param in param_key_to_param.items(): + # Do not include parameters without state to avoid empty mappings + # just like in normal `torch.optim.Optimizer.state_dict()` + if param_key not in optim_state_dict["state"]: + continue + fqns = param_to_fqns[param] + is_fsdp_managed = isinstance(param, FlatParameter) + if is_fsdp_managed: + if fqns[0] not in fqn_to_fsdp_param_info: + raise AssertionError( + f"Expected {fqns[0]} to be in fqn_to_fsdp_param_info, got keys: {list(fqn_to_fsdp_param_info.keys())}" + ) + is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info + optim_state_key = _OptimStateKey( + unflat_param_names=tuple(fqns), + is_fsdp_managed=is_fsdp_managed, + ) + if rank == 0 or merge_keys: + all_optim_state_keys.append(optim_state_key) + optim_state_key_to_param_key[optim_state_key] = param_key + + if merge_keys: + all_keys: list[list[_OptimStateKey]] = [ + [] for _ in range(dist.get_world_size(group)) + ] + dist.all_gather_object(all_keys, all_optim_state_keys, group=group) + merge_all_optim_state_keys = [*chain.from_iterable(all_keys)] + all_optim_state_keys = sorted(set(merge_all_optim_state_keys)) + else: + key_obj_list: list[Optional[list[_OptimStateKey]]] = ( + [all_optim_state_keys] if rank == 0 else [None] + ) + dist.broadcast_object_list(key_obj_list, src=0, group=group) + if key_obj_list[0] is None: + raise AssertionError( + f"Expected key_obj_list[0] to be not None, got {key_obj_list[0]}" + ) + all_optim_state_keys = key_obj_list[0] + _check_missing_keys_on_rank( + all_optim_state_keys, + optim_state_key_to_param_key, + param_key_to_param, + group, + ) + + return all_optim_state_keys, optim_state_key_to_param_key + + +def _unflatten_param_groups( + state_dict: dict[str, Any], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], +) -> list[dict[str, Any]]: + param_groups: list[dict[str, Any]] = [] + for flat_param_group in state_dict["param_groups"]: + unflat_param_group = copy.deepcopy(flat_param_group) + param_group_params = [ + param_key_to_param[flat_param_key] + for flat_param_key in flat_param_group["params"] + ] + nested_unflat_param_names = [ + param_to_fqns[param] for param in param_group_params + ] + unflat_param_group["params"] = [ + *chain.from_iterable(nested_unflat_param_names) + ] # flatten the list of lists + param_groups.append(unflat_param_group) + return param_groups + + +def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool: + """ + Returns whether the state_dict is from a NamedOptimizer. + This function checks that the keys in the state_dict['state'] are strings + (which usually are FQNs) versus integers (which usually refer to param_ids + from a vanilla torch.optim.Optimizer). + """ + state = optim_state_dict.get("state") + if not state: + # If we cannot find a state, assume it is not NamedOptimizer as + # NamedOptimizer has eager initialization. + return False + try: + key = next(iter(state.keys())) + except Exception as e: + raise Exception(optim_state_dict) from e # noqa: TRY002 + return isinstance(key, str) + + +@dataclass +class StateInfo: + # The key of these dictionaries are the state name, e.g., `exp_avg`. + tensors: dict[str, _PosDimTensorInfo] + scalar_tensors: dict[str, torch.Tensor] + non_tensors: dict[str, Any] + + +def _allgather_state_info( + fsdp_state: _FSDPState, + input_states: dict[str, Any], +) -> list[dict[str, StateInfo]]: + """ + Given the ``input_states``, allgather StateInfo for each state. The function + uses all_gather_object to gather StateInfo so no GPU tensors are sent. + """ + + processed_state_dict: dict[str, StateInfo] = {} + gathered_state_info: list[dict[str, StateInfo]] = [ + {} for _ in range(fsdp_state.world_size) + ] + + for fqn, optim_state in input_states.items(): + # Allgather the scalar tensor state, non-tensor states and tensors metadata. + processed_state = StateInfo({}, {}, {}) + for state_name, value in sorted_items(optim_state): + if torch.is_tensor(value): + if value.dim() == 0: + # Ensure that `step` is on CPU. + processed_state.scalar_tensors[state_name] = value.cpu() + else: + processed_state.tensors[state_name] = _PosDimTensorInfo( + value.shape, value.dtype + ) + else: + processed_state.non_tensors[state_name] = value + processed_state_dict[fqn] = processed_state + dist.all_gather_object( + gathered_state_info, + processed_state_dict, + group=fsdp_state.process_group, + ) + return gathered_state_info + + +def _convert_all_state_info( + fsdp_param_info: FSDPParamInfo, + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], + output_states: dict[str, dict[str, Any]], +) -> tuple[Optional[torch.dtype], dict[str, list[Optional[torch.Tensor]]]]: + """ + Given the ``gathered_state_info`` and ``input_states``, the API converted + the StateInfo into the original state if the state is not a non-scalar + tensor. For a multi-dimensional tensor, the local state will be stored in + ``state_buffer`` in a correct order for later allgather purpose. + """ + + state_buffers: dict[str, list[Optional[torch.Tensor]]] = {} + + for fqn, gathered_state in output_states.items(): + state_info = [s[fqn] for s in gathered_state_info] + all_tensor_states = sorted({n for state in state_info for n in state.tensors}) + empty_ranks: set[int] = set() + dtype: Optional[torch.dtype] = None + # First check all the non-scalar states and get the information of + # states on each rank. + for state_name in all_tensor_states: + numels = [] + _empty_ranks: set[int] = set() + for rank, object_state in enumerate(state_info): + numels.append(0) + info = object_state.tensors.get(state_name, None) + if info is not None: + numels[-1] = info.shape.numel() + if not dtype: + dtype = info.dtype + else: + if dtype != info.dtype: + raise AssertionError( + f"Expected dtype == info.dtype, got {dtype} != {info.dtype}" + ) + if numels[-1] == 0: + _empty_ranks.add(rank) + + if not (not empty_ranks or empty_ranks == _empty_ranks): + raise AssertionError( + f"Expected empty_ranks to be empty or equal to _empty_ranks, got {empty_ranks} vs {_empty_ranks}" + ) + empty_ranks = _empty_ranks + if state_name not in state_buffers: + state_buffers[state_name] = [ + None for _ in fsdp_param_info.param_indices + ] + local_state = input_states[fqn].get(state_name, None) + # N.B. We need to move the state to compute_device. The reason is + # not yet clear and we need to figure out why the state may be on a + # different device. + if local_state is not None: + local_state = local_state.to(fsdp_param_info.state.compute_device) + state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state + + # Restoring the scalar and non-tensor states. If the corresponding + # non-scalar states do not exist on the rank, we also skip the scalar + # non-tensor states on that rank. + for rank, object_state in enumerate(state_info): + if rank in empty_ranks: + continue + for name, non_tensor_value in object_state.non_tensors.items(): + curr_non_tensor_value = gathered_state.get(name, None) + if not ( + curr_non_tensor_value is None + or curr_non_tensor_value == non_tensor_value + ): + raise AssertionError( + f"Rank {rank} has different values for {name}: {non_tensor_value}." + + f" Other ranks: {curr_non_tensor_value}" + ) + gathered_state[name] = non_tensor_value + + for name, scalar_tensor_value in object_state.scalar_tensors.items(): + curr_scalar_tensor_value = gathered_state.get(name, None) + if not ( + curr_scalar_tensor_value is None + or torch.equal(scalar_tensor_value, curr_scalar_tensor_value) + ): + raise AssertionError( + f"Rank {rank} has different values for {name}: {scalar_tensor_value}." + + f" Other ranks: {curr_scalar_tensor_value}" + ) + gathered_state[name] = scalar_tensor_value + + return dtype, state_buffers # type: ignore[possibly-undefined] + + +def _unflatten_orig_param_states( + fsdp_param_info: FSDPParamInfo, + output_states: dict[str, dict[str, Any]], + state_name: str, + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> None: + """ + Given a output state dict, ``output_states``, which the keys are FQNs to the + original parameters (not FlatParameters nor parameter ID), and the values + are gathered states, unflatten the states to the original dimensions. + + This function performs the unflattening process in-place. + """ + if not to_save: + return + flat_param = fsdp_param_info.handle.flat_param + fsdp_state = fsdp_param_info.state + for fqn, gathered_state in output_states.items(): + value = gathered_state[state_name] + param_idx = fsdp_param_info.param_indices[fqn] + + # TODO: This solution is not general and only apply to PTD TP solution. + if isinstance(value, DTensor): + placement = value.placements[0] + # If gathered state is a DTensor and its TP placement is not Replicate(), we need to + # gather the tensor on its TP dimension before chunking them into DTensor again. + if placement != Replicate(): + placement_dim = placement.dim # type: ignore[attr-defined] + value.redistribute(placements=(Replicate(),)) + reshape_size = list(flat_param._shapes[param_idx]) + reshape_size[placement_dim] *= value.device_mesh.size(0) + reshape_size = torch.Size(reshape_size) + value = value.reshape(reshape_size) + # If gathered state is a replicate DTensor, we directly reshape it. + else: + value = value.reshape(flat_param._shapes[param_idx]) + else: + # If gathered state is a tensor, we directly reshape it into unflatten state. + value = value.reshape(flat_param._shapes[param_idx]) + + if shard_state: + osd_config = fsdp_state._optim_state_dict_config + if getattr(osd_config, "_use_dtensor", False): + if fsdp_state._device_mesh is None: + raise AssertionError( + f"Expected _device_mesh to be not None, got {fsdp_state._device_mesh}" + ) + value = _ext_chunk_dtensor( + value, + fsdp_state.rank, + fsdp_state._device_mesh, + fsdp_state._fsdp_extension, + ) + else: + if fsdp_state.process_group is None: + raise AssertionError( + f"Expected process_group to be not None, got {fsdp_state.process_group}" + ) + value = _ext_chunk_tensor( + value, + fsdp_state.rank, + fsdp_state.world_size, + fsdp_state._device_handle.device_count(), + fsdp_state.process_group, + fsdp_state._fsdp_extension, + ) + elif not cpu_offload: + with SimpleProfiler.profile("clone"): + value = value.detach().clone() + + if cpu_offload: + with SimpleProfiler.profile(SimpleProfiler.Type.D2H): + value = value.cpu() + gathered_state[state_name] = value + + +def _allgather_orig_param_states( + fsdp_param_info: FSDPParamInfo, + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> dict[str, dict[str, Any]]: + """ + Given the ``gathered_state_info`` and ``input_states``, the API allgathers + all tensor states and restore non-tensor states from ``gathered_state_info``. + """ + fsdp_state = fsdp_param_info.state + if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL: + logger.info( + "Memory Summary before calling to _allgather_orig_param_states %s", + fsdp_state._device_handle.memory_summary(), + ) + + output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states} + + dtype, state_buffers = _convert_all_state_info( + fsdp_param_info, gathered_state_info, input_states, output_states + ) + + if len(state_buffers) == 0: + return output_states + + has_state_params: list[bool] = [ + fqn in output_states for fqn, idx in fsdp_param_info.param_indices.items() + ] + + # Loop through the ``state_buffers`` and construct the flattened, concatenated, + # sharded states. The size of the constructed state will be the same size as + # flat_param (also sharded). + # Then we perform an allgather_into_tensor to get the full flat_param state. + # The full flat_param state is the result of concatenation of multiple states + # the order of of flat_param._fqns. + # The final step is to split the flat_param state into original param states + # and return the result. + flat_param = fsdp_param_info.handle.flat_param + empty_func = functools.partial( + torch.empty, dtype=dtype, device=fsdp_state.compute_device + ) + gathered_tensor = empty_func(flat_param._padded_unsharded_size) + # Synchronize can be slow but this will be easier for us to debug. + fsdp_state._device_handle.synchronize() + for state_name, buffers in state_buffers.items(): + local_buffers: list[torch.Tensor] = [] + begin = fsdp_state.rank * flat_param._sharded_size.numel() + # End is inclusive. + end = begin + flat_param._sharded_size.numel() - 1 + # param_idx corresponds to the parameter index in the FlatParameter. + mem_offset, param_idx = 0, 0 + for numel, is_padding in zip( + flat_param._numels_with_padding, flat_param._is_padding_mask + ): + frozen_and_no_state = not is_padding and ( + not fsdp_param_info.param_requires_grad[param_idx] + and not has_state_params[param_idx] + ) + + if is_padding or frozen_and_no_state: + # This memory range is a padding or the param is frozen and does + # not require gradient. For the later case, we treat it as a + # padding and add empty values to the local_buffers. + + padding_begin, padding_end = mem_offset, mem_offset + numel - 1 + if padding_begin <= begin <= padding_end: + # The range is an align padding before the first parameter in + # the shard. The shard includes parts of this align padding. + padding_len = ( + padding_end - begin + 1 + if end >= padding_end + else end - begin + 1 + ) + elif padding_begin <= end <= padding_end: + # The range is an align padding after the last parameter in + # the shard. The shard includes parts of this align padding. + padding_len = ( + end - padding_begin + 1 + if begin <= padding_begin + else end - begin + 1 + ) + elif begin < padding_begin <= padding_end < end: + # The range is an align padding that is completely in the + # shard. + padding_len = numel + else: + padding_len = 0 + if padding_len: + local_buffers.append(empty_func(padding_len)) + + if not is_padding: + # This memory range is a parameter in FlatParameter. So there + # should be an corresponding state in the optimizer unless the + # parameter is frozen, which we treat it as a padding above. + + # We need to check if this rank owns the buffer. If this is None: + # 1.) the rank does not own any part of the original parameter. + # As a result, there is no corresponding optimizer state on + # the rank as well. + # 2.) the parameter is frozen AND no optimizer state for the + # parameter. If a parameter is frozen, there can still be + # optimizer state if the parameter is not frozen in the + # previous steps. + if buffers[param_idx] is not None: + local_buffers.append(cast(torch.Tensor, buffers[param_idx])) + param_idx += 1 + + mem_offset += numel + + shard_numel_padded = flat_param._sharded_size.numel() - ( + sum(t.numel() for t in local_buffers) + ) + + if flat_param._shard_numel_padded != shard_numel_padded: + raise AssertionError( + "Manually calculated _sharded_numel_padded is incorrect. " + f"_shard_numel_padded={flat_param._shard_numel_padded}, " + f"shard_numel_padded={shard_numel_padded}, " + f"_sharded_size.numel={flat_param._sharded_size.numel()}, " + f"_numels_with_padding={flat_param._numels_with_padding}, " + f"begin={begin}, end={end}," + ) + if shard_numel_padded > 0: + # Add right-handed padding. + local_buffers.append(empty_func(shard_numel_padded)) + local_shard = torch.cat(local_buffers) + if local_shard.numel() * fsdp_state.world_size != gathered_tensor.numel(): + raise AssertionError( + "The size of local shard times the world size should equal to the " + "gathered tensor size. The inconsistency may be from a bug of " + "FlatParameter's metadata or the reconstruction logic in optimizer " + "state dict." + ) + fsdp_state._device_handle.synchronize() + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): + dist.all_gather_into_tensor( + gathered_tensor, local_shard, group=fsdp_state.process_group + ) + # Synchronize can be slow but this will be easier for us to debug. + fsdp_state._device_handle.synchronize() + + unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()] + flat_param_handle = fsdp_param_info.handle + orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor) + if len(orig_states) != len(fsdp_param_info.param_indices): + raise AssertionError( + "The number of parameters from FlatParameter is not consistent to " + "the number of states used by optimizer state dict reconstruction " + "logic." + ) + for fqn, idx in fsdp_param_info.param_indices.items(): + if fsdp_param_info.param_requires_grad[idx] or fqn in output_states: + output_states[fqn][state_name] = orig_states[idx] + + _unflatten_orig_param_states( + fsdp_param_info, + output_states, + state_name, + shard_state, + to_save, + cpu_offload, + ) + + del gathered_tensor + return output_states + + +def _gather_all_orig_param_state( + fsdp_param_info: FSDPParamInfo, + input_states: dict[str, Any], + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> dict[str, Any]: + """ + Given a optimizer state dict, ``input_states``, which the keys are FQNs to the + original parameters (not FlatParameters nor parameter ID), gather all the + states and unflatten them to the original dimensions. Note that all the + params referred by the ``input_states`` must be managed by FSDP. + """ + fsdp_state = fsdp_param_info.state + if ( + fsdp_state.world_size == 1 + or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + ): + return input_states if to_save else {} + + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ): + gathered_state_info = _allgather_state_info(fsdp_state, input_states) + output_states = _allgather_orig_param_states( + fsdp_param_info, + gathered_state_info, + input_states, + shard_state, + to_save, + cpu_offload, + ) + if to_save: + for key, idx in fsdp_param_info.param_indices.items(): + if key in output_states: + continue + if not fsdp_param_info.param_requires_grad[idx]: + continue + + raise RuntimeError( + f"{key} is not in the output state. " + "The FSDPParamInfo has the param keys " + f"{sorted(fsdp_param_info.param_indices.keys())} while " + "the output_states has the param keys " + f"{sorted(output_states.keys())}." + ) + return output_states + else: + return {} + + +def _convert_state_with_orig_params( + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool = True, +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} + # This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo + # usually corresponds to multiple parameters. We could not use FSDPParamInfo + # as the key because FSDPParamInfo is not hashable. As a result, we fall back + # to `id(FSDPParamInfo)`, which the type is an integer. + all_states: dict[int, dict[str, Any]] = {} + # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers + # across ranks + for optim_state_key in all_optim_state_keys: + param_key: Union[str, int, None] = optim_state_key_to_param_key.get( + optim_state_key + ) + + if param_key is None and not optim_state_key.is_fsdp_managed: + continue + + if optim_state_key.is_fsdp_managed: + fqn = optim_state_key.unflat_param_names[0] + fsdp_param_info = fqn_to_fsdp_param_info.get(fqn) + if fsdp_param_info is None: + # This can happen if the not all FSDP instances have all the + # parameters. This can happen with FSDP + some MPMD style + # parallelism. + + # TODO: it is unclear if we need to do the same check with + # non-FSDP managed keys. + continue + state = {} if param_key is None else optim_state_dict[param_key] + if id(fsdp_param_info) not in all_states: + all_states[id(fsdp_param_info)] = {} + all_states[id(fsdp_param_info)][fqn] = state + + elif to_save: + if len(optim_state_key.unflat_param_names) != 1: + raise AssertionError( + f"Expected len(optim_state_key.unflat_param_names) == 1, got {len(optim_state_key.unflat_param_names)}" + ) + unflat_param_name = optim_state_key.unflat_param_names[0] + with SimpleProfiler.profile("none_fsdp_managed_copy"): + param_key = cast(Union[str, int], param_key) + fsdp_osd_state[unflat_param_name] = copy.copy( + optim_state_dict[param_key] + ) + if cpu_offload: + for state_name, value in sorted_items( + fsdp_osd_state[unflat_param_name] + ): + if not torch.is_tensor(value): + continue + fsdp_osd_state[unflat_param_name][state_name] = value.cpu() + + # Instead of gathering the state of each parameter individually, we perform + # the gathering all at once to speed up the process. + for _all_states in all_states.values(): + fqn = next(iter(_all_states.keys())) + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if len(fsdp_param_info.param_requires_grad) <= 0: + raise AssertionError( + "With use_orig_params, FSDPParamInfo should have requires_grad " + "information. However, the length is zero." + ) + for key, idx in fsdp_param_info.param_indices.items(): + if key in _all_states: + continue + if not fsdp_param_info.param_requires_grad[idx]: + continue + raise RuntimeError( + f"{key} is not in the optimizer state. " + "The FSDPParamInfo has the param keys " + f"{sorted(fsdp_param_info.param_indices.keys())} while " + "the optimizer has the param keys " + f"{sorted(_all_states.keys())}." + ) + fsdp_osd_state.update( + _gather_all_orig_param_state( + fsdp_param_info, + _all_states, + shard_state, + to_save, + cpu_offload, + ) + ) + + return fsdp_osd_state + + +def _convert_state_with_flat_params( + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool = True, +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} + # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers + # across ranks + for optim_state_key in all_optim_state_keys: + param_key: Union[str, int, None] = optim_state_key_to_param_key.get( + optim_state_key + ) + + if param_key is None: + raise AssertionError( + "If use_orig_params is False, we must be able to find the " + f"corresponding param id. {optim_state_key} {param_key}" + ) + + if optim_state_key.is_fsdp_managed: + # If there are multiple unflat_param_names (not use_orig_params), + # they share the same FSDPParamInfo. So the first unflat_param_name + # is sufficient to fetch the FSDPParamInfo. + fqn = optim_state_key.unflat_param_names[0] + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + unflat_state = _unflatten_optim_state( + fsdp_param_info, + optim_state_dict[param_key], + to_save, + shard_state, + cpu_offload, + ) + if to_save: + if len(unflat_state) != len(optim_state_key.unflat_param_names): + raise AssertionError( + f"Expected len(unflat_state) == len(optim_state_key.unflat_param_names), " + f"got {len(unflat_state)} != {len(optim_state_key.unflat_param_names)}" + ) + fsdp_osd_state.update( + zip( + optim_state_key.unflat_param_names, + unflat_state, + ) + ) + elif to_save: + if len(optim_state_key.unflat_param_names) != 1: + raise AssertionError( + f"Expected len(optim_state_key.unflat_param_names) == 1, got {len(optim_state_key.unflat_param_names)}" + ) + unflat_param_name = optim_state_key.unflat_param_names[0] + fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key]) + if cpu_offload: + for state_name, value in sorted_items( + fsdp_osd_state[unflat_param_name] + ): + if not torch.is_tensor(value): + continue + fsdp_osd_state[unflat_param_name][state_name] = value.cpu() + + return fsdp_osd_state + + +@torch.no_grad() +def _optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ], + rank0_only: bool, + shard_state: bool, + group: Optional[dist.ProcessGroup], + using_optim_input: bool, + use_orig_params: bool = False, + cpu_offload: bool = True, +) -> dict[str, Any]: + """ + Consolidates the optimizer state and returns it as a :class:`dict` + following the convention of :meth:`torch.optim.Optimizer.state_dict`, + i.e. with keys ``"state"`` and ``"param_groups"``. + The flat parameters in ``FSDP`` modules contained in ``model`` are mapped + back to their unflattened parameters. + + Parameter keys are not well-defined. For a regular optimizer, the optimizer + state_dict contains a mapping from parameter IDs to parameter states. + Parameter IDs are the order of parameters in ``optim.param_groups()`` across + all the groups. This API also allows user to pass ``optim_input`` for the + mapping between parameters and parameter IDs. Using ``optim_input`` is being + deprecated. + + If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not + contain parameter IDs mapping but a mapping from parameter FQNs to parameter + states. This API finds the mapping from FQNs to parameters if the optimizer + is a ``NamedOptimizer``. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP knows how to aggregate them. However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- those are managed by other parallelisms and FSDP does not + know how to handle/aggregate them. + + Args: + model (nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + shard_state (bool): If ``True``, shard and distribute all + non-zero-dimension states. + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=False``, + then nonzero ranks return an empty :class:`dict`. + """ + SimpleProfiler.reset() + cm = ExitStack() + cm.enter_context(SimpleProfiler.profile(SimpleProfiler.Type.ALL)) + _reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model)) + to_save = not rank0_only or dist.get_rank(group) == 0 or shard_state + + with SimpleProfiler.profile("preprocessing"): + param_to_fqns = _get_param_to_fqns(model) + flat_param_to_fqn = _get_flat_param_to_fqn(model) + is_named_optimizer = _is_named_optimizer(optim_state_dict) + + param_key_to_param = cast( + dict[Union[int, str], nn.Parameter], + ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_key_to_param( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + ), + ) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + + with SimpleProfiler.profile("preprocessing_with_comm"): + ( + all_optim_state_keys, + optim_state_key_to_param_key, + ) = _map_param_key_to_optim_keys( + optim_state_dict, + group, + param_key_to_param, + param_to_fqns, + fqn_to_fsdp_param_info, + merge_keys=use_orig_params, + ) + + with SimpleProfiler.profile("state_converting"): + convert_fn = ( + _convert_state_with_orig_params + if use_orig_params + else _convert_state_with_flat_params + ) + fsdp_osd_state = convert_fn( + all_optim_state_keys, + optim_state_key_to_param_key, + fqn_to_fsdp_param_info, + optim_state_dict["state"], + to_save, + shard_state, + cpu_offload, + ) + + # At this point, communication is complete and ranks can return early if nothing + # will be saved on that rank. + if not to_save: + return {} + + fsdp_osd: dict[str, Any] = {"state": fsdp_osd_state} + + flat_param_fqns = set(flat_param_to_fqn.values()) + for key, value in optim_state_dict["state"].items(): + if key in fsdp_osd_state: + continue + if key in flat_param_fqns: + continue + if key in param_key_to_param: + continue + # This key is not recognized by FSDP. It may be a user-defined state + # or some parameters state that FSDP is unable to map from + # ``optim.param_groups``. + warnings.warn( + f"Found a optim state, {key}, that FSDP cannot process. FSDP " + "will directly copy everything to the returned state_dict. In " + "most cases, this is a user-defined state that is not " + "associated with any particular parameter. Another possible " + "case is this state is managed by TorchRec. Otherwise, there may " + " be a mismatched assumption of optim_state_dict of this mode.", + stacklevel=2, + ) + fsdp_osd_state[key] = value + + if "param_groups" in optim_state_dict: + fsdp_osd["param_groups"] = _unflatten_param_groups( + optim_state_dict, param_key_to_param, param_to_fqns + ) + + cm.close() + SimpleProfiler.dump_and_reset("FSDP _optim_state_dict() profiling: ") + + return fsdp_osd + + +def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]: + """ + Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo`` + if the param is managed by FSDP. Shared parameters, or original parameters that + are shared across multiple nn.Modules, are required to belong to one and only + one FSDP instance and thus correspond to one ``FlatParameter``. Within the one + ``FlatParameter``, ``FlatParameter._fqns`` only stores the first FQN of a shared + parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters. + """ + + def module_fn(module, prefix, tree_level, fqn_to_param_info): + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state is None: + return + _lazy_init(fsdp_state, module) + handle = _module_handle(fsdp_state, module) + if not handle: + return + flat_param = handle.flat_param + fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, []) + # NOTE: `idx` indexes into the data structures *without* padding + # elements + for idx, local_fqn in enumerate(flat_param._fqns): + fqn = clean_tensor_name(prefix + local_fqn) + if fqn in fqn_to_param_info: + if fqn_to_param_info[fqn].handle.flat_param is not flat_param: + raise AssertionError( + f"Expected fqn_to_param_info[fqn].handle.flat_param is flat_param for {fqn}" + ) + fqn_to_param_info[fqn] = fsdp_param_info + fsdp_param_info.param_indices[fqn] = idx + if flat_param._params is not None: + fsdp_param_info.param_requires_grad.append( + flat_param._params[idx].requires_grad + ) + + def return_fn(fqn_to_param_info): + return fqn_to_param_info + + fqn_to_param_info: dict[str, FSDPParamInfo] = {} + # FlatParameter._fqns stores the local fqn, starting from the root of the + # FSDP. Using _apply_to_modules() with model (may not be the FSDP root + # module) allows us to construct the global fqn. + return _apply_to_modules( + model, + module_fn, + return_fn, + [fqn for fqn, _ in _named_parameters_with_duplicates(model)], + fqn_to_param_info, + ) + + +@no_type_check +def _set_optim_use_dtensor( + fsdp_state: _FSDPState, + state_dict_settings: StateDictSettings, +) -> None: + # If device_mesh is passed in when initializing FSDP, we automatically turn the + # _use_dtensor flag to be true for ShardedOptimStateDictConfig() if state_dict_type + # has to be set to SHARDED_STATE_DICT. + if getattr(fsdp_state, "_device_mesh", None): + state_dict_type = state_dict_settings.state_dict_type + if state_dict_type == StateDictType.LOCAL_STATE_DICT: + raise RuntimeError( + "Found state_dict_type LOCAL_STATE_DICT.", + "DeviceMesh is not compatible with LOCAL_STATE_DICT.", + "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.", + ) + else: + state_dict_settings.optim_state_dict_config._use_dtensor = True diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eab47412f5d25a3c8a3472141208d6833ec633d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py @@ -0,0 +1,1654 @@ +# mypy: allow-untyped-defs +import functools +import logging +from collections.abc import Callable +from enum import auto, Enum +from typing import Any, no_type_check, Optional + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.autograd.graph import register_multi_grad_hook +from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS +from torch.distributed.fsdp._common_utils import ( + _assert_in_training_states, + _FSDPState, + _get_module_fsdp_state, + _is_composable, + _log_post_backward_hook, + _no_dispatch_record_stream, + clean_tensor_name, + TrainingState, +) +from torch.distributed.fsdp._flat_param import ( + FlatParameter, + FlatParamHandle, + HandleShardingStrategy, + HandleTrainingState, + RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES, +) +from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES +from torch.distributed.fsdp.api import BackwardPrefetch +from torch.distributed.utils import ( + _apply_to_tensors, + _cast_forward_inputs, + _p_assert, + _to_kwargs, +) +from torch.utils import _pytree as pytree + + +logger = logging.getLogger(__name__) + +# Do not include "process_group" to enable hybrid shard and MoE cases +HOMOGENEOUS_ATTR_NAMES = ( + "_use_orig_params", + "limit_all_gathers", + "_use_full_prec_in_eval", +) + + +class _PrefetchMode(Enum): + BACKWARD = auto() + FORWARD = auto() + + +def _get_fsdp_root_states_with_modules( + module: nn.Module, +) -> tuple[list[_FSDPState], list[nn.Module]]: + """ + Returns a tuple containing: + 1. A list of the root ``_FSDPState`` instances in the module tree rooted at + ``module`` without any duplicates and following the ``module.modules()`` + traversal order (which is assumed to be depth-first). + 2. A corresponding list of the root modules owning the states in the first + list. + + This is similar to :func:`_get_fsdp_states_with_modules` except that we + must call :func:`_is_fsdp_root` to force a lazy initialization to determine + the FSDP root in case lazy initialization has not yet happened. + """ + fsdp_root_states: list[_FSDPState] = [] + fsdp_root_modules: list[nn.Module] = [] + visited_fsdp_states: set[_FSDPState] = set() + # NOTE: This function assumes that `module.modules()` proceeds top-down. + for submodule in module.modules(): + optional_state = _get_module_fsdp_state(submodule) + if ( + optional_state is not None + and optional_state not in visited_fsdp_states + and _is_fsdp_root(optional_state, submodule) + ): + visited_fsdp_states.add(optional_state) + fsdp_root_states.append(optional_state) + fsdp_root_modules.append(submodule) + return fsdp_root_states, fsdp_root_modules + + +def _get_fsdp_root_states(module: nn.Module) -> list[_FSDPState]: + """See :func:`_get_fsdp_root_states_with_modules`.""" + fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module) + return fsdp_root_states + + +def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool: + """ + Returns if ``state`` corresponds to that of an FSDP root. + + For the wrapper code path, ``state`` and ``module`` should be the same. For + the non-wrapper code path, ``state`` should be ``module`` 's state. + """ + # Force a lazy initialization to determine the FSDP root + _lazy_init(state, module) + if state._is_root is None: + raise AssertionError("Expected _is_root to be set after lazy init") + return state._is_root + + +@no_type_check +def _lazy_init( + state: _FSDPState, + root_module: nn.Module, +) -> _FSDPState: + """ + Performs initialization lazily, typically right before the first forward + pass. The laziness is needed to ensure that the parameter device/dtype and + the FSDP hierarchy have finalized. This method's actual logic only runs on + the root FSDP instance, which performs initialization for all non-root FSDP + instances to avoid partial initialization. + + For the non-composable code path, ``state`` and ``root_module`` should be + the same, namely the FSDP instance itself. + """ + if state._is_root is not None: + return # no-op: already lazily initialized + if not state._device_handle.is_available(): + # Allow the FSDP constructor to run even without CUDA but check this + # once we start real execution + raise RuntimeError("FSDP does not support CPU only execution") + # The following logic is only run on the root FSDP instance since it will + # set `_is_root=False` for the non-root instances + state._is_root = True + _assert_in_training_states(state, [TrainingState.IDLE]) + _check_flat_params_on_expected_device(state, root_module) + state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module) + _init_streams(state) + buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module) + _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device) + state._exec_order_data.init(state, root_module, state.process_group) + _share_state_and_init_handle_attrs(state, root_module) + return state + + +def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module): + """ + Checks that all ``FlatParameter``s in ``module`` 's tree managed by + ``state`` are on the expected device for *lazy initialization*. + """ + cpu_device = torch.device("cpu") + for handle in traversal_utils._get_fsdp_handles(module): + if ( + not handle._offload_params + and handle.flat_param.device != state.compute_device + ): + raise RuntimeError( + "An FSDP-managed module unexpectedly has parameters on " + f"{handle.flat_param.device}. Make sure to move the module to " + f"{state.compute_device} before training." + ) + elif handle._offload_params and handle.flat_param.device != cpu_device: + raise RuntimeError( + "An FSDP-managed module with parameter CPU offloading enabled " + f"has parameters on {handle.flat_param.device}. Make sure to " + f"not move the module from CPU when offloading parameters." + ) + + +@no_type_check +def _share_state_and_init_handle_attrs( + root_state: _FSDPState, + root_module: nn.Module, +) -> None: + """ + Shares data structure state from the ``root_state`` to all FSDP states in + ``root_module`` 's module tree, and initializes handle attributes. These + are done together to require a single loop over the states. + """ + handle = root_state._handle + if handle: + handle.init_flat_param_attributes() + attr_name_to_values: dict[str, set[Any]] = {} + for attr_name in HOMOGENEOUS_ATTR_NAMES: + attr_name_to_values[attr_name] = set() + root_state._all_handles = root_state._exec_order_data.all_handles # share reference + # Update _has_optim_in_backward for each handle. + for handle in root_state._all_handles: + flat_param = handle.flat_param + if hasattr(flat_param, "_in_backward_optimizers"): + raise RuntimeError( + "FSDP optimizer in backward only supported with use_orig_params=True!" + ) + handle._has_optim_in_backward = flat_param._params is not None and any( + hasattr(param, "_in_backward_optimizers") for param in flat_param._params + ) + if handle._has_optim_in_backward: + torch._C._log_api_usage_once("fsdp.optimizer_in_backward") + for fsdp_state in root_state._all_fsdp_states: + for attr_name in HOMOGENEOUS_ATTR_NAMES: + _p_assert( + hasattr(fsdp_state, attr_name), + f"FSDP state missing attribute {attr_name}", + ) + attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name)) + if fsdp_state is root_state: + continue + # Relax the assert for non-root FSDP instances in case the nested + # initialized module is wrapped again in FSDP later (e.g. after + # training to run inference) + _p_assert( + fsdp_state._is_root is None or not fsdp_state._is_root, + "Non-root FSDP instance's `_is_root` should not have been " + "set yet or should have been set to `False`", + ) + fsdp_state._is_root = False + fsdp_state._unshard_stream = root_state._unshard_stream + fsdp_state._post_backward_stream = root_state._post_backward_stream + fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream + fsdp_state._all_reduce_stream = root_state._all_reduce_stream + fsdp_state._default_stream = root_state._default_stream + fsdp_state._exec_order_data = root_state._exec_order_data + fsdp_state._free_event_queue = root_state._free_event_queue + if fsdp_state._fsdp_extension is not None: + fsdp_state._fsdp_extension.compute_stream = root_state._default_stream + handle = fsdp_state._handle + if handle: + handle.init_flat_param_attributes() + for attr_name, attr_values in attr_name_to_values.items(): + if len(attr_values) != 1: + raise ValueError( + f"Expects one homogeneous value for {attr_name} but got {attr_values}" + ) + + +@no_type_check +def _init_streams( + state: _FSDPState, +) -> None: + """ + Initializes CUDA streams for overlapping communication, computation, and + data transfers. The streams should be shared across FSDP instances. + """ + if not state._is_root: + raise AssertionError("Expected state to be root") + if not state._device_handle.is_available(): + raise AssertionError("Expected device handle to be available") + uses_hybrid_sharding = any( + fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES + for fsdp_state in state._all_fsdp_states + ) + # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and + # preserve the default priority of 0 otherwise + high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0 + # Default stream for computation + state._default_stream = state._device_handle.current_stream() + if state._fsdp_extension is not None: + # set the compute stream to the FSDP extension + state._fsdp_extension.compute_stream = state._default_stream + + # Stream for unshard logic, including allocating the all-gather destination + # tensors and the all-gathers themselves + state._unshard_stream = state._device_handle.Stream(priority=high_priority) + # Stream for overlapping gradient reduction with the backward pass gradient + # computation + state._post_backward_stream = state._device_handle.Stream(priority=high_priority) + # Stream for pre-unshard logic, namely allocations and writes for CPU + # offloading (H2D copy) and mixed precision (low precision cast) + state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority) + # Stream to run HSDP's all-reduce as async (if using HSDP) + state._all_reduce_stream = ( + state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream + ) + + +@no_type_check +def _unshard( + state: _FSDPState, + handle: FlatParamHandle, + unshard_stream: torch.Stream, + pre_unshard_stream: torch.Stream, +) -> None: + """ + Unshards the handles in ``handles``. If the handles are in + :meth:`summon_full_params` and are using mixed precision, then they are + forced to full precision. + + Postcondition: handle's ``FlatParameter`` 's data is the padded + unsharded flat parameter on the compute device. + """ + if not handle: + return + with state._device_handle.stream(pre_unshard_stream): + ran_pre_unshard = handle.pre_unshard() + if ran_pre_unshard: + unshard_stream.wait_stream(pre_unshard_stream) + if state.limit_all_gathers: + event = state._free_event_queue.dequeue_if_needed() + if event: + with torch.profiler.record_function( + "FullyShardedDataParallel.rate_limiter" + ): + event.synchronize() + with state._device_handle.stream(unshard_stream): + handle.unshard() + handle.post_unshard() + + +@no_type_check +def _reshard( + state: _FSDPState, + handle: FlatParamHandle, + free_unsharded_flat_param: bool, +): + """ + Reshards the handle. ``free_unsharded_flat_param`` indicates whether to + free the handle's padded unsharded flat parameter. + """ + handle.reshard(free_unsharded_flat_param) + if state.limit_all_gathers and free_unsharded_flat_param: + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # We don't run a even queue for freeing under torch compile atm + # But maybe we need to? TODO(voz): Look into this + free_event = state._device_handle.Event() + free_event.record() + state._free_event_queue.enqueue(free_event) + handle.post_reshard() + # Flat parameter freed or not, we always have to "unshard" the parameter + # upon next access to get its shape correct. + handle._prefetched = False + + +def _unshard_grads( + handle: Optional[FlatParamHandle], +) -> None: + if handle: + handle.unshard_grad() + + +def _reshard_grads( + handle: Optional[FlatParamHandle], +) -> None: + if handle: + handle.reshard_grad() + + +@no_type_check +def _pre_forward( + state: _FSDPState, + handle: Optional[FlatParamHandle], + unshard_fn: Callable, + module: nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """ + Runs the pre-forward logic. This includes an opportunity to unshard + currently sharded parameters such as those for the current forward and + registering post-backward hooks for these current parameters. This function + also converts forward ``args`` and ``kwargs`` to the given precision. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters used in + the current forward. + unshard_fn (Optional[Callable]): A callable to unshard any currently + sharded parameters or ``None`` to not do any unsharding. + module (nn.Module): Module whose forward this method runs right before; + expected by the hook signature. + args (Tuple[Any, ...]): Module forward ``args``. + kwargs (Dict[str, Any]): Module forward ``kwargs``. + """ + with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"): + # For `fully_shard` + `checkpoint`, skip pre-forward logic in the + # recomputed forward + if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE: + # For both checkpoint implementations, we do not need to re-cast + # inputs here since they will be checkpointed in the low precision + # either by AC or normally by autograd as long as the AC region is + # nested within FSDP + return args, kwargs + state.training_state = TrainingState.FORWARD_BACKWARD + state._exec_order_data.record_pre_forward(handle, module.training) + if handle: + handle._training_state = HandleTrainingState.FORWARD + if unshard_fn is not None: + unshard_fn(state, handle) + # Register post-backward hooks to reshard the parameters and reduce-scatter + # their gradients. They must be re-registered every forward pass in case + # the `grad_fn` is mutated. + _register_post_backward_hook(state, handle) + # We have to reallocate the _cpu_grad if optimizer overlap + # set the grad to None in the backward pass. + if handle and handle._offload_params and handle.flat_param._cpu_grad is None: + handle.flat_param._cpu_grad = torch.zeros_like( + handle.flat_param._local_shard, device=torch.device("cpu") + ).pin_memory() + + should_cast_forward_inputs = ( + state._handle and not state._handle._force_full_precision + ) + + if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs: + # Recursively convert args and kwargs to specified precision. + input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype + args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs) + _register_post_backward_reshard_only_hook(state, handle, args, kwargs) + return args, kwargs + + +@no_type_check +def _pre_forward_unshard( + state: _FSDPState, + handle: Optional[FlatParamHandle], +) -> None: + """Unshards parameters in the pre-forward.""" + if not handle: + return + # If the handles have been prefetched, then there is no need to call + # `_unshard()` again + if not handle._prefetched: + _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream) + handle._needs_pre_forward_unshard = False + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + current_stream = state._device_handle.current_stream() + if state._unshard_event is not None: + current_stream.wait_event(state._unshard_event) + state._unshard_event = None + else: + current_stream.wait_stream(state._unshard_stream) + with torch.profiler.record_function( + "FullyShardedDataParallel._pre_forward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.FORWARD) + + +@no_type_check +def _post_forward( + state: _FSDPState, + handle: Optional[FlatParamHandle], + reshard_fn: Callable, + module: nn.Module, + input: Any, + output: Any, +) -> Any: + """ + Runs the post-forward logic. This includes an opportunity to reshard + currently unsharded parameters such as those used in the current forward + and registering pre-backward hooks on the forward outputs. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters used in + the current forward. + reshard_fn (Optional[Callable]): A callable to reshard any currently + unsharded parameters (e.g. from the current forward) or ``None`` to + not do any resharding. + module (nn.Module): Module whose forward just ran, which should be a + fully sharded module (see [Note: Fully Sharded Module]); expected + by the hook signature. + input (Any): Unused; expected by the hook signature. + output (Any): Forward pass output; pre-backward hooks are registered on + the tensors that require gradients in this output. + + Postcondition: Each ``FlatParameter`` 's data points to the sharded flat + parameter. + """ + with torch.profiler.record_function("FullyShardedDataParallel._post_forward"): + # For `fully_shard` + `checkpoint`, skip post-forward logic in the + # recomputed forward + if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE: + return output + + state._exec_order_data.record_post_forward(handle) + if reshard_fn is not None: + reshard_fn(state, handle) + # Register pre-backward hooks to unshard the flat parameters for the + # gradient computation (if needed) + output = _register_pre_backward_hooks(state, module, output, handle) + state.training_state = TrainingState.IDLE + if handle: + handle._training_state = HandleTrainingState.IDLE + return output + + +@no_type_check +def _post_forward_reshard( + state: _FSDPState, + handle: FlatParamHandle, +) -> None: + """Reshards parameters in the post-forward.""" + if not handle: + return + # Do not free the root's parameters in the post-forward for `FULL_SHARD` + # with the intention that they are immediately used for backward + # computation (though this may not be true) + free_unsharded_flat_param = ( + not state._is_root + and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + _reshard(state, handle, free_unsharded_flat_param) + + +@no_type_check +def _root_pre_forward( + state: _FSDPState, + module: nn.Module, + args, + kwargs, +) -> None: + """ + Runs pre-forward logic specific to the root FSDP instance, which should run + before any individual module's pre-forward. This starts with an attempt at + lazy initialization (which only runs non-vacuously once). Otherwise, if + this is called on a non-root FSDP instance, then it returns directly. + + Args: + module (nn.Module): Module for which this logic tries to run. It may or + may not be the root. If not, then this method does not do anything. + """ + with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"): + _lazy_init(state, module) + _p_assert(state._is_root is not None, "Expects a root FSDP to have been set") + if not state._is_root: + # Always cast forward inputs in the root of this local FSDP unit for mixed + # precision, as this is where mixed precision could be configured. + # This is more useful for auto wrapping that is recommended in composable path. + # For manual wrapping, cast forward inputs on each local FSDP unit root will + # increase some overhead, so not turned on for model wrapper path right now where + # manual wrapping is more broadly used. + if _is_composable(state): + return _root_cast_forward_input(state, module, args, kwargs) + return args, kwargs + + # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers + # are in full precision and if we should cast them back to lower precision, which happens when + # exiting eval() mode. + handle = state._handle + if handle: + should_cast_buffers_to_full_prec = handle._force_full_precision + else: + # If the root has no handle (no managed parameters), then we fall + # back to checking if any child wants to force full precision as a + # workaround + handles = traversal_utils._get_fsdp_handles(module) + should_cast_buffers_to_full_prec = any( + handle._force_full_precision for handle in handles + ) + + if should_cast_buffers_to_full_prec: + _cast_buffers_to_dtype_and_device( + buffers=dict(module.named_buffers()).values(), + buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()), + device=state.compute_device, + ) + # This flag is only set when we cast buffers to full precision, to avoid the + # CPU overhead that can stem from retrieving all buffers and their types in the + # following else branch. + state._needs_buffer_dtype_restore_check = True + elif getattr(state, "_needs_buffer_dtype_restore_check", False): + # Check if buffers are in full precision and we need to cast them + # back down. + ( + buffers, + buffer_dtypes_for_computation, + ) = _get_buffers_and_dtypes_for_computation(state, module) + if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0: + if any( + buffer.dtype != buffer_dtype_for_computation + for buffer, buffer_dtype_for_computation in zip( + buffers, buffer_dtypes_for_computation + ) + ): + # Assume we have to cast everything if there is one mismatch + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes_for_computation, state.compute_device + ) + # We don't have to check this again until we cast buffers to full precision again. + state._needs_buffer_dtype_restore_check = False + + if state.forward_prefetch: + handles = [ + fsdp_state._handle + for fsdp_state in state._all_fsdp_states + if fsdp_state._handle + ] + for handle in handles: + handle._needs_pre_forward_unshard = True + handle._prefetched = False + _wait_for_computation_stream( + state._device_handle.current_stream(), + state._unshard_stream, + state._pre_unshard_stream, + ) + _reset_flat_param_grad_info_if_needed(state._all_handles) + + # Prepares the forward inputs by moving them to ``compute_device`` + # TODO: Do not use the side stream for tensor copies for now; investigate + # the perf with/without it. + with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"): + args_tuple, kwargs_tuple = _to_kwargs( + args, kwargs, state.compute_device, False + ) + args = args_tuple[0] if args_tuple else tuple() + kwargs = kwargs_tuple[0] if kwargs_tuple else {} + + return _root_cast_forward_input(state, module, args, kwargs) + + +@no_type_check +def _root_cast_forward_input( + state: _FSDPState, module: torch.nn.Module, args, kwargs +) -> tuple[Any, Any]: + if state._handle: + force_full_precision = not state._handle._force_full_precision + else: + force_full_precision = True + + should_cast_forward_inputs = ( + (module.training or not state._use_full_prec_in_eval) and force_full_precision + ) and state.mixed_precision.cast_root_forward_inputs + + if should_cast_forward_inputs: + input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype + args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs) + + return args, kwargs + + +@no_type_check +def _pre_backward_hook( + state: _FSDPState, + module: nn.Module, + handle: FlatParamHandle, + grad, + *unused: Any, +) -> Any: + """ + Prepares ``_handle`` 's ``FlatParameter`` s for gradient computation. + + Args: + module (nn.Module): Fully sharded module (see [Note: Fully Sharded + Module]). + """ + # Only run the pre-backward hook once per group of handles involved in the + # same module forward computation + if ( + handle + and hasattr(handle, "_ran_pre_backward_hook") + and handle._ran_pre_backward_hook + ): + return grad + + with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"): + # Queue the post-backward callback once for the root FSDP instance to + # attach it to the outermost backward graph task so that it is called + # after all backward calls complete + if state._is_root and not state._post_backward_callback_queued: + _register_post_backward_final_callback(state, module) + _reset_flat_param_grad_info_if_needed(state._all_handles) + elif handle: + allowed_states = [TrainingState.IDLE] + if _is_composable(state): + allowed_states.append(TrainingState.FORWARD_BACKWARD) + _assert_in_training_states(state, allowed_states) + state.training_state = TrainingState.FORWARD_BACKWARD + # Queueing the post-backward callback is the only logic that is not + # per-handle in the pre-backward hook, so we can return early here if + # there are no handles. + if not handle: + return grad + handle._training_state = HandleTrainingState.BACKWARD_PRE + + if handle._needs_pre_backward_unshard: + # If the handles have been prefetched, then there is no need to + # call `_unshard()` again + if not handle._prefetched: + _unshard( + state, + handle, + state._unshard_stream, + state._pre_unshard_stream, + ) + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._device_handle.current_stream().wait_stream(state._unshard_stream) + + # Set this to `False` to ensure that a mistargeted prefetch does not + # actually unshard these handles + handle._needs_pre_backward_unshard = False + with torch.profiler.record_function( + "FullyShardedDataParallel._pre_backward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.BACKWARD) + handle.prepare_gradient_for_backward() + handle._ran_pre_backward_hook = True + return grad + + +@no_type_check +@torch.no_grad() +def _post_backward_hook( + state: _FSDPState, + handle: FlatParamHandle, + flat_param, + *unused: Any, +): + """ + Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``. + + Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the + unsharded gradient for the local batch. + + Postcondition: + - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced + unsharded gradient. + - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded + gradient (accumulating with any existing gradient). + """ + _log_post_backward_hook(state, handle, logger) + flat_param = handle.flat_param + flat_param._post_backward_called = True + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel._post_backward_hook" + ): + _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD]) + # For multiple applications of reentrant AC across submodules sharing + # the same `FlatParameter`, the post-backward hook may run multiple + # times in one backward, in which case we permit the state to already + # be in `BACKWARD_POST`. + _p_assert( + handle._training_state + in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST), + f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}", + ) + handle._training_state = HandleTrainingState.BACKWARD_POST + + if flat_param.grad is None: + return + if flat_param.grad.requires_grad: + raise RuntimeError("FSDP does not support gradients of gradients") + + _post_backward_reshard(state, handle) + if not state._sync_gradients: + if handle._use_orig_params: + handle._use_unsharded_grad_views() + return + + # Wait for all ops in the current stream (e.g. gradient computation) to + # finish before reduce-scattering the gradient + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._post_backward_stream.wait_stream( + state._device_handle.current_stream() + ) + + with state._device_handle.stream(state._post_backward_stream): + autograd_computed_grad = flat_param.grad.data + if ( + not _low_precision_hook_enabled(state) + and flat_param.grad.dtype != handle._reduce_dtype + # If we are forcing full precision but communicating grads + # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient. + and not handle._force_full_precision + ): + flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype) + if handle.uses_sharded_strategy: + _reduce_grad(state, handle) + else: + _reduce_grad_no_shard(state, handle) + # Since the unsharded gradient is produced in the computation + # stream and consumed in the post-backward stream, inform the + # caching allocator (before it goes out of scope) + _no_dispatch_record_stream( + autograd_computed_grad, state._post_backward_stream + ) + + +def _post_backward_reshard_only_hook( + state: _FSDPState, + handle: FlatParamHandle, + *unused: Any, +) -> None: + with torch.profiler.record_function( + "FullyShardedDataParallel._post_backward_hook_reshard_only" + ): + # `_pre_backward_hook` may not get executed + # if forward output does not require grad + # overwrite IDLE state for post-backward prefetching + state.training_state = TrainingState.FORWARD_BACKWARD + handle._training_state = HandleTrainingState.BACKWARD_POST + _post_backward_reshard(state, handle) + + +def _post_backward_reshard( + state: _FSDPState, + handle: FlatParamHandle, + *unused: Any, +) -> None: + free_unsharded_flat_param = _should_free_in_backward(state, handle) + _reshard(state, handle, free_unsharded_flat_param) + + # TODO: Post-backward prefetching does not support the multiple handles + # per module case since the post-backward hook runs per handle, not per + # group of handles. + with torch.profiler.record_function( + "FullyShardedDataParallel._post_backward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.BACKWARD) + + +@no_type_check +def _should_free_in_backward( + state: _FSDPState, + handle: FlatParamHandle, +) -> bool: + """ + Returns whether FSDP should free the unsharded flat parameter in the + post-backward or not. + """ + if not handle.uses_sharded_strategy: + return False + # If not syncing gradients, then we do not free for strategies that do not + # reshard after forward as a *heuristic* to tradeoff higher memory for + # higher throughput. + return ( + state._sync_gradients + or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + + +@no_type_check +def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None: + """ + For sharded strategies, this runs gradient reduction, sharded gradient + accumulation if needed, and the post-reduction callback. + """ + flat_param = handle.flat_param + uses_hybrid_sharded_strategy = handle._sharding_strategy in ( + HandleShardingStrategy.HYBRID_SHARD, + HandleShardingStrategy._HYBRID_SHARD_ZERO2, + ) + # We clear `.grad` to permit multiple backwards. This avoids a race where + # the second backward pass computation precedes ahead of the first backward + # pass reduction, which is possible since the reduction is issued in a + # separate stream and is async and would result in reducing the wrong + # gradient. + unsharded_grad = flat_param.grad.data + flat_param.grad = None + padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors( + state, unsharded_grad + ) + if state._comm_hook is None: # default path + _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor) + pg = ( + handle._fake_process_group + if handle._use_fake_reduce + else state.process_group + ) + dist.reduce_scatter_tensor( + new_sharded_grad, + padded_unsharded_grad, + group=pg, + ) + if uses_hybrid_sharded_strategy: + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._all_reduce_stream.wait_stream(state._post_backward_stream) + with state._device_handle.stream(state._all_reduce_stream): + # Since the new sharded gradient is produced in the post- + # backward stream and consumed in the all-reduce stream, + # inform the caching allocator + _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream) + dist.all_reduce(new_sharded_grad, group=state._inter_node_pg) + _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor) + grad_to_offload = _accumulate_sharded_grad( + state, handle, new_sharded_grad + ) + _post_reduce_grad_callback(state, handle, grad_to_offload) + return + _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor) + else: + state._comm_hook( + state._comm_hook_state, padded_unsharded_grad, new_sharded_grad + ) + # NOTE: HSDP variants do not support communication hook. + grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad) + _post_reduce_grad_callback(state, handle, grad_to_offload) + + +@no_type_check +def _get_reduce_scatter_tensors( + state: _FSDPState, unsharded_grad: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns the input and output tensors to reduce-scatter, respectively. + """ + chunks = list(unsharded_grad.chunk(state.world_size)) + numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel() + padded_unsharded_grad = ( + F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad + ) + new_sharded_grad = torch.empty_like(chunks[0]) # padded + return padded_unsharded_grad, new_sharded_grad + + +@no_type_check +def _accumulate_sharded_grad( + state: _FSDPState, + handle: FlatParamHandle, + sharded_grad: torch.Tensor, +) -> torch.Tensor: + """ + Accumulates the reduce-scattered sharded gradient with any existing sharded + gradient if needed, returning the gradient to offload (if CPU offloading is + enabled). + """ + flat_param = handle.flat_param + _cast_grad_to_param_dtype(state, sharded_grad, flat_param) + # Save the sharded gradient in `_saved_grad_shard` to support gradient + # accumulation -- for multiple backwards, the gradient reductions may + # happen in arbitrary order + accumulate_grad = hasattr(flat_param, "_saved_grad_shard") + if accumulate_grad: + _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard) + flat_param._saved_grad_shard += sharded_grad + else: + flat_param._saved_grad_shard = sharded_grad + grad_to_offload = flat_param._saved_grad_shard + return grad_to_offload + + +@no_type_check +def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None: + """ + For no-shard, this runs gradient reduction (which directly covers any + gradient accumulation implicitly) and the post-reduction callback. + """ + flat_param = handle.flat_param + if state._comm_hook is None: # default path + _div_if_needed(flat_param.grad, state._gradient_predivide_factor) + dist.all_reduce(flat_param.grad, group=state.process_group) + _div_if_needed(flat_param.grad, state._gradient_postdivide_factor) + else: + state._comm_hook(state._comm_hook_state, flat_param.grad) + # For `NO_SHARD`, we can keep the low precision gradients by simply + # omitting the cast altogether + if not handle._keep_low_precision_grads: + _cast_grad_to_param_dtype(state, flat_param.grad, flat_param) + grad_to_offload = flat_param.grad.data + _post_reduce_grad_callback(state, handle, grad_to_offload) + + +@no_type_check +def _post_reduce_grad_callback( + state: _FSDPState, + handle: FlatParamHandle, + # Additional arguments needed for the callback logic + grad_to_offload: torch.Tensor, +): + """ + This callback captures any logic to run after the gradient reduction + finishes. Currently, this offloads the gradient to CPU if CPU offloading is + enabled and uses sharded gradient views if ``use_orig_params=True``. + """ + _offload_grad(state, handle, grad_to_offload) + _post_backward_use_sharded_grad_views(handle) + + +@no_type_check +def _offload_grad( + state: _FSDPState, + handle: FlatParamHandle, + grad_to_offload: torch.Tensor, +): + if not handle._offload_params: + return + # Offload the gradient to CPU to ensure parameters and gradients are on the + # same device as required by the optimizer + # TODO: Investigate why `NO_SHARD` breaks correctness when using + # `non_blocking=True` here. + # TODO (rohan-varma): When CPU offload and optimizer overlap, + # non_blocking=True won't work since the copy may have not finished before + # the optimizer step executes on CPU. If we want to use non-blocking=True + # here, we'll have to synchronize before using result on CPU. + non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward + handle.flat_param._cpu_grad.copy_( + grad_to_offload.detach(), non_blocking=non_blocking + ) # synchronized in the post-backward callback + # Since the gradient being offloaded may have been produced in the + # computation stream and is being consumed here in the post-backward + # stream, inform the caching allocator + _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream) + + +@no_type_check +def _post_backward_use_sharded_grad_views(handle: FlatParamHandle): + if not handle._use_orig_params: + return + # Since the handle's `FlatParameter` completed its gradient computation, we + # should reset the gradient noneness mask + handle._reset_is_grad_none() + # Delay using sharded gradient views until after the reduce-scatter instead + # of immediately after resharding + handle._use_sharded_grad_views() + if handle._has_optim_in_backward: + handle.prepare_gradient_for_optim() + for orig_param in handle.flat_param._params: + # Check for `None` gradient to filter parameters not in the rank + if orig_param.grad is not None and hasattr( + orig_param, "_in_backward_optimizers" + ): + # TODO (rohan-varma): For CPU offload, this unfortunately + # operates on CPU because the parameters and gradients have + # already been offloaded. We should run this on GPU after + # refactoring. + for optim in orig_param._in_backward_optimizers: + optim.step() + + optim.zero_grad(set_to_none=True) + handle._reset_flat_param_grad_info_if_needed() + if handle._offload_params: + handle.flat_param._cpu_grad = None + + +def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None: + if div_factor > 1: + tensor.div_(div_factor) + + +@no_type_check +def _cast_grad_to_param_dtype( + state: _FSDPState, + sharded_grad: torch.Tensor, + param: FlatParameter, +): + """ + Casts ``sharded_grad`` back to the full parameter dtype so that the + optimizer step runs with that dtype. This performs an actual cast if + 1. parameters were in reduced precision during the forward since then + gradients would be in that reduced precision, or + 2. parameters were not in reduced precision but gradients were in + reduced precision for communication. + However, if a low precision communication hook is registered, then this + dtype cast happens in the hook instead. + """ + _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD]) + if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype: + low_prec_grad_data = sharded_grad.data + sharded_grad.data = sharded_grad.data.to(dtype=param.dtype) + # Since for `NO_SHARD`, the gradient is produced in the computation + # stream and consumed here in the post-backward stream, inform the + # caching allocator; for the sharded strategies, the gradient is + # produced in the post-backward stream, so this `record_stream()` + # should be a no-op + _no_dispatch_record_stream( + low_prec_grad_data, state._device_handle.current_stream() + ) + + +def _check_grad_to_accumulate( + new_sharded_grad: torch.Tensor, + accumulated_grad: torch.Tensor, +) -> None: + _p_assert( + accumulated_grad.shape == new_sharded_grad.shape, + "Shape mismatch when accumulating gradients: " + f"existing gradient shape={accumulated_grad.shape} " + f"new gradient shape={new_sharded_grad.shape}", + ) + _p_assert( + accumulated_grad.device == new_sharded_grad.device, + "Device mismatch when accumulating gradients: " + f"existing gradient device={accumulated_grad.device} " + f"new gradient device={new_sharded_grad.device}", + ) + + +@no_type_check +def _low_precision_hook_enabled(state: _FSDPState) -> bool: + return state._comm_hook in LOW_PRECISION_HOOKS + + +@no_type_check +@torch.no_grad() +def _post_backward_final_callback( + state: _FSDPState, + module: nn.Module, +): + """ + This waits for the post-backward to finish and performs some final cleanup. + This runs at the end of the entire backward pass and should only be called + on the root FSDP instance. + """ + _p_assert( + state._is_root, + "The post-backward callback should only be called on the root FSDP instance", + ) + root_state = state + + if root_state._sync_gradients: + current_stream = state._device_handle.current_stream() + # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish + # since it currently runs in the post-backward stream. That can be + # pushed to the next forward if run in a different stream + current_stream.wait_stream(root_state._post_backward_stream) + if root_state._all_reduce_stream is not current_stream: # uses HSDP + current_stream.wait_stream(root_state._all_reduce_stream) + if root_state.cpu_offload.offload_params: + # Wait for non-blocking GPU -> CPU sharded gradient copies from the + # post-backward hooks to finish explicitly since CPU gradients do + # not automatically synchronize with the GPU + state._device_handle.current_stream().synchronize() + root_state._exec_order_data.next_iter() + + for fsdp_state in state._all_fsdp_states: + _catch_all_reshard(fsdp_state) + _finalize_params(fsdp_state) + fsdp_state.training_state = TrainingState.IDLE + handle = fsdp_state._handle + if handle: + handle._ran_pre_backward_hook = False + handle._needs_pre_backward_unshard = False + handle._post_forward_index = None + handle._training_state = HandleTrainingState.IDLE + handle._prefetched = False + # Reset for cases like one forward and multiple backwards + root_state._post_backward_callback_queued = False + + +@no_type_check +def _catch_all_reshard( + state: _FSDPState, +) -> None: + """ + Reshards the parameters that may not have been resharded in the + post-backward hook. This can happen when a module's output is used in the + forward pass, meaning that its pre-backward hook runs (unsharding the + parameter), but the post-backward hook does not run because the output was + not jused in the loss computation corresponding to this backward pass. + """ + # Wrap with a try-except to provide a more informative traceback if an + # error is raised + try: + if state._handle: + # TODO: This already-resharded check is brittle: + # https://github.com/pytorch/pytorch/issues/83956 + already_resharded = ( + state._handle.flat_param.data_ptr() + == state._handle.flat_param._local_shard.data_ptr() + # If FSDP skipped using sharded views, then the flat parameter + # still points to the sharded data, so we need to reshard to + # use sharded views + and not state._handle._skipped_use_sharded_views + ) + if already_resharded: + return + free_unsharded_flat_param = _should_free_in_backward(state, state._handle) + _reshard(state, state._handle, free_unsharded_flat_param) + except Exception as e: + _p_assert( + False, + f"Got exception in the catch-all reshard for {state}: {str(e)}", + raise_assertion_error=False, + ) + raise e + + +@no_type_check +def _finalize_params( + state: _FSDPState, +) -> None: + """Finalizes the parameters before the next iteration.""" + handle = state._handle + if not handle: + return + flat_param = handle.flat_param + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + if hasattr(flat_param, "_post_backward_hook_handle"): + pbhs_handle = flat_param._post_backward_hook_handle + pbhs_handle.remove() + del flat_param._post_backward_hook_handle + else: + if hasattr(flat_param, "_post_backward_hook_state"): + post_backward_hook_state_len = len(flat_param._post_backward_hook_state) + expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1 + _p_assert( + post_backward_hook_state_len == expected_post_backward_hook_state_len, + f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}", + ) + flat_param._post_backward_hook_state[-1].remove() + delattr(flat_param, "_post_backward_hook_state") + if flat_param.requires_grad: + if not state._sync_gradients: + # Preserve the gradient accumulation state if not synchronizing + # gradients: `.grad` remains the unsharded gradient from prior + # `no_sync()` iterations, and `_saved_grad_shard` remains the + # sharded gradient from the last synchronized iteration + return + if not handle._has_optim_in_backward: + handle.prepare_gradient_for_optim() + _p_assert( + hasattr(flat_param, "_post_backward_called"), + "Expects `_post_backward_called` to be set on the `FlatParameter`", + ) + flat_param._post_backward_called = False + + +@no_type_check +def _prefetch_handle( + state: _FSDPState, + current_handle: Optional[FlatParamHandle], + prefetch_mode: _PrefetchMode, +) -> None: + """ + Prefetches the next handles if needed (without synchronization). An empty + handles key cannot prefetch. + """ + if not current_handle: + return + handle = _get_handle_to_prefetch(state, current_handle) + if not handle: + return + # Temporarily emulate the training state while calling `_unshard` to + # ensure the correct `as_params` for `_use_unsharded_views()` + prev_training_state = handle._training_state + if prefetch_mode == _PrefetchMode.BACKWARD: + handle._training_state = HandleTrainingState.BACKWARD_PRE + elif prefetch_mode == _PrefetchMode.FORWARD: + handle._training_state = HandleTrainingState.FORWARD + else: + raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}") + # Prefetch the next set of handles without synchronizing to allow + # the sync to happen as late as possible to maximize overlap + _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream) + handle._training_state = prev_training_state + handle._prefetched = True + + +@no_type_check +def _get_handle_to_prefetch( + state: _FSDPState, + current_handle: FlatParamHandle, +) -> FlatParamHandle: + """ + Returns a :class:`list` of the handles keys to prefetch for the next + module(s), where ``current_handle`` represents the current module. + + "Prefetching" refers to running the unshard logic early (without + synchronization), and the "next" modules depend on the recorded execution + order and the current training state. + """ + training_state = _get_training_state(current_handle) + valid_training_states = ( + HandleTrainingState.BACKWARD_PRE, + HandleTrainingState.BACKWARD_POST, + HandleTrainingState.FORWARD, + ) + _p_assert( + training_state in valid_training_states, + f"Prefetching is only supported in {valid_training_states} but " + f"currently in {training_state}", + ) + eod = state._exec_order_data + target_handle: Optional[FlatParamHandle] = None + if ( + training_state == HandleTrainingState.BACKWARD_PRE + and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE + ) or ( + training_state == HandleTrainingState.BACKWARD_POST + and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST + ): + target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle) + if ( + target_handle_candidate + and target_handle_candidate._needs_pre_backward_unshard + and not target_handle_candidate._prefetched + ): + target_handle = target_handle_candidate + else: + target_handle = None + elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch: + target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle) + if ( + target_handle_candidate + and target_handle_candidate._needs_pre_forward_unshard + and not target_handle_candidate._prefetched + ): + target_handle = target_handle_candidate + else: + target_handle = None + + return target_handle + + +def _get_training_state( + handle: FlatParamHandle, +) -> HandleTrainingState: + """Returns the training state of the handles in ``handle``.""" + _p_assert(handle, "Expects a non-empty handle") + return handle._training_state + + +@no_type_check +def _register_pre_forward_hook( + state: _FSDPState, + module: nn.Module, +) -> None: + """ + Registers a pre-forward hook on ``module``. + """ + for forward_handle in state._pre_forward_handles: + forward_handle.remove() + state._pre_forward_handles.clear() + module_param_handle = state._fully_sharded_module_to_handle.get(module, None) + hook = functools.partial( + _pre_forward, state, module_param_handle, _pre_forward_unshard + ) + state._pre_forward_handles.append( + module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True) + ) + + +@no_type_check +def _register_post_forward_hook( + state: _FSDPState, + module: nn.Module, +) -> None: + """ + Registers a post-forward hook on ``module``. Even if the module has no + handles, we should register the hook since it will register the module's + pre-backward hook. + """ + for forward_handle in state._post_forward_handles: + forward_handle.remove() + state._post_forward_handles.clear() + module_param_handle = state._fully_sharded_module_to_handle.get(module, None) + hook = functools.partial( + _post_forward, + state, + module_param_handle, + _post_forward_reshard, + ) + state._post_forward_handles.append(module.register_forward_hook(hook)) + + +@no_type_check +def _register_root_pre_forward_hook( + state: _FSDPState, + module: nn.Module, +): + """ + Registers root pre-forward hook on ``module``, which should be the local + FSDP root. + + NOTE: For the current composable FSDP design, we have each application of + ``fully_shard()`` to a module to indicate that that module is the local + FSDP root. We may remove this assumption in the future, in which case we + will need to register this root pre-forward hook on any candidate module + that may be the local FSDP root. + """ + for forward_handle in state._root_pre_forward_handles: + forward_handle.remove() + state._root_pre_forward_handles.clear() + hook = functools.partial(_root_pre_forward, state) + state._root_pre_forward_handles.append( + module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True) + ) + + +@no_type_check +def _register_pre_backward_hooks( + state: _FSDPState, + module: nn.Module, + outputs: Any, + handle: FlatParamHandle, +) -> None: + """ + Registers pre-backward hooks on the tensors that require gradients in the + forward pass outputs ``outputs``, which were computed using the + ``FlatParameter`` s of ``handles``. + + Args: + module (nn.Module): Fully sharded module (see [Note: Fully Sharded + Module]). + + Returns: + Forward pass outputs with pre-backward hooks registered to tensors that + require gradients. + """ + # If there is no gradient computation, then there is no need for + # pre-backward logic + if not torch.is_grad_enabled(): + return outputs + if state._is_root: + state._post_backward_callback_queued = False # only defined on the root + + if handle: + handle._needs_pre_backward_unshard = False + # Since these handles' `FlatParameter`s participated in a forward, we + # conservatively assume that they will be used in the backward + handle._ran_pre_backward_hook = False + + def _register_hook(t: torch.Tensor) -> torch.Tensor: + if t.requires_grad: + t.register_hook( + torch.utils.hooks.unserializable_hook( + functools.partial(_pre_backward_hook, state, module, handle) + ) + ) + if handle: + handle._needs_pre_backward_unshard = True + return t + + return _apply_to_tensors(_register_hook, outputs) + + +def _register_post_backward_hook( + state: _FSDPState, + handle: Optional[FlatParamHandle], +) -> None: + """ + Registers post-backward hooks on the ``FlatParameter`` s' + ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients. + + The ``AccumulateGrad`` object represents the last function that finalizes + the ``FlatParameter`` 's gradient, so it only runs after its entire + gradient computation has finished. + + We register the post-backward hook only once in the *first* forward that a + ``FlatParameter`` participates in. This relies on the ``AccumulateGrad`` + object being preserved through multiple forwards. + + NOTE: We follow this heuristic to prefer the *first* forward to target the + parameter mixed precision case, where there are *separate* + ``AccumulateGrad`` objects across the different forwards. (Without + parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If + we instead prefer the *last* forward, then the hook runs early. + """ + # If there is no gradient computation, then there is no need for + # post-backward logic + if not torch.is_grad_enabled(): + return + if not handle: + return + flat_param = handle.flat_param + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_registered = hasattr(flat_param, "_post_backward_hook_handle") + if already_registered or not flat_param.requires_grad: + return + hook = functools.partial(_post_backward_hook, state, handle) + hook_handle = flat_param.register_post_accumulate_grad_hook(hook) + flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined] + else: + already_registered = hasattr(flat_param, "_post_backward_hook_state") + if already_registered or not flat_param.requires_grad: + return + # Get the `AccumulateGrad` object + temp_flat_param = flat_param.expand_as(flat_param) + _p_assert( + temp_flat_param.grad_fn is not None, + "The `grad_fn` is needed to access the `AccumulateGrad` and " + "register the post-backward hook", + ) + acc_grad = temp_flat_param.grad_fn.next_functions[0][0] # type: ignore[union-attr] + if acc_grad is None: + raise AssertionError("Expected acc_grad to be set") + hook_handle = acc_grad.register_hook( + functools.partial(_post_backward_hook, state, handle) + ) + flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined] + + +def _register_post_backward_reshard_only_hook( + state: _FSDPState, + handle: Optional[FlatParamHandle], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> None: + """ + Registers post-backward hooks to reshard flat parameters that do not + require gradient. We register these using multi-post-grad hooks on the + input activations to ensure that all gradients that may depend on the + parameters have been computed before resharding. + """ + # If there is no gradient computation, then there is no need for + # post-backward logic + if not torch.is_grad_enabled(): + return + # Construct `inp_tensors` lazily to avoid CPU overhead in typical case + # where each flat parameter requires gradient + inp_tensors: Optional[list[torch.Tensor]] = None + if not handle: + return + flat_param = handle.flat_param + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_registered = hasattr(flat_param, "_post_backward_hook_handle") + else: + already_registered = hasattr(flat_param, "_post_backward_hook_state") + + if already_registered or flat_param.requires_grad: + return + if inp_tensors is None: + args_flat = pytree.arg_tree_leaves(*args, **kwargs) + inp_tensors = [ + obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad + ] + if inp_tensors is None: + raise AssertionError("Expected inp_tensors to be set") + hook_handle = register_multi_grad_hook( + inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle) + ) + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined, assignment] + else: + flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment] + + +@no_type_check +def _register_post_backward_final_callback( + state: _FSDPState, module: nn.Module +) -> None: + """ + Registers the post-backward final callback that runs at the end of the + backward pass. This should be called from the root FSDP instance at the + beginning of the pre-backward. + """ + _p_assert( + state._is_root, + "Only the root FSDP instance should register the post-backward callback", + ) + if state._post_backward_callback_queued: + return + _assert_in_training_states(state, [TrainingState.IDLE]) + # Trace does not need this callback + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._post_backward_callback_queued = True + Variable._execution_engine.queue_callback( + functools.partial(_post_backward_final_callback, state, module) + ) + + +def _wait_for_computation_stream( + computation_stream: torch.Stream, + unshard_stream: torch.Stream, + pre_unshard_stream: torch.Stream, +): + """ + Has the unshard and pre-unshard streams wait for the computation stream. + For example, this should be called in the FSDP root's pre-forward to + respect optimizer step computation. + """ + # Tracing does not need to wait + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + return + unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] + # Having the pre-all-gather stream wait for the current stream even if we + # do not leverage the pre-all-gather stream is tolerable since this only + # runs once per iteration + pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] + + +def _reset_flat_param_grad_info_if_needed( + handles: list[FlatParamHandle], +): + """ + Clears the original parameters' gradients if needed. This method's CPU + overhead is minimal, so we may call it throughout FSDP methods, which serve + as callsites to free the gradient memory earlier. + """ + if not isinstance(handles, list): + handles = [handles] + for handle in handles: + if handle._use_orig_params: + handle._reset_flat_param_grad_info_if_needed() + + +@no_type_check +def _get_buffers_and_dtypes_for_computation( + state: _FSDPState, + root_module: nn.Module, +) -> tuple[list[torch.Tensor], list[Optional[torch.dtype]]]: + """ + Returns all buffers in the module tree rooted at ``root_module`` and a + corresponding list of the buffer dtypes for computation. Each buffer dtype + is either ``None`` if buffer mixed precision is not enabled or the buffer + low precision dtype otherwise. + """ + _p_assert(state._is_root, "Expects the root to cast buffers") + buffers: list[torch.Tensor] = [] + buffer_dtypes: list[Optional[torch.dtype]] = [] + visited_buffers: set[torch.Tensor] = set() + # Traverse the FSDP states bottom-up so that we prefer the owning FSDP + # instance's mixed precision setting for each buffer + fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules( + root_module + ) + for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)): + for buffer_name, buffer in fsdp_module.named_buffers(): + if buffer in visited_buffers: + continue + visited_buffers.add(buffer) + if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names: + continue + buffers.append(buffer) + buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype) + if len(buffers) != len(buffer_dtypes): + raise AssertionError( + f"Expected buffers and buffer_dtypes to have the same length, got {len(buffers)} and {len(buffer_dtypes)}" + ) + return buffers, buffer_dtypes + + +@no_type_check +def _get_orig_buffer_dtypes( + state: _FSDPState, + buffer_names: list[str], +) -> list[torch.dtype]: + """ + Returns the original buffer types of the given buffer names. + """ + buffer_dtypes: list[torch.dtype] = [] + for buffer_name in buffer_names: + _p_assert( + buffer_name in state._buffer_name_to_orig_dtype, + f"{buffer_name} is missing from pre-computed dict on rank " + f"{state.rank}, which only has keys " + f"{state._buffer_name_to_orig_dtype.keys()}", + ) + buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name]) + return buffer_dtypes + + +def _cast_buffers_to_dtype_and_device( + buffers: list[torch.Tensor], + buffer_dtypes: list[Optional[torch.dtype]], + device: torch.device, +) -> None: + """ + Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them + to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the + corresponding buffer is only moved to ``device``. + """ + _p_assert( + buffer_dtypes is None or len(buffers) == len(buffer_dtypes), + f"Expects `buffers` and `buffer_dtypes` to have the same length if " + f"`buffer_dtypes` is specified but got {len(buffers)} and " + f"{len(buffer_dtypes)}", + ) + for buffer, buffer_dtype in zip(buffers, buffer_dtypes): + if not torch.is_floating_point(buffer) or buffer_dtype is None: + buffer.data = buffer.to(device=device) + else: + buffer.data = buffer.to(device=device, dtype=buffer_dtype) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_shard_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eca5b9bd398749f1f38f50a48969cfbc3758352a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_shard_utils.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import math +from typing import Optional + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed import distributed_c10d +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) +from torch.distributed._shard.sharding_spec import ShardMetadata +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard + + +def _get_remote_device_str(rank, device_type, num_devices_per_node): + if device_type.lower() == "cpu": + return f"rank:{rank}/{device_type}" + elif device_type.lower() == "hpu": + return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}" + else: + return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}" + + +def _create_chunk_sharded_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, +) -> ShardedTensor: + """ + Shard a tensor to chunks along the first dimension. The local rank will gets its + corresponding chunk as the local shard to create a ShardedTensor. + """ + chunks = tensor.chunk(world_size, dim=0) + if len(chunks) > rank: + local_shard = chunks[rank].clone() + offsets = [0 for _ in tensor.size()] + offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank + local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)] + else: + local_shards = [] + + # Create a ShardedTensor without invoking communication. + chunk_sizes = [list(chunk.size()) for chunk in chunks] + dim0_offsets = [0] + list( + itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes]) + )[:-1] + offsets = [0] * (len(chunk_sizes[0]) - 1) + chunk_offsets = [[d0] + offsets for d0 in dim0_offsets] + device_type = ( + distributed_c10d._get_pg_default_device(pg).type + if device is None + else device.type + ) + placements = [ + _get_remote_device_str( + dist.get_global_rank(pg, r), + device_type, + num_devices_per_node, + ) + for r in range(len(chunk_sizes)) + ] + if len(chunk_sizes) != len(chunk_offsets) or len(chunk_sizes) != len(placements): + raise AssertionError( + f"Expected chunk_sizes, chunk_offsets, and placements to have the same length, " + f"got {len(chunk_sizes)}, {len(chunk_offsets)}, {len(placements)}" + ) + shard_metadata = [ + ShardMetadata(offset, size, placement) + for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements) + ] + sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=shard_metadata, + size=tensor.size(), + tensor_properties=TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ), + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg + ) + + +def _create_chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, +) -> DTensor: + """ + Shard a tensor to chunks along the first dimension. The local rank will gets its + corresponding chunk as the local tensor to create a DTensor. + """ + # We need to explicitly call .detach() to return a new tensor detached from the current graph. + tensor = tensor.detach().clone() + + # FSDP placements: [Shard(0)] + # HSDP placements: [Replicate(), Shard(0)] + replicate_placements = [Replicate() for _ in range(device_mesh.ndim)] + shard_placements = [Replicate() for _ in range(device_mesh.ndim)] + shard_placements[-1] = DShard(0) # type: ignore[call-overload] + + return DTensor.from_local( + tensor, device_mesh, replicate_placements, run_check=False + ).redistribute( + placements=shard_placements, + ) + + +def _all_gather_dtensor( + tensor: DTensor, + root_mesh: Optional[DeviceMesh], +) -> torch.Tensor: + """ + All gather a DTensor in its sharded dimension and return the local tensor. + """ + if root_mesh != tensor.device_mesh: + raise AssertionError("The device mesh of a tensor should be a root mesh.") + + placements = list(copy.deepcopy(tensor.placements)) + # FSDP placements: [Shard(0)] -> [Replicate()] + # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + placements[-1] = Replicate() + tensor = tensor.redistribute( + device_mesh=tensor.device_mesh, + placements=placements, + ) + + return tensor.to_local() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec648ced837e155018c7002560bb7e297b163c78 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py @@ -0,0 +1,932 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +import warnings +from collections.abc import Callable, Generator, Iterator +from typing import Any, cast, no_type_check + +import torch +import torch.distributed as dist +import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard, + ShardedTensor, +) +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_module_fsdp_state_if_fully_sharded_module, + _has_fsdp_params, + _is_composable, + _module_handle, + clean_tensor_name, + FSDP_PREFIX, + FSDP_WRAPPED_MODULE, +) +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._runtime_utils import ( + _cast_buffers_to_dtype_and_device, + _get_orig_buffer_dtypes, + _lazy_init, + _reset_flat_param_grad_info_if_needed, +) +from torch.distributed.fsdp.api import ( + FullStateDictConfig, + ShardingStrategy, + StateDictType, +) +from torch.distributed.tensor import DTensor +from torch.distributed.utils import _replace_by_prefix + +from ._fsdp_extensions import ( + _ext_all_gather_dtensor, + _ext_chunk_dtensor, + _ext_chunk_tensor, + _ext_post_unflatten_transform, + _ext_pre_load_state_dict_transform, +) +from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM + + +logger = logging.getLogger(__name__) + + +def _should_unshard_params(fsdp_state: _FSDPState) -> bool: + return not ( + fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + and (_is_composable(fsdp_state) or fsdp_state._use_orig_params) + ) + + +def _convert_to_wrapped_module_name(module_name: str) -> str: + module_name = module_name.replace(f"{FSDP_PREFIX}", "") + module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "") + if module_name: + module_name = f"{module_name}." + # `CheckpointWrapper` adds a prefix that has to be removed as well. + module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "") + return module_name + + +def _param_name_infos( + module: nn.Module, fsdp_state: _FSDPState +) -> Iterator[tuple[str, str, str]]: + if not _has_fsdp_params(fsdp_state, module): + return + for param_name, module_name in _module_handle( + fsdp_state, module + ).param_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +def _shared_param_name_infos( + module: nn.Module, fsdp_state +) -> Iterator[tuple[str, str, str]]: + for param_name, module_name in _module_handle( + fsdp_state, module + ).shared_param_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +@no_type_check +def _enter_unshard_params_ctx( + module: nn.Module, + fsdp_state: _FSDPState, + writeback: bool = False, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, +) -> None: + """ + state_dict hooks cannot use the pure context call as the checkpoint flow + requires to enter the context in the pre-hook but leave the context in the + post-hook. This API enters the context of ``_unshard_fsdp_state_params``. + """ + if module in fsdp_state._unshard_params_ctx: + raise AssertionError( + "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] " + "is not None." + ) + fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params( + module, + fsdp_state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + fsdp_state._unshard_params_ctx[module].__enter__() + + +@no_type_check +def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None: + """A helper function to exit ``_unshard_fsdp_state_params`` context.""" + fsdp_state._unshard_params_ctx[module].__exit__(None, None, None) + fsdp_state._unshard_params_ctx.pop(module) + + +def _common_pre_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, +) -> None: + """Performs the pre-state_dict tasks shared by all state_dict types.""" + if fsdp_state._device_handle.is_available(): + fsdp_state._device_handle.synchronize() + # TODO: need to check if this is always correct for composable FSDP. + _lazy_init(fsdp_state, module) + if fsdp_state._is_root: + _reset_flat_param_grad_info_if_needed(fsdp_state._all_handles) + + +def _common_unshard_pre_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + offload_to_cpu: bool, + rank0_only: bool, +) -> None: + """ + Performs the pre-state_dict tasks shared by all state_dict types that require + ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. + """ + # For composable `fully_shard`, it does not need to unshard parameters for `NO_SHARD` cases. + if not _should_unshard_params(fsdp_state): + return + _enter_unshard_params_ctx( + module, + fsdp_state, + writeback=False, + offload_to_cpu=offload_to_cpu, + rank0_only=rank0_only, + ) + + +@no_type_check +def _common_unshard_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, + param_hook: Callable, +) -> dict[str, Any]: + """ + The post-state_dict flow that shared by all state_dict types that require + ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this + hook. + """ + _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) + # Return early for trivial cases + if not state_dict or not _has_fsdp_params(fsdp_state, module): + if _should_unshard_params(fsdp_state): + _exit_unshard_params_ctx(module, fsdp_state) + return state_dict + + # If a rank does not have unsharded parameters(when `rank0_only=True` + # and `rank != 0`), then the rank only needed to participate in the + # all-gather and does not need to save the # state dict. We simply check + # rank0_only to ensure this issue. + rank0_only = ( + fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT + and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only + ) + # no_fsdp_return means the state_dict returned by this rank should contain + # only non-FSDP controlled parameters and buffers. + no_fsdp_return = rank0_only and fsdp_state.rank != 0 + if no_fsdp_return and not fsdp_state._use_orig_params: + for clean_key in fsdp_state._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_key.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + state_dict.pop(f"{prefix}{clean_key}", None) + # Non-zero ranks have flat_param key when rank0_only=True, because rank0_only=True is + # passed in to unshard context, but nonzero ranks reshard early, causing this flat_param + # to appear in state_dict. + state_dict.pop(f"{prefix}{FLAT_PARAM}") + _exit_unshard_params_ctx(module, fsdp_state) + return state_dict + + # Loop only the parameters saved in this instance's wrapped module to + # avoid processing buffers. + for fqn, param_name, module_name in _param_name_infos(module, fsdp_state): + fqn = f"{prefix}{fqn}" + if no_fsdp_return: + state_dict.pop(fqn) + continue + if fqn not in state_dict: + raise AssertionError( + f"FSDP assumes {fqn} is in the state_dict but the state_dict only " + f"has {state_dict.keys()}. " + f"prefix={prefix}, module_name={module_name}, " + f"param_name={param_name} rank={fsdp_state.rank}." + ) + + param_hook(state_dict, prefix, fqn) + + if _should_unshard_params(fsdp_state): + _exit_unshard_params_ctx(module, fsdp_state) + + cpu_device = torch.device("cpu") + buffer_clean_fqns = [] + buffers = [] + for clean_key in fsdp_state._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_tensor_name(clean_key) + fqn = f"{prefix}{clean_key}" + if fqn not in state_dict: + # A buffer can be registered as non-persistent. + continue + if no_fsdp_return: + state_dict.pop(fqn) + else: + buffer = state_dict[fqn] + if ( + fsdp_state._state_dict_config.offload_to_cpu + and buffer.device != cpu_device + ): + state_dict[fqn] = buffer.to(cpu_device) + # skip upcasting for ignored buffers + if clean_key not in fsdp_state._ignored_buffer_names: + buffer_clean_fqns.append(clean_key) + buffers.append(state_dict[fqn]) + + if buffers: + mixed_precision_enabled_for_buffers = ( + fsdp_state._mixed_precision_enabled_for_buffers() + if not _is_composable(fsdp_state) + else (fsdp_state.mixed_precision.buffer_dtype is not None) + ) + if mixed_precision_enabled_for_buffers: + buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns) + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes, fsdp_state.compute_device + ) + for buffer, clean_fqn in zip(buffers, buffer_clean_fqns): + fqn = f"{prefix}{clean_fqn}" + logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype) + state_dict[fqn] = buffer.clone() + return state_dict + + +@no_type_check +def _full_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. pre-state_dict hook is + not actually supported by ``nn.Module``. As a result, this API is called + from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict + is supported in ``nn.Module``, this hook will be registered as a hook in + ``nn.Module``. + """ + if getattr(fsdp_state, "_device_mesh", False): + fsdp_state._device_mesh._get_root_mesh() + + _common_pre_state_dict_hook(module, fsdp_state) + _common_unshard_pre_state_dict_hook( + module, + fsdp_state, + offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, + rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only, + ) + + +@no_type_check +def _full_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + Hook that runs after model.state_dict() is called before returning result to + user. For FSDP, we may have to clone the tensors in state_dict as params go + back to sharded version after _unshard_fsdp_state_params ends, and also remove + the ``FSDP_WRAPPED_MODULE`` prefix. + """ + + def param_hook( + state_dict: dict[str, Any], + prefix: str, + fqn: str, + ) -> None: + clean_key = fqn + clean_prefix = clean_tensor_name(prefix) + # Strip prefix out of key if needed as buffer names and param names + # do not have prefix considered as they are not computed in `state_dict` + # call. + clean_key = clean_key.removeprefix(clean_prefix) + + # Clone parameters before exiting the `_unshard_fsdp_state_params()` context. + if not getattr(state_dict[fqn], "_has_been_cloned", False): + try: + state_dict[fqn] = state_dict[fqn].detach().clone() + state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] + except BaseException as e: # noqa: B036 + warnings.warn( + f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " + "This may mean that this state_dict entry could point to invalid " + "memory regions after returning from state_dict() call if this " + "parameter is managed by FSDP. Please check clone " + f"implementation of {fqn}. Error: {str(e)}", + stacklevel=2, + ) + + return _common_unshard_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) + + +def _full_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + _lazy_init(fsdp_state, module) + if _should_unshard_params(fsdp_state): + with SimpleProfiler.profile("_enter_unshard_params_ctx"): + _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + # Add FSDP_PREFIX only for wrapper-based FSDP. + if not _is_composable(fsdp_state): + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") + + +def _full_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if _should_unshard_params(fsdp_state): + with SimpleProfiler.profile("_exit_unshard_params_ctx"): + _exit_unshard_params_ctx(module, fsdp_state) + + +def _local_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. Right now, pre-state_dict + hook is not supported by the PyTorch core. So this API is called from + `_local_post_state_dict_hook()` to simulate the case. + """ + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handle(fsdp_state, module).uses_sharded_strategy + ): + raise RuntimeError( + "``local_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, fsdp_state) + + +@no_type_check +def _local_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + This hook create a ShardedTensor from the local flat_param and replace + the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy + will happen. The underlying storage is the same. + """ + + _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix) + if not _has_fsdp_params(fsdp_state, module): + return state_dict + + # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor + # value as the flat_param but it is a pure Tensor because + # nn.Module.state_dict() will detach the parameter. Therefore, we need + # to get flat_param to get the metadata. + if not _module_handle(fsdp_state, module): + raise AssertionError("Should have returned early") + flat_param = _module_handle(fsdp_state, module).flat_param + # Constructs a ShardedTensor from the flat_param "without" padding. + # Removing the padding allows users to change the number of ranks + # when loading the local_state_dict. + full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] + shard_offset = flat_param.numel() * fsdp_state.rank + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + if valid_data_size > 0: + # If FlatParameter is returned, FlatParameter._local_shard cause a + # pickling issue (can be torch.save but not torch.load). Since there + # is no benefit for state_dict to return the actual FlatParameter class, + # a view (which is a tensor) of the FlatParameter will be returned. + flat_param = flat_param[:valid_data_size].view(valid_data_size) + local_shards = [ + Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank) + ] + else: + local_shards = [] + sharded_tensor = init_from_local_shards( + local_shards, full_numel, process_group=fsdp_state.process_group + ) # type: ignore[assignment] + # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT. + if fsdp_state._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor + return state_dict + + +def _local_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + pass + + +def _local_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + """ + This hook finds the local flat_param for this FSDP module from the + state_dict. The flat_param should be a ShardedTensor. This hook converts + the ShardedTensor to a tensor. No copy happen unless padding is required. + """ + _lazy_init(fsdp_state, module) + _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") + fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}" + if fqn not in state_dict: + if _has_fsdp_params(fsdp_state, module): + raise AssertionError( + "No `FlatParameter` in `state_dict` for this FSDP instance " + "but it has parameters" + ) + return + load_tensor = state_dict[fqn] + if not isinstance(load_tensor, ShardedTensor): + raise AssertionError("Tensors in local_state_dict should be ShardedTensor.") + + # Convert the ShardedTensor to a Tensor. + flat_param = _module_handle(fsdp_state, module).flat_param + if flat_param is None: + raise AssertionError("Expected flat_param to be set") + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + shards = load_tensor.local_shards() + if valid_data_size > 0: + if not len(shards): + raise AssertionError( + "load_local_state_dict assume one shard per ShardedTensor." + ) + load_tensor = shards[0].tensor + + # Get the metadata of the flat_param to decide whether to pad the loaded + # tensor. + if flat_param._shard_numel_padded > 0: + if load_tensor.numel() >= flat_param.numel(): + raise AssertionError( + f"Local shard size = {flat_param.numel()} and the tensor in " + f"the state_dict is {load_tensor.numel()}." + ) + load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded]) + else: + load_tensor = flat_param + # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT. + state_dict[fqn] = load_tensor + + +def _sharded_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. Check + ``_full_pre_load_state_dict_hook`` for the detail. + """ + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handle(fsdp_state, module).uses_sharded_strategy + ): + raise RuntimeError( + "``sharded_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, fsdp_state) + # Setting offload_to_cpu here does not work even if offload_to_cpu is True. + # We have to create ShardedTensor first then move it to CPU. + _common_unshard_pre_state_dict_hook( + module, + fsdp_state, + offload_to_cpu=False, + rank0_only=False, + ) + + +@no_type_check +def _sharded_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + The hook replaces the unflattened, unsharded parameter in the state_dict + with a unflattened, sharded parameter (a ShardedTensor). + """ + + def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str): + param = state_dict[fqn] + if not fsdp_state._state_dict_config._use_dtensor: + sharded_tensor = _ext_chunk_tensor( + tensor=param, + rank=fsdp_state.rank, + world_size=fsdp_state.world_size, + num_devices_per_node=fsdp_state._device_handle.device_count(), + pg=fsdp_state.process_group, + fsdp_extension=fsdp_state._fsdp_extension, + ) + else: + sharded_tensor = _ext_chunk_dtensor( + tensor=param, + rank=fsdp_state.rank, + device_mesh=fsdp_state._device_mesh, + fsdp_extension=fsdp_state._fsdp_extension, + ) + if fsdp_state._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[fqn] = sharded_tensor + + return _common_unshard_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) + + +@no_type_check +def _sharded_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if _has_fsdp_params(fsdp_state, module): + with SimpleProfiler.profile("_exit_unshard_params_ctx"): + _exit_unshard_params_ctx(module, fsdp_state) + + +@no_type_check +def _sharded_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + """ + The hook combines the unflattened, sharded parameters (ShardedTensor) to + a new FlatParameter and shards the new FlatParameter to the local chunk. + """ + _lazy_init(fsdp_state, module) + if not _is_composable(fsdp_state): + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") + if not _has_fsdp_params(fsdp_state, module): + return + + handle = _module_handle(fsdp_state, module) + if not handle.uses_sharded_strategy: + raise RuntimeError( + "load_sharded_state_dict can only be called when parameters " + "are flattened and sharded." + ) + fqn_to_param_ext = dict( + zip(handle.flat_param._fqns, handle.flat_param._param_extensions) + ) + + for fqn, _, _ in _param_name_infos(module, fsdp_state): + if not _is_composable(fsdp_state): + fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}" + else: + fqn_from_global_root = f"{prefix}{fqn}" + try: + param = state_dict.pop(fqn_from_global_root) + except KeyError: + logger.warning( + f"Did not find param with FQN {fqn_from_global_root}, skipping it. " # noqa: G004 + "The weight will not be filled if you expect it to be." + ) + continue # TODO: Improve unittesting for state_dict finetuning + # cases: https://github.com/pytorch/pytorch/issues/109134 + + if not fsdp_state._state_dict_config._use_dtensor: + # All-gather the param (ShardedTensor) + param, shards = _ext_pre_load_state_dict_transform( + param, fsdp_state._fsdp_extension + ) + + if len(shards) >= 2: + raise AssertionError( + "Expects 0 or 1 shard per rank " + f"but got {len(shards)} shards on rank {fsdp_state.rank}." + ) + param_numel = param.size().numel() + dim_0_size = param.size()[0] + chunk_size = ( + math.ceil(dim_0_size / fsdp_state.world_size) + * param_numel + // dim_0_size + ) + if len(shards) == 1: + local_tensor = shards[0].tensor.flatten() + with SimpleProfiler.profile(SimpleProfiler.Type.H2D): + local_tensor = local_tensor.to(fsdp_state.compute_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros( + chunk_size, dtype=param.dtype, device=fsdp_state.compute_device + ) + tensor = torch.empty( + chunk_size * fsdp_state.world_size, + dtype=local_tensor.dtype, + device=fsdp_state.compute_device, + ) + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): + dist.all_gather_into_tensor( + tensor, local_tensor, group=fsdp_state.process_group + ) + tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) + state_dict[fqn_from_global_root] = tensor + else: + if param.device != fsdp_state._device_mesh.device_type: + param = param.to(fsdp_state._device_mesh.device_type) + + root_mesh = fsdp_state._device_mesh._get_root_mesh() + local_tensor = _ext_all_gather_dtensor( + param, root_mesh, fsdp_state._fsdp_extension + ) + + if fqn_to_param_ext.get(fqn) is not None: + ext = fqn_to_param_ext[fqn] + local_tensor = _ext_post_unflatten_transform( + local_tensor, ext, fsdp_state._fsdp_extension + ) + state_dict[fqn_from_global_root] = local_tensor + + with SimpleProfiler.profile("_enter_unshard_params_ctx"): + _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + + +@contextlib.contextmanager +def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator: + old_state_dict_config = fsdp_state._state_dict_config + old_state_dict_type = fsdp_state._state_dict_type + fsdp_state._state_dict_config = FullStateDictConfig() + fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT + yield + fsdp_state._state_dict_config = old_state_dict_config + fsdp_state._state_dict_type = old_state_dict_type + + +@no_type_check +@torch.no_grad() +def _post_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, +) -> dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this + FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide + what postprocessing will be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " + "be returned.", + stacklevel=2, + ) + else: + context = contextlib.nullcontext() + + with context: + _post_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, + } + processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type]( + module, fsdp_state, state_dict, prefix + ) + + if fsdp_state._is_root: + logger.info("FSDP finished processing state_dict(), prefix=%s", prefix) + for key, tensor in sorted(processed_state_dict.items()): + if key.startswith(prefix) and isinstance(tensor, torch.Tensor): + local_shape = tensor.shape + device = None + if isinstance(tensor, ShardedTensor): + local_shape = None + shards = tensor.local_shards() + if shards: + local_shape = shards[0].tensor.shape + device = shards[0].tensor.device + elif isinstance(tensor, DTensor): + local_shape = tensor.to_local().shape + device = tensor.device + else: + device = tensor.device + logger.info( + "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s", + key, + type(tensor), + tensor.shape, + local_shape, + tensor.dtype, + device, + ) + + return processed_state_dict + + +@no_type_check +@torch.no_grad() +def _pre_state_dict_hook( + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + This is called before the core state dict saving logic of ``module``. + ``fsdp_state._state_dict_type`` is used to decide what postprocessing will + be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " + "be returned.", + stacklevel=2, + ) + else: + _set_use_dtensor(fsdp_state) + context = contextlib.nullcontext() + + with context: + _pre_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook, + } + _pre_state_dict_hook_fn[fsdp_state._state_dict_type]( + fsdp_state, + module, + *args, + **kwargs, + ) + + +@no_type_check +def _set_use_dtensor(fsdp_state: _FSDPState) -> None: + # If device_mesh is passed in when initializing FSDP, we automatically turn the + # _use_dtensor flag to be true for ShardedStateDictConfig(). + if getattr(fsdp_state, "_device_mesh", None): + state_dict_type = fsdp_state._state_dict_type + if state_dict_type == StateDictType.LOCAL_STATE_DICT: + raise RuntimeError( + "Found state_dict_type LOCAL_STATE_DICT", + "DeviceMesh is not compatible with LOCAL_STATE_DICT.", + "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.", + ) + else: + fsdp_state._state_dict_config._use_dtensor = True + + +@no_type_check +@torch.no_grad() +def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, +) -> None: + """ + This is called before ``module._load_from_state_dict()``. + ``fsdp_state._state_dict_type`` is used to decide what preprocessing will + be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" + "be returned.", + stacklevel=2, + ) + else: + _set_use_dtensor(fsdp_state) + context = contextlib.nullcontext() + + _lazy_init(fsdp_state, module) + if fsdp_state._is_root: + SimpleProfiler.reset() + + with context: + _pre_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, + } + # Code that is common for all state_dict impls + if fsdp_state._device_handle.is_available(): + fsdp_state._device_handle.synchronize() + # Dispatch into state_dict specific implementation of pre-hook. + _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type]( + module, fsdp_state, state_dict, prefix + ) + + +@no_type_check +@torch.no_grad() +def _post_load_state_dict_hook( + module: nn.Module, + incompatible_keys: tuple[list[str], list[str]], + *args: Any, +) -> None: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" + "be returned.", + stacklevel=2, + ) + else: + context = contextlib.nullcontext() + + with context: + _post_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, + } + # Code that is common for all state_dict impls + # Dispatch into state_dict type specific implementation of post-hook for + # loading state_dict. + _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state) + + # When reporting incompatible keys, trim FSDP prefixes. + missing_keys = incompatible_keys[0] + unexpected_keys = incompatible_keys[1] + for i in range(len(missing_keys)): + missing_keys[i] = clean_tensor_name(missing_keys[i]) + + for i in range(len(unexpected_keys)): + unexpected_keys[i] = clean_tensor_name(unexpected_keys[i]) + + if fsdp_state._is_root: + SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ") + + +def _register_all_state_dict_hooks(state: _FSDPState): + """ + Registers pre-save, post-save, pre-load, and post-load state dict hooks. + """ + for hook_registration_fn_str, hook, hook_registration_fn_kwargs in ( + ("register_state_dict_pre_hook", _pre_state_dict_hook, {}), + ("_register_state_dict_hook", _post_state_dict_hook, {}), + ( + "_register_load_state_dict_pre_hook", + _pre_load_state_dict_hook, + {"with_module": True}, + ), + ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}), + ): + _register_state_dict_hooks_base( + state, hook_registration_fn_str, hook, hook_registration_fn_kwargs + ) + + +@no_type_check +def _register_state_dict_hooks_base( + state: _FSDPState, + hook_registration_fn_name: str, + hook: Callable, + hook_registration_fn_kwargs: dict[str, Any], +) -> None: + """Registers ``hook`` using ``hook_registration_fn``.""" + if not _is_composable(state): + getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs) + else: + handle = state._handle + if handle: + getattr(handle._fully_sharded_module, hook_registration_fn_name)( + hook, **hook_registration_fn_kwargs + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_trace_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_trace_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d514c5c6474b3a984424b1cd7563e1656f3f2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_trace_utils.py @@ -0,0 +1,240 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, NamedTuple, Optional + +import torch +import torch.nn as nn + + +@dataclass +class TracingConfig: + """ + This represents a symbolic tracing configuration. + + Args: + tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to + use for symbolic tracing. The default value is the native + :class:`torch.fx.Tracer` constructed with default arguments. + However, the user may want to pass a different value such as the + ``HFTracer`` for models in the HuggingFace Transformers_ library. + .. _Transformers: https://huggingface.co/docs/transformers/index + concrete_args (Optional[Dict[str, Any]]): Concrete arguments that + should not be treated as ``torch.fx.Proxy`` when tracing the + module ``forward()``. Passing ``concrete_args`` allows partially + specializing the forward, e.g. to remove control flow or data + structures. This ``concrete_args`` here is the same argument used + in :meth:`~torch.fx.Tracer.trace`. + """ + + tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer) + concrete_args: Optional[dict[str, Any]] = None + + +class _ParamUsageInfo(NamedTuple): + """ + This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record + execution information. The ``dict`` maps modules to a list of these + ``_ParamUsageInfo`` instances, where each instance represents a group of + parameters used together. + + Specifically, for each module key in the ``dict``, each instance of this + class represents either: + (1) the module and some sublist of its ``named_parameters()`` used + together in execution (see ``_patched_create_proxy()``), or + (2) a submodule and all of ``submodule.named_parameters()`` (see + ``_patched_call_module()``). + + Type (1) corresponds to directly using parameters in ops without calling + ``forward()``, and type (2) corresponds to calling ``forward()``. The + mapped-to lists in the ``dict`` follow the execution order. + """ + + module: nn.Module + named_params: list[tuple[str, nn.Parameter]] + + +class _ExecutionInfo: + """ + This represents the execution order information from the forward pass. + + Attributes: + curr_module (nn.Module): Current module being traced. + module_forward_order (List[nn.Module]): The modules in (pre-)forward + order, i.e. the order in which their ``forward()`` methods are + called. Each call to a module's ``forward()`` corresponds to one + element in the list. + module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]): + Maps a module to a list of module execution infos. See + :class:`_ParamUsageInfo` for details. + param_forward_order (List[nn.Parameter]): The parameters in forward + execution order, where only a parameter's first participation is + included. + visited_params (Set[nn.Parameter]): The parameters visited so far + during the trace. This is only used during tracing for fast + membership check. Invariant: The parameters in + ``param_forward_order`` are exactly those in ``visited_params``. + """ + + def __init__(self, root_module: nn.Module) -> None: + self.curr_module: nn.Module = root_module + self.module_forward_order: list[nn.Module] = [root_module] + self.module_to_param_usage_infos: dict[nn.Module, list[_ParamUsageInfo]] = { + root_module: [] + } + self.param_forward_order: list[nn.Parameter] = [] + self.visited_params: set[nn.Parameter] = set() + + +class _ExecOrderTracer: + def __init__(self) -> None: + self.exec_info: Optional[_ExecutionInfo] = None + + @contextmanager + def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module): + self.exec_info = _ExecutionInfo(root_module) + orig_call_module = tracer.call_module + orig_create_proxy = tracer.create_proxy + tracer.call_module = functools.partial( # type: ignore[method-assign] + self._patched_call_module, orig_call_module, self.exec_info + ) + fqn_to_param = dict(root_module.named_parameters()) + tracer.create_proxy = functools.partial( # type: ignore[method-assign] + self._patched_create_proxy, + orig_create_proxy, + self.exec_info, + fqn_to_param, + ) + try: + yield + finally: + tracer.call_module = orig_call_module # type: ignore[method-assign] + tracer.create_proxy = orig_create_proxy # type: ignore[method-assign] + + def _patched_call_module( + self, + call_module: Callable, + exec_info: _ExecutionInfo, + # Below are the expected arguments to `call_module()` + module: nn.Module, + forward: Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + """ + Overrides ``call_module`` to save execution information to + ``exec_info``. Note that ``call_module`` is called during symbolic + tracing for each non-root module. + + Args: + call_module (Callable): Original ``call_module`` to override. + exec_info (_ExecutionInfo): Used to record execution information. + module (nn.Module): Module corresponding to this ``call_module``. + forward (Callable): ``forward()`` method of ``module`` to be called + for this ``call_module``. + args (Tuple[Any, ...]): Positional arguments for ``forward``. + kwargs (Dict[str, Any]): Keyword arguments for ``forward``. + + Returns: + Same return value as ``call_module``. + """ + exec_info.module_forward_order.append(module) + named_params = list(module.named_parameters()) + curr_module = exec_info.curr_module + if named_params: + if curr_module not in exec_info.module_to_param_usage_infos: + raise AssertionError( + "The current module should have already been processed by a patched `call_module`" + ) + exec_info.module_to_param_usage_infos[exec_info.curr_module].append( + _ParamUsageInfo(module, named_params) + ) + prev_curr_module = curr_module + exec_info.curr_module = module + exec_info.module_to_param_usage_infos[module] = [] + output = call_module(module, forward, args, kwargs) + exec_info.curr_module = prev_curr_module + return output + + def _patched_create_proxy( + self, + create_proxy: Callable, + exec_info: _ExecutionInfo, + fqn_to_param: dict[str, nn.Parameter], + # Below are the expected arguments to `create_proxy()` + kind: str, + target: torch.fx.node.Target, + args: tuple[Any, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None, + ) -> torch.fx.Proxy: + """ + Overrides ``create_proxy`` to save execution information to + ``exec_info``. Note that ``create_proxy`` is called during symbolic + tracing for each leaf function/method/module. + + Args: + create_proxy (Callable): Original ``create_proxy`` to override. + exec_info (_ExecutionInfo): Used to record execution information. + fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the + root module's ``named_parameters()`` with FQN as key and + parameter as value. + kind (str): Kind of the target method ('call_function', + 'call_method', 'get_attr', 'call_module', 'placeholder', or + 'output'). See :class:`torch.fx.Graph` for details. This is + passed to ``create_proxy``. + target (torch.fx.node.Target): Contains the string name of the + function/method/module. This is passed to ``create_proxy``. + args (Tuple[Any, ...]): Positional arguments for the function/ + method/module. This is passed to ``create_proxy``. + kwargs (Dict[str, Any]): Keyword arguments for the function/method/ + module. This is passed to ``create_proxy`` + name (Optional[str]): An optional string name for the ``Node`` + created in ``create_proxy``. This is passed to + ``create_proxy``. + type_expr (Optional[Any]): An optional type annotation representing + the Python type that the output of the node has. This is passed + to ``create_proxy``. + proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]): + An alternative proxy constructor used in ``create_proxy``. This + is passed to ``create_proxy``. + + Returns: + torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object. + """ + proxy = create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + curr_module = exec_info.curr_module + if kind in ("call_function", "call_method"): + if args is not None: + named_params: list[tuple[str, nn.Parameter]] = [] + for arg in args: + if ( + isinstance(arg, torch.fx.Proxy) + and arg.node.target in fqn_to_param + ): + param = fqn_to_param[arg.node.target] # type: ignore[index] + named_params.append((arg.node.target, param)) # type: ignore[arg-type] + if param not in exec_info.visited_params: + exec_info.visited_params.add(param) + exec_info.param_forward_order.append(param) + if named_params: + exec_info.module_to_param_usage_infos[curr_module].append( + _ParamUsageInfo(curr_module, named_params) + ) + elif kind == "call_module": + named_params = list(curr_module.named_parameters()) + if named_params: + exec_info.module_to_param_usage_infos[curr_module].append( + _ParamUsageInfo(curr_module, named_params) + ) + for _, param in named_params: + if param not in exec_info.visited_params: + exec_info.visited_params.add(param) + exec_info.param_forward_order.append(param) + return proxy diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_traversal_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_traversal_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51140d3b0a8d3d16ab50226b414e651f22772648 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_traversal_utils.py @@ -0,0 +1,112 @@ +""" +NOTE: This file must be imported like +``import torch.distributed.fsdp._traversal_utils`` and not like +``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular +imports. For brevity, we may import the file as ``traversal_utils``. +""" + +import collections + +import torch.nn as nn +from torch.distributed._composable.contract import _get_registry +from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state + + +""" +[Note: FSDP State Traversal] +For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel`` +module wrapping a fully sharded module, and for the non-wrapper code path, +``_FSDPState`` is an object that gets embedded on a fully sharded module. +See [Note: Fully Sharded Module] for the definition. + +There are three common traversal idioms: Given a root module, +- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree. +- ``get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the +tree (i.e. those with ``_is_root == True``). +- ``_get_fsdp_handles()``returns all ``FlatParamHandle`` s in the tree. + +All of these methods must take in the root module (i.e. an ``nn.Module``) and +not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph +traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal. +""" + + +def _composable(module: nn.Module) -> bool: + """ + Returns if ``module`` can compose with ``fully_shard``. + """ + # TODO: Add any other composable APIs that are mutually exclusive. + registry = _get_registry(module) + if registry is None: + return True + return "replicate" not in registry + + +# TODO (awgu): We may be able to remove this function if we retired the +# `use_orig_params=False` code path since so far we only need the module for +# `FlatParameter` registration, which is not needed for `use_orig_params=True`. +def _get_fsdp_states_with_modules( + module: nn.Module, +) -> tuple[list[_FSDPState], list[nn.Module]]: + """ + Returns a tuple containing: + 1. A list of the ``_FSDPState`` instances in the module tree rooted at + ``module`` without any duplicates and following the ``module.modules()`` + traversal order (which is assumed to be depth-first). + 2. A corresponding list of the modules owning the states in the first list. + + For the wrapper code path, both returned lists are the same, each + containing all ``FullyShardedDataParallel`` instances. For the composable + code path, this returns a list of all composable state instances and a list + of the corresponding fully sharded modules. See [Note: Fully Sharded + Module]. + + NOTE: The traversal does not proceed into any module annotated by an + incompatible API (e.g. ``replicate``). + """ + fsdp_states: list[_FSDPState] = [] + fsdp_modules: list[nn.Module] = [] + # Track the visited FSDP states since multiple modules may share the same + # one and we want to return a de-duplicated list + visited_fsdp_states: set[_FSDPState] = set() + # Track the visited modules in case of shared modules, which implies the + # module graph is no longer a tree + visited_modules: set[nn.Module] = set() + + # Perform depth-first search from `module` to ensure that we do not + # traverse into an incompatible API's subtree (use DFS instead of BFS to + # match `.modules()` order) + deque: collections.deque[nn.Module] = collections.deque([module]) + while deque: + submodule = deque.popleft() + visited_modules.add(submodule) + if not _composable(submodule): + continue + for child_module in reversed(list(submodule.children())): + if child_module not in visited_modules: + deque.appendleft(child_module) + optional_state = _get_module_fsdp_state(submodule) + if optional_state is not None and optional_state not in visited_fsdp_states: + visited_fsdp_states.add(optional_state) + fsdp_states.append(optional_state) + fsdp_modules.append(submodule) + return fsdp_states, fsdp_modules + + +def _get_fsdp_states(module: nn.Module) -> list[_FSDPState]: + """See :func:`_get_fsdp_states_with_modules`.""" + fsdp_states, _ = _get_fsdp_states_with_modules(module) + return fsdp_states + + +def _get_fsdp_handles(module: nn.Module) -> list: + """ + Returns all ``FlatParamHandle`` s in the module tree rooted at ``module`` + following the rules in :func:`_get_fsdp_state`. + """ + handles = [ + fsdp_state._handle + for fsdp_state in _get_fsdp_states(module) + if fsdp_state._handle is not None + ] + return handles diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71dc1a9f4e28c7101fc0acdae2582be89e954013 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py @@ -0,0 +1,340 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +from collections.abc import Generator +from typing import cast + +import torch +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_module_fsdp_state, + _has_fsdp_params, + _module_handle, + HandleTrainingState, + TrainingState, +) +from torch.distributed.fsdp._runtime_utils import ( + _lazy_init, + _reset_flat_param_grad_info_if_needed, + _reshard, + _reshard_grads, + _unshard, + _unshard_grads, +) +from torch.distributed.utils import _p_assert + +from ._flat_param import FlatParamHandle + + +FLAT_PARAM = "_flat_param" + + +@torch.no_grad() +def _writeback_to_local_shard( + handle: FlatParamHandle, + writeback_grad: bool, +): + """ + For the handle, writes back the this rank's shard of the unsharded + flattened parameter to the sharded flattened parameter. If + ``writeback_grad=True``, then writes back to the sharded gradient as + well. + + Precondition: The handle's ``FlatParameter`` 's data points to the + padded unsharded flattened parameter. + """ + + def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor: + if handle.uses_sharded_strategy: + # For sharded strategies, get the *unpadded* shard instead of + # the *padded* shard to persist user changes to the padding + # (though FSDP does not explicitly support this) + shard, _ = FlatParamHandle._get_unpadded_shard( + flat_param_or_grad, + handle.rank, + handle.world_size, + ) + return shard + # For `NO_SHARD`, the `flat_param` or its gradient may be modified, + # so we write it back directly + return flat_param_or_grad + + param_shard = _get_shard(handle.flat_param) + handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) # type: ignore[attr-defined] + if writeback_grad: + existing_grad = handle.sharded_grad + if existing_grad is not None: + if handle.flat_param.grad is None: + raise AssertionError("Expected handle.flat_param.grad to not be None") + grad_shard = _get_shard(handle.flat_param.grad) + existing_grad[: grad_shard.numel()].copy_(grad_shard) + + +def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + De-registers the flattened parameter from the wrapped module, hiding it + from ``nn.Module`` methods. + + We do not use ``del`` because we want ``FLAT_PARAM`` to always be an + attribute but dynamically change whether it is visible to ``nn.Module`` + methods. + """ + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None) + + +def _register_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + Registers the flattened parameter to the wrapped module, making it + visible to ``nn.Module`` methods. + + We do not use :meth:`nn.Module.register_parameter` because we want + ``FLAT_PARAM`` to always be an attribute but dynamically change whether + it is visible to ``nn.Module`` methods. + """ + handle = _module_handle(state, module) + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param + + +@contextlib.contextmanager +def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator: + """ + Assumes that the flattened parameter is unsharded. When in the context, + de-registers the flattened parameter and unflattens the original + parameters as ``nn.Parameter`` views into the flattened parameter. + After the context, re-registers the flattened parameter and restores + the original parameters as ``Tensor`` views into the flattened + parameter. + """ + handle = _module_handle(state, module) + if not handle: + yield + else: + _deregister_flat_param(state, module) + try: + with handle.unflatten_as_params(): + yield + finally: + if not handle._use_orig_params: + _register_flat_param(state, module) + + +def _validate_unshard_params_args( + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +) -> None: + if with_grads and (offload_to_cpu or not state._use_orig_params): + raise NotImplementedError( + f"with_grads={with_grads}, " + f"use_orig_params={state._use_orig_params}, " + f"offload_to_cpu={offload_to_cpu} " + f"is not supported yet" + ) + if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy): + raise NotImplementedError( + "offload_to_cpu=True and NO_SHARD is not supported yet" + ) + if writeback and rank0_only: + # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to + # persist the changes. + raise NotImplementedError( + "writeback=True and rank0_only=True is not supported yet" + ) + if offload_to_cpu and not rank0_only: + warnings.warn( + "offload_to_cpu=True and rank0_only=False may result in the" + "unsharded parameters being redundantly copied to CPU memory for " + "GPUs sharing the same CPU memory, which risks CPU OOM. We " + "recommend using offload_to_cpu=True with rank0_only=True.", + stacklevel=2, + ) + + +@contextlib.contextmanager +def _unshard_fsdp_state_params( + module: nn.Module, + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + """ + This unshards the parameters for a single FSDP state ``state`` that + corresponds to ``module``. + """ + _validate_unshard_params_args( + state, writeback, rank0_only, offload_to_cpu, with_grads + ) + state._device_handle.synchronize() + # If handles are shared by other module(s), the handle may be already unsharded. + maybe_handle = _module_handle(state, module) + handle = None + if ( + maybe_handle + and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS + ): + handle = maybe_handle + if not handle: + yield + return + + if handle._training_state != HandleTrainingState.IDLE: + raise AssertionError( + f"Expects the handle training to be IDLE but got {handle._training_state}" + ) + + handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS + + _reset_flat_param_grad_info_if_needed(handle) + free_unsharded_flat_param = handle.needs_unshard() + # No need to call `wait_stream()` since we unshard in the computation + # stream directly + computation_stream = state._device_handle.current_stream() + _unshard(state, handle, computation_stream, computation_stream) + if with_grads: + _unshard_grads(handle) + + if rank0_only and state.rank != 0: + # Free the unsharded flattened parameter early + _reshard(state, handle, free_unsharded_flat_param) + if with_grads: + _reshard_grads(handle) + try: + yield + finally: + handle._training_state = HandleTrainingState.IDLE + else: + # Unflatten the unsharded flattened parameters + with contextlib.ExitStack() as stack: + # Invariant: rank == 0 or !rank0_only + if offload_to_cpu and handle.uses_sharded_strategy: + stack.enter_context(handle.to_cpu()) + # NOTE: Since PyTorch enforces that a parameter and its + # gradients need to match metadata (e.g. device), we must + # move gradients to CPU *after* we move parameters. + # NOTE: This assumes 1 `FlatParameter` + if not state._use_orig_params: + stack.enter_context(_unflatten_as_params(state, module)) + try: + yield + finally: + stack.close() + if writeback: + _writeback_to_local_shard(handle, with_grads) + _reshard(state, handle, free_unsharded_flat_param) + if with_grads: + _reshard_grads(handle) + handle._training_state = HandleTrainingState.IDLE + + +@contextlib.contextmanager +def _unshard_params_for_summon( + module: nn.Module, + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + _validate_unshard_params_args( + state, writeback, rank0_only, offload_to_cpu, with_grads + ) + _lazy_init(state, module) + if state.training_state == TrainingState.FORWARD_BACKWARD: + raise AssertionError( + "Cannot manually unshard parameters during forward/backward" + ) + elif state.training_state == TrainingState.SUMMON_FULL_PARAMS: + raise AssertionError( + "Cannot manually unshard parameters when already unsharding parameters" + ) + with _unshard_fsdp_state_params( + module=module, + state=state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ): + try: + state.training_state = TrainingState.SUMMON_FULL_PARAMS + yield + finally: + state.training_state = TrainingState.IDLE + + +@contextlib.contextmanager +def _unshard_params( + module: nn.Module, + recurse: bool, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + """ + This unshards FSDP-managed parameters for all modules with FSDP applied in + the module tree rooted at ``module``. + """ + if not recurse: + optional_state = _get_module_fsdp_state(module) + if optional_state is None: + with contextlib.nullcontext(): + yield + return + states_and_modules = ([optional_state], [module]) + else: + states_and_modules = traversal_utils._get_fsdp_states_with_modules(module) + with contextlib.ExitStack() as stack: + for state, module in zip(*states_and_modules): + stack.enter_context( + _unshard_params_for_summon( + module=module, + state=state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + ) + yield + + +def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the original parameters; registers the ``FlatParameter``. + """ + handle = _module_handle(state, module) + if not handle: + return + _p_assert( + handle._use_orig_params, + f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} " + f"handle: {handle._use_orig_params}", + ) + handle._deregister_orig_params() + _register_flat_param(state, module) + + +def _register_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the ``FlatParameter``; registers the original parameters. + """ + handle = _module_handle(state, module) + if not handle: + return + _deregister_flat_param(state, module) + if handle.is_sharded(handle.flat_param): + handle._use_sharded_views() + handle._use_sharded_grad_views() + else: + handle._use_unsharded_views(as_params=True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41dc4d8575198875b8403c2a41c7b2f547a1b742 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-defs +import collections +import functools +import inspect +import warnings +from collections.abc import Callable +from functools import partial +from typing import Any, Union + +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _get_module_fsdp_state, + _override_module_mixed_precision, +) +from torch.distributed.fsdp.wrap import ( + _construct_wrap_fn, + _or_policy, + _Policy, + _post_order_apply, + _recursive_wrap, + _run_mixed_precision_override_policy, + _wrap_module_cls_individually, +) + + +def _auto_wrap( + root_module: nn.Module, + policy: Union[Callable, _Policy], + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + root_kwargs: dict[str, Any], + fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard` +): + """ + Auto wraps modules in ``root_module`` 's tree according to ``policy`` + following a post-order traversal. + + Precondition: ``root_kwargs`` should contain all arguments except + ``module``. This function accepts the kwargs dict directly since it gets + forwarded into the post-order traversal function. + """ + mixed_precision = root_kwargs["mixed_precision"] + is_wrapper = inspect.isclass(fsdp_fn) + # TODO: We may relax this no-nested-wrapping constraint to support manual + # wrapping followed by auto wrapping. + _check_nested_wrapping(root_module) + + if isinstance(policy, _Policy): + root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None + target_module_to_kwargs = policy._run_policy( + root_module, ignored_modules, root_kwargs + ) + if mixed_precision is not None: + target_module_to_kwargs = _run_mixed_precision_override_policy( + root_module, + mixed_precision._module_classes_to_ignore, + ignored_modules, + root_kwargs, + target_module_to_kwargs, + ) + overridden_module_classes = _override_module_mixed_precision( + root_module, mixed_precision._module_classes_to_ignore + ) + _warn_on_overridden_mixed_precision(overridden_module_classes) + use_orig_params = root_kwargs.get("use_orig_params", False) + _validate_frozen_params( + root_module, + set(target_module_to_kwargs.keys()), + ignored_params, + use_orig_params, + ) + wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn) + _post_order_apply(root_module, wrap_fn) + return + + recursive_wrap_kwargs = { + "module": root_module, + "auto_wrap_policy": policy, + "wrapper_cls": fsdp_fn, + "ignored_modules": ignored_modules, + "ignored_params": ignored_params, + "only_wrap_children": True, + } + if mixed_precision is not None: + # Wrap modules of the ignored types separately and register forward + # hooks to cast to fp32 and back to the original dtype, respectively + overridden_module_classes = _override_module_mixed_precision( + root_module, mixed_precision._module_classes_to_ignore + ) + policy = functools.partial( + _or_policy, + policies=[ + policy, + partial( + _wrap_module_cls_individually, + module_classes=mixed_precision._module_classes_to_ignore, + ), + ], + ) + recursive_wrap_kwargs["auto_wrap_policy"] = policy + _warn_on_overridden_mixed_precision(overridden_module_classes) + _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type] + + +def _check_nested_wrapping(root_module: nn.Module): + for module_name, module in root_module.named_modules(): + if _get_module_fsdp_state(module) is not None: + raise ValueError( + "FSDP auto wrapping requires modules to not already have " + f"FSDP applied but found {module_name} in\n{root_module}" + ) + + +def _warn_on_overridden_mixed_precision( + overridden_module_classes: set[type[nn.Module]], +): + if len(overridden_module_classes) == 0: + return + warnings.warn( + "Both mixed precision and an auto_wrap_policy were specified to FSDP, " + f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n" + "These modules will be wrapped as separate FSDP instacnes with mixed " + "precision disabled.", + stacklevel=2, + ) + + +def _validate_frozen_params( + root_module: nn.Module, + modules_to_wrap: set[nn.Module], + ignored_params: set[nn.Parameter], + use_orig_params: bool, +): + """ + This checks that, given ``modules_to_wrap``, each module would manage + parameters that are uniformly frozen or non-frozen. This uniformity + requirement is strict for ``use_orig_params=False`` (hard error) and highly + recommended for ``use_orig_params=True`` (user warning). + """ + post_order_named_modules = _get_post_order_named_modules(root_module) + visited_modules: set[nn.Module] = set() + for module_name, module in post_order_named_modules: + if module in modules_to_wrap: + param_to_fqn = _get_managed_param_to_fqn( + module, ignored_params, visited_modules, module_name + ) + frozen_param_fqns: list[str] = [] + frozen_param_numel = 0 + nonfrozen_param_fqns: list[str] = [] + nonfrozen_param_numel = 0 + for param, fqn in param_to_fqn.items(): + if param.requires_grad: + nonfrozen_param_fqns.append(fqn) + nonfrozen_param_numel += param.numel() + else: + frozen_param_fqns.append(fqn) + frozen_param_numel += param.numel() + if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0: + msg = f"{module_name} has both parameters with requires_grad=True and False." + if use_orig_params: + total_param_numel = frozen_param_numel + nonfrozen_param_numel + msg += ( + " We do not recommend wrapping such modules since " + "the gradient memory usage will be higher than expected " + f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel " + "before sharding via reduce-scatter). " + ) + else: + msg += " FSDP does not support wrapping such modules when use_orig_params=False. " + msg += "If possible, wrap the frozen parameters with FSDP separately.\n" + msg += ( + f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n" + f"The following parameters have requires_grad=False:\n{frozen_param_fqns}" + ) + if use_orig_params: + warnings.warn(msg, stacklevel=2) + else: + raise ValueError(msg) + + +def _get_post_order_named_modules( + root_module: nn.Module, +) -> list[tuple[str, nn.Module]]: + """ + This returns the named modules following a post-order traversal, which is a + valid reverse topological sort. We achieve this using the reverse of a + stack-based DFS order instead of reversing ``root_module.named_modules()`` + since the former gives the modules in registration order at each level in + the module tree (as opposed to the reverse), which allows us to error/warn + on the first registered module that violates the condition. + + For example, consider the following module structure: + M( + S1(), + S2( + SS1(), + SS2(), + ), + S3(), + ) + The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse + ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M]. + """ + visited_modules = {root_module} + stack = [("", root_module)] + # Append and reverse at the end for linear-time algorithm + reverse_post_order_named_modules: list[tuple[str, nn.Module]] = [] + while stack: + module_name, module = stack.pop() + reverse_post_order_named_modules.append((module_name, module)) + for child_module_name, child_module in module.named_children(): + if child_module is None: # only for overrides of `named_children()` + continue + if child_module not in visited_modules: + visited_modules.add(child_module) + if module_name != "": + child_module_name = module_name + "." + child_module_name + stack.append((child_module_name, child_module)) + post_order_named_modules = list(reversed(reverse_post_order_named_modules)) + return post_order_named_modules + + +def _get_managed_param_to_fqn( + module_to_wrap: nn.Module, + ignored_params: set[nn.Parameter], + visited_modules: set[nn.Module], + root_prefix: str, +) -> dict[nn.Parameter, str]: + """ + This returns a dict that maps managed parameter to its FQN for the given + ``module_to_wrap``. The dict's keys are exactly the parameters that would + be managed by the module, where this is achieved by calling this function + on the modules to wrap in reverse topological order, destructively updating + ``visited_modules``, and not traversing into those modules. The FQNs are + prefixed from the root (via ``root_prefix``) to be more informative. + + NOTE: This function is meant to be called pre-wrapping and iteratively in + reverse topological order to cover the full module tree. This differs from + the ``_get_param_to_fqn()`` function meant to be called post-wrapping and + on the full module tree in one shot. Given those differences, we do not try + to unify the two. + """ + param_to_fqn: dict[nn.Parameter, str] = {} + # Run BFS (or any tree traversal works) + queue = collections.deque([(module_to_wrap, root_prefix)]) + visited_modules.add(module_to_wrap) + while queue: + module, prefix = queue.popleft() + for param_name, param in module.named_parameters(recurse=False): + if param not in ignored_params: + fqn = param_name if prefix == "" else prefix + "." + param_name + param_to_fqn[param] = fqn + for child_module_name, child_module in module.named_children(): + if child_module is None: # only for overrides of `named_children()` + continue + if child_module not in visited_modules: + visited_modules.add(child_module) + child_prefix = ( + child_module_name + if prefix == "" + else prefix + "." + child_module_name + ) + queue.append((child_module, child_prefix)) + return param_to_fqn diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/api.py new file mode 100644 index 0000000000000000000000000000000000000000..17ed0483f1c26248673fe888bc5489e099b1313b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/api.py @@ -0,0 +1,417 @@ +""" +This file includes public APIs for FSDP such as the classes used for the +constructor arguments. +""" + +from collections.abc import Sequence +from dataclasses import dataclass +from enum import auto, Enum +from typing import Optional + +import torch +from torch.nn.modules.batchnorm import _BatchNorm + + +__all__ = [ + "ShardingStrategy", + "BackwardPrefetch", + "MixedPrecision", + "CPUOffload", + "StateDictType", + "StateDictConfig", + "FullStateDictConfig", + "LocalStateDictConfig", + "ShardedStateDictConfig", + "OptimStateDictConfig", + "FullOptimStateDictConfig", + "LocalOptimStateDictConfig", + "ShardedOptimStateDictConfig", + "StateDictSettings", +] + + +class ShardingStrategy(Enum): + """ + This specifies the sharding strategy to be used for distributed training by + :class:`FullyShardedDataParallel`. + + - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded. + For the parameters, this strategy unshards (via all-gather) before the + forward, reshards after the forward, unshards before the backward + computation, and reshards after the backward computation. For gradients, + it synchronizes and shards them (via reduce-scatter) after the backward + computation. The sharded optimizer states are updated locally per rank. + - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during + computation, and additionally, parameters are sharded outside + computation. For the parameters, this strategy unshards before the + forward, does not reshard them after the forward, and only reshards them + after the backward computation. The sharded optimizer states are updated + locally per rank. Inside ``no_sync()``, the parameters are not resharded + after the backward computation. + - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded + but instead replicated across ranks similar to PyTorch's + :class:`DistributedDataParallel` API. For gradients, this strategy + synchronizes them (via all-reduce) after the backward computation. The + unsharded optimizer states are updated locally per rank. + - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across + nodes. This results in reduced communication volume as expensive all-gathers and + reduce-scatters are only done within a node, which can be more performant for medium + -sized models. + - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across + nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput + since the unsharded parameters are not freed after the forward pass, saving the + all-gathers in the pre-backward. + """ + + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + HYBRID_SHARD = auto() + _HYBRID_SHARD_ZERO2 = auto() + + +class BackwardPrefetch(Enum): + """ + This configures explicit backward prefetching, which improves throughput by + enabling communication and computation overlap in the backward pass at the + cost of slightly increased memory usage. + + - ``BACKWARD_PRE``: This enables the most overlap but increases memory + usage the most. This prefetches the next set of parameters *before* the + current set of parameters' gradient computation. This overlaps the *next + all-gather* and the *current gradient computation*, and at the peak, it + holds the current set of parameters, next set of parameters, and current + set of gradients in memory. + - ``BACKWARD_POST``: This enables less overlap but requires less memory + usage. This prefetches the next set of parameters *after* the current + set of parameters' gradient computation. This overlaps the *current + reduce-scatter* and the *next gradient computation*, and it frees the + current set of parameters before allocating memory for the next set of + parameters, only holding the next set of parameters and current set of + gradients in memory at the peak. + - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables + the backward prefetching altogether. This has no overlap and does not + increase memory usage. In general, we do not recommend this setting since + it may degrade throughput significantly. + + For more technical context: For a single process group using NCCL backend, + any collectives, even if issued from different streams, contend for the + same per-device NCCL stream, which implies that the relative order in which + the collectives are issued matters for overlapping. The two backward + prefetching values correspond to different issue orders. + """ + + # NOTE: For both modes, the ordering that defines "current" and "next" is + # not always exact in the current implementation. A mistargeted prefetch + # simply means that the parameter memory is allocated earlier than needed, + # possibly increasing peak memory usage, but does not affect correctness. + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + + +@dataclass +class MixedPrecision: + """ + This configures FSDP-native mixed precision training. + + Attributes: + param_dtype (Optional[torch.dtype]): This specifies the dtype for model + parameters during forward and backward and thus the dtype for + forward and backward computation. Outside forward and backward, the + *sharded* parameters are kept in full precision (e.g. for the + optimizer step), and for model checkpointing, the parameters are + always saved in full precision. (Default: ``None``) + reduce_dtype (Optional[torch.dtype]): This specifies the dtype for + gradient reduction (i.e. reduce-scatter or all-reduce). If this is + ``None`` but ``param_dtype`` is not ``None``, then this takes on + the ``param_dtype`` value, still running gradient reduction in low + precision. This is permitted to differ from ``param_dtype``, e.g. + to force gradient reduction to run in full precision. (Default: + ``None``) + buffer_dtype (Optional[torch.dtype]): This specifies the dtype for + buffers. FSDP does not shard buffers. Rather, FSDP casts them to + ``buffer_dtype`` in the first forward pass and keeps them in that + dtype thereafter. For model checkpointing, the buffers are saved + in full precision except for ``LOCAL_STATE_DICT``. (Default: + ``None``) + keep_low_precision_grads (bool): If ``False``, then FSDP upcasts + gradients to full precision after the backward pass in preparation + for the optimizer step. If ``True``, then FSDP keeps the gradients + in the dtype used for gradient reduction, which can save memory if + using a custom optimizer that supports running in low precision. + (Default: ``False``) + cast_forward_inputs (bool): If ``True``, then this FSDP module casts + its forward args and kwargs to ``param_dtype``. This is to ensure + that parameter and input dtypes match for forward computation, as + required by many ops. This may need to be set to ``True`` when only + applying mixed precision to some but not all FSDP modules, in which + case a mixed-precision FSDP submodule needs to recast its inputs. + (Default: ``False``) + cast_root_forward_inputs (bool): If ``True``, then the root FSDP module + casts its forward args and kwargs to ``param_dtype``, overriding + the value of ``cast_forward_inputs``. For non-root FSDP modules, + this does not do anything. (Default: ``True``) + _module_classes_to_ignore: (Sequence[Type[nn.Module]]): This specifies + module classes to ignore for mixed precision when using an + ``auto_wrap_policy``: Modules of these classes will have FSDP + applied to them separately with mixed precision disabled (meaning + that the final FSDP construction would deviate from the specified + policy). If ``auto_wrap_policy`` is not specified, then this does + not do anything. This API is experimental and subject to change. + (Default: ``(_BatchNorm,)``) + + .. note:: This API is experimental and subject to change. + + .. note:: Only floating point tensors are cast to their specified dtypes. + + .. note:: In ``summon_full_params``, parameters are forced to full + precision, but buffers are not. + + .. note:: Layer norm and batch norm accumulate in ``float32`` even when + their inputs are in a low precision like ``float16`` or ``bfloat16``. + Disabling FSDP's mixed precision for those norm modules only means that + the affine parameters are kept in ``float32``. However, this incurs + separate all-gathers and reduce-scatters for those norm modules, which + may be inefficient, so if the workload permits, the user should prefer + to still apply mixed precision to those modules. + + .. note:: By default, if the user passes a model with any ``_BatchNorm`` + modules and specifies an ``auto_wrap_policy``, then the batch norm + modules will have FSDP applied to them separately with mixed precision + disabled. See the ``_module_classes_to_ignore`` argument. + + .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and + ``cast_forward_inputs=False`` by default. For the root FSDP instance, + its ``cast_root_forward_inputs`` takes precedence over its + ``cast_forward_inputs``. For non-root FSDP instances, their + ``cast_root_forward_inputs`` values are ignored. The default setting is + sufficient for the typical case where each FSDP instance has the same + ``MixedPrecision`` configuration and only needs to cast inputs to the + ``param_dtype`` at the beginning of the model's forward pass. + + .. note:: For nested FSDP instances with different ``MixedPrecision`` + configurations, we recommend setting individual ``cast_forward_inputs`` + values to configure casting inputs or not before each instance's + forward. In such a case, since the casts happen before each FSDP + instance's forward, a parent FSDP instance should have its non-FSDP + submodules run before its FSDP submodules to avoid the activation dtype + being changed due to a different ``MixedPrecision`` configuration. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + >>> model[1] = FSDP( + >>> model[1], + >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), + >>> ) + >>> model = FSDP( + >>> model, + >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), + >>> ) + + The above shows a working example. On the other hand, if ``model[1]`` + were replaced with ``model[0]``, meaning that the submodule using + different ``MixedPrecision`` ran its forward first, then ``model[1]`` + would incorrectly see ``float16`` activations instead of ``bfloat16`` + ones. + + """ + + param_dtype: Optional[torch.dtype] = None + reduce_dtype: Optional[torch.dtype] = None + buffer_dtype: Optional[torch.dtype] = None + keep_low_precision_grads: bool = False + cast_forward_inputs: bool = False + cast_root_forward_inputs: bool = True + _module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,) + + +@dataclass +class CPUOffload: + """ + This configures CPU offloading. + + Attributes: + offload_params (bool): This specifies whether to offload parameters to + CPU when not involved in computation. If ``True``, then this + offloads gradients to CPU as well, meaning that the optimizer step + runs on CPU. + """ + + offload_params: bool = False + + +class StateDictType(Enum): + """ + This enum indicates that which type of ``state_dict`` the FSDP module is + currently processing (returning or loading). + The default value is FULL_STATE_DICT to comply the PyTorch convention. + + .. note:: + FSDP currently supports three types of ``state_dict``: + 1. ``state_dict/load_state_dict`: this pair of APIs return and load + the non-sharded, unflattened parameters. The semantics is the + same as using DDP. + 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return + and load local sharded, flattened parameters. The values returned + by ``_local_state_dict`` can be directly used by FSDP and is only + meaningful to FSDP (because parameters are flattened). Note that + these APIs are meant for use via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): + ... state = fsdp.state_dict() # loads local state dict + 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs + return and load sharded, unflattened parameters. The ``state_dict`` + return by ``sharded_state_dict`` can be used by all other parallel + schemes (resharding may be required). + """ + + FULL_STATE_DICT = auto() + LOCAL_STATE_DICT = auto() + SHARDED_STATE_DICT = auto() + + +@dataclass +class StateDictConfig: + """ + ``StateDictConfig`` is the base class for all ``state_dict`` configuration + classes. Users should instantiate a child class (e.g. + ``FullStateDictConfig``) in order to configure settings for the + corresponding ``state_dict`` type supported by FSDP. + + Attributes: + offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict + values to CPU, and if ``False``, then FSDP keeps them on GPU. + (Default: ``False``) + """ + + offload_to_cpu: bool = False + + +@dataclass +class FullStateDictConfig(StateDictConfig): + """ + ``FullStateDictConfig`` is a config class meant to be used with + ``StateDictType.FULL_STATE_DICT``. We recommend enabling both + ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state + dicts to save GPU memory and CPU memory, respectively. This config class + is meant to be used via the :func:`state_dict_type` context manager as + follows: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> fsdp = FSDP(model, auto_wrap_policy=...) + >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): + >>> state = fsdp.state_dict() + >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. + >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: + >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP + >>> if dist.get_rank() == 0: + >>> # Load checkpoint only on rank 0 to avoid memory redundancy + >>> state_dict = torch.load("my_checkpoint.pt") + >>> model.load_state_dict(state_dict) + >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument + >>> # communicates loaded checkpoint states from rank 0 to rest of the world. + >>> fsdp = FSDP( + ... model, + ... device_id=torch.cuda.current_device(), + ... auto_wrap_policy=..., + ... sync_module_states=True, + ... ) + >>> # After this point, all ranks have FSDP model with loaded checkpoint. + + Attributes: + rank0_only (bool): If ``True``, then only rank 0 saves the full state + dict, and nonzero ranks save an empty dict. If ``False``, then all + ranks save the full state dict. (Default: ``False``) + """ + + rank0_only: bool = False + + +@dataclass +class LocalStateDictConfig(StateDictConfig): + pass + + +@dataclass +class ShardedStateDictConfig(StateDictConfig): + """ + ``ShardedStateDictConfig`` is a config class meant to be used with + ``StateDictType.SHARDED_STATE_DICT``. + + Attributes: + _use_dtensor (bool): If ``True``, then FSDP saves the state dict values + as ``DTensor``, and if ``False``, then FSDP saves them as + ``ShardedTensor``. (Default: ``False``) + + .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig` + and it is used by FSDP to determine the type of state dict values. Users should not + manually modify ``_use_dtensor``. + """ + + _use_dtensor: bool = False + + +@dataclass +class OptimStateDictConfig: + """ + ``OptimStateDictConfig`` is the base class for all ``optim_state_dict`` + configuration classes. Users should instantiate a child class (e.g. + ``FullOptimStateDictConfig``) in order to configure settings for the + corresponding ``optim_state_dict`` type supported by FSDP. + + Attributes: + offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's + tensor values to CPU, and if ``False``, then FSDP keeps them on the + original device (which is GPU unless parameter CPU offloading is + enabled). (Default: ``True``) + """ + + offload_to_cpu: bool = True + + +@dataclass +class FullOptimStateDictConfig(OptimStateDictConfig): + """ + Attributes: + rank0_only (bool): If ``True``, then only rank 0 saves the full state + dict, and nonzero ranks save an empty dict. If ``False``, then all + ranks save the full state dict. (Default: ``False``) + """ + + rank0_only: bool = False + + +@dataclass +class LocalOptimStateDictConfig(OptimStateDictConfig): + offload_to_cpu: bool = False + + +@dataclass +class ShardedOptimStateDictConfig(OptimStateDictConfig): + """ + ``ShardedOptimStateDictConfig`` is a config class meant to be used with + ``StateDictType.SHARDED_STATE_DICT``. + + Attributes: + _use_dtensor (bool): If ``True``, then FSDP saves the state dict values + as ``DTensor``, and if ``False``, then FSDP saves them as + ``ShardedTensor``. (Default: ``False``) + + .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig` + and it is used by FSDP to determine the type of state dict values. Users should not + manually modify ``_use_dtensor``. + """ + + _use_dtensor: bool = False + + +@dataclass +class StateDictSettings: + state_dict_type: StateDictType + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc5ef424e7052a41ddb986da07e1edb389bed27 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -0,0 +1,2199 @@ +# mypy: ignore-errors + +import contextlib +import copy +import functools +import math +import traceback +import warnings +from collections.abc import Callable, Generator, Iterable, Iterator +from contextlib import contextmanager +from enum import auto, Enum +from typing import Any, Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_WRAPPED_MODULE, + ActivationWrapper, +) +from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_param_to_fqns, + FSDP_PREFIX, + FSDP_WRAPPED_MODULE, + HandleTrainingState, + TrainingState, +) +from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo +from torch.distributed.fsdp._init_utils import ( + _check_orig_params_flattened, + _init_buffer_state, + _init_core_state, + _init_device_handle, + _init_extension, + _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, + _init_process_group_state, + _init_runtime_state, + _init_state_dict_state, + HYBRID_SHARDING_STRATEGIES, + ProcessGroupType, +) +from torch.distributed.fsdp._runtime_utils import ( + _get_fsdp_root_states, + _is_fsdp_root, + _lazy_init, + _post_forward, + _post_forward_reshard, + _pre_forward, + _pre_forward_unshard, + _root_pre_forward, + _unshard, + _wait_for_computation_stream, +) +from torch.distributed.fsdp._wrap_utils import _auto_wrap +from torch.distributed.fsdp.api import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + LocalOptimStateDictConfig, + LocalStateDictConfig, + MixedPrecision, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictSettings, + StateDictType, +) +from torch.distributed.tensor import DeviceMesh +from torch.distributed.utils import _p_assert + +from ._flat_param import FlatParameter, FlatParamHandle +from ._optim_utils import ( + _flatten_optim_state_dict, + _get_param_id_to_param_from_optim_input, + _get_param_key_to_param, + _get_param_to_param_id_from_optim_input, + _get_param_to_param_key, + _optim_state_dict, + _rekey_sharded_optim_state_dict, + _set_optim_use_dtensor, +) +from ._state_dict_utils import _register_all_state_dict_hooks +from ._unshard_param_utils import ( + _deregister_orig_params, + _register_flat_param, + _register_orig_params, + _unshard_params, + _unshard_params_for_summon, +) +from .wrap import CustomPolicy, ModuleWrapPolicy + + +__all__ = [ + "FullyShardedDataParallel", + "OptimStateKeyType", +] + + +FLAT_PARAM = "_flat_param" + + +class OptimStateKeyType(Enum): + """Represents the type of key in an optimizer state-dict.""" + + PARAM_NAME = auto() + PARAM_ID = auto() + + +class FullyShardedDataParallel(nn.Module, _FSDPState): + """A wrapper for sharding module parameters across data parallel workers. + + This is inspired by `Xu et al. `_ as + well as the ZeRO Stage 3 from `DeepSpeed `_. + FullyShardedDataParallel is commonly shortened to FSDP. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> torch.cuda.set_device(device_id) + >>> sharded_module = FSDP(my_module) + >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) + >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) + >>> loss = x.sum() + >>> loss.backward() + >>> optim.step() + + Using FSDP involves wrapping your module and then initializing your + optimizer after. This is required since FSDP changes the parameter + variables. + + When setting up FSDP, you need to consider the destination CUDA + device. If the device has an ID (``dev_id``), you have three options: + + * Place the module on that device + * Set the device using ``torch.cuda.set_device(dev_id)`` + * Pass ``dev_id`` into the ``device_id`` constructor argument. + + This ensures that the FSDP instance's compute device is the + destination device. For option 1 and 3, the FSDP initialization + always occurs on GPU. For option 2, the FSDP initialization + happens on module's current device, which may be a CPU. + + If you're using the ``sync_module_states=True`` flag, you need to + ensure that the module is on a GPU or use the ``device_id`` + argument to specify a CUDA device that FSDP will move the module + to in the FSDP constructor. This is necessary because + ``sync_module_states=True`` requires GPU communication. + + FSDP also takes care of moving input tensors to the forward method + to the GPU compute device, so you don't need to manually move them + from CPU. + + For ``use_orig_params=True``, + ``ShardingStrategy.SHARD_GRAD_OP`` exposes the unsharded + parameters, not the sharded parameters after forward, unlike + ``ShardingStrategy.FULL_SHARD``. If you want + to inspect the gradients, you can use the ``summon_full_params`` + method with ``with_grads=True``. + + With ``limit_all_gathers=True``, you may see a gap in the FSDP + pre-forward where the CPU thread is not issuing any kernels. This is + intentional and shows the rate limiter in effect. Synchronizing the CPU + thread in that way prevents over-allocating memory for subsequent + all-gathers, and it should not actually delay GPU kernel execution. + + FSDP replaces managed modules' parameters with ``torch.Tensor`` + views during forward and backward computation for autograd-related + reasons. If your module's forward relies on saved references to + the parameters instead of reacquiring the references each + iteration, then it will not see FSDP's newly created views, + and autograd will not work correctly. + + Finally, when using ``sharding_strategy=ShardingStrategy.HYBRID_SHARD`` + with the sharding process group being intra-node and the + replication process group being inter-node, setting + ``NCCL_CROSS_NIC=1`` can help improve the all-reduce times over + the replication process group for some cluster setups. + + **Limitations** + + There are several limitations to be aware of when using FSDP: + + * FSDP currently does not support gradient accumulation outside + ``no_sync()`` when using CPU offloading. This is because FSDP + uses the newly-reduced gradient instead of accumulating with any + existing gradient, which can lead to incorrect results. + + * FSDP does not support running the forward pass of a submodule + that is contained in an FSDP instance. This is because the + submodule's parameters will be sharded, but the submodule itself + is not an FSDP instance, so its forward pass will not all-gather + the full parameters appropriately. + + * FSDP does not work with double backwards due to the way it + registers backward hooks. + + * FSDP has some constraints when freezing parameters. + For ``use_orig_params=False``, each FSDP instance must manage + parameters that are all frozen or all non-frozen. For + ``use_orig_params=True``, FSDP supports mixing frozen and + non-frozen parameters, but it's recommended to avoid doing so to + prevent higher than expected gradient memory usage. + + * As of PyTorch 1.12, FSDP offers limited support for shared + parameters. If enhanced shared parameter support is needed for + your use case, please post in + `this issue `__. + + * You should avoid modifying the parameters between forward and + backward without using the ``summon_full_params`` context, as + the modifications may not persist. + + Args: + module (nn.Module): + This is the module to be wrapped with FSDP. + process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]): + This is the process group over which the model is sharded and thus + the one used for FSDP's all-gather and reduce-scatter collective + communications. If ``None``, then FSDP uses the default process + group. For hybrid sharding strategies such as + ``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of + process groups, representing the groups over which to shard and + replicate, respectively. If ``None``, then FSDP constructs process + groups for the user to shard intra-node and replicate inter-node. + (Default: ``None``) + sharding_strategy (Optional[ShardingStrategy]): + This configures the sharding strategy, which may trade off memory + saving and communication overhead. See :class:`ShardingStrategy` + for details. (Default: ``FULL_SHARD``) + cpu_offload (Optional[CPUOffload]): + This configures CPU offloading. If this is set to ``None``, then + no CPU offloading happens. See :class:`CPUOffload` for details. + (Default: ``None``) + auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]): + This specifies a policy to apply FSDP to submodules of ``module``, + which is needed for communication and computation overlap and thus + affects performance. If ``None``, then FSDP only applies to + ``module``, and users should manually apply FSDP to parent modules + themselves (proceeding bottom-up). For convenience, this accepts + ``ModuleWrapPolicy`` directly, which allows users to specify the + module classes to wrap (e.g. the transformer block). Otherwise, + this should be a callable that takes in three arguments + ``module: nn.Module``, ``recurse: bool``, and + ``nonwrapped_numel: int`` and should return a ``bool`` specifying + whether the passed-in ``module`` should have FSDP applied if + ``recurse=False`` or if the traversal should continue into the + module's subtree if ``recurse=True``. Users may add additional + arguments to the callable. The ``size_based_auto_wrap_policy`` in + ``torch.distributed.fsdp.wrap.py`` gives an example callable that + applies FSDP to a module if the parameters in its subtree exceed + 100M numel. We recommend printing the model after applying FSDP + and adjusting as needed. + + Example:: + + >>> def custom_auto_wrap_policy( + >>> module: nn.Module, + >>> recurse: bool, + >>> nonwrapped_numel: int, + >>> # Additional custom arguments + >>> min_num_params: int = int(1e8), + >>> ) -> bool: + >>> return nonwrapped_numel >= min_num_params + >>> # Configure a custom `min_num_params` + >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) + + backward_prefetch (Optional[BackwardPrefetch]): + This configures explicit backward prefetching of all-gathers. If + ``None``, then FSDP does not backward prefetch, and there is no + communication and computation overlap in the backward pass. See + :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``) + mixed_precision (Optional[MixedPrecision]): + This configures native mixed precision for FSDP. If this is set to + ``None``, then no mixed precision is used. Otherwise, parameter, + buffer, and gradient reduction dtypes can be set. See + :class:`MixedPrecision` for details. (Default: ``None``) + ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose + own parameters and child modules' parameters and buffers are + ignored by this instance. None of the modules directly in + ``ignored_modules`` should be :class:`FullyShardedDataParallel` + instances, and any child modules that are already-constructed + :class:`FullyShardedDataParallel` instances will not be ignored if + they are nested under this instance. This argument may be used to + avoid sharding specific parameters at module granularity when using an + ``auto_wrap_policy`` or if parameters' sharding is not managed by + FSDP. (Default: ``None``) + param_init_fn (Optional[Callable[[nn.Module], None]]): + A ``Callable[torch.nn.Module] -> None`` that + specifies how modules that are currently on the meta device should + be initialized onto an actual device. As of v1.12, FSDP detects + modules with parameters or buffers on meta device via ``is_meta`` + and either applies ``param_init_fn`` if specified or calls + ``nn.Module.reset_parameters()`` otherwise. For both cases, the + implementation should *only* initialize the parameters/buffers of + the module, not those of its submodules. This is to avoid + re-initialization. In addition, FSDP also supports deferred + initialization via torchdistX's (https://github.com/pytorch/torchdistX) + ``deferred_init()`` API, where the deferred modules are initialized + by calling ``param_init_fn`` if specified or torchdistX's default + ``materialize_module()`` otherwise. If ``param_init_fn`` is + specified, then it is applied to all meta-device modules, meaning + that it should probably case on the module type. FSDP calls the + initialization function before parameter flattening and sharding. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> module = MyModule(device="meta") + >>> def my_init_fn(module: nn.Module): + >>> # E.g. initialize depending on the module type + >>> ... + >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) + >>> print(next(fsdp_model.parameters()).device) # current CUDA device + >>> # With torchdistX + >>> module = deferred_init.deferred_init(MyModule, device="cuda") + >>> # Will initialize via deferred_init.materialize_module(). + >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy) + + device_id (Optional[Union[int, torch.device]]): An ``int`` or + ``torch.device`` giving the CUDA device on which FSDP + initialization takes place, including the module initialization + if needed and the parameter sharding. This should be specified to + improve initialization speed if ``module`` is on CPU. If the + default CUDA device was set (e.g. via ``torch.cuda.set_device``), + then the user may pass ``torch.cuda.current_device`` to this. + (Default: ``None``) + sync_module_states (bool): If ``True``, then each FSDP module will + broadcast module parameters and buffers from rank 0 to ensure that + they are replicated across ranks (adding communication overhead to + this constructor). This can help load ``state_dict`` checkpoints + via ``load_state_dict`` in a memory efficient way. See + :class:`FullStateDictConfig` for an example of this. (Default: + ``False``) + forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches + the next forward-pass all-gather before the current forward + computation. This is only useful for CPU-bound workloads, in which + case issuing the next all-gather earlier may improve overlap. This + should only be used for static-graph models since the prefetching + follows the first iteration's execution order. (Default: ``False``) + limit_all_gathers (bool): If ``True``, then FSDP explicitly + synchronizes the CPU thread to ensure GPU memory usage from only + *two* consecutive FSDP instances (the current instance running + computation and the next instance whose all-gather is prefetched). + If ``False``, then FSDP allows the CPU thread to issue all-gathers + without any extra synchronization. (Default: ``True``) We often + refer to this feature as the "rate limiter". This flag should only + be set to ``False`` for specific CPU-bound workloads with low + memory pressure in which case the CPU thread can aggressively issue + all kernels without concern for the GPU memory usage. + use_orig_params (bool): Setting this to ``True`` has FSDP use + ``module`` 's original parameters. FSDP exposes those original + parameters to the user via :meth:`nn.Module.named_parameters` + instead of FSDP's internal :class:`FlatParameter` s. This means + that the optimizer step runs on the original parameters, enabling + per-original-parameter hyperparameters. FSDP preserves the original + parameter variables and manipulates their data between unsharded + and sharded forms, where they are always views into the underlying + unsharded or sharded :class:`FlatParameter`, respectively. With the + current algorithm, the sharded form is always 1D, losing the + original tensor structure. An original parameter may have all, + some, or none of its data present for a given rank. In the none + case, its data will be like a size-0 empty tensor. Users should not + author programs relying on what data is present for a given + original parameter in its sharded form. ``True`` is required to + use ``torch.compile()``. Setting this to ``False`` exposes FSDP's + internal :class:`FlatParameter` s to the user via + :meth:`nn.Module.named_parameters`. (Default: ``False``) + ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]): + Ignored parameters or modules that will not be managed by this FSDP + instance, meaning that the parameters are not sharded and their + gradients are not reduced across ranks. This argument unifies with + the existing ``ignored_modules`` argument, and we may deprecate + ``ignored_modules`` soon. For backward compatibility, we keep both + ``ignored_states`` and `ignored_modules``, but FSDP only allows one + of them to be specified as not ``None``. + device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an alternative to + process_group. When device_mesh is passed, FSDP will use the underlying process + groups for all-gather and reduce-scatter collective communications. Therefore, + these two args need to be mutually exclusive. For hybrid sharding strategies such as + ``ShardingStrategy.HYBRID_SHARD``, users can pass in a 2D DeviceMesh instead + of a tuple of process groups. For 2D FSDP + TP, users are required to pass in + device_mesh instead of process_group. For more DeviceMesh info, please visit: + https://pytorch.org/tutorials/recipes/distributed_device_mesh.html + """ + + def __init__( + self, + module: nn.Module, + process_group: ProcessGroupType = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[ + Union[Callable, ModuleWrapPolicy, CustomPolicy] + ] = None, + backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + device_id: Optional[Union[int, torch.device]] = None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + limit_all_gathers: bool = True, + use_orig_params: bool = False, + ignored_states: Union[ + Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] + ] = None, + device_mesh: Optional[DeviceMesh] = None, + ): + torch._C._log_api_usage_once("torch.distributed.fsdp") + super().__init__() + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + warnings.warn( + "FSDP will not all-gather parameters for containers that do " + f"not implement forward: {module}", + stacklevel=2, + ) + _init_ignored_module_states(self, module, ignored_modules, ignored_states) + _init_device_handle(self, module, self._ignored_params, device_id) + + # Add module annotations for Dynamo support (see function for details) + _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params) + + # Initializes self.process_group, along with rank and world size. This will + # also set another attribute, _inter_node_pg, to control the process group + # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}. + # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up + # the same process group state as the root FSDP module. + self._device_mesh = device_mesh + _init_process_group_state( + self, + process_group, + sharding_strategy, + auto_wrap_policy, + device_mesh, + ) + if auto_wrap_policy is not None: + root_kwargs = { + "process_group": process_group, + "sharding_strategy": sharding_strategy, + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + "mixed_precision": mixed_precision, + "param_init_fn": param_init_fn, + "device_id": device_id, + "sync_module_states": sync_module_states, + "forward_prefetch": forward_prefetch, + "limit_all_gathers": limit_all_gathers, + "use_orig_params": use_orig_params, + "ignored_states": self._ignored_params, + "device_mesh": device_mesh, + } + if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None: + # Share root process groups with children to maintain + # the invariant that all FSDP modules will have the same + # process groups. + root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) + + _auto_wrap( + module, + auto_wrap_policy, + self._ignored_modules, + self._ignored_params, + root_kwargs, + FullyShardedDataParallel, + ) + + backward_prefetch_limit = 1 + forward_prefetch_limit = 1 + _init_core_state( + self, + sharding_strategy, + mixed_precision, + cpu_offload, + limit_all_gathers, + use_orig_params, + backward_prefetch_limit, + forward_prefetch_limit, + ) + _init_runtime_state(self) + _init_prefetching_state(self, backward_prefetch, forward_prefetch) + _init_buffer_state(self, module) + # extension needs to be set before `_init_param_handle_from_module()` + _init_extension(self, device_mesh) + _init_param_handle_from_module( + self, + module, + device_id, + param_init_fn, + sync_module_states, + ) + self._fsdp_wrapped_module = module + if not use_orig_params: + _check_orig_params_flattened(self, self._ignored_params) + _register_flat_param(self, self) + + # `_state_dict_type` controls the `state_dict()` behavior, which is + # implemented using post-save and pre-load hooks + _init_state_dict_state(self) + _register_all_state_dict_hooks(self) + self._zero_scalar = None + + @property + def module(self) -> nn.Module: + """Return the wrapped module.""" + # FSDP's `.module` must refer to the innermost wrapped module when + # composing with other module wrappers in order for state dict to work + if isinstance(self._fsdp_wrapped_module, ActivationWrapper): + return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE) + return self._fsdp_wrapped_module + + @property + def _has_params(self) -> bool: + """Returns whether this FSDP instance manages any parameters.""" + return hasattr(self, "_handle") and self._handle is not None + + @property + def _flat_param(self) -> Optional[FlatParameter]: + return self._handle.flat_param if self._handle else None + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._fsdp_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is an ``nn.Sequential``.""" + if hasattr(self, FSDP_WRAPPED_MODULE): + return self._fsdp_wrapped_module.__getitem__(key) # type: ignore[operator] + return super().__getitem__(key) + + def check_is_root(self) -> bool: + """Check if this instance is a root FSDP module.""" + return _is_fsdp_root(self, self) + + @staticmethod + def fsdp_modules( + module: nn.Module, + root_only: bool = False, + ) -> list["FullyShardedDataParallel"]: + """Return all nested FSDP instances. + + This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``. + + Args: + module (torch.nn.Module): Root module, which may or may not be an + ``FSDP`` module. + root_only (bool): Whether to return only FSDP root modules. + (Default: ``False``) + + Returns: + List[FullyShardedDataParallel]: FSDP modules that are nested in + the input ``module``. + """ + if root_only: + return _get_fsdp_root_states(module) + return traversal_utils._get_fsdp_states(module) + + def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel": + r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. + + Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). + + Compared to ``torch.nn.Module.apply``, this version additionally gathers + the full parameters before applying ``fn``. It should not be called from + within another ``summon_full_params`` context. + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + """ + uninitialized = self._is_root is None + self._assert_state(TrainingState.IDLE) + # Use `_unshard_params_for_summon()` with `recurse=False` instead of + # `_unshard_fsdp_state_params()` directly to perform lazy + # initialization, which is needed to initialize `FlatParameter` + # parameter attributes as required by the unshard logic + with _unshard_params_for_summon( + self, + self, + writeback=True, + rank0_only=False, + offload_to_cpu=False, + with_grads=False, + ): + ret = super().apply(fn) + + # Reset lazy init called in `_unshard_params_for_summon()` since + # `apply()` may have been called on FSDP instance that is not truly a + # root, in which case it will be incorrectly marked as one. + if uninitialized and self._is_root: + for module in traversal_utils._get_fsdp_states(self): + module._reset_lazy_init() + + return ret + + def _mixed_precision_enabled_for_buffers(self) -> bool: + """Return whether the user explicitly enabled buffer mixed precision. + + NOTE: Unlike parameters and gradient reduction, buffer mixed precision + is applied at the FSDP instance level, not the ``FlatParameter`` level, + which may be different for the composable code path. + """ + return self.mixed_precision.buffer_dtype is not None + + def _low_precision_hook_enabled(self) -> bool: + """Whether a low precision hook is registered or not.""" + return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS + + def _reset_lazy_init(self) -> None: + """Reset instance so :func:`_lazy_init` will run on the next forward.""" + self._is_root: Optional[bool] = None + + @staticmethod + def set_state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + optim_state_dict_config: Optional[OptimStateDictConfig] = None, + ) -> StateDictSettings: + """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. + + Also takes (optional) configuration for the model's and optimizer's state dict. + The target module does not have to be a FSDP module. If the target + module is a FSDP module, its ``state_dict_type`` will also be changed. + + .. note:: This API should be called for only the top-level (root) + module. + + .. note:: This API enables users to transparently use the conventional + ``state_dict`` API to take model checkpoints in cases where the + root FSDP module is wrapped by another ``nn.Module``. For example, + the following will ensure ``state_dict`` is called on all non-FSDP + instances, while dispatching into `sharded_state_dict` implementation + for FSDP: + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = DDP(FSDP(...)) + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.SHARDED_STATE_DICT, + >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), + >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), + >>> ) + >>> param_state_dict = model.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict(model, optim) + + Args: + module (torch.nn.Module): Root module. + state_dict_type (StateDictType): the desired ``state_dict_type`` to set. + state_dict_config (Optional[StateDictConfig]): the configuration for the + target ``state_dict_type``. + optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration + for the optimizer state dict. + + Returns: + A StateDictSettings that include the previous state_dict type and + configuration for the module. + """ + warnings.warn( + "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being " + "deprecated. Please use APIs, get_state_dict() and set_state_dict(), " + "which can support different parallelisms, FSDP1, FSDP2, DDP. " + "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html" + "#torch.distributed.checkpoint.state_dict.get_state_dict ." + "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .", + FutureWarning, + stacklevel=2, + ) + _state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, + } + _optim_state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig, + } + + # Use the default config if a state_dict config is not set. + state_dict_config_type = _state_dict_type_to_config[state_dict_type] + optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type] + if state_dict_config is None: + state_dict_config = state_dict_config_type() + if optim_state_dict_config is None: + optim_state_dict_config = optim_state_dict_config_type() + if state_dict_config_type is not type(state_dict_config): + raise RuntimeError( + f"Expected state_dict_config of type {state_dict_config_type} " + f"but got {type(state_dict_config)}" + ) + if optim_state_dict_config_type is not type(optim_state_dict_config): + raise RuntimeError( + f"Expected optim_state_dict_config of type {optim_state_dict_config_type} " + f"but got {type(optim_state_dict_config)}" + ) + + # Set the state_dict type and configurations. + prev_state_dict_type = None + prev_state_dict_config = None + prev_optim_state_dict_config = None + for submodule in traversal_utils._get_fsdp_states(module): + if prev_state_dict_type is None: + prev_state_dict_type = submodule._state_dict_type + else: + if prev_state_dict_type != submodule._state_dict_type: + raise AssertionError( + "All FSDP modules should have the same state_dict_type." + ) + if prev_state_dict_config is None: + prev_state_dict_config = submodule._state_dict_config + else: + if not isinstance( + submodule._state_dict_config, type(prev_state_dict_config) + ): + raise AssertionError( + "All FSDP modules must have the same type of state_dict_config." + ) + if prev_optim_state_dict_config is None: + prev_optim_state_dict_config = submodule._optim_state_dict_config + else: + if not isinstance( + submodule._optim_state_dict_config, + type(prev_optim_state_dict_config), + ): + raise AssertionError( + "All FSDP modules must have the same type of optim_state_dict_config." + ) + + submodule._state_dict_type = state_dict_type + submodule._state_dict_config = state_dict_config + submodule._optim_state_dict_config = optim_state_dict_config + + return StateDictSettings( + prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config + ) + + @staticmethod + def get_state_dict_type(module: nn.Module) -> StateDictSettings: + """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``. + + The target module does not have to be an FSDP module. + + Returns: + A ``StateDictSettings`` containing the state_dict_type and + state_dict / optim_state_dict configs that are currently set. + + Raises: + ``AssertionError`` if the ``StateDictSettings`` for different + FSDP submodules differ. + """ + state_dict_settings: Optional[StateDictSettings] = None + for submodule in FullyShardedDataParallel.fsdp_modules(module): + if state_dict_settings is None: + state_dict_settings = StateDictSettings( + state_dict_type=submodule._state_dict_type, + state_dict_config=submodule._state_dict_config, + optim_state_dict_config=submodule._optim_state_dict_config, + ) + _set_optim_use_dtensor(submodule, state_dict_settings) + else: + submodule_settings = StateDictSettings( + submodule._state_dict_type, + submodule._state_dict_config, + submodule._optim_state_dict_config, + ) + if state_dict_settings != submodule_settings: + raise AssertionError( + "All FSDP modules must have the same state dict settings." + f"Got {submodule_settings} and {state_dict_settings}." + ) + _set_optim_use_dtensor(submodule, submodule_settings) + return state_dict_settings + + @staticmethod + @contextlib.contextmanager + def state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + optim_state_dict_config: Optional[OptimStateDictConfig] = None, + ) -> Generator: + """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. + + This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of + :meth:`set_state_dict_type` for the detail. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = DDP(FSDP(...)) + >>> with FSDP.state_dict_type( + >>> model, + >>> StateDictType.SHARDED_STATE_DICT, + >>> ): + >>> checkpoint = model.state_dict() + + Args: + module (torch.nn.Module): Root module. + state_dict_type (StateDictType): the desired ``state_dict_type`` to set. + state_dict_config (Optional[StateDictConfig]): the model ``state_dict`` + configuration for the target ``state_dict_type``. + optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer + ``state_dict`` configuration for the target ``state_dict_type``. + """ + prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ) + yield + FullyShardedDataParallel.set_state_dict_type( + module, + prev_state_dict_settings.state_dict_type, + prev_state_dict_settings.state_dict_config, + prev_state_dict_settings.optim_state_dict_config, + ) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.""" + handle = self._handle + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel.forward" + ): + args, kwargs = _root_pre_forward(self, self, args, kwargs) + unused = None + args, kwargs = _pre_forward( + self, + handle, + _pre_forward_unshard, + self._fsdp_wrapped_module, + args, + kwargs, + ) + if handle: + _p_assert( + handle.flat_param.device == self.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{self.compute_device} but got {handle.flat_param.device}", + ) + output = self._fsdp_wrapped_module(*args, **kwargs) + return _post_forward( + self, handle, _post_forward_reshard, self, unused, output + ) + + @staticmethod + @contextlib.contextmanager + def summon_full_params( + module: nn.Module, + recurse: bool = True, + writeback: bool = True, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, + ) -> Generator: + r"""Expose full params for FSDP instances with this context manager. + + Can be useful *after* forward/backward for a model to get + the params for additional processing or checking. It can take a non-FSDP + module and will summon full params for all contained FSDP modules as + well as their children, depending on the ``recurse`` argument. + + .. note:: This can be used on inner FSDPs. + .. note:: This can *not* be used within a forward or backward pass. Nor + can forward and backward be started from within this context. + .. note:: Parameters will revert to their local shards after the context + manager exits, storage behavior is the same as forward. + .. note:: The full parameters can be modified, but only the portion + corresponding to the local param shard will persist after the + context manager exits (unless ``writeback=False``, in which case + changes will be discarded). In the case where FSDP does not shard + the parameters, currently only when ``world_size == 1``, or ``NO_SHARD`` + config, the modification is persisted regardless of ``writeback``. + .. note:: This method works on modules which are not FSDP themselves but + may contain multiple independent FSDP units. In that case, the given + arguments will apply to all contained FSDP units. + + .. warning:: Note that ``rank0_only=True`` in conjunction with + ``writeback=True`` is not currently supported and will raise an + error. This is because model parameter shapes would be different + across ranks within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + + .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will + result in full parameters being redundantly copied to CPU memory for + GPUs that reside on the same machine, which may incur the risk of + CPU OOM. It is recommended to use ``offload_to_cpu`` with + ``rank0_only=True``. + + Args: + recurse (bool, Optional): recursively summon all params for nested + FSDP instances (default: True). + writeback (bool, Optional): if ``False``, modifications to params are + discarded after the context manager exits; + disabling this can be slightly more efficient (default: True) + rank0_only (bool, Optional): if ``True``, full parameters are + materialized on only global rank 0. This means that within the + context, only rank 0 will have full parameters and the other + ranks will have sharded parameters. Note that setting + ``rank0_only=True`` with ``writeback=True`` is not supported, + as model parameter shapes will be different across ranks + within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + offload_to_cpu (bool, Optional): If ``True``, full parameters are + offloaded to CPU. Note that this offloading currently only + occurs if the parameter is sharded (which is only not the case + for world_size = 1 or ``NO_SHARD`` config). It is recommended + to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid + redundant copies of model parameters being offloaded to the same CPU memory. + with_grads (bool, Optional): If ``True``, gradients are also + unsharded with the parameters. Currently, this is only + supported when passing ``use_orig_params=True`` to the FSDP + constructor and ``offload_to_cpu=False`` to this method. + (Default: ``False``) + """ + with _unshard_params( + module, recurse, writeback, rank0_only, offload_to_cpu, with_grads + ): + yield + + @contextlib.contextmanager + def _deregister_orig_params_ctx(self): + """Deregister the original parameters and expose the :class:`FlatParameter`. + + If a :class:`FlatParameter` is sharded, then + this refreshes the sharded views before exiting. This method should + only be called when using the original parameters. + """ + _p_assert( + self._use_orig_params, + "`_deregister_orig_params_ctx()` should only be called when " + "`_use_orig_params=True`", + ) + for fsdp_module in traversal_utils._get_fsdp_states(self): + _deregister_orig_params(fsdp_module, fsdp_module) + try: + yield + finally: + for fsdp_module in traversal_utils._get_fsdp_states(self): + _register_orig_params(fsdp_module, fsdp_module) + + def _apply(self, *args, **kwargs): + """Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``.""" + # When using the original parameters: Since (1) the `FlatParameter`s + # own the storage and (2) `_apply()` is the subroutine underlying the + # most common storage-changing ops like `to()` and `cuda()`, we + # override `_apply()` to have the storage change directly performed on + # the `FlatParameter`s instead of applying to the original parameters + # and then writing back to the `FlatParameter`s. + context = ( + self._deregister_orig_params_ctx() + if self._use_orig_params + else contextlib.nullcontext() + ) + with context: + return super()._apply(*args, **kwargs) + + def named_buffers( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself. + + Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix + when inside the :meth:`summon_full_params` context manager. + """ + should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS + for buffer_name, buffer in super().named_buffers(*args, **kwargs): + if should_clean_name: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + buffer_name = buffer_name.replace(FSDP_PREFIX, "") + yield (buffer_name, buffer) + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.nn.Parameter]]: + """Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself. + + Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix + when inside the :meth:`summon_full_params` context manager. + """ + should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS + for param_name, param in super().named_parameters(*args, **kwargs): + if should_clean_name: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + param_name = param_name.replace(FSDP_PREFIX, "") + yield (param_name, param) + + def _assert_state(self, state: Union[TrainingState, list[TrainingState]]) -> None: + """Assert we are in the given state.""" + # Since assert can be turned off and this error checking + # is really important, we use explicit error checking + # and raise a ValueError if needed. + if isinstance(state, TrainingState): + state = [state] + if self.training_state not in state: + msg = ( + f"expected to be in states {state} but current state " + f"is {self.training_state}" + ) + # In case we are failing in the context of autograd hook, asserting + # may not generate useful msg. So, let's print it to be sure. + if self.rank == 0: + print(f"Asserting FSDP instance is: {self}") + print(f"ERROR: {msg}") + traceback.print_stack() + raise ValueError(msg) + + @contextmanager + def no_sync(self) -> Generator: + """Disable gradient synchronizations across FSDP instances. + + Within this context, gradients will be accumulated in module + variables, which will later be synchronized in the first + forward-backward pass after exiting the context. This should only be + used on the root FSDP instance and will recursively apply to all + children FSDP instances. + + .. note:: This likely results in higher memory usage because FSDP will + accumulate the full model gradients (instead of gradient shards) + until the eventual sync. + + .. note:: When used with CPU offloading, the gradients will not be + offloaded to CPU when inside the context manager. Instead, they + will only be offloaded right after the eventual sync. + """ + _lazy_init(self, self) + if not self._is_root: + raise RuntimeError( + "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module." + ) + self._assert_state(TrainingState.IDLE) + old_flags = [] + for m in self.modules(): + if isinstance(m, FullyShardedDataParallel): + old_flags.append((m, m._sync_gradients)) + m._sync_gradients = False + try: + yield + finally: + for m, old_flag in old_flags: + if m._sync_gradients: + raise AssertionError( + "`_sync_gradients` was incorrectly set to " + "`True` while in the `no_sync()` context manager" + ) + m._sync_gradients = old_flag + + @torch.no_grad() + def clip_grad_norm_( + self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0 + ) -> torch.Tensor: + """Clip the gradient norm of all parameters. + + The norm is computed over all parameters' gradients as viewed as a single vector, and the + gradients are modified in-place. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` + for infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + + If every FSDP instance uses ``NO_SHARD``, meaning that no + gradients are sharded across ranks, then you may directly use + :func:`torch.nn.utils.clip_grad_norm_`. + + If at least some FSDP instance uses a sharded strategy (i.e. + one other than ``NO_SHARD``), then you should use this method + instead of :func:`torch.nn.utils.clip_grad_norm_` since this method + handles the fact that gradients are sharded across ranks. + + The total norm returned will have the "largest" dtype across + all parameters/gradients as defined by PyTorch's type promotion + semantics. For example, if *all* parameters/gradients use a low + precision dtype, then the returned norm's dtype will be that low + precision dtype, but if there exists at least one parameter/ + gradient using FP32, then the returned norm's dtype will be FP32. + + .. warning:: This needs to be called on all ranks since it uses + collective communications. + """ + _lazy_init(self, self) + if not self._is_root: + raise RuntimeError( + "`clip_grad_norm_()` should only be called on the root FSDP instance" + ) + if self._zero_scalar is None: + self._zero_scalar = torch.tensor(0.0, device=self.compute_device) + self._assert_state(TrainingState.IDLE) + # If every FSDP instance uses `NO_SHARD`, then we can directly use + # the normal `nn.utils` one targeting local gradients + all_no_shard = all( + not handle.uses_sharded_strategy for handle in self._all_handles + ) + if all_no_shard: + return torch.nn.utils.clip_grad_norm_( + self.parameters(), max_norm, norm_type + ) + # Otherwise, there exists some FSDP instance using a sharded strategy, + # where sharded and non-sharded parameters must be handled separately + max_norm = float(max_norm) + norm_type = float(norm_type) + sharded_params_set = set() + nonsharded_params_set = set() # `NO_SHARD` or not FSDP-managed + # Make sure to compute the local norm using lists for deterministic + # iteration order and hence deterministic total norm computation + sharded_params = [] + nonsharded_params = [] + grads: list[torch.Tensor] = [] + for handle in self._all_handles: + if handle.uses_sharded_strategy: + target_set = sharded_params_set + target_list = sharded_params + else: + target_set = nonsharded_params_set + target_list = nonsharded_params + if handle._use_orig_params: + for param in handle.flat_param._params: + if param not in target_set: + target_set.add(param) + target_list.append(param) + if param.grad is not None: + grads.append(param.grad) + else: + if handle.flat_param not in target_set: + target_set.add(handle.flat_param) + target_list.append(handle.flat_param) + if handle.flat_param.grad is not None: + grads.append(handle.flat_param.grad) + for param in self.parameters(): + not_fsdp_managed = ( + param not in sharded_params_set and param not in nonsharded_params_set + ) + if not_fsdp_managed: + nonsharded_params_set.add(param) + nonsharded_params.append(param) + if param.grad is not None: + grads.append(param.grad) + # Compute local norms (forced to be in FP32) + local_sharded_norm = _get_grad_norm( + sharded_params, norm_type, self._zero_scalar, self.compute_device + ) + local_nonsharded_norm = ( + _get_grad_norm( + nonsharded_params, norm_type, self._zero_scalar, self.compute_device + ) + if nonsharded_params + else None + ) + # Reconstruct the total gradient norm depending on the norm type + if norm_type == math.inf: + total_norm = ( + torch.maximum(local_sharded_norm, local_nonsharded_norm) + if local_nonsharded_norm is not None + else local_sharded_norm + ) + dist.all_reduce( + total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group + ) + else: + total_norm = local_sharded_norm**norm_type + dist.all_reduce(total_norm, group=self.process_group) + # All-reducing the local non-sharded norm would count it an extra + # world-size-many times + if local_nonsharded_norm is not None: + total_norm += local_nonsharded_norm**norm_type + total_norm = total_norm ** (1.0 / norm_type) + if self.cpu_offload.offload_params: + total_norm = total_norm.cpu() + + clip_coef = max_norm / (total_norm + 1e-6) + # Multiplying by the clamped coefficient is meaningless when it is + # equal to 1, but it avoids the host-device sync that would result from + # `if clip_coef < 1` + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for grad in grads: + grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype)) + # Use the "largest" dtype by type promotion semantics to use the same + # dtype as if we did not force local norm computation to be in FP32 + if len(grads) == 0: + # If this rank has no gradients, then we must default to FP32 + # unless we use additional communication, which we prefer to avoid + # since `clip_grad_norm_()` is called in the training loop + warnings.warn( + f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no " + "gradients -- returning the total norm in the default dtype " + f"{total_norm.dtype}", + stacklevel=2, + ) # warn since this is generally unexpected + return total_norm + total_norm_dtype = functools.reduce( + torch.promote_types, + [grad.dtype for grad in grads], + ) + return total_norm.to(total_norm_dtype) + + @staticmethod + def _warn_optim_input(optim_input, *, stacklevel: int = 1): + if optim_input is not None: + warnings.warn( + "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. " + "You may remove it from your code without changing its functionality.", + FutureWarning, + stacklevel=stacklevel + 1, + ) + + @staticmethod + def _is_using_optim_input(optim_input, optim) -> bool: + if optim_input is None and optim is None: + # Use the default behavior of `optim_input`` + return True + if optim_input is not None: + # Use the `optim_input` code path + return True + # Use the `optim` code path + return False + + @staticmethod + def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1): + warnings.warn( + f"``FullyShardedDataParallel.{curr}``is being deprecated and is " + f"replaced by ``FullyShardedDataParallel.{new}``. " + f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.", + FutureWarning, + stacklevel=stacklevel + 1, + ) + + @staticmethod + def _optim_state_dict_impl( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + rank0_only: bool = True, + full_state_dict: bool = True, + group: Optional[dist.ProcessGroup] = None, + cpu_offload: bool = True, + *, + _stacklevel: int = 1, + ) -> dict[str, Any]: + """Transform the state-dict of an optimizer corresponding to a sharded model. + + This is the internal API that is used by all the optim_state_dict implementations. + Given model, optim, the original optim_state_dict, this API removes the + FSDP internal information and internal sharding from the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input( + optim_input, stacklevel=_stacklevel + 1 + ) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + if optim_input is not None or rank0_only: + raise AssertionError( + f"Expected optim_input to be None and rank0_only to be False, " + f"got optim_input={optim_input}, rank0_only={rank0_only}" + ) + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[ + 0 + ]._use_orig_params + if not all( + use_orig_params == m._use_orig_params + for m in FullyShardedDataParallel.fsdp_modules(model) + ): + raise AssertionError( + "Not all FSDP modules have the same _use_orig_params value" + ) + + return _optim_state_dict( + model=model, + optim=optim, + optim_state_dict=optim_state_dict, + optim_input=optim_input, + rank0_only=rank0_only, + shard_state=not full_state_dict, + group=group, + using_optim_input=using_optim_input, + use_orig_params=use_orig_params, + cpu_offload=cpu_offload, + ) + + @staticmethod + def _optim_state_dict_to_load_impl( + optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + full_state_dict: bool = True, + rank0_only: bool = False, + is_named_optimizer: bool = False, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. + + This is the internal API that is used by all the load optim_state_dict implementations. + Given model, optim, and the saved optim_state_dict, this API adds the FSDP + internal information and internal sharding to the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + if optim_input is not None or rank0_only: + raise AssertionError( + f"Expected optim_input to be None and rank0_only to be False, " + f"got optim_input={optim_input}, rank0_only={rank0_only}" + ) + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[ + 0 + ]._use_orig_params + if not all( + use_orig_params == m._use_orig_params + for m in FullyShardedDataParallel.fsdp_modules(model) + ): + raise AssertionError( + "Not all FSDP modules have the same _use_orig_params value" + ) + + if rank0_only and dist.get_rank(group) > 0: + optim_state_dict = {} + sharded_osd = _flatten_optim_state_dict( + optim_state_dict, + model=model, + use_orig_params=use_orig_params, + optim=(optim if is_named_optimizer else None), + rank0_only=rank0_only, + group=group, + ) + return _rekey_sharded_optim_state_dict( + sharded_osd, + model=model, + optim=optim, + optim_input=optim_input, + using_optim_input=using_optim_input, + is_named_optimizer=is_named_optimizer, + ) + + @staticmethod + def full_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + rank0_only: bool = True, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """Return the full optimizer state-dict. + + Consolidates the full optimizer state on rank 0 and returns it + as a :class:`dict` following the convention of + :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"`` + and ``"param_groups"``. The flattened parameters in ``FSDP`` modules + contained in ``model`` are mapped back to their unflattened parameters. + + This needs to be called on all ranks since it uses + collective communications. However, if ``rank0_only=True``, then + the state dict is only populated on rank 0, and all other ranks + return an empty :class:`dict`. + + Unlike ``torch.optim.Optimizer.state_dict()``, this method + uses full parameter names as keys instead of parameter IDs. + + Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors + contained in the optimizer state dict are not cloned, so there may + be aliasing surprises. For best practices, consider saving the + returned optimizer state dict immediately, e.g. using + ``torch.save()``. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer ``optim`` representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + group (dist.ProcessGroup): Model's process group or ``None`` if using + the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``, + then nonzero ranks return an empty :class:`dict`. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "full_optim_state_dict", + "optim_state_dict", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim.state_dict(), + optim_input=optim_input, + rank0_only=rank0_only, + group=group, + full_state_dict=True, + _stacklevel=2, + ) + + @staticmethod + def sharded_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """Return the optimizer state-dict in its sharded form. + + The API is similar to :meth:`full_optim_state_dict` but this API chunks + all non-zero-dimension states to :class:`ShardedTensor` to save memory. + This API should only be used when the model ``state_dict`` is derived + with the context manager ``with state_dict_type(SHARDED_STATE_DICT):``. + + For the detailed usage, refer to :meth:`full_optim_state_dict`. + + .. warning:: The returned state dict contains ``ShardedTensor`` and + cannot be directly used by the regular ``optim.load_state_dict``. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "sharded_optim_state_dict", + "optim_state_dict", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim.state_dict(), + optim_input=None, + rank0_only=False, + full_state_dict=False, + group=group, + _stacklevel=2, + ) + + @staticmethod + def shard_full_optim_state_dict( + full_optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> dict[str, Any]: + """Shard a full optimizer state-dict. + + Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened + parameters and restricts to only this rank's part of the optimizer state. + The first argument should be the return value of :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) + >>> torch.save(full_osd, PATH) + >>> # Define new model with possibly different world size + >>> new_model, new_optim = ... + >>> full_osd = torch.load(PATH) + >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + full non-sharded optimizer state. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "shard_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + is_named_optimizer=False, + ) + + @staticmethod + def flatten_sharded_optim_state_dict( + sharded_optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim: torch.optim.Optimizer, + ) -> dict[str, Any]: + """Flatten a sharded optimizer state-dict. + + The API is similar to :meth:`shard_full_optim_state_dict`. The only + difference is that the input ``sharded_optim_state_dict`` should be + returned from :meth:`sharded_optim_state_dict`. Therefore, there will + be all-gather calls on each rank to gather ``ShardedTensor`` s. + + Args: + sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + sharded optimizer state. + model (torch.nn.Module): + Refer to :meth:`shard_full_optim_state_dict`. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + + Returns: + Refer to :meth:`shard_full_optim_state_dict`. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "flatten_sharded_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=sharded_optim_state_dict, + model=model, + optim_input=None, + optim=optim, + full_state_dict=False, + is_named_optimizer=False, + ) + + @staticmethod + def scatter_full_optim_state_dict( + full_optim_state_dict: Optional[dict[str, Any]], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + group: Optional[Any] = None, + ) -> dict[str, Any]: + """Scatter the full optimizer state dict from rank 0 to all other ranks. + + Returns the sharded optimizer state dict on each rank. + The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank + 0, the first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 + >>> # Define new model with possibly different world size + >>> new_model, new_optim, new_group = ... + >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state + dict corresponding to the unflattened parameters and holding + the full non-sharded optimizer state if on rank 0; the argument + is ignored on nonzero ranks. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + group (dist.ProcessGroup): Model's process group or ``None`` if + using the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "scatter_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + rank0_only=True, + is_named_optimizer=False, + group=group, + ) + + @staticmethod + def rekey_optim_state_dict( + optim_state_dict: dict[str, Any], + optim_state_key_type: OptimStateKeyType, + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> dict[str, Any]: + """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``. + + This can be used to achieve compatibility between optimizer state dicts from models with FSDP + instances and ones without. + + To re-key an FSDP full optimizer state dict (i.e. from + :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to + a non-wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> wrapped_model, wrapped_optim = ... + >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) + >>> nonwrapped_model, nonwrapped_optim = ... + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) + >>> nonwrapped_optim.load_state_dict(rekeyed_osd) + + To re-key a normal optimizer state dict from a non-wrapped model to be + loadable to a wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> nonwrapped_model, nonwrapped_optim = ... + >>> osd = nonwrapped_optim.state_dict() + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) + >>> wrapped_model, wrapped_optim = ... + >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) + >>> wrapped_optim.load_state_dict(sharded_osd) + + Returns: + Dict[str, Any]: The optimizer state dict re-keyed using the + parameter keys specified by ``optim_state_key_type``. + """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + if optim_state_key_type not in ( + OptimStateKeyType.PARAM_NAME, + OptimStateKeyType.PARAM_ID, + ): + raise AssertionError( + f"Expected optim_state_key_type to be PARAM_NAME or PARAM_ID, got {optim_state_key_type}" + ) + osd = optim_state_dict # alias + # Validate that the existing parameter keys are uniformly typed + uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]] + uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]] + if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or ( + any(uses_param_id_mask) and not all(uses_param_id_mask) + ): + error_msg = f"Invalid parameter keys: {osd['state'].keys()}" + raise ValueError(error_msg) + # Return directly if the existing key type matches the target key type + if ( + optim_state_key_type == OptimStateKeyType.PARAM_NAME + and all(uses_param_name_mask) + ) or ( + optim_state_key_type == OptimStateKeyType.PARAM_ID + and all(uses_param_id_mask) + ): + return osd + # Otherwise, actually perform the re-keying + new_osd = {} + if optim_state_key_type == OptimStateKeyType.PARAM_NAME: # ID -> name + param_id_to_param = ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_key_to_param(optim) + ) + param_to_param_name = _get_param_to_fqn(model) + param_id_to_param_name: list[str] = [ + param_to_param_name[param] for param in param_id_to_param.values() + ] + new_osd["state"] = { + param_id_to_param_name[param_id]: param_state + for param_id, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted( + [ + param_id_to_param_name[param_id] + for param_id in param_group["params"] + ] + ) + return new_osd + elif optim_state_key_type == OptimStateKeyType.PARAM_ID: # name -> ID + param_name_to_param = _get_fqn_to_param(model) + param_to_param_id = ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_key(optim) + ) + # Because not all model parameters may be passed as the optimizer + # input, we may need to drop some parameters from this mapping + param_name_to_param_id = { + param_name: param_to_param_id[param] + for param_name, param in param_name_to_param.items() + if param in param_to_param_id + } + new_osd["state"] = { + param_name_to_param_id[param_name]: param_state + for param_name, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted( + [ + param_name_to_param_id[param_name] + for param_name in param_group["params"] + ] + ) + return new_osd + return new_osd # should never reach here + + @staticmethod + def optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: Optional[dict[str, Any]] = None, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Transform the state-dict of an optimizer corresponding to a sharded model. + + The given state-dict can be transformed to one of three types: + 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict. + + For full optimizer state_dict, all states are unflattened and not sharded. + Rank0 only and CPU only can be specified via :meth:`state_dict_type` to + avoid OOM. + + For sharded optimizer state_dict, all states are unflattened but sharded. + CPU only can be specified via :meth:`state_dict_type` to further save + memory. + + For local state_dict, no transformation will be performed. But a state + will be converted from nn.Tensor to ShardedTensor to represent its sharding + nature (this is not supported yet). + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> from torch.distributed.fsdp import FullStateDictConfig + >>> from torch.distributed.fsdp import FullOptimStateDictConfig + >>> # Save a checkpoint + >>> model, optim = ... + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> state_dict = model.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict(model, optim) + >>> save_a_checkpoint(state_dict, optim_state_dict) + >>> # Load a checkpoint + >>> model, optim = ... + >>> state_dict, optim_state_dict = load_a_checkpoint() + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> model.load_state_dict(state_dict) + >>> optim_state_dict = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state_dict + >>> ) + >>> optim.load_state_dict(optim_state_dict) + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_state_dict (Dict[str, Any]): the target optimizer state_dict to + transform. If the value is None, optim.state_dict() will be used. ( + Default: ``None``) + group (dist.ProcessGroup): Model's process group across which parameters + are sharded or ``None`` if using the default process group. ( + Default: ``None``) + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model``. The sharding of the optimizer state is based on + ``state_dict_type``. + """ + state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model) + if optim_state_dict is None: + optim_state_dict = optim.state_dict() + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim_state_dict, + optim_input=None, + rank0_only=getattr( + state_dict_settings.optim_state_dict_config, "rank0_only", False + ), + full_state_dict=state_dict_settings.state_dict_type + == StateDictType.FULL_STATE_DICT, + group=group, + cpu_offload=getattr( + state_dict_settings.optim_state_dict_config, "offload_to_cpu", True + ), + _stacklevel=2, + ) + + @staticmethod + def optim_state_dict_to_load( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + is_named_optimizer: bool = False, + load_directly: bool = False, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. + + Given a ``optim_state_dict`` that is transformed through + :meth:`optim_state_dict`, it gets converted to the flattened optimizer + state_dict that can be loaded to ``optim`` which is the optimizer for + ``model``. ``model`` must be sharded by FullyShardedDataParallel. + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> from torch.distributed.fsdp import FullStateDictConfig + >>> from torch.distributed.fsdp import FullOptimStateDictConfig + >>> # Save a checkpoint + >>> model, optim = ... + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> state_dict = model.state_dict() + >>> original_osd = optim.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict( + >>> model, + >>> optim, + >>> optim_state_dict=original_osd + >>> ) + >>> save_a_checkpoint(state_dict, optim_state_dict) + >>> # Load a checkpoint + >>> model, optim = ... + >>> state_dict, optim_state_dict = load_a_checkpoint() + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> model.load_state_dict(state_dict) + >>> optim_state_dict = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state_dict + >>> ) + >>> optim.load_state_dict(optim_state_dict) + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_state_dict (Dict[str, Any]): The optimizer states to be loaded. + is_named_optimizer (bool): Is this optimizer a NamedOptimizer or + KeyedOptimizer. Only set to True if ``optim`` is TorchRec's + KeyedOptimizer or torch.distributed's NamedOptimizer. + load_directly (bool): If this is set to True, this API will also + call optim.load_state_dict(result) before returning the result. + Otherwise, users are responsible to call ``optim.load_state_dict()`` + (Default: ``False``) + group (dist.ProcessGroup): Model's process group across which parameters + are sharded or ``None`` if using the default process group. ( + Default: ``None``) + """ + state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model) + result = FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=optim_state_dict, + model=model, + optim_input=None, + optim=optim, + full_state_dict=( + state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT + ), + rank0_only=getattr( + state_dict_settings.optim_state_dict_config, "rank0_only", False + ), + is_named_optimizer=is_named_optimizer, + group=group, + ) + if load_directly: + optim.load_state_dict(result) + return result + + def register_comm_hook(self, state: object, hook: callable): + """Register a communication hook. + + This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates + gradients across multiple workers. + This hook can be used to implement several algorithms like + `GossipGrad `_ and gradient compression + which involve different communication strategies for + parameter syncs while training with :class:`FullyShardedDataParallel`. + + .. warning :: + FSDP communication hook should be registered before running an initial forward pass + and only once. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + Examples include error feedback in gradient compression, + peers to communicate with next in `GossipGrad `_, etc. + It is locally stored by each worker + and shared by all the gradient tensors on the worker. + hook (Callable): Callable, which has one of the following signatures: + 1) ``hook: Callable[torch.Tensor] -> None``: + This function takes in a Python tensor, which represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). + It then performs all necessary processing and returns ``None``; + 2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``: + This function takes in two Python tensors, the first one represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). The latter + represents a pre-sized tensor to store a chunk of a sharded gradient after + reduction. + In both cases, callable performs all necessary processing and returns ``None``. + Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case. + Callables with signature 2 are expected to handle gradient communication for sharded cases. + + """ + if not self.check_is_root(): + raise AssertionError( + "register_comm_hook can only be called on a root instance." + ) + for fsdp_state in traversal_utils._get_fsdp_states(self): + if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + raise AssertionError( + f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}" + ) + if fsdp_state._comm_hook is not None: + raise AssertionError("A communication hook is already registered") + if not callable(hook): + raise ValueError( + f"The communication hook must be callable but got {hook}" + ) + fsdp_state._comm_hook = hook + fsdp_state._comm_hook_state = state + + def _unshard(self, async_op: bool = False): + class UnshardHandle: + def __init__( + self, + flat_param_handle: Optional[FlatParamHandle], + unshard_event: torch.Event, + ): + self._flat_param_handle = flat_param_handle + self._unshard_event = unshard_event + + def wait(self): + if self._flat_param_handle is not None: + current_stream = ( + self._flat_param_handle._device_handle.current_stream() + ) + current_stream.wait_event(self._unshard_event) + self._flat_param_handle = None + + if self._handle: + with self._use_training_state( + TrainingState.FORWARD_BACKWARD, HandleTrainingState.FORWARD + ): + _unshard( + self, self._handle, self._unshard_stream, self._pre_unshard_stream + ) + self._unshard_event = self._unshard_stream.record_event() + self._handle._prefetched = True + unshard_handle = UnshardHandle(self._handle, self._unshard_stream) + if async_op: + return unshard_handle + unshard_handle.wait() + return None + + def _wait_unshard_streams_on_current_stream(self): + _wait_for_computation_stream( + self._device_handle.current_stream(), + self._unshard_stream, + self._pre_unshard_stream, + ) + + @contextlib.contextmanager + def _use_training_state( + self, training_state: TrainingState, handle_training_state: HandleTrainingState + ): + prev_training_state = self.training_state + self.training_state = training_state + if self._handle: + prev_handle_training_state = self._handle._training_state + self._handle._training_state = handle_training_state + try: + yield + finally: + self.training_state = prev_training_state + if self._handle: + self._handle._training_state = prev_handle_training_state + + +def _get_grad_norm( + params: Iterable[nn.Parameter], + norm_type: float, + zero: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + """ + Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector. + + The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream + use of this return value is a reduction across ranks. + """ + params_with_grad = [param for param in params if param.grad is not None] + if len(params_with_grad) == 0: + # Reuse a tensor for zero to avoid a GPU sync + return zero + grads = [param.grad for param in params_with_grad] + grad_dtypes = {grad.dtype for grad in grads} + if len(grad_dtypes) != 1: + raise ValueError( + f"Requires uniform dtype across all gradients but got {grad_dtypes}" + ) + # Compute the gradient norm in FP32, where we treat the gradients as a + # single vector + grad_norm = torch.linalg.vector_norm( + torch.stack( + [ + torch.linalg.vector_norm(grad.detach(), norm_type, dtype=torch.float32) + for grad in grads + ], + ), + norm_type, + dtype=torch.float32, + ) + return grad_norm.to(device=device) + + +def _get_param_to_fqn( + model: torch.nn.Module, +) -> dict[torch.nn.Parameter, str]: + """ + Construct a mapping from parameters to their parameter names. + + The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which + means that none of the parameters should be ``FlatParameter`` s. As a + result, compared to :meth:`_get_param_to_fqns`, the mapped + values may be flattened from singleton :class:`list` s to the contained + names themselves. + + Args: + model (torch.nn.Module): Root module, which should not contain any + :class:`FullyShardedDataParallel` instances. + """ + param_to_param_names = _get_param_to_fqns(model) + for param_names in param_to_param_names.values(): + if len(param_names) == 0: + raise AssertionError( + "`_get_param_to_fqns()` should not construct empty lists" + ) + if len(param_names) > 1: + raise RuntimeError( + "Each parameter should only map to one parameter name but got " + f"{len(param_names)}: {param_names}" + ) + param_to_param_name = { + param: param_names[0] for param, param_names in param_to_param_names.items() + } + return param_to_param_name + + +def _get_fqn_to_param( + model: torch.nn.Module, +) -> dict[str, torch.nn.Parameter]: + """Construct the inverse mapping of :meth:`_get_param_to_fqn`.""" + param_to_param_name = _get_param_to_fqn(model) + return dict(zip(param_to_param_name.values(), param_to_param_name.keys())) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..3986d733328c80f12e6eed138386a9e8aafe6a3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py @@ -0,0 +1,377 @@ +# mypy: allow-untyped-defs +import logging +from collections import abc, defaultdict +from collections.abc import Iterable +from typing import Any, Optional, overload, Union + +import torch +import torch.distributed as dist +from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState +from torch.distributed.distributed_c10d import ProcessGroup + + +logger = logging.getLogger(__name__) + + +def _refresh_per_optimizer_state() -> dict[str, Any]: + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +def _is_supported_device(tensor: torch.Tensor) -> bool: + return tensor.is_cuda or tensor.device.type in ( + "xla", + "cpu", + "hpu", + "mtia", + "xpu", + torch._C._get_privateuse1_backend_name(), + ) + + +class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator): + """ + Lazily serves tensor to request device. This class extends + _MultiDeviceReplicator to allow support for "cpu" as a device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + if not _is_supported_device(master_tensor): + raise AssertionError( + f"Expected supported device, got {master_tensor.device}" + ) + self.master = master_tensor + self._per_device_tensors: dict[torch.device, torch.Tensor] = {} + + +class ShardedGradScaler(GradScaler): + """ + ShardedGradScaler helps perform gradient scaling in a shard aware manner. It extends + functionality from GradScaler: + * Supports Pytorch DDP and FSDP implementations + * Support CPU offloaded tensors (as used in fully sharded data parallel[FSDP]) + * Supports the custom Mixed Precision loss dtype (fp16, bf16) that FSDP returns + * Sync inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across + nodes + + Example:: + + # Creates a ShardedGradScaler once at the beginning of training. + scaler = ShardedGradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See :class:`GradScaler` for explanation of scaling/unscaling and more use cases. + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD): + process group for sharding + """ + + def __init__( + self, + device: str = "cuda", + init_scale: float = 2.0**16, + backoff_factor: float = 0.5, + growth_factor: float = 2.0, + growth_interval: int = 2000, + enabled: bool = True, + process_group: Optional[ProcessGroup] = dist.group.WORLD, + ) -> None: + super().__init__( + device, + init_scale=init_scale, + backoff_factor=backoff_factor, + growth_factor=growth_factor, + growth_interval=growth_interval, + enabled=enabled, + ) + if self._enabled: + self.process_group = process_group + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + @overload + def scale(self, outputs: torch.Tensor) -> torch.Tensor: ... + + @overload + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ... + + @overload + def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ... + + @overload + def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ... + + def scale( + self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]] + ) -> Union[torch.Tensor, Iterable[torch.Tensor]]: + if not self._enabled: + return outputs + + if isinstance(outputs, torch.Tensor): + if not _is_supported_device(outputs): + raise AssertionError(f"Expected supported device, got {outputs.device}") + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + if self._scale is None: + raise AssertionError("Expected _scale to be initialized, got None") + scaled_output = outputs * self._scale.to( + device=outputs.device, non_blocking=True + ) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_output.type(outputs.dtype) + + stash: list[_GeneralMultiDeviceReplicator] = [] + + def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): + if isinstance(val, torch.Tensor): + if not _is_supported_device(val): + raise AssertionError(f"Expected supported device, got {val.device}") + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + if self._scale is None: + raise AssertionError( + "Expected _scale to be initialized, got None" + ) + stash.append(_GeneralMultiDeviceReplicator(self._scale)) + scaled_val = val * stash[0].get(val.device) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_val.type(val.dtype) + if isinstance(val, abc.Iterable): + iterator = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterator) + return iterator + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_( + self, + optimizer: torch.optim.Optimizer, + inv_scale: torch.Tensor, + found_inf: torch.Tensor, + allow_fp16: bool = True, + ) -> dict[torch.device, torch.Tensor]: + per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale) + per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be thousands of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + # coalesce is not supported in torch.float16 + param_grad_fp32 = param.grad.type(torch.float32).coalesce() + param.grad = param_grad_fp32.type(torch.float16) + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) + # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some + # ranks may have no (non-zero sized) parameter shards, necessitating the + # initialization of `per_device_found_inf._per_device_tensors` here + if not per_device_found_inf._per_device_tensors: + if self._scale is None: + raise AssertionError("Expected _scale to be initialized, got None") + per_device_found_inf.get(self._scale.device) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer: torch.optim.Optimizer) -> None: + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + if self._scale is None: + raise AssertionError("Expected _scale to be initialized, got None") + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, True + ) + optimizer_state["stage"] = OptState.UNSCALED + + # Synchronize the detected inf across the ranks + optimizer_state = self._per_optimizer_states[id(optimizer)] + works = [] + found_inf_on_cpus = [] + found_inf_on_devices = [] + + for found_inf in optimizer_state["found_inf_per_device"].values(): + if self._device != "cpu" and found_inf.device.type == "cpu": + found_inf_on_cpus.append(found_inf) + found_inf_on_device = found_inf.to(self._device) + found_inf_on_devices.append(found_inf_on_device) + works.append( + dist.all_reduce( + found_inf_on_device, async_op=True, group=self.process_group + ) + ) + else: + works.append( + dist.all_reduce(found_inf, async_op=True, group=self.process_group) + ) + for work in works: + work.wait() + if found_inf_on_cpus: + torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices) + + def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None: + """ + If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero. + Otherwise, scale is multiplied by the growth factor when the growth interval is reached. + """ + if self._scale is None or self._growth_tracker is None: + raise AssertionError( + "Expected _scale and _growth_tracker to be initialized, got None" + ) + + if found_inf.item() >= 1.0: + self._scale *= self._backoff_factor + self._growth_tracker.fill_(0) + else: + successful = self._growth_tracker + 1 + if successful == self._growth_interval: + self._scale *= self._growth_factor + self._growth_tracker.fill_(0) + else: + self._growth_tracker = successful + + def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None: + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") # type: ignore[var-annotated] + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = ( + "new_scale should be a float or a 1-element torch.cuda.FloatTensor or " + "torch.FloatTensor with requires_grad=False." + ) + if new_scale.device.type != self._device: + raise AssertionError(reason) + if new_scale.numel() != 1: + raise AssertionError(reason) + if new_scale.requires_grad is not False: + raise AssertionError(reason) + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + if len(found_infs) == 0: + raise AssertionError("No inf checks were recorded prior to update.") + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + if _scale.device.type == "cpu": + self._amp_update_scale_cpu_(found_inf_combined) + else: + torch._amp_update_scale_( + self._scale, # type: ignore[arg-type] + self._growth_tracker, # type: ignore[arg-type] + found_inf_combined, + self._growth_factor, # type: ignore[arg-type] + self._backoff_factor, # type: ignore[arg-type] + self._growth_interval, # type: ignore[arg-type] + ) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..f731854dab2eb475e7c8321738552fed205db70d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py @@ -0,0 +1,608 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Iterable, Sequence +from typing import Any, cast, Optional, Union + +import torch.nn as nn + + +__all__ = [ + "always_wrap_policy", + "lambda_auto_wrap_policy", + "transformer_auto_wrap_policy", + "size_based_auto_wrap_policy", + "enable_wrap", + "wrap", + "CustomPolicy", + "ModuleWrapPolicy", +] + + +# NOTE: We intentionally keep this function simple and isolate the complexity +# to `fn` to enable using this function generically. We may move this to a +# non-FSDP-specific folder and/or make it public in the future. +def _post_order_apply( + root_module: nn.Module, + fn: Callable[[nn.Module], Optional[nn.Module]], +): + """ + This applies ``fn`` to every module in the module tree of ``root_module`` + following a post-order traversal. If ``fn`` returns an :class:`nn.Module`, + then this replaces the original module with the newly returned one in the + tree. Otherwise, ``fn`` should return ``None``, in which case the module is + not changed. + """ + # Track visited modules to avoid visiting shared modules multiple times + visited_modules: set[nn.Module] = {root_module} + + def _post_order_apply_inner( + module: nn.Module, + module_name: str, + parent_module: Optional[nn.Module], + ): + for child_module_name, child_module in module.named_children(): + if child_module not in visited_modules: + visited_modules.add(child_module) + _post_order_apply_inner(child_module, child_module_name, module) + optional_module = fn(module) + if optional_module is not None: + if not isinstance(parent_module, nn.Module): + raise AssertionError( + "Non-root modules should have their parent module set but got " + f"{parent_module} for {module}" + ) + if not module_name: + raise AssertionError( + "Non-root modules should have their module name set but got " + f"an empty module name for {module}" + ) + if not isinstance(optional_module, nn.Module): + raise AssertionError( + f"fn should return None or an nn.Module but got {optional_module}" + ) + setattr(parent_module, module_name, optional_module) + + _post_order_apply_inner(root_module, "", None) + + +def _construct_wrap_fn( + root_module: nn.Module, + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], + fsdp_fn: Callable, +) -> Callable[[nn.Module], Optional[nn.Module]]: + """ + This constructs the "wrap" function to pass to :func:`_post_order_apply` + based on ``target_module_to_kwargs``, which should be constructed from the + wrapping policy. + """ + + def fn(module: nn.Module) -> Optional[nn.Module]: + # Explicitly avoid wrapping the root module since for FSDP, it is + # handled by the caller + if module in target_module_to_kwargs and module is not root_module: + kwargs = target_module_to_kwargs[module] + return fsdp_fn(module, **kwargs) + return None + + return fn + + +def _run_mixed_precision_override_policy( + root_module: nn.Module, + module_classes: Iterable[type[nn.Module]], + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], +): + module_classes_tuple = tuple(set(module_classes)) + for module in root_module.modules(): + if module in ignored_modules: + continue + elif isinstance(module, module_classes_tuple): + # This policy overrides any existing policy + if module not in target_module_to_kwargs: + # Only inherit from the root kwargs if not already specified + target_module_to_kwargs[module] = root_kwargs + target_module_to_kwargs[module]["mixed_precision"] = None + return target_module_to_kwargs + + +def always_wrap_policy(*args, **kwargs) -> bool: + """ + A simple recursive wrap policy that always returns ``True``. This means + that every submodule is wrapped by the wrapper class in + :func:`_recursive_wrap`. + """ + return True + + +class _Policy(ABC): + """ + This defines an abstract base class that represents a policy for applying + a module-level API. + """ + + @abstractmethod + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + """ + This should return a dict ``target_module_to_kwargs`` that maps from + each target module to wrap to its kwargs. + """ + ... + + +def _module_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + module_classes: set[type[nn.Module]], +) -> bool: + """ + This auto wrap policy wraps every module that is an instance of any type in + ``module_classes`` as its own FSDP instance. The root module given by + ``module`` is always wrapped as an FSDP instance regardless. Since the + wrapping proceeds bottom up, each FSDP instance manages the parameters in + its subtree excluding any already managed by a child FSDP instance. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + module_classes (Set[Type[nn.Module]]): Set of module classes that are + wrapped as FSDP instances. + + Returns: + ``True`` if ``recurse=True``, and whether ``module`` should be wrapped + if ``recurse=False``. + """ + if recurse: + return True # always recurse + return isinstance(module, tuple(module_classes)) + + +class ModuleWrapPolicy(_Policy): + """ + This policy applies to every module of the specified module classes, + passing in the kwargs given to the root. + """ + + def __init__(self, module_classes: Iterable[type[nn.Module]]): + module_classes_set = set(module_classes) + self._module_classes = module_classes_set + self._module_classes_str = str(module_classes_set) + + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + module_classes = tuple(self._module_classes) + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} + for module in root_module.modules(): + if module in ignored_modules: + continue + elif isinstance(module, module_classes): + # Shallow copy to avoid coupling changes across modules + target_module_to_kwargs[module] = copy.copy(root_kwargs) + return target_module_to_kwargs + + def __call__(self, module, recurse, *args, **kwargs): + # nonwrapped_numel is not used. + return _module_wrap_policy( + module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes + ) + + def __repr__(self) -> str: + return super().__repr__() + f"({self._module_classes_str})" + + +class CustomPolicy(_Policy): + """ + This policy takes in a lambda function that maps a given ``nn.Module`` to + either ``False``, ``True``, or a kwarg dictionary. + - If the function returns ``False`` or an empty dictionary, then the module + does not have the API applied. + - If the function returns ``True``, then the module has the API applied + with the root's kwargs. + - If the function returns a non-empty dictionary, then the module has the + API applied, and the dictionary overrides the root's kwargs. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = init_transformer_model(...) + >>> def lambda_fn(module: nn.Module): + >>> if module is model.lm_head: + >>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP} + >>> elif isinstance(module, TransformerBlock): + >>> return True + >>> return False + >>> policy = CustomPolicy(lambda_fn) + >>> fsdp_model = FSDP(model, auto_wrap_policy=policy) + """ + + def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]): + self._lambda_fn = lambda_fn + + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} + for module in root_module.modules(): + if module in ignored_modules: + continue + res = self._lambda_fn(module) + if not isinstance(res, (dict, bool)): + raise ValueError( + "The lambda_fn passed to CustomPolicy should return " + f"False/True or a kwarg dict, but it returned {res}" + ) + if not res: + continue + kwargs = copy.copy(root_kwargs) + if isinstance(res, dict): + # Override the root kwargs with the ones specified by the + # lambda function + kwargs.update(res) + target_module_to_kwargs[module] = kwargs + return target_module_to_kwargs + + +def lambda_auto_wrap_policy( + module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable +) -> bool: + """ + A convenient auto wrap policy to wrap submodules based on an arbitrary user + function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as + a `wrapper_cls` unit. + + Return if a module should be wrapped during auto wrapping. + + The first three parameters are required by :func:`_recursive_wrap`. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then + this module will be wrapped. + """ + if recurse: + return True # always recurse + return lambda_fn(module) + + +def transformer_auto_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + transformer_layer_cls: set[type[nn.Module]], +) -> bool: + """ + See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the + same as ``module_classes``. Note that shared parameters must be wrapped in + the same FSDP instance, so this auto wrap policy can help wrap shared + embeddings into the same FSDP instance for transformer models. + """ + return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) + + +def _wrap_module_cls_individually( + module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs +): + if recurse: + # always recurse + return True + else: + # if not recursing, decide whether we should wrap based on whether the type of module + # is in `module_classes`. + return isinstance(module, tuple(module_classes)) + + +def _or_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + policies, +) -> bool: + """ + A policy that wraps ``module`` if any policy in the passed in iterable of + ``policies`` returns ``True``. + """ + return any( + policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel) + for policy in policies + ) + + +def size_based_auto_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + # Additional custom arguments + min_num_params: int = int(1e8), + force_leaf_modules: Optional[set[type[nn.Module]]] = None, + exclude_wrap_modules: Optional[set[type[nn.Module]]] = None, +) -> bool: + """ + A size-based auto wrap policy. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + min_num_params (int): Customizable policy input that controls the size + threshold over which a module is ready to be wrapped. This is in + units of numel. + force_leaf_modules (Optional[set[type[nn.Module]]]): Set of module types to keep + as leaves, i.e. their children will never be wrapped. + exclude_wrap_modules (Optional[set[type[nn.Module]]]): Set of module types to be + excluded in wrapping. + + Returns: + Whether ``module`` should be wrapped. + """ + force_leaf_modules = ( + size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] + if force_leaf_modules is None + else force_leaf_modules + ) + exclude_wrap_modules = ( + size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined] + if exclude_wrap_modules is None + else exclude_wrap_modules + ) + + # Keep the argument `min_num_params` for BC for now, but it represents the + # minimum non-wrapped *numel* before triggering a wrapping + min_nonwrapped_numel = min_num_params + is_large = nonwrapped_numel >= min_nonwrapped_numel + if recurse: + # We should recurse if the module is big enough but not in force_leaf_modules list. + return is_large and not isinstance(module, tuple(force_leaf_modules)) + else: + # If we are not recursing, determine if we should wrap. + return is_large and not isinstance(module, tuple(exclude_wrap_modules)) + + +# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported. +size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined] +size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined] + + +@contextlib.contextmanager +def enable_wrap( + *, wrapper_cls: Any, **wrapper_kwargs: Any +) -> Generator[None, None, None]: + """ + Context manager to wrap modules using a wrapper. + + Useful for when you'd like to apply the same configuration arguments to all + child modules that you wrap. A particularly important use case is wrapping + large layers so that they get sharded (in-place) during initialization, to + avoid running out of system memory. Large layers can indicate that they + should be sharded via the ``wrap`` annotation and this context manager can + provide the exact configuration for these nested instances. + + Usage:: + + with enable_wrap(wrapper_cls, **params): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + wrapper_cls: + Class that `wrap` annotation will `wrap` modules with, such as + `FullyShardedDataParallel`. + **wrapper_kwargs: + Configuration settings that will be passed to all ``wrap`` + instances inside the context + """ + kwargs = { + "wrapper_cls": wrapper_cls, + **wrapper_kwargs, + } + with _ConfigAutoWrap(**kwargs): + yield + + +def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: + """ + Annotate that a module should be wrapped. Annotated modules will only be + wrapped if inside of an :func:`enable_wrap` context manager. This allows + a module to be initialized both with and without a wrapper without code + change. + + The class that this function wraps the passed in ``nn.Module`` with is the + passed in ``wrapper_cls`` argument into ``enable_wrap``. Both + ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct + the ``wrapper_cls`` instance. In the case of duplicate kwargs in + ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be + respected. + + Usage:: + + with enable_wrap(wrapper_cls=FSDP, **fsdp_config): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + module (nn.Module): module to wrap (if in :func:`enable_wrap` context) + **wrap_overrides: configuration overrides that will take priority over + the values provided by the :func:`enable_wrap` context + """ + if _ConfigAutoWrap.in_autowrap_context: + if _ConfigAutoWrap.wrapper_cls is None: + raise AssertionError("Expected _ConfigAutoWrap.wrapper_cls to be set") + + wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides} + return _wrap( + module, + _ConfigAutoWrap.wrapper_cls, + **wrap_overrides, + ) + return module + + +def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: + if wrapper_cls is None: + raise AssertionError("Expected wrapper_cls to be set") + if hasattr(module, "_wrap_overrides"): + # If module has a _wrap_overrides attribute, we force overriding the + # FSDP config with these attributes for this module. Currently this + # is only used to disable mixed precision for BatchNorm when + # auto_wrapping. + overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type, dict-item] + return wrapper_cls(module, **overrides) + + return wrapper_cls(module, **kwargs) + + +def _recursive_wrap( + module: nn.Module, + auto_wrap_policy: Callable, + wrapper_cls: Callable, + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + only_wrap_children: bool = False, + **kwargs: Any, +) -> tuple[nn.Module, int]: + """ + Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns + ``True`` with ``wrapper_cls``. + + Args: + module (nn.Module): Module to recursively wrap. + auto_wrap_policy (Callable): A callable representing a policy that + determines which modules to recursively wrap with ``wrapper_cls``. + ignored_modules (set[torch.nn.Module]): Modules to ignore when + wrapping. + ignored_params (set[torch.nn.Parameter]): Parameters to ignore when + wrapping; these should be the parameters contained in the modules + in ``ignored_modules``. + Returns: + (nn.Module, int): + ``module`` after wrapping and the numel recursively wrapped. + """ + if auto_wrap_policy is None: + raise AssertionError("Must specify auto_wrap_policy.") + if wrapper_cls is None: + raise AssertionError("Must specify wrapper_cls") + # Make sure no child is already wrapped. + for _, child in module.named_modules(): + if child in ignored_modules: + continue + try: + if isinstance(child, cast(type, wrapper_cls)): + raise AssertionError( + f"Child module {child} is already wrapped by {wrapper_cls}" + ) + except TypeError: + # wrapper_cls is a function as opposed to a class type, just bypass above check. + pass + + # We count all params, assuming none of them are already wrapped. + nonwrapped_numel = sum( + p.numel() for p in module.parameters() if p not in ignored_params + ) + + if auto_wrap_policy is None: + raise AssertionError("Expected auto_wrap_policy to be set") + if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): + total_wrapped_numel = 0 + # Iterate through the children, recursively wrap if necessary + for name, child in module.named_children(): + if child in ignored_modules: + continue + wrapped_child, num_wrapped_params = _recursive_wrap( + module=child, + auto_wrap_policy=auto_wrap_policy, + wrapper_cls=wrapper_cls, + ignored_modules=ignored_modules, + ignored_params=ignored_params, + **kwargs, + ) + setattr(module, name, wrapped_child) + # Keep track of how many parameters have been wrapped + total_wrapped_numel += num_wrapped_params + # decide if we need to wrap the current module, + # since the left over parameters exceed the number of params to wrap + remainder = nonwrapped_numel - total_wrapped_numel + if not only_wrap_children and auto_wrap_policy( + module=module, recurse=False, nonwrapped_numel=remainder + ): + # Leaf node or final wrapping of the remainder both happen here. + return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel + else: + return module, total_wrapped_numel + return module, 0 + + +class _ConfigAutoWrap: + """ + Helper class to wrap modules based on default config args via a context manager. + See :func:`enable_wrap` for more information. + """ + + in_autowrap_context: bool = False # Context flag + wrapper_cls: Optional[Callable] = None # The wrapper class + kwargs: dict[str, Any] = {} # Wrapper's args + + def __init__(self, **kwargs: dict[str, Any]): + self.kwargs = kwargs + + @staticmethod + def enable_autowrap_context(kwargs: Any) -> None: + if _ConfigAutoWrap.in_autowrap_context: + raise NotImplementedError( + "You are already within an autowrap context and we currently do not supported nested autowrap." + ) + _ConfigAutoWrap.in_autowrap_context = True + # Get and save the wrapper cls for the context. + if "wrapper_cls" not in kwargs: + raise AssertionError( + "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." + ) + _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) + del kwargs["wrapper_cls"] + # Save the rest. + _ConfigAutoWrap.kwargs = kwargs + + @staticmethod + def disable_autowrap_context() -> None: + _ConfigAutoWrap.in_autowrap_context = False + _ConfigAutoWrap.wrapper_cls = None + _ConfigAutoWrap.kwargs = {} + + def __enter__(self) -> None: + self.enable_autowrap_context(self.kwargs) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.disable_autowrap_context() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb744a2b93615b703eb0dafb7c8e6c71bc1ad5d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torch.distributed.launcher.api import ( # noqa: F401 + elastic_launch, + launch_agent, + LaunchConfig, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eb02c017e9560b5af818e1c892f6046acc56fa2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fad50f238fffaa3e1b6df279c5396d72c1f2e5ff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2adf5549fecf13560d0c8637085872688c9454a4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/launcher/api.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import sys +import uuid +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.distributed.elastic.rendezvous.registry as rdzv_registry +from torch._utils_internal import get_default_numa_options +from torch.distributed.elastic import events, metrics +from torch.distributed.elastic.agent.server.api import WorkerSpec +from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs, + LogsSpecs, + SignalException, +) +from torch.distributed.elastic.multiprocessing.errors import ChildFailedError +from torch.distributed.elastic.rendezvous import RendezvousParameters +from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint +from torch.distributed.elastic.utils.logging import get_logger +from torch.numa.binding import NumaOptions + + +__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] + +logger = get_logger(__name__) + + +@dataclass +class LaunchConfig: + """ + Creates a rendezvous config. + + Args: + min_nodes: Minimum amount of nodes that the user function will + be launched on. Elastic agent ensures that the user + function start only when the min_nodes amount enters + the rendezvous. + max_nodes: Maximum amount of nodes that the user function + will be launched on. + nproc_per_node: On each node the elastic agent will launch + this amount of workers that will execute user + defined function. + rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). + rdzv_endpoint: The endpoint of the rdzv sync. storage. + rdzv_configs: Key, value pair that specifies rendezvous specific configuration. + rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going + to be removed in future versions, see the note below. The default timeout is 900 seconds. + run_id: The unique run id of the job (if not passed a unique one will be + deduced from run environment - flow workflow id in flow - or auto generated). + role: User defined role of the worker (defaults to "trainer"). + max_restarts: The maximum amount of restarts that elastic agent will conduct + on workers before failure. + monitor_interval: The interval in seconds that is used by the elastic_agent + as a period of monitoring workers. + start_method: The method is used by the elastic agent to start the + workers (spawn, fork, forkserver). + metrics_cfg: configuration to initialize metrics. + local_addr: address of the local node if any. If not set, a lookup on the local + machine's FQDN will be performed. + local_ranks_filter: ranks for which to show logs in console. If not set, show from all. + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. + duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines + that match _any_ of the filter strings. + duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines + that match _any_ of the filter strings. + virtual_local_rank: Enable virtual local rank mode for workers (defaults to False). + When enabled, LOCAL_RANK is set to 0 for all workers and + CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its + assigned GPU at device index 0. + + + .. note:: + `rdzv_timeout` is a legacy argument that will be removed in future. + Set the timeout via `rdzv_configs['timeout']` + + """ + + min_nodes: int + max_nodes: int + nproc_per_node: int + logs_specs: LogsSpecs | None = None + run_id: str = "" + role: str = "default_role" + rdzv_endpoint: str = "" + rdzv_backend: str = "etcd" + rdzv_configs: dict[str, Any] = field(default_factory=dict) + rdzv_timeout: int = -1 + max_restarts: int = 3 + monitor_interval: float = 0.1 + start_method: str = "spawn" + log_line_prefix_template: str | None = None + metrics_cfg: dict[str, str] = field(default_factory=dict) + local_addr: str | None = None + event_log_handler: str = "null" + numa_options: NumaOptions | None = None + signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT" + duplicate_stdout_filters: list[str] | None = None + duplicate_stderr_filters: list[str] | None = None + virtual_local_rank: bool = False + + def __post_init__(self): + default_timeout = 900 + if self.rdzv_timeout != -1: + self.rdzv_configs["timeout"] = self.rdzv_timeout + elif "timeout" not in self.rdzv_configs: + self.rdzv_configs["timeout"] = default_timeout + + # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage + if self.logs_specs is None: + self.logs_specs = DefaultLogsSpecs() + + if ( + self.numa_options is None + and torch.cuda.is_available() + # We assume local_rank n uses cuda device n. + and torch.cuda.device_count() == self.nproc_per_node + ): + self.numa_options = get_default_numa_options() + logger.info("Using default numa options = %r", self.numa_options) + + +class elastic_launch: + """ + Launches an torchelastic agent on the container that invoked the entrypoint. + + 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/ + ``entrypoint`` can be a function or a command. + 2. The return value is a map of each worker's output mapped + by their respective global rank. + + Usage + + :: + + def worker_fn(foo): + # ... + + def main(): + # entrypoint is a function. + outputs = elastic_launch(LaunchConfig, worker_fn)(foo) + # return rank 0's output + return outputs[0] + + # entrypoint is a command and ``script.py`` is the python module. + outputs = elastic_launch(LaunchConfig, "script.py")(args) + outputs = elastic_launch(LaunchConfig, "python")("script.py") + """ + + def __init__( + self, + config: LaunchConfig, + entrypoint: Callable | str | None, + ): + self._config = config + self._entrypoint = entrypoint + + def __call__(self, *args): + return launch_agent(self._config, self._entrypoint, list(args)) + + +def _get_entrypoint_name(entrypoint: Callable | str | None, args: list[Any]) -> str: + """Retrieve entrypoint name with the rule: + 1. If entrypoint is a function, use ``entrypoint.__qualname__``. + 2. If entrypoint is a string, check its value: + 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` + which does not start with hifen letter (for example, "-u" will be skipped). + 2.2 otherwise, use ``entrypoint`` value. + 3. Otherwise, return empty string. + """ + if isinstance(entrypoint, Callable): # type: ignore[arg-type] + return entrypoint.__name__ # type: ignore[union-attr] + elif isinstance(entrypoint, str): + if entrypoint == sys.executable: + return next((arg for arg in args if arg[0] != "-"), "") + else: + return entrypoint + else: + return "" + + +def _get_addr_and_port( + rdzv_parameters: RendezvousParameters, +) -> tuple[str | None, int | None]: + if rdzv_parameters.backend != "static": + return (None, None) + endpoint = rdzv_parameters.endpoint + endpoint = endpoint.strip() + if not endpoint: + raise ValueError( + "Endpoint is missing in endpoint. Try to add --master-addr and --master-port" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1) + if master_port == -1: + raise ValueError( + f"port is missing in endpoint: {endpoint}. Try to specify --master-port" + ) + return (master_addr, master_port) + + +def launch_agent( + config: LaunchConfig, + entrypoint: Callable | str | None, + args: list[Any], +) -> dict[int, Any]: + if not config.run_id: + run_id = str(uuid.uuid4().int) + logger.warning("config has no run_id, generated a random run_id: %s", run_id) + config.run_id = run_id + + entrypoint_name = _get_entrypoint_name(entrypoint, args) + + logger.info( + "Starting elastic_operator with launch configs:\n" + " entrypoint : %(entrypoint)s\n" + " min_nodes : %(min_nodes)s\n" + " max_nodes : %(max_nodes)s\n" + " nproc_per_node : %(nproc_per_node)s\n" + " run_id : %(run_id)s\n" + " rdzv_backend : %(rdzv_backend)s\n" + " rdzv_endpoint : %(rdzv_endpoint)s\n" + " rdzv_configs : %(rdzv_configs)s\n" + " max_restarts : %(max_restarts)s\n" + " monitor_interval : %(monitor_interval)s\n" + " log_dir : %(log_dir)s\n" + " metrics_cfg : %(metrics_cfg)s\n" + " event_log_handler : %(event_log_handler)s\n" + " numa_options : %(numa_options)s\n" + " signals_to_handle : %(signals_to_handle)s\n" + " duplicate_stdout_filters : %(duplicate_stdout_filters)s\n" + " duplicate_stderr_filters : %(duplicate_stderr_filters)s\n", + { + "entrypoint": entrypoint_name, + "min_nodes": config.min_nodes, + "max_nodes": config.max_nodes, + "nproc_per_node": config.nproc_per_node, + "run_id": config.run_id, + "rdzv_backend": config.rdzv_backend, + "rdzv_endpoint": config.rdzv_endpoint, + "rdzv_configs": config.rdzv_configs, + "max_restarts": config.max_restarts, + "monitor_interval": config.monitor_interval, + "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] + "metrics_cfg": config.metrics_cfg, + "event_log_handler": config.event_log_handler, + "numa_options": config.numa_options, + "signals_to_handle": config.signals_to_handle, + "duplicate_stdout_filters": config.duplicate_stdout_filters, + "duplicate_stderr_filters": config.duplicate_stderr_filters, + }, + ) + + rdzv_parameters = RendezvousParameters( + backend=config.rdzv_backend, + endpoint=config.rdzv_endpoint, + run_id=config.run_id, + min_nodes=config.min_nodes, + max_nodes=config.max_nodes, + local_addr=config.local_addr, + **config.rdzv_configs, + ) + + master_addr, master_port = _get_addr_and_port(rdzv_parameters) + + # Set the signals to handle in the environment variable + os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = config.signals_to_handle + + spec = WorkerSpec( + role=config.role, + local_world_size=config.nproc_per_node, + entrypoint=entrypoint, + args=tuple(args), + rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters), + max_restarts=config.max_restarts, + monitor_interval=config.monitor_interval, + master_addr=master_addr, + master_port=master_port, + local_addr=config.local_addr, + event_log_handler=config.event_log_handler, + numa_options=config.numa_options, + duplicate_stdout_filters=config.duplicate_stdout_filters, + duplicate_stderr_filters=config.duplicate_stderr_filters, + virtual_local_rank=config.virtual_local_rank, + ) + + agent = LocalElasticAgent( + spec=spec, + logs_specs=config.logs_specs, # type: ignore[arg-type] + start_method=config.start_method, + log_line_prefix_template=config.log_line_prefix_template, + ) + + shutdown_rdzv = True + try: + metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) + + result = agent.run() + # records that agent.run() has succeeded NOT that workers have succeeded + events.record(agent.get_event_succeeded(), config.event_log_handler) + + if result.is_failed(): + # ChildFailedError is treated specially by @record + # if the error files for the failed children exist + # @record will copy the first error (root cause) + # to the error file of the launcher process. + raise ChildFailedError( + name=entrypoint_name, + failures=result.failures, + ) + + return result.return_values + except ChildFailedError: + raise + except SignalException: + # when the agent dies with a signal do NOT shutdown the rdzv_handler + # since this closes the rendezvous on this rdzv_id permanently and + # prevents any additional scaling events + shutdown_rdzv = False + events.record(agent.get_event_failed(), config.event_log_handler) + raise + except Exception: + events.record(agent.get_event_failed(), config.event_log_handler) + raise + finally: + if shutdown_rdzv: + spec.rdzv_handler.shutdown() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e15fb517052e4aefeb7377d1f0ca63cf2b2da753 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__init__.py @@ -0,0 +1,7 @@ +import torch + +from .functional import * # noqa: F403 + + +if torch.distributed.rpc.is_available(): + from .api.remote_module import RemoteModule diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21539a4eea224f93612f99ae1e6371296ccc72cf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__pycache__/functional.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__pycache__/functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fd46a58ea7acbadb62f84a2b8a76d4b961f2abf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/__pycache__/functional.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..446ead37cbdb24eb51df2bb7ff6c44cc43580f4f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56bb603a3d281c3aad607220a3b111933e88030d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/remote_module.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/remote_module.py new file mode 100644 index 0000000000000000000000000000000000000000..728bf9c0288a2002c6c81d5347cc0c44d27957da --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/api/remote_module.py @@ -0,0 +1,771 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs +import collections +import io +import sys +import types +from collections.abc import Callable, Iterator, Mapping +from typing import Any, TypeVar, Union +from typing_extensions import Self + +import torch +import torch.distributed.rpc as rpc +from torch import device, dtype, nn, Tensor +from torch.distributed import _remote_device +from torch.distributed.nn.jit import instantiator +from torch.distributed.rpc.internal import _internal_rpc_pickler +from torch.nn import Module +from torch.nn.parameter import Parameter +from torch.utils.hooks import RemovableHandle + + +__all__ = ["RemoteModule"] + +_grad_t = Union[tuple[Tensor, ...], Tensor] +# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use +# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be +# the type of the subclass, not the looser type of `Module`. +T = TypeVar("T", bound="Module") + +_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = ( + instantiator.instantiate_non_scriptable_remote_module_template() +) + +_REMOTE_MODULE_PICKLED_ATTRIBUTES = ( + "on", + "device", + "is_device_map_set", + "is_scriptable", + "generated_methods", + "module_rref", +) + +_SerializedRemoteModule = collections.namedtuple( # type: ignore[misc] + "_SerializedRemoteModule", + _REMOTE_MODULE_PICKLED_ATTRIBUTES, +) + +# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled. +# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES +# or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. +# Otherwise, it will not be pickled. +_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = ( + "training", + "_parameters", + "_buffers", + "_non_persistent_buffers_set", + "_backward_hooks", + "_backward_pre_hooks", + "_is_full_backward_hook", + "_forward_hooks", + "_forward_hooks_with_kwargs", + "_forward_hooks_always_called", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_state_dict_pre_hooks", + "_modules", + # The two attributes below are generated methods, not available at pickling time. + "forward_async", + "forward", +) + + +# RPC handler. +def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda): + instantiator.instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + + +def _create_module(module_cls, args, kwargs, device): + module = module_cls(*args, **kwargs) + if not isinstance(module, nn.Module): + raise ValueError( + "Expect `module_cls(*args, **kwargs)` returns an instance of , " + f"but it returns an instance of {type(module)}." + ) + module.to(device) + return module + + +def _create_module_with_interface( + module_cls, args, kwargs, device, module_interface_cls +): + module = _create_module(module_cls, args, kwargs, device) + if module_interface_cls is not None: + module = torch.jit.script(module) + return rpc.RRef(module, module_interface_cls) + + +def _param_rrefs(module_rref, recurse) -> list[rpc.RRef[Parameter]]: + ret: list[rpc.RRef[Parameter]] = [ + rpc.RRef(param) for param in module_rref.local_value().parameters(recurse) + ] + return ret + + +def _raise_not_supported(name: str) -> None: + raise ValueError(f"Method ``{name}`` not supported for RemoteModule") + + +class _RemoteModule(nn.Module): + def __new__(cls, *args, **kwargs): + # Use __new__ for logging purposes. + torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module") + return super().__new__(cls) + + def __init__( + self, + remote_device: str, + module_cls: type[nn.Module], + args: tuple | None = None, + kwargs: dict[str, Any] | None = None, + _module_interface_cls: Any = None, + ): + """ + RemoteModule instance can only be created after RPC initialization. + + It creates a user-specified module on a specified remote node. + It behaves like a regular ``nn.Module`` except that the ``forward`` method is + executed on the remote node. + It takes care of autograd recording to ensure the backward pass propagates + gradients back to the corresponding remote module. + It can be shared across processors using `RPC framework `__, + without incurring any overheads of copying the actual module, + which is equivalent to an :class:`~torch.distributed.rpc.RRef` + pointing to the remote module. + + The arguments of ``forward_async`` and ``forward`` are the same as + the ``forward`` method of the module returned by the ``module_cls``. + + Apart from ``forward_async`` and ``forward``, no other methods are supported from nn.Module for now. + + Particularly, to create a hybrid model, typically the local modules should be + created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``). + Hybrid Example: + >>> class HybridModel(nn.Module): + >>> def __init__(self) -> None: + >>> nn.Module.__init__(self) + >>> self.remote_embedding = RemoteModule(...) + >>> self.local_linear = nn.Linear(...) + + For example, if ``module_cls`` returns an instance of ``nn.Linear``, + that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``, + the generated ``RemoteModule`` will have 2 methods in signature of + ``def forward(input: Tensor) -> Tensor:`` and + ``def forward_async(input: Tensor) -> Future[Tensor]:``. + + .. note:: + If the remote module is placed on a cuda device, + any input CPU tensors will be automatically moved to the same cuda device, + and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend. + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + + In addition, the device field can be optional and the default value is "cpu". + module_cls (nn.Module): For example, + >>> class MyModule(nn.Module): + >>> def forward(input): + >>> return input + 1 + >>> + >>> module_cls = MyModule + args (Sequence, optional): args to be passed to ``module_cls``. + kwargs (Dict, optional): kwargs to be passed to ``module_cls``. + _module_interface_cls (type, optional): The TorchScript interface type for the module + to be created. The type object should be decorated by @torch.jit.interface. + If not provided, the generated RemoteModule is not torchscript-able. + Warning, this is an experimental API and susceptible to frequent changes. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_cls``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> from torch import nn, Tensor + >>> from torch.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_linear_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> input = torch.randn(128, 20) + >>> ret_fut = remote_linear_module.forward_async(input) + >>> ret = ret_fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + super().__init__() + + enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device) + + # Default arguments preparation. + args = args if args is not None else () + kwargs = kwargs if kwargs is not None else {} + + if _module_interface_cls is not None: + # Users reply on this field to know if this generated RemoteModule is TorchScript-able. + self.is_scriptable = True + + # Instantiate template on remote side. + fut = rpc.rpc_async( + self.on, + _instantiate_template, + (_module_interface_cls, enable_moving_cpu_tensors_to_cuda), + ) + + self._init_template( + _module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + + # Instantiate template on remote side. + fut = rpc.rpc_async( + self.on, + _instantiate_template, + (_module_interface_cls, enable_moving_cpu_tensors_to_cuda), + ) + + # Create the module on the remote side. + fut.wait() # Ensure remote_module_cls is available on remote side. + + # TODO: We need to change this to rpc.remote, and make it async (see the else branch below). + # For that we need to be able to apply _module_interface_cls to the RRef returned by rpc.remote + # See https://github.com/pytorch/pytorch/issues/58098 for more context. + self.module_rref = rpc.rpc_sync( + self.on, + _create_module_with_interface, + (module_cls, args, kwargs, self.device, _module_interface_cls), + ) + else: + self.is_scriptable = False + self.generated_methods = ( + _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods + ) + # Create the module on the remote side. + self.module_rref = rpc.remote( + self.on, + _create_module, + (module_cls, args, kwargs, self.device), + ) + + self._install_generated_methods() + self._check_attribute_picklability() + + def remote_parameters(self, recurse: bool = True) -> list[rpc.RRef[Parameter]]: + """ + Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters. + + This can typically be used in conjunction + with :class:`~torch.distributed.optim.DistributedOptimizer`. + + Args: + recurse (bool): if True, then returns parameters of the remote + module and all submodules of the remote module. Otherwise, + returns only parameters that are direct members of the + remote module. + + Returns: + A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``) + to remote module's parameters. + """ + return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse)) + + def get_module_rref(self) -> rpc.RRef[nn.Module]: + """Return an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``) pointing to the remote module.""" + return self.module_rref + + @torch.jit.export + def __getstate__(self): + raise RuntimeError( + "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC" + ) + + @torch.jit.export + def __setstate__(self, state): + raise RuntimeError( + "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC" + ) + + def register_buffer( + self, name: str, tensor: Tensor | None, persistent: bool = True + ) -> None: + _raise_not_supported(self.register_buffer.__name__) + + def register_parameter(self, name: str, param: Parameter | None) -> None: + _raise_not_supported(self.register_parameter.__name__) + + def add_module(self, name: str, module: Module | None) -> None: + _raise_not_supported(self.add_module.__name__) + + def apply(self, fn: Callable[[Module], None]) -> Self: # type: ignore[return] + _raise_not_supported(self.apply.__name__) + + def cuda(self, device: int | device | None = None) -> Self: # type: ignore[return] + _raise_not_supported(self.cuda.__name__) + + def ipu(self, device: int | device | None = None) -> Self: # type: ignore[return] + _raise_not_supported(self.ipu.__name__) + + def xpu(self, device: int | device | None = None) -> Self: # type: ignore[return] + _raise_not_supported(self.xpu.__name__) + + def cpu(self) -> Self: # type: ignore[return] + _raise_not_supported(self.cpu.__name__) + + def type(self, dst_type: dtype | str) -> Self: # type: ignore[return] + _raise_not_supported(self.type.__name__) + + def float(self) -> Self: # type: ignore[return] + _raise_not_supported(self.float.__name__) + + def double(self) -> Self: # type: ignore[return] + _raise_not_supported(self.double.__name__) + + def half(self) -> Self: # type: ignore[return] + _raise_not_supported(self.half.__name__) + + def bfloat16(self) -> Self: # type: ignore[return] + _raise_not_supported(self.bfloat16.__name__) + + def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var] + _raise_not_supported(self.to.__name__) + + def register_backward_hook( # type: ignore[return] + self, + hook: Callable[[Module, _grad_t, _grad_t], None | _grad_t], + # pyrefly: ignore [bad-return] + ) -> RemovableHandle: + _raise_not_supported(self.register_backward_hook.__name__) + + def register_forward_pre_hook( # type: ignore[return] + self, + hook: Callable[[T, tuple[Any, ...]], Any | None] + | Callable[ + [T, tuple[Any, ...], dict[str, Any]], tuple[Any, dict[str, Any]] | None + ], + prepend: bool = False, + with_kwargs: bool = False, + # pyrefly: ignore [bad-return] + ) -> RemovableHandle: + _raise_not_supported(self.register_forward_pre_hook.__name__) + + def register_forward_hook( # type: ignore[return, override] + self, + hook: Callable[[T, tuple[Any, ...], Any], Any | None] + | Callable[[T, tuple[Any, ...], dict[str, Any], Any], Any | None], + prepend: bool = False, + with_kwargs: bool = False, + # pyrefly: ignore [bad-return] + ) -> RemovableHandle: + _raise_not_supported(self.register_forward_hook.__name__) + + def state_dict(self, *args, **kwargs): + _raise_not_supported(self.state_dict.__name__) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ): + _raise_not_supported(self.load_state_dict.__name__) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + raise ValueError( + "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead." + ) + + def named_parameters( # type: ignore[return] + self, + prefix: str = "", + recurse: bool = True, + remove_duplicate: bool = True, + # pyrefly: ignore [bad-return] + ) -> Iterator[tuple[str, Parameter]]: + _raise_not_supported(self.named_parameters.__name__) + + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return] + _raise_not_supported(self.buffers.__name__) + + def named_buffers( # type: ignore[return] + self, + prefix: str = "", + recurse: bool = True, + remove_duplicate: bool = True, + # pyrefly: ignore [bad-return] + ) -> Iterator[tuple[str, Tensor]]: + _raise_not_supported(self.named_buffers.__name__) + + def children(self) -> Iterator[Module]: # type: ignore[return] + _raise_not_supported(self.children.__name__) + + def named_children(self) -> Iterator[tuple[str, Module]]: # type: ignore[return] + _raise_not_supported(self.named_children.__name__) + + def modules(self) -> Iterator[Module]: # type: ignore[return] + _raise_not_supported(self.modules.__name__) + + def named_modules( + self, + memo: set[Module] | None = None, + prefix: str = "", + remove_duplicate: bool = True, + ): + _raise_not_supported(self.named_modules.__name__) + + def train(self, mode: bool = True) -> Self: + return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr] + + def eval(self) -> Self: + return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr] + + def requires_grad_(self, requires_grad: bool = True) -> Self: # type: ignore[return] + _raise_not_supported(self.requires_grad_.__name__) + + def zero_grad(self, set_to_none: bool = True) -> None: + _raise_not_supported(self.zero_grad.__name__) + + def share_memory(self) -> Self: # type: ignore[return] + _raise_not_supported(self.share_memory.__name__) + + def extra_repr(self) -> str: # type: ignore[return] + _raise_not_supported(self.extra_repr.__name__) + + def _prepare_init(self, remote_device_str: str) -> bool: + """Prepare the initialization and returns whether to enable automatically moving CPU tensors to CUDA devices.""" + # Sanity check. + assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC." + + remote_device = _remote_device(remote_device_str) + self.on = ( + remote_device.worker_name() + if remote_device.worker_name() is not None + else remote_device.rank() + ) + self.device = str(remote_device.device()) + agent = rpc._get_current_rpc_agent() + # If the device map of the remote worker is set, + # then enable moving any input CPU tensors to the same cuda device. + self.is_device_map_set = bool( + agent._get_device_map(agent.get_worker_info(self.on)) # type: ignore[arg-type] + ) + # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``: + # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set, + # then any CPU tensors can still be moved to a cuda device to run forward, + # but the output must be moved back to CPU before being sent over the wire. + enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda" + return enable_moving_cpu_tensors_to_cuda + + def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda): + """Instantiate template on local side.""" + generated_module = instantiator.instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + self.generated_methods = generated_module._generated_methods + + def _check_attribute_picklability(self): + """Check if all the attribute has explicitly defined whether to be pickled (i.e., picklability).""" + for k in self.__dict__: + if ( + k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES + and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING + ): + raise AttributeError( + f"Attribute {k} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or " + "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``." + ) + + def _install_generated_methods(self): + for method in self.generated_methods: + method_name = method.__name__ + method = torch.jit.export(method) + setattr(self, method_name, types.MethodType(method, self)) + + @staticmethod + def init_from_module_rref( + remote_device: str, + module_rref: rpc.RRef[nn.Module], + _module_interface_cls: Any = None, + ): + """ + Besides the constructor, a RemoteModule instance can also be initialized given a module RRef. + + This alternate initialization method can be particularly useful if we want to create multiple + RemoteModule instances that share the same underlying module and reduce memory consumption. + + Moreover, this also provides a workaround for passing script RemoteModule over RPC, + which is not supported. The recommended way is as follows: + + 1. the sender creates a RemoteModule; + 2. the sender sends its ``module_rref`` over RPC; + 3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> from torch import nn, Tensor + >>> from torch.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> + >>> remote_module1 = rpc.rpc_sync( + >>> "worker1/cpu", + >>> RemoteModule.init_from_module_rref, + >>> ("worker1/cpu", remote_module1.get_module_rref()), + >>> ) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + + In addition, the device field can be optional and the default value is "cpu". + module_rref (RRef[nn.Module]): The module reference shared by both the caller and + the created remote module. + _module_interface_cls (type, optional): The TorchScript interface type for the module + to be created. The type object should be decorated by @torch.jit.interface. + If not provided, the generated RemoteModule is not torchscript-able. + Warning, this is an experimental API and susceptible to frequent changes. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_rref``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + """ + # NOTE: if a new attribute is added to this class, also need to add it + # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling. + + remote_module = object.__new__(RemoteModule) + + # pyrefly: ignore [missing-attribute] + enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device) + + if _module_interface_cls is not None: + # Users reply on this field to know if this generated RemoteModule is TorchScript-able. + # pyrefly: ignore [missing-attribute] + remote_module.is_scriptable = True + + # pyrefly: ignore [missing-attribute] + remote_module._init_template( + _module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + else: + # pyrefly: ignore [missing-attribute] + remote_module.is_scriptable = False + # pyrefly: ignore [missing-attribute] + remote_module.generated_methods = ( + _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods + ) + # pyrefly: ignore [missing-attribute] + remote_module.module_rref = module_rref + + # pyrefly: ignore [missing-attribute] + remote_module._install_generated_methods() + # pyrefly: ignore [missing-attribute] + remote_module._check_attribute_picklability() + + return remote_module + + +class RemoteModule(_RemoteModule): + """ + A RemoteModule instance can only be created after RPC initialization. + + It creates a user-specified module on a specified remote node. + It behaves like a regular ``nn.Module`` except that the ``forward`` method is + executed on the remote node. + It takes care of autograd recording to ensure the backward pass propagates + gradients back to the corresponding remote module. + + It generates two methods ``forward_async`` and ``forward`` based on the + signature of the ``forward`` method of ``module_cls``. ``forward_async`` + runs asynchronously and returns a Future. The arguments of ``forward_async`` + and ``forward`` are the same as the ``forward`` method of the module + returned by the ``module_cls``. + + For example, if ``module_cls`` returns an instance of ``nn.Linear``, + that has ``forward`` method signature: ``def forward(input: Tensor) -> Tensor:``, + the generated ``RemoteModule`` will have 2 methods with the signatures: + + | ``def forward(input: Tensor) -> Tensor:`` + | ``def forward_async(input: Tensor) -> Future[Tensor]:`` + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + module_cls (nn.Module): Class for the module to be created remotely. For example, + + >>> class MyModule(nn.Module): + >>> def forward(input): + >>> return input + 1 + >>> + >>> module_cls = MyModule + + args (Sequence, optional): args to be passed to ``module_cls``. + kwargs (Dict, optional): kwargs to be passed to ``module_cls``. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_cls``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> from torch import nn, Tensor + >>> from torch.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_linear_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> input = torch.randn(128, 20) + >>> ret_fut = remote_linear_module.forward_async(input) + >>> ret = ret_fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Furthermore, a more practical example that is combined with + `DistributedDataParallel `__ (DDP) + can be found in this `tutorial `__. + """ + + def __init__( + self, + remote_device: str, + module_cls: type[nn.Module], + args: tuple | None = None, + kwargs: dict[str, Any] | None = None, + ): + super().__init__(remote_device, module_cls, args, kwargs) + + +def _remote_module_receiver( + *remote_module_pickled_attrs, +): + """Deserializes a RemoteModule.""" + serialized_remote_module = _SerializedRemoteModule._make( + remote_module_pickled_attrs + ) + m = object.__new__(RemoteModule) + m.__dict__.update(serialized_remote_module._asdict()) + + # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method. + # pyrefly: ignore [missing-attribute] + m.module_rref = rpc.PyRRef._deserialize(m.module_rref) + + # Install generated methods when unpickled. + # pyrefly: ignore [missing-attribute] + for method in m.generated_methods: + method_name = method.__name__ + method = torch.jit.export(method) + setattr(m, method_name, types.MethodType(method, m)) + + return m + + +def _remote_module_reducer(remote_module): + """Serialize a RemoteModule.""" + pickled_attrs = {} + for k, v in remote_module.__dict__.items(): + # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method. + if k == "module_rref": + pickled_attrs[k] = v._serialize() + elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES: + pickled_attrs[k] = v + # Check if unpickled attributes are all in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. + elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING: + print( + f"The new attribute ``{k}`` of RemoteModule is ignored during RPC pickling. " + "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. " + "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.", + file=sys.stderr, + ) + + return ( + _remote_module_receiver, + tuple(pickled_attrs.values()), + ) + + +def _recursive_script_module_receiver( + recursive_script_module_serialized, +): + """Deserializes a RecursiveScriptModule that does not contain a script RemoteModule.""" + f = io.BytesIO(recursive_script_module_serialized) + m = torch.jit.load(f) + return m + + +def _recursive_script_module_reducer(recursive_script_module): + """Serialize a RecursiveScriptModule that does not contain a script RemoteModule, and raises an error otherwise.""" + if hasattr(recursive_script_module._c, "module_rref"): + raise RuntimeError( + "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, " + "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`." + ) + + f = io.BytesIO() + torch.jit.save(recursive_script_module, f) + return (_recursive_script_module_receiver, (f.getvalue(),)) + + +_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer) +_internal_rpc_pickler._register_reducer( + torch.jit.RecursiveScriptModule, _recursive_script_module_reducer +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/functional.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..287775be924a399aff01fcda66f6ebc838c62873 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/functional.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +from torch.autograd import Function + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +from torch.distributed import group, ReduceOp + + +def broadcast(tensor, src, group=group.WORLD): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Arguments: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process. + src (int): Source rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Received tensor from the broadcast op. + + """ + return _Broadcast.apply(src, group, tensor) + + +def gather(tensor, dst=0, group=group.WORLD): + """ + Gathers a list of tensors in a single process. + + Arguments: + tensor (Tensor): Input tensor. + dst (int, optional): Destination rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]: List of appropriately-sized tensors with the gathered data. + """ + return _Gather.apply(dst, group, tensor) + + +def scatter(tensors, src=0, group=group.WORLD): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Arguments: + tensors (list[Tensor]): List of tensors to scatter on the source rank. + Receivers must pass ``None`. + src (int, optional): Source rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output tensor from the scatter operation. + + """ + return _Scatter.apply(src, group, *tensors) + + +def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Arguments: + tensor (Tensor): Input of the collective. + dst (int): Destination rank. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce.apply(dst, op, group, tensor) + + +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Arguments: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce_Scatter.apply(op, group, output, *input_list) + + +def all_gather(tensor, group=group.WORLD): + """ + Gathers tensors from the whole group in a list. + + Arguments: + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AllGather.apply(group, tensor) + + +def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Examples: + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> # xdoctest: +SKIP("incorrect want text") + >>> output_tensor = torch.zeros(2, dtype=torch.int64) + >>> output_tensor + [tensor([0, 0])] # Rank 0 and 1 + >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank + >>> tensor + tensor([1]) # Rank 0 + tensor([2]) # Rank 1 + >>> dist.all_gather_base(output_tensor, tensor) + >>> output_tensor + tensor([1,2]) # Rank 0 + tensor([1,2]) # Rank 1 + + .. warning:: + `_all_gather_base` is experimental and subject to change. + It is the caller's responsibility to ensure the output_tensor + is correctly sized. + + """ + return _AllGatherBase.apply(output_tensor, input_tensor, group) + + +def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): + """ + Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Arguments: + output_tensor_list (list[Tensor]): list of tensors to gather one per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list) + + +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=group.WORLD, +): + """ + Each process splits input tensor and then scatters the split list to all processes in a group. + + Then concatenate the received tensors from all the processes in the group and return single output tensor. + + Arguments: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + + Returns: + Tensor: Output of the collective. + + """ + return _AlltoAllSingle.apply( + group, output, output_split_sizes, input_split_sizes, input + ) + + +def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines in such a way that all get the final result. + + After the call the returned tensor is going to be bitwise + identical in all processes. + + Arguments: + tensor (Tensor): Input of the collective. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective + + """ + return _AllReduce.apply(op, group, tensor) + + +class _Broadcast(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, src, group, tensor): + ctx.src = src + ctx.group = group + ctx.rank = dist.get_rank(group=group) + # torch.distributed makes all the calls in place + # we allocate new tensors to avoid this + tensor = tensor.clone() + dist.broadcast(tensor, src, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output) + if ctx.src != ctx.rank: + gx.zero_() + return (None, None, gx) + + +class _Gather(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, dst, group, tensor): + ctx.dst = dst + ctx.group = group + # Need to create a list of tensors here to do the + # aggregation, get it from the group size + # tensor should be correctly sized for the method + # gathering + tensor_list = [ + torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group)) + ] + + tensor = tensor.contiguous() + if dist.get_rank(group=group) == dst: + dist.gather(tensor, tensor_list, dst, group=group) + else: + dist.gather(tensor, None, dst, group=group) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),) + + +class _Scatter(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, src, group, *tensors): + ctx.src = src + ctx.group = group + assert all(t.size() == tensors[0].size() for t in tensors) + output = torch.zeros_like(tensors[0]) + if dist.get_rank(group=group) == src: + dist.scatter(output, list(tensors), src, group=group) + else: + dist.scatter(output, None, src, group=group) + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) + + +class _Reduce(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, src, op, group, tensor): + ctx.src = src + ctx.group = group + tensor = tensor.clone() + dist.reduce(tensor, src, op=op, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) + + +class _Reduce_Scatter(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, op, group, tensor, *input_tensor_list): + ctx.group = group + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + input_tensor_list = tuple(t.contiguous() for t in input_tensor_list) + dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None, None) + _AllGather.apply(ctx.group, grad_output) + + +class _AllGather(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, group, tensor): + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + + ctx.group = group + out_tensor_list = [ + torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group)) + ] + + dist.all_gather(out_tensor_list, tensor, group=group) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + rank = dist.get_rank(group=ctx.group) + gx = torch.empty_like(grad_outputs[rank]) + gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs) + else: + # As many backends doesn't support ReduceScatter, we use AlltoAll with .sum() + # to emulate the ReduceScatter behavior + tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs] + gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + gx = torch.sum(torch.stack(gxs), dim=0) + return (None, gx) + + +class _AllGatherBase(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, output_tensor, input_tensor, group): + ctx.group = group + dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group) + return output_tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + world_size = dist.get_world_size(group=ctx.group) + out_size = list(grad_output.size()) + if out_size[0] % world_size != 0: + raise RuntimeError( + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" + ) + out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) + gx = torch.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) + dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) + else: + raise RuntimeError("Backend not supported!") + return (None, gx, None) + + +class _AlltoAll(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, group, out_tensor_list, *tensors): + ctx.group = group + ctx.input_tensor_size_list = [ + tensors[i].size() for i in range(dist.get_world_size(group=group)) + ] + my_rank = dist.get_rank(group=group) + tensors = tuple(t.contiguous() for t in tensors) + # Implement it on means of scatter/gather, send/recv async operations have issues + if dist.get_backend(group=group) is dist.Backend.GLOO: + for i in range(dist.get_world_size(group=group)): + to_send = None + if i == my_rank: + to_send = list(tensors) + dist.scatter(out_tensor_list[i], to_send, i, group=group) + else: + dist.all_to_all( + out_tensor_list, + list(tensors), + group=group, + ) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + tensor_list = [ + torch.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) + for size in ctx.input_tensor_size_list + ] + return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + + +class _AlltoAllSingle(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): + ctx.group = group + ctx.input_size = input.size() + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + tensor = torch.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) + return (None, None, None, None) + ( + _AlltoAllSingle.apply( + ctx.group, + tensor, + ctx.output_split_sizes, + ctx.input_split_sizes, + grad_output.contiguous(), + ), + ) + + +class _AllReduce(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, op, group, tensor): + ctx.group = group + ctx.op = op + tensor = tensor.clone(memory_format=torch.contiguous_format) + dist.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ffd74dd8f0e0c0b25caa390e2ba609d111ec13d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f07570776028cb4c55313307546ec85b4cad4c71 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/instantiator.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/instantiator.py new file mode 100644 index 0000000000000000000000000000000000000000..a6dee7e61ef5731f35915f704d2ffaf0444d168b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/instantiator.py @@ -0,0 +1,141 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs +import importlib.abc +import importlib.util +import sys + +import torch +from torch.distributed.nn.jit.templates.remote_module_template import ( + get_remote_module_template, +) + + +_FILE_PREFIX = "_remote_module_" + + +def get_arg_return_types_from_interface(module_interface): + assert getattr(module_interface, "__torch_script_interface__", False), ( + "Expect a TorchScript class interface decorated by @torch.jit.interface." + ) + qualified_name = torch._jit_internal._qualified_name(module_interface) + cu = torch.jit._state._python_cu + module_interface_c = cu.get_interface(qualified_name) + assert "forward" in module_interface_c.getMethodNames(), ( + f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}" + ) + method_schema = module_interface_c.getMethod("forward") + + arg_str_list = [] + arg_type_str_list = [] + assert method_schema is not None + for argument in method_schema.arguments: + arg_str_list.append(argument.name) + + if argument.has_default_value(): + default_value_str = f" = {argument.default_value}" + else: + default_value_str = "" + arg_type_str = f"{argument.name}: {argument.type}{default_value_str}" + arg_type_str_list.append(arg_type_str) + + arg_str_list = arg_str_list[1:] # Remove "self". + args_str = ", ".join(arg_str_list) + + arg_type_str_list = arg_type_str_list[1:] # Remove "self". + arg_types_str = ", ".join(arg_type_str_list) + + assert len(method_schema.returns) == 1 + argument = method_schema.returns[0] + return_type_str = str(argument.type) + + return args_str, arg_types_str, return_type_str + + +class _StringLoader(importlib.abc.SourceLoader): + def __init__(self, data): + self.data = data + + def get_source(self, fullname): + return self.data + + def get_data(self, path): + return self.data.encode("utf-8") + + def get_filename(self, fullname): + return fullname + + +def _do_instantiate_remote_module_template( + generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda +): + if generated_module_name in sys.modules: + return sys.modules[generated_module_name] + + loader = _StringLoader( + get_remote_module_template(enable_moving_cpu_tensors_to_cuda).format(**str_dict) + ) + spec = importlib.util.spec_from_loader( + generated_module_name, loader, origin="torch-git" + ) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[generated_module_name] = module + loader.exec_module(module) + return module + + +def instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda=True +): + if not getattr(module_interface_cls, "__torch_script_interface__", False): + raise ValueError( + f"module_interface_cls {module_interface_cls} must be a type object decorated by " + "@torch.jit.interface" + ) + + # Generate the template instance name. + module_interface_cls_name = torch._jit_internal._qualified_name( + module_interface_cls + ).replace(".", "_") + generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}" + + # Generate type annotation strs. + assign_module_interface_cls_str = ( + f"from {module_interface_cls.__module__} import " + f"{module_interface_cls.__name__} as module_interface_cls" + ) + args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface( + module_interface_cls + ) + kwargs_str = "" + arrow_and_return_type_str = f" -> {return_type_str}" + arrow_and_future_return_type_str = f" -> Future[{return_type_str}]" + + str_dict = dict( + assign_module_interface_cls=assign_module_interface_cls_str, + arg_types=arg_types_str, + arrow_and_return_type=arrow_and_return_type_str, + arrow_and_future_return_type=arrow_and_future_return_type_str, + args=args_str, + kwargs=kwargs_str, + jit_script_decorator="@torch.jit.script", + ) + return _do_instantiate_remote_module_template( + generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda + ) + + +def instantiate_non_scriptable_remote_module_template(): + generated_module_name = f"{_FILE_PREFIX}non_scriptable" + str_dict = dict( + assign_module_interface_cls="module_interface_cls = None", + args="*args", + kwargs="**kwargs", + arg_types="*args, **kwargs", + arrow_and_return_type="", + arrow_and_future_return_type="", + jit_script_decorator="", + ) + # For a non-scriptable template, always enable moving CPU tensors to a cuda device, + # because there is no syntax limitation on the extra handling caused by the script. + return _do_instantiate_remote_module_template(generated_module_name, str_dict, True) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb4be4c1fd938b67a1d8ef0e35a879e967fe328 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ced230d158d415c34c2fe48090d1fc9866f435fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py new file mode 100644 index 0000000000000000000000000000000000000000..07b055774b36af4835e308c8a1f85afd0ab35f0f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py @@ -0,0 +1,108 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs + + +def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool): + return _TEMPLATE_PREFIX + ( + _REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA + if enable_moving_cpu_tensors_to_cuda + else _REMOTE_FORWARD_TEMPLATE + ) + + +_TEMPLATE_PREFIX = """from typing import * + +import torch +import torch.distributed.rpc as rpc +from torch import Tensor +from torch._jit_internal import Future +from torch.distributed.rpc import RRef +from typing import Tuple # pyre-ignore: unused import + + +{assign_module_interface_cls} + + +def forward_async(self, {arg_types}){arrow_and_future_return_type}: + args = (self.module_rref, self.device, self.is_device_map_set, {args}) + kwargs = {{{kwargs}}} + return rpc.rpc_async( + self.module_rref.owner(), + _remote_forward, + args, + kwargs, + ) + + +def forward(self, {arg_types}){arrow_and_return_type}: + args = (self.module_rref, self.device, self.is_device_map_set, {args}) + kwargs = {{{kwargs}}} + ret_fut = rpc.rpc_async( + self.module_rref.owner(), + _remote_forward, + args, + kwargs, + ) + return ret_fut.wait() + + +_generated_methods = [ + forward_async, + forward, +] + + +{jit_script_decorator} +""" + +# This template may cause typing error (the mismatch between ``Tuple[()]`` and ``Tuple[Any]``) +# even if the code is only used for instantiation but not execution. +# Therefore, only include handling moving CPU tensors to a cuda device if necessary. +# TODO: Merge these two templates together in the future once TorchScript syntax is improved. +_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA = """ +def _remote_forward( + module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}: + module = module_rref.local_value() + device = torch.device(device) + + if device.type != "cuda": + return module.forward({args}, {kwargs}) + + # If the module is on a cuda device, + # move any CPU tensor in args or kwargs to the same cuda device. + # Since torch script does not support generator expression, + # have to use concatenation instead of + # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``. + args = ({args},) + out_args: Tuple[()] = () + for arg in args: + arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,) + out_args = out_args + arg + + kwargs = {{{kwargs}}} + for k, v in kwargs.items(): + if isinstance(v, Tensor): + kwargs[k] = kwargs[k].to(device) + + if is_device_map_set: + return module.forward(*out_args, {kwargs}) + + # If the device map is empty, then only CPU tensors are allowed to send over wire, + # so have to move any GPU tensor to CPU in the output. + # Since torch script does not support generator expression, + # have to use concatenation instead of + # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, {kwargs}))``. + ret: Tuple[()] = () + for i in module.forward(*out_args, {kwargs}): + i = (i.cpu(),) if isinstance(i, Tensor) else (i,) + ret = ret + i + return ret +""" + +_REMOTE_FORWARD_TEMPLATE = """ +def _remote_forward( + module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}: + module = module_rref.local_value() + + return module.forward({args}, {kwargs}) +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faac68bb632934ba730ba7c5ce3cf7fe934a58cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__init__.py @@ -0,0 +1,44 @@ +""" +:mod:`torch.distributed.optim` exposes DistributedOptimizer, which takes a list +of remote parameters (:class:`~torch.distributed.rpc.RRef`) and runs the +optimizer locally on the workers where the parameters live. The distributed +optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to +apply the gradients on each worker. +""" + +import warnings + +import torch +from torch import optim + +from .apply_optimizer_in_backward import ( + _apply_optimizer_in_backward, + _get_in_backward_optimizers, +) +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD +from .named_optimizer import _NamedOptimizer +from .utils import as_functional_optim + + +# DistributedOptimizer imports torch.distributed.rpc names, so gate availability +# based on RPC being available. +if hasattr(torch._C, "_rpc_init"): + from .optimizer import DistributedOptimizer + +from .post_localSGD_optimizer import PostLocalSGDOptimizer +from .zero_redundancy_optimizer import ZeroRedundancyOptimizer + + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b978513c3d1e20177fc8f4e23b5bc5dd3be30b94 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/_deprecation_warning.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/_deprecation_warning.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21082a6f6e5a0398bd82f056f4743e5340c22d88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/_deprecation_warning.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c781313b0a07e3e2ef01c881955ab5d003439c8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9e65a3382781076a6a46228476098adf21013b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dacb493d1b33fec2aa310160baf56100e6741005 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceaa65c8cd9fdc67e4df59a127942b64b4718d16 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f81e552135becdb066824caf1647787c895e627 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddaad9087883457920ad7834b607f22ef47cb907 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0d77988176635941692515798c9efbb503552fc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b59ec3e548acde69250e140838673331a19b4ce Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d957fcc1f1acf0e9a2953909442cde98c3cf819 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..078d55ecc5c20081813f31367011cacec19e1587 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..588667e4d2974046b3692169cfa9f7689b7005a8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d735a42d3884b1cffa82474aeba3fe1524ad45c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bee5f6cc03a30fd48f85ede16ad9a199b518617 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35853649c5dd1e42ee998f1734074d0a24af0f80 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/_deprecation_warning.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/_deprecation_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..c3434a4cd4f081843295e488c18a67a5c297fcbf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/_deprecation_warning.py @@ -0,0 +1,16 @@ +import warnings + +import torch + + +@torch.jit.ignore # type: ignore[misc] +def _scripted_functional_optimizer_deprecation_warning(stacklevel: int = 0) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`TorchScript` support for functional optimizers is deprecated " + "and will be removed in a future PyTorch release. " + "Consider using the `torch.compile` optimizer instead.", + DeprecationWarning, + stacklevel=stacklevel + 2, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..1ff9854793df1aa96a27cb105a1afd1190df942a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py @@ -0,0 +1,121 @@ +from collections.abc import Iterable +from typing import Any, no_type_check + +import torch + + +__all__: list[str] = [] + +# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter +# without changing it's life-time. +# NOTE: Alternative is to add the meta-data as an attribute to the tensor, +# but that will serialize the meta-data if Tensor is serialized. +param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary() +param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() + + +@no_type_check +def _apply_optimizer_in_backward( + optimizer_class: type[torch.optim.Optimizer], + params: Iterable[torch.nn.Parameter], + optimizer_kwargs: dict[str, Any], + register_hook: bool = True, +) -> None: + """ + Upon ``backward()``, the optimizer specified for each parameter will fire after + the gradient has been accumulated into the parameter. + + Note - gradients for these parameters will be set to None after ``backward()``. + This means that any other optimizer not specified via `_apply_optimizer_in_backward` + over this parameter will be a no-op. + + Args: + optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter + params: (Iterator[nn.Parameter]): parameters to apply optimizer state to + optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor + register_hook: (bool): whether to register a hook that runs the optimizer + after gradient for this parameter is accumulated. This is the default + way that optimizer in backward is implemented, but specific use cases + (such as DDP) may wish to override this to implement custom behavior. + (Default = True) + + Example:: + params_generator = model.parameters() + param_1 = next(params_generator) + remainder_params = list(params_generator) + + apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02}) + apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04}) + + model(...).sum().backward() # after backward, parameters will already + # have their registered optimizer(s) applied. + + """ + torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward") + + @no_type_check + def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: + # view_as creates a node in autograd graph that allows us access to the + # parameter's AccumulateGrad autograd function object. We register a + # hook on this object to fire the optimizer when the gradient for + # this parameter is ready (has been accumulated into .grad field) + + # Don't create a new acc_grad if we already have one + # i.e. for shared parameters or attaching multiple optimizers to a param. + if param not in param_to_acc_grad_map: + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] + + optimizer = optimizer_class([param], **optimizer_kwargs) + + if not hasattr(param, "_in_backward_optimizers"): + param._in_backward_optimizers = [] # type: ignore[attr-defined] + # TODO: Remove these attributes once we have a better way of accessing + # optimizer classes and kwargs for a parameter. + param._optimizer_classes = [] # type: ignore[attr-defined] + param._optimizer_kwargs = [] # type: ignore[attr-defined] + + param._in_backward_optimizers.append(optimizer) # type: ignore[attr-defined] + param._optimizer_classes.append(optimizer_class) # type: ignore[attr-defined] + param._optimizer_kwargs.append(optimizer_kwargs) # type: ignore[attr-defined] + + if not register_hook: + return + + def optimizer_hook(*_unused) -> None: + for opt in param._in_backward_optimizers: # type: ignore[attr-defined] + opt.step() + + param.grad = None + + handle = param_to_acc_grad_map[param].register_hook(optimizer_hook) # type: ignore[attr-defined] + if param not in param_to_optim_hook_handle_map: + param_to_optim_hook_handle_map[param] = [] + param_to_optim_hook_handle_map[param].append(handle) + + for param in params: + _apply_optimizer_in_backward_to_param(param) + + +def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Optimizer]: + """ + Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these + optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called + by the user and are intended to be used for things like checkpointing. + + Args: + module: (torch.nn.Module): model to retrieve in-backward optimizers for + + Returns: + List[torch.optim.Optimizer]: the in-backward optimizers. + + Example:: + _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01}) + optims = _get_optimizers_in_backward(model) + """ + optims: list[torch.optim.Optimizer] = [] + for param in module.parameters(): + optims.extend(getattr(param, "_in_backward_optimizers", [])) + + return optims diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adadelta.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adadelta.py new file mode 100644 index 0000000000000000000000000000000000000000..e8455c5ef5a41613dc15140b6c562ceb3134ca4e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adadelta.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adadelta Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdadelta: + def __init__( + self, + params: list[Tensor], + lr: float = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "rho": rho, + "eps": eps, + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + acc_deltas = [] + state_steps = [] + lr = self.defaults["lr"] + rho = self.defaults["rho"] + eps = self.defaults["eps"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["square_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["acc_delta"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + acc_deltas.append(state["acc_delta"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adadelta( + params_with_grad, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adagrad.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..3da4e29b3f0154ab58206c835f80a24ae208a05c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adagrad.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adagrad Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly let the user pass gradients to the `step` function +# this is so that we could separate the gradients and parameters +# and allow multithreaded trainer to update the parameters +# without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdagrad: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + warmup_lr_multiplier: float = 1.0, + warmup_num_iters: float = 0.0, + eps: float = 1e-10, + coalesce_grad: bool = True, + foreach: bool = False, + fused: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "lr_decay": lr_decay, + "eps": eps, + "weight_decay": weight_decay, + "initial_accumulator_value": initial_accumulator_value, + "warmup_lr_multiplier": warmup_lr_multiplier, + "warmup_num_iters": warmup_num_iters, + } + self.coalesce_grad = coalesce_grad + self.foreach = foreach + self.fused = fused + self.maximize = maximize + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + # TODO: no union or any types in TorchScript, make step a scalar tensor instead + # This is also needed by if we want to share_memory on the step across processes + for p in self.param_group["params"]: + self.state[p] = { + "sum": torch.full_like(p.data, initial_accumulator_value), + "step": torch.tensor(0.0), + } + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + state_sums = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad, has_complex = False, False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_sparse_grad |= gradient.is_sparse + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + state = self.state[param] + state_sums.append(state["sum"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adagrad( + params, + grads, + state_sums, + state_steps, + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + lr_decay=self.defaults["lr_decay"], + eps=self.defaults["eps"], + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adam.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..1763edd14c9da1c19081fcc1334e267c889472c1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adam.py @@ -0,0 +1,201 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adam Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdam: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Tensor | None): + """ + Similar to step, but operates on a single parameter and optionally a + gradient tensor. + """ + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = torch.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with torch.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = False + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adamax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adamax.py new file mode 100644 index 0000000000000000000000000000000000000000..595a5668a78fc0f8451fa9e2a81c03d049bb4b82 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adamax.py @@ -0,0 +1,122 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adamax Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdamax: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_infs = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_inf"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_infs.append(state["exp_inf"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adamax( + params_with_grad, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=self.defaults["eps"], + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adamw.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..d695ce8b473af8fbf1bde28293e576ff69fe6f04 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_adamw.py @@ -0,0 +1,202 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional AdamW Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdamW: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Tensor | None): + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = torch.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with torch.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_rmsprop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_rmsprop.py new file mode 100644 index 0000000000000000000000000000000000000000..45341b03237b456419ec181ae8b771dec081d3cb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_rmsprop.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional RMSprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalRMSprop: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "alpha": alpha, + "eps": eps, + "weight_decay": weight_decay, + "momentum": momentum, + } + self.centered = centered + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + grad_avgs = [] + momentum_buffer_list = [] + state_steps = [] + lr = self.defaults["lr"] + alpha = self.defaults["alpha"] + eps = self.defaults["eps"] + momentum = self.defaults["momentum"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["square_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if momentum > 0: + state["momentum_buffer"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.centered: + state["grad_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + if momentum > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if self.centered: + grad_avgs.append(state["grad_avg"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.rmsprop( + params_with_grad, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=self.centered, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_rprop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_rprop.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc9c510dabca7871d19890c5e52e0f5eeafcd49 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_rprop.py @@ -0,0 +1,106 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Rprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalRprop: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + etas: tuple[float, float] = (0.5, 1.2), + step_sizes: tuple[float, float] = (1e-6, 50), + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + } + self.etas = etas + self.step_sizes = step_sizes + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + prevs = [] + step_sizes = [] + state_steps = [] + lr = self.defaults["lr"] + etaminus, etaplus = self.etas + step_size_min, step_size_max = self.step_sizes + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["prev"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["step_size"] = torch.full_like(gradient, lr) + + state = self.state[param] + prevs.append(state["prev"]) + step_sizes.append(state["step_size"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.rprop( + params_with_grad, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_sgd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..aed92403e6fb62394e2f755fffbe5b7f323200ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/functional_sgd.py @@ -0,0 +1,165 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional SGD Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalSGD: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "weight_decay": weight_decay, + } + self.nesterov = nesterov + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Tensor | None): + """Similar to self.step, but operates on a single parameter and + its gradient. + """ + # TODO: Once step_param interface is robust, refactor step to call + # step param on each param. + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + lr = self.defaults["lr"] + params = [param] + momentum_buffer_list: list[Tensor | None] = [] + grads = [] + + has_sparse_grad = False + if grad is not None: + grads.append(grad) + if grad.is_sparse: + has_sparse_grad = True + if param not in self.state: + self.state[param] = {} + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with torch.no_grad(): + F.sgd( + params, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + # update momentum_buffer in state + state = self.state[param] + momentum_buffer = momentum_buffer_list[0] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + momentum_buffer_list: list[Tensor | None] = [] + lr = self.defaults["lr"] + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad = False + for param, gradient in zip(params, gradients): + if gradient is not None: + params_with_grad.append(param) + grads.append(gradient) + if gradient.is_sparse: + has_sparse_grad = True + + if param not in self.state: + self.state[param] = {} + + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with torch.no_grad(): + F.sgd( + params_with_grad, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + # update momentum_buffers in state + for i, p in enumerate(params_with_grad): + state = self.state[p] + momentum_buffer = momentum_buffer_list[i] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/named_optimizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/named_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a8432e198a083e194a1e48bf8c0af76ffa6b83a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/named_optimizer.py @@ -0,0 +1,328 @@ +import logging +import warnings +from collections.abc import Callable, Collection, Mapping +from copy import deepcopy +from typing import Any, overload + +import torch +import torch.nn as nn +from torch import optim +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + +__all__: list[str] = [] + +logger = logging.getLogger(__name__) + + +class _NamedOptimizer(optim.Optimizer): + """ + ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key. + + We replace the original key (number) in an optim to the + fully qualified name (FQN) string. User can initialize the optim as they + initialize a PyTorch optim, the only difference is that they also need to + pass in the FQN of each parameters. + + Args: + named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]): + Mapping from FQN to parameter. + optimizer_class (optim.Optimizer): + The class of optimizer to instantiate. + param_groups (Collection[Mapping[str, Any]]): + `param_groups` to pass to optimizer if specified. + The key of the inner map needs to be FQNs. + Default: None + module (nn.Module): the module whose parameters to updated + by the optimizer. + args: arguments to pass to the optimizer constructor. + kwargs: arguments to pass to the optimizer constructor. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from torch import optim + >>> from torch.distributed.optim import _NamedOptimizer + >>> + >>> # Define the named optimizer. + >>> m = Model(...) + >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD) + >>> # Forward pass + backward pass. + >>> named_optim.step() + >>> ... + >>> # Call state_dict for the named optimizer returns a FQN state_dict. + >>> named_optim.state_dict() + + Warning: This API is still in development and subject to change. + + TODO: Add tutorial for _NamedOptimizer. + TODO: Add documentation in the docstring for the public attributes + like self.param_groups and self.named_parameters. + """ + + def __init__( + self, + named_parameters: Mapping[str, torch.Tensor | ShardedTensor], + optimizer_class: optim.Optimizer, + param_groups: Collection[Mapping[str, Any]] | None = None, + module: nn.Module | None = None, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], + ) -> None: + torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer") + self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment] + self._param_groups_check() + self.named_parameters = dict(named_parameters) + params_for_optimizer = ( + self.named_parameters.values() if param_groups is None else param_groups + ) + self._optimizer = optimizer_class( # type: ignore[operator] + params_for_optimizer, + *args, + **kwargs, + ) + self.module = module + if param_groups is None: + self.ordered_param_keys = list(self.named_parameters.keys()) + else: + warnings.warn( + "Since we pass in param_groups, we will use param_groups to " + "initialize the optimizer, not all parameters of the module.", + stacklevel=2, + ) + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + ordered_param_keys = [] + for group in param_groups: + for param in group["params"]: + if param not in param_to_key: + raise ValueError( + f"Expect param name {param} found in param group but is missing." + ) + ordered_param_keys.append(param_to_key[param]) + self.ordered_param_keys = ordered_param_keys + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def _param_groups_check(self) -> None: + if self.param_groups is not None: + for param_group in self.param_groups: + assert isinstance(param_group, dict), "param group must be a dict" + assert "params" in param_group, "param group must contain key params" + params = param_group["params"] + if isinstance(params, torch.Tensor): + params = [params] + params = list(params) + for param in params: + if not isinstance(param, torch.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + torch.typename(param) + ) + param_group["params"] = params + + def state_dict(self) -> dict[str, Any]: + """ + Return the ``state_dict`` of the optimizer. + + Instead of using number to index + parameters, we will use module fully qualified name (FQN) as the key. + """ + state_dict = self._optimizer.state_dict() + param_groups = state_dict["param_groups"] + + ret_state = { + self.ordered_param_keys[st_key]: state_val + for st_key, state_val in state_dict["state"].items() + } + + ret_groups = [] + for group in param_groups: + param_keys = [self.ordered_param_keys[param] for param in group["params"]] + ret_group = {"params": sorted(param_keys)} + for k, v in group.items(): + if k != "params": + ret_group[k] = deepcopy(v) + ret_groups.append(ret_group) + + return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) + + @overload + def step(self, closure: None = None) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """ + Perform a single optimization step. + + This will call :meth:`torch.optim.Optimizer.step` on the wrapped + optimizer. + """ + return self._optimizer.step(closure=closure) + + @property + def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override] + return self._optimizer.state + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Define the default behavior to load a state_dict for ``_NamedOptimizer``. + + Sample Code + ``` + my_model = MyModule() + optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad) + ... + + optim_state_dict = optimizer.state_dict() + ... + ... + + optimizer.load_state_dict(optim_state_dict) + ... + ``` + Args: + state_dict (dict[str, Any]) : A ``state_dict`` to load into the optimizer. + Note that this state dict update is performed in place. + + .. note:: PyTorch is using lazy init to initialize the optim states. + So it is possible that there is no optim state when user call + ``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter + that users can only call ``load_state_dict`` after the state is initialized. + By doing this, we can validate the optim ``state_dict`` to be loaded. + """ + new_state_dict = self._optimizer.state_dict() + state_dict = self._pre_load_state_dict(state_dict) + state = state_dict["state"] + new_state = new_state_dict["state"] + if len(new_state) == 0: + raise ValueError( + "Expects the optim to be initialized before load but found not initialized." + ) + + for idx, param_key in enumerate(self.ordered_param_keys): + # When the conditional training is performed, not all parameters are updated in the optim. + if param_key not in state: + continue + if len(state[param_key]) != len(new_state[idx]): + raise ValueError( + f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}" + ) + # Iterate through all optimizer states. + for state_key, state_val in new_state[idx].items(): + if state_key not in state[param_key]: + raise ValueError( + f"Expects state {state_key} for parameter {param_key} but not found." + ) + + src_state_val = state[param_key][state_key] + if isinstance(state_val, ShardedTensor): + assert isinstance(src_state_val, ShardedTensor) + num_shards = len(state_val.local_shards()) + num_new_shards = len(src_state_val.local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}" + ) + for shard, src_shard in zip( + state_val.local_shards(), src_state_val.local_shards() + ): + shard.tensor.detach().copy_(src_shard.tensor) + elif isinstance(state_val, torch.Tensor): + assert isinstance(src_state_val, torch.Tensor) + state_val.detach().copy_(src_state_val) + else: + new_state[idx][state_key] = deepcopy(src_state_val) + + # Load param_groups of state_dict + src_param_groups = state_dict["param_groups"] + new_param_groups = new_state_dict["param_groups"] + + src_group_map = {} + for group in src_param_groups: + param_keys = list(group["params"]) + src_group_map[_gen_param_group_key(param_keys)] = group + new_group_map = {} + for new_group in new_param_groups: + param_keys = [] + for param_key in new_group["params"]: + param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload] + new_group_map[_gen_param_group_key(param_keys)] = new_group + for group_key, new_group in new_group_map.items(): + # When not all parameters are used in training or receive gradient, aka., not all parameters + # would be in the param_group. Thus we skip the group_key here. + if group_key not in src_group_map: + continue + src_group = src_group_map[group_key] + if len(src_group) != len(new_group): + raise ValueError( + f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}." + ) + for k in src_group: + if k not in new_group: + raise ValueError( + f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing." + ) + if k != "params": + new_group[k] = deepcopy(src_group[k]) + + self._optimizer.load_state_dict(new_state_dict) + + def add_param_group(self, param_group: Mapping[str, Any]) -> None: + """ + Add a param group to the :class:`_NamedOptimizer` s `param_groups`. + + Warning: This API is still in development and subject to change. + """ + assert isinstance(param_group, dict), "param group must be a dict" + + params = param_group["params"] + if isinstance(params, torch.Tensor): + param_group["params"] = [params] + else: + param_group["params"] = list(params) + + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + for param in param_group["params"]: + if param not in param_to_key: + raise ValueError("some parameters are not in the module") + self.ordered_param_keys.append(param_to_key[param]) + + self._optimizer.add_param_group(param_group) + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def init_state(self) -> None: + """ + Run a dummy optimizer step, which allows to initialize optimizer state because we do lazy init for most optimizers. + + This allows doing in-place loading of optimizer state from a checkpoint. + """ + for param in self.named_parameters.values(): + if param.requires_grad: + t = torch.zeros_like(param) + param.grad = torch.autograd.Variable(t) + # Calling ``step`` will load the initial state for optimizer states. + self.step(closure=None) + + def _pre_load_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + return FSDP.optim_state_dict_to_load( + self.module, self._optimizer, state_dict, is_named_optimizer=True + ) + return state_dict + + def _post_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + FSDP.optim_state_dict(self.module, self._optimizer, state_dict) + return state_dict + + +def _gen_param_group_key(param_keys: list[str]) -> str: + """Concatenate all param keys as a unique identifier for one param group.""" + return "/".join(sorted(param_keys)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/optimizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f9477aa414b429e4cb4ca8bf1d1fedf9788d4eaa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/optimizer.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +import logging +from collections import defaultdict +from threading import Lock + +import torch +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +import torch.jit as jit +import torch.nn as nn +from torch import Tensor +from torch.distributed.rpc import RRef + +from .utils import functional_optim_map + + +__all__ = ["DistributedOptimizer"] + +logger = logging.getLogger(__name__) + + +# XXX: we define a _ScriptModuleOptimizer here to explicitly +# compile the FunctionalOptimizer class into TorchScript +# This is because ScriptClass instance still lives in +# python unless you explicitly compile it as an attribute +# in ScriptModule or pass it to a ScriptFunction +# _ScriptLocalOptimizerInterface serves as a common +# interface type for Optimizer ScriptModules. +# +# TODO (wanchaol): remove this once we added TorchScript +# class reference semantics +@jit.interface +class _ScriptLocalOptimizerInterface: + def step(self, autograd_ctx_id: int) -> None: + pass + + +class _ScriptLocalOptimizer(nn.Module): + # TorchScript does not support multithread concurrent compiling. + # request_callback might invoke concurrent compiling, so we + # serialize the compiling with a lock + compile_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + super().__init__() + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + @jit.export + def step(self, autograd_ctx_id: int): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + # apply functional optimizer step with a list of gradients + grads: list[Tensor | None] = [ + all_local_grads[p] if p in all_local_grads else None # noqa: SIM401 + for p in self._local_params + ] + + self.optim.step(grads) + + +# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once +# we have converted all to functional optimizer in distributed.optim +class _LocalOptimizer: + # Ideally we would only need to share a lock for instances of + # _LocalOptimizer that deal with the same parameters. We are + # making a simplifying assumption here that if there is more + # than one instance of _LocalOptimizer per worker, they will + # be optimizing the same parameters (e.g. each data parallel + # trainer will create its own instance of _LocalOptimizer but + # they will all optimize the same parameters on each worker) + global_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + def step(self, autograd_ctx_id): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + + with _LocalOptimizer.global_lock: + for param, grad in all_local_grads.items(): + param.grad = grad + self.optim.step() + + +def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)) + + +def _local_optimizer_step(local_optim_rref, autograd_ctx_id): + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer +def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs) + + with _ScriptLocalOptimizer.compile_lock: + script_optim = jit.script(optim) + return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface) + + +@jit.script +def _script_local_optimizer_step( + local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int +) -> None: + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +def _wait_for_all(rpc_futs): + # TODO: improve error propagation + exception = None + results = [] + for fut in rpc_futs: + try: + results.append(fut.wait()) + except Exception as e: + results.append(e) + exception = e + if exception is not None: + raise exception + return results + + +class DistributedOptimizer: + """ + DistributedOptimizer takes remote references to parameters scattered + across workers and applies the given optimizer locally for each parameter. + + This class uses :meth:`~torch.distributed.autograd.get_gradients` in order + to retrieve the gradients for specific parameters. + + Concurrent calls to + :meth:`~torch.distributed.optim.DistributedOptimizer.step`, + either from the same or different clients, will + be serialized on each worker -- as each worker's optimizer can only work + on one set of gradients at a time. However, there is no guarantee that + the full forward-backward-optimizer sequence will execute for one client + at a time. This means that the gradients being applied may not correspond + to the latest forward pass executed on a given worker. Also, there is no + guaranteed ordering across workers. + + `DistributedOptimizer` creates the local optimizer with TorchScript enabled + by default, so that optimizer updates are not blocked by the Python Global + Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed + Model Parallel). This feature is currently enabled for most optimizers. You + can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support + for your own custom optimizers. + + Args: + optimizer_class (optim.Optimizer): the class of optimizer to + instantiate on each worker. + params_rref (list[RRef]): list of RRefs to local or remote parameters + to optimize. + args: arguments to pass to the optimizer constructor on each worker. + kwargs: arguments to pass to the optimizer constructor on each worker. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> import torch.distributed.autograd as dist_autograd + >>> import torch.distributed.rpc as rpc + >>> from torch import optim + >>> from torch.distributed.optim import DistributedOptimizer + >>> + >>> with dist_autograd.context() as context_id: + >>> # Forward pass. + >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) + >>> loss = rref1.to_here() + rref2.to_here() + >>> + >>> # Backward pass. + >>> dist_autograd.backward(context_id, [loss.sum()]) + >>> + >>> # Optimizer. + >>> dist_optim = DistributedOptimizer( + >>> optim.SGD, + >>> [rref1, rref2], + >>> lr=0.05, + >>> ) + >>> dist_optim.step(context_id) + + __ https://github.com/pytorch/tutorials/pull/1465 + """ + + def __init__(self, optimizer_class, params_rref, *args, **kwargs): + torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer") + per_worker_params_rref = defaultdict(list) + for param in params_rref: + per_worker_params_rref[param.owner()].append(param) + + if optimizer_class in functional_optim_map and jit._state._enabled: + optim_ctor = functional_optim_map.get(optimizer_class) + else: + optim_ctor = optimizer_class + self.is_functional_optim = optim_ctor != optimizer_class + + if self.is_functional_optim: + optimizer_new_func = _new_script_local_optimizer + else: + logger.warning( + "Creating the optimizer %s without TorchScript support, " + "this might result in slow computation time in multithreading environment" + "(i.e. Distributed Model Parallel training on CPU) due to the Python's " + "Global Interpreter Lock (GIL). Please file an issue if you need this " + "optimizer in TorchScript. ", + optimizer_class, + ) + optimizer_new_func = _new_local_optimizer + + remote_optim_futs = [] + for worker, param_rrefs in per_worker_params_rref.items(): + remote_optim_rref_fut = rpc.rpc_async( + worker, + optimizer_new_func, + args=(optim_ctor, param_rrefs) + args, + kwargs=kwargs, + ) + remote_optim_futs.append(remote_optim_rref_fut) + + self.remote_optimizers = _wait_for_all(remote_optim_futs) + + def step(self, context_id): + """ + Performs a single optimization step. + + This will call :meth:`torch.optim.Optimizer.step` on each worker + containing parameters to be optimized, and will block until all workers + return. The provided ``context_id`` will be used to retrieve the + corresponding :class:`~torch.distributed.autograd.context` that + contains the gradients that should be applied to the parameters. + + Args: + context_id: the autograd context id for which we should run the + optimizer step. + """ + dist_autograd._is_valid_context(context_id) + + optimizer_step_func = ( + _script_local_optimizer_step + if self.is_functional_optim + else _local_optimizer_step + ) + + rpc_futs = [ + rpc.rpc_async( + optimizer.owner(), + optimizer_step_func, + args=(optimizer, context_id), + ) + for optimizer in self.remote_optimizers + ] + _wait_for_all(rpc_futs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/post_localSGD_optimizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/post_localSGD_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b78510ed1a111998a4eda21546b003eedbcce7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/post_localSGD_optimizer.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +import warnings + +import torch +import torch.distributed.algorithms.model_averaging.averagers as averagers + + +class PostLocalSGDOptimizer(torch.optim.Optimizer): + r""" + Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD `_, + This optimizer runs local optimizer at every step. + After the warm-up stage, it averages parameters periodically after the local optimizer is applied. + + Args: + optim: The local optimizer. + averager: A model averager instance to run post-localSGD algorithm. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> import torch.distributed as dist + >>> import torch.distributed.algorithms.model_averaging.averagers as averagers + >>> import torch.nn as nn + >>> from torch.distributed.optim import PostLocalSGDOptimizer + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> + >>> # Register a post-localSGD communication hook. + >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Create a post-localSGD optimizer that wraps a local optimizer. + >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as + >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) + >>> opt = PostLocalSGDOptimizer( + >>> optim=local_optim, + >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) + >>> ) + >>> + >>> # In the first 100 steps, DDP runs global gradient averaging at every step. + >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), + >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. + >>> for step in range(0, 200): + >>> opt.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> opt.step() + """ + + def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager): + self.optim = optim + self.param_groups = self.optim.param_groups + self.averager = averager + + @property + def state(self): # type: ignore[override] + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + def state_dict(self): + r""" + This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`, + but adds an extra entry to record model averager's step to the checkpoint + to ensure reload does not cause unnecessary warm up again. + """ + optim_state_dict = self.optim.state_dict() + optim_state_dict["step"] = self.averager.step + return optim_state_dict + + def load_state_dict(self, state_dict): + r""" + This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`, + but also restores model averager's step value to the one + saved in the provided ``state_dict``. + + If there is no ``"step"`` entry in ``state_dict``, + it will raise a warning and initialize the model averager's step to 0. + """ + self.optim.load_state_dict(state_dict) + if "step" in state_dict: + self.averager.step = state_dict["step"] + else: + warnings.warn( + "Loaded state dict does not contain a step counter for an averager. " + "Setting step counter to 0.", + stacklevel=2, + ) + self.averager.step = 0 + + def step(self): # type: ignore[override] + r""" + Performs a single optimization step (parameter update). + """ + self.optim.step() + self.averager.average_parameters(params=self.param_groups) + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + self.optim.zero_grad(set_to_none=set_to_none) + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7075edd2e5210f1dc3d50aaa09688a4a4e1d09c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/utils.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs + +from torch import optim + +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD + + +# dict to map a user passed in optimizer_class to a functional +# optimizer class if we have already defined inside the +# distributed.optim package, this is so that we hide the +# functional optimizer to user and still provide the same API. +functional_optim_map = { + optim.Adagrad: _FunctionalAdagrad, + optim.Adam: _FunctionalAdam, + optim.AdamW: _FunctionalAdamW, + optim.SGD: _FunctionalSGD, + optim.Adadelta: _FunctionalAdadelta, + optim.RMSprop: _FunctionalRMSprop, + optim.Rprop: _FunctionalRprop, + optim.Adamax: _FunctionalAdamax, +} + + +def register_functional_optim(key, optim): + """ + Interface to insert a new functional optimizer to functional_optim_map + ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key + need not be of :class:`torch.optim.Optimizer` (e.g. for custom optimizers) + Example:: + >>> # import the new functional optimizer + >>> # xdoctest: +SKIP + >>> from xyz import fn_optimizer + >>> from torch.distributed.optim.utils import register_functional_optim + >>> fn_optim_key = "XYZ_optim" + >>> register_functional_optim(fn_optim_key, fn_optimizer) + """ + if key not in functional_optim_map: + functional_optim_map[key] = optim + + +def as_functional_optim(optim_cls: type, *args, **kwargs): + try: + functional_cls = functional_optim_map[optim_cls] + except KeyError as e: + raise ValueError( + f"Optimizer {optim_cls} does not have a functional counterpart!" + ) from e + + return _create_functional_optim(functional_cls, *args, **kwargs) + + +def _create_functional_optim(functional_optim_cls: type, *args, **kwargs): + return functional_optim_cls( + [], + *args, + **kwargs, + _allow_empty_param_list=True, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3183299a48347b4444cfe7b5105c1a1aadc8b4fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py @@ -0,0 +1,1679 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +r"""Zero Redundancy Optimizer.""" + +import collections +import copy +import enum +import inspect +import io +import logging +from collections.abc import Callable +from itertools import chain +from typing import Any + +import torch +import torch.distributed as dist +from torch.distributed.algorithms.join import Join, Joinable, JoinHook +from torch.distributed.optim.utils import functional_optim_map +from torch.optim import Optimizer + + +__all__ = ["ZeroRedundancyOptimizer"] + + +logger = logging.getLogger(__name__) + + +# Credits: classy_vision/generic/distributed_util.py +def _recursive_copy_to_device( + value: Any, + non_blocking: bool, + device: torch.device, +) -> Any: + r""" + Recursively searches lists, tuples, dicts and copies tensors to device if possible. + + Non-tensor values are passed as-is in the result. + + .. note:: + These are all copies, so if there are two objects that reference + the same object, then after this call, there will be two different objects + referenced on the device. + """ + if isinstance(value, torch.Tensor): + return value.to(device, non_blocking=non_blocking) + + if isinstance(value, (list, tuple)): + values = [ + _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) + for val in value + ] + return values if isinstance(value, list) else tuple(values) + + if isinstance(value, collections.abc.Mapping): + return { + key: _recursive_copy_to_device( + val, non_blocking=non_blocking, device=device + ) + for key, val in value.items() + } + + return value + + +def _is_trainable(param: torch.Tensor) -> bool: + r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient.""" + return param.requires_grad + + +def _broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + device: torch.device = torch.device("cpu"), +) -> Any: + r""" + Broadcasts an object to the given group. + + It will be sending the object if called from the source rank and receiving + the object otherwise. + + Arguments: + obj: object to broadcast; only used if called on the source rank. + src_rank (int): source rank. + group (``ProcessGroup``, optional): group used for the broadcast + (default: ``dist.group.WORLD``). + device (``torch.device``, optional): device to send from or receive + to (default: ``torch.device("cpu")``). + + Returns: + The broadcasted object. + """ + if dist.get_rank() == src_rank: + # Send the object + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.LongTensor([len(data)]).to(device) + data_send_tensor = torch.ByteTensor(data).to(device) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Receive the object + length_tensor = torch.LongTensor([0]).to(device) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = torch.empty( + [int(length_tensor.item())], dtype=torch.uint8, device=device + ) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = torch.load(buffer, map_location=device, weights_only=False) + return obj + + +class _ZeROJoinHook(JoinHook): + def __init__(self, zero): + assert isinstance(zero, ZeroRedundancyOptimizer), ( + "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " + "instance as the state" + ) + self.zero = zero + super().__init__() + + def main_hook(self): + """ + Perform an optimizer step. + + This step updates the joined process's shard of + the parameters and broadcasts those parameters. + """ + self.zero.step() + + +class _DDPBucketAssignment: + r""" + Represent a :class:`DistributedDataParallel` bucket assignment. + + This means that a (possibly non-strict) subset of the parameters corresponding to + a DDP bucket assigned to a rank to update. + + Attributes: + bucket_index (int): index of the bucket determined by the DDP gradient + bucket all-reduce order. + parameters (List[torch.Tensor]): model parameters in the bucket + assigned to this rank. + offset (int): offset into the :class:`GradBucket` 's :meth:`parameters` + giving the index of the first element in the passed-in + ``parameters``; this equivalently indexes into the + :class:`GradBucket` 's :meth:`gradients`. + device (torch.device): device on which the parameters are stored. + tensor (torch.Tensor): flattened tensor giving the data of the + parameter subset assigned to the rank. + """ + + def __init__( + self, + bucket_index: int, + parameters: list[torch.Tensor], + offset: int, + ): + self.bucket_index = bucket_index + self.parameters = parameters + self.offset = offset + if len(self.parameters) == 0: + raise ValueError("Empty bucket assignment") + # DDP guarantees all parameters in the bucket have the same device + # pyrefly: ignore [read-only] + self.device: torch.device = self.parameters[0].device + self.tensor: torch.Tensor | None = None + + +class _OverlapStatus(enum.IntEnum): + r""" + Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`. + + Attributes: + ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and + is waiting for DDP to finalize its bucketing. + ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that + its bucketing is finalized. The ZeRO instance can now collect the + necessary information about the DDP bucketing. + ``INITIALIZED``: The ZeRO instance is fully initialized and can now + optimize parameters. + """ + + UNINITIALIZED = 0 + DDP_HAS_REBUILT_BUCKETS = 1 + INITIALIZED = 2 + + +class _OverlapInfo: + r""" + Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`. + + Arguments: + world_size (int): world size of the process group being used. + + Attributes: + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity following + a threshold given by the total parameter size divided by the world + size; if ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank); + this should be set to the value passed into the hook constructor. + status (_OverlapStatus): current status; see :class:`_OverlapStatus` + for more information. + params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]`` + gives the model parameters in the ``i``th bucket. + params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]`` + gives the model parameters assigned to the ``i``th rank, where the + parameters are grouped by increasing bucket indices. + offsets (Dict[int, int]): maps from bucket index to the offset in + ``self.params_per_rank[rank]`` giving the index of the first + parameter in that bucket, where ``rank`` is this process's own + rank; the keys of this :class:`dict` are the bucket indices + assigned to this rank. + num_bucket_assignments (int): total number of bucket assignments across + all ranks; this is equal to the number of + :class:`DistributedDataParallel` gradient buckets if + ``shard_buckets=False`` and possibly greater otherwise. + total_size (int, optional): total size of all buckets (i.e. sum of + ``param.numel()`` for all ``param`` across all buckets) if + ``shard_buckets=True``; otherwise, ``None``. + broadcast_handles (List[Work]): :class:`list` of async work handles for + the parameter broadcasts. + bucket_index_to_future (Dict[int, torch.futures.Future]): + :class:`dict` mapping bucket index to the corresponding all-reduce + future. + bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict` + mapping bucket index to the corresponding bucket. + bucket_indices_seen (List[int]): :class:`list` of the bucket indices + seen on this iteration. + """ + + def __init__(self, world_size) -> None: + self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED + self.shard_buckets: bool = False + + # Modified per bucket reconstruction + self.params_per_bucket: list[list[torch.Tensor]] = [] + self.params_per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)] + self.offsets: dict[int, int] = {} + # Group Ranks + self.assigned_ranks_per_bucket: list[set[int]] = [] + self.num_bucket_assignments: int = 0 + self.total_size: int | None = None + + # Modified per iteration + self.broadcast_handles: list[Any] = [] + self.bucket_indices_seen: list[int] = [] + # Used by `hook_with_zero_step()` + self.bucket_index_to_future: dict[int, torch.futures.Future] = {} + self.bucket_index_to_bucket: dict[int, dist.GradBucket] = {} + + def wait_for_broadcasts(self) -> None: + r""" + Wait for all parameter broadcasts. + + This function should be called once all broadcasts have been scheduled, + meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` + in preparation for the next iteration. + """ + assert len(self.broadcast_handles) == self.num_bucket_assignments, ( + f"Missing at least one broadcast handle on rank {dist.get_rank()}" + ) + _ = [x.wait() for x in self.broadcast_handles] + self.broadcast_handles.clear() + + def clear_per_iter_info(self) -> None: + r""" + Clear the data structures that are modified per-iteration. + + This function should be called at the end of an iteration. + """ + self.bucket_indices_seen.clear() + self.bucket_index_to_future.clear() + self.bucket_index_to_bucket.clear() + + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + r""" + Wrap an arbitrary :class:`optim.Optimizer ` and shards its states across ranks in the group. + + The sharing is done as described by `ZeRO `_. + + The local optimizer instance in each rank is only + responsible for updating approximately ``1 / world_size`` parameters and + hence only needs to keep ``1 / world_size`` optimizer states. After + parameters are updated locally, each rank will broadcast its parameters to + all other peers to keep all model replicas in the same state. + ``ZeroRedundancyOptimizer`` can be used in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak + memory consumption. + + ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number + of parameters at each rank. Each parameter belongs to a single rank and is + not divided among ranks. The partition is arbitrary and might not match the + the parameter registration or usage order. + + Arguments: + params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s + or :class:`dict` s giving all parameters, which will be sharded + across ranks. + + Keyword Args: + optimizer_class (:class:`torch.nn.Optimizer`): the class of the local + optimizer. + process_group (``ProcessGroup``, optional): ``torch.distributed`` + ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by + :meth:`torch.distributed.init_process_group`). + parameters_as_bucket_view (bool, optional): if ``True``, parameters are + packed into buckets to speed up communication, and ``param.data`` + fields point to bucket views at different offsets; if ``False``, + each individual parameter is communicated separately, and each + ``params.data`` stays intact (default: ``False``). + overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is + overlapped with :class:`DistributedDataParallel` 's gradient + synchronization; this requires (1) either a functional optimizer + for the ``optimizer_class`` argument or one with a functional + equivalent and (2) registering a DDP communication hook + constructed from one of the functions in ``ddp_zero_hook.py``; + parameters are packed into buckets matching those in + :class:`DistributedDataParallel`, meaning that the + ``parameters_as_bucket_view`` argument is ignored. + If ``False``, :meth:`step` runs disjointly after the backward pass + (per normal). + (default: ``False``) + **defaults: any trailing arguments, which are forwarded to the local + optimizer. + + Example:: + + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> from torch.distributed.optim import ZeroRedundancyOptimizer + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) + >>> ddp = DDP(model, device_ids=[rank]) + >>> opt = ZeroRedundancyOptimizer( + >>> ddp.parameters(), + >>> optimizer_class=torch.optim.Adam, + >>> lr=0.01 + >>> ) + >>> ddp(inputs).sum().backward() + >>> opt.step() + + .. warning:: + Currently, ``ZeroRedundancyOptimizer`` requires that all of the + passed-in parameters are the same dense type. + + .. warning:: + If you pass ``overlap_with_ddp=True``, be wary of the following: Given + the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. To adjust for this, one option + is to prepend dummy inputs. + + .. warning:: ZeroRedundancyOptimizer is experimental and subject to change. + """ + + def __init__( + self, + params, + optimizer_class: type[Optimizer], + process_group: Any | None = None, + parameters_as_bucket_view: bool = False, + overlap_with_ddp: bool = False, + **defaults: Any, + ): + r"""Init.""" + # Perform type and assumption checks on the input parameters + params = self._verify_and_init_params(params) + self._verify_same_dense_param_type() + + # NOTE: The parent constructor uses `add_param_group()` which is + # partially overloaded in ZeroRedundancyOptimizer, so we use the + # `initialized` flag to dissociate the behaviour of `add_param_group()` + # between the parent and child. + self.initialized = False + + Optimizer.__init__(self, params, defaults) + Joinable.__init__(self) + # Now, all parameters are held in both `self._all_params` and + # `self.param_groups` + + # Internal data structures (`_cache` indicates lazily evaluated) + self._param_to_rank_cache: dict[torch.Tensor, int] = {} + self._param_to_index_cache: dict[torch.Tensor, int] = {} + self._partition_parameters_cache: list[list[dict]] = [] + self._index_to_param_cache: list[torch.Tensor] = [] + self._device_to_params_per_rank_cache: dict[ + torch.device, list[list[torch.Tensor]] + ] = {} + self._bucket_assignments_per_rank_cache: list[ + dict[int, _DDPBucketAssignment] + ] = [] + self._is_trainable_mask = self._get_is_trainable_mask() + + # Default device for collective communication and buckets + self._default_device = self._all_params[0].device + + self.process_group = ( + process_group if process_group is not None else dist.group.WORLD + ) + self.world_size: int = dist.get_world_size(self.process_group) + self.rank: int = dist.get_rank(self.process_group) + self.global_rank: int = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + self.rank, + ) + + self._overlap_with_ddp: bool = overlap_with_ddp + self._optim_defaults = defaults + self._optim_constructor = self._get_optimizer_constructor(optimizer_class) + + # If `overlap_with_ddp=True`, local optimizer initialization is delayed + # to run time after the necessary information has been collected + if not overlap_with_ddp: + self._init_local_optimizer() + else: + self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size) + if parameters_as_bucket_view: + logger.warning( + "`parameters_as_bucket_view=True` will be ignored since " + "`overlap_with_ddp=True`; instead, a different bucketing " + "strategy will be used" + ) + + # `self._buckets` is used if `parameters_as_bucket_view=True`, in + # which case parameter data is flattened into contiguous bucket tensors + self.parameters_as_bucket_view = parameters_as_bucket_view + self._buckets: list[list[torch.Tensor]] = [] + self._build_param_buckets() + + # Optional consolidated optimizer state, only populated if this rank + # is the target in `consolidate_state_dict()` + self._all_state_dicts: list[dict[str, Any]] = [] + + self.initialized = True + + def _clear_cache(self) -> None: + r"""Clear the cached data structures giving partition information.""" + self._partition_parameters_cache.clear() + self._param_to_rank_cache.clear() + self._index_to_param_cache.clear() + self._param_to_index_cache.clear() + self._device_to_params_per_rank_cache.clear() + self._bucket_assignments_per_rank_cache.clear() + + def add_param_group(self, param_group: dict[str, Any]) -> None: + r""" + Add a parameter group to the :class:`Optimizer` 's ``param_groups``. + + This can be useful when fine tuning a pre-trained network, as frozen + layers can be made trainable and added to the :class:`Optimizer` as + training progresses. + + Arguments: + param_group (dict): specifies the parameters to be optimized and + group-specific optimization options. + + .. warning:: This method handles updating the shards on all partitions + but needs to be called on all ranks. Calling this on a subset of + the ranks will cause the training to hang because communication + primitives are called depending on the managed parameters and + expect all the ranks to participate on the same set of parameters. + """ + if self.initialized and self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only " + "supports a single parameter group" + ) + + super().add_param_group(param_group) + # NOTE: The rest of the method assumes that the call to the parent's + # `add_param_group()` appends the new parameter group and preserves + # the previous parameter-group ordering + + if self.initialized: + # Force a re-partitioning of the parameters + self._clear_cache() + param_groups = self._partition_parameters()[self.rank] + # NOTE: All parameters in the old parameter groups should be + # assigned to the same ranks so that the local optimizers do not + # need to be reinitialized + + # Add the parameters assigned to this rank from the new parameter + # group to the local optimizer, if any + if len(param_groups) == len(self.optim.param_groups) + 1: + self.optim.add_param_group(param_groups[-1]) + + # Update the bucketing strategy accordingly + if self.parameters_as_bucket_view: + self._build_param_buckets() + + def consolidate_state_dict(self, to: int = 0) -> None: + r""" + Consolidate a list of ``state_dict`` s (one per rank) on the target rank. + + Arguments: + to (int): the rank that receives the optimizer states (default: 0). + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + + .. warning:: This needs to be called on all ranks. + """ + self._check_overlap_initialized() + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Pull the sharded state from all ranks and store them in rank order + empty_messenger = torch.tensor( + [0], dtype=torch.uint8, device=self._default_device + ) + + # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`) + # due to compatibility issues with NCCL backend; a possible follow-up + # is to move all sharded state management to RPC RRef + self._all_state_dicts = [] + for rank in range(self.world_size): + global_rank = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + rank, + ) + if self.rank == to: + # Consolidate all local `state_dict`s on this rank, storing on + # CPU to save GPU memory + if rank == self.rank: + # Directly append own optimizer state + self._all_state_dicts.append( + _recursive_copy_to_device( + self.optim.state_dict(), + non_blocking=True, + device=torch.device("cpu"), + ) + ) + else: + # Receive the optimizer state from the source rank + local_state_dict = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + self._all_state_dicts.append( + _recursive_copy_to_device( + local_state_dict, + non_blocking=True, + device=torch.device("cpu"), + ) + ) + else: + if rank == self.rank: + # Send the optimizer state to the target rank + _ = _broadcast_object( + self.optim.state_dict(), + src_rank=self.global_rank, + group=self.process_group, + device=self._default_device, + ) + elif rank != to: + # Discard the received object; `broadcast()` is used for + # compatibility reasons + _ = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + + def _verify_params_per_rank( + self, + params_per_rank: list[list[torch.Tensor]], + ) -> None: + r""" + Verify ``params_per_rank`` for :meth:`_partition_parameters`. + + The verification is done by checking that ``params_per_rank`` has length equal + to the world size and that it does not contain any parameters not passed into the + :class:`ZeroRedundancyOptimizer` constructor. + + The parameters in ``params_per_rank`` being a strict subset of those + passed into the constructor is valid since some parameters may be + frozen. + + Raises: + ValueError: if ``params_per_rank`` does not have length equal to + the world size or if it contains a parameter that was not + passed into the :class:`ZeroRedundancyOptimizer` constructor. + """ + if len(params_per_rank) != self.world_size: + raise ValueError( + "`params_per_rank` must have length equal to the world size" + ) + all_params_set = set(self._all_params) + for params in params_per_rank: + for param in params: + if param not in all_params_set: + raise ValueError( + "Passing a new parameter in `params_per_rank` that " + "was not passed into the ZeroRedundancyOptimizer " + "constructor" + ) + + def _partition_param_group( + self, param_group: dict[str, Any], params_per_rank: list[list[torch.Tensor]] + ) -> None: + r""" + Partition the parameter group ``param_group`` according to ``params_per_rank``. + + The partition will modify the ``self._partition_parameters_cache``. This method should + only be used as a subroutine for :meth:`_partition_parameters`. + + Arguments: + param_group (dict[str, Any]): a parameter group as normally defined + in an optimizer state. + params_per_rank (list[list[torch.Tensor]]): a :class:`list` of + length world size containing :class:`list` s of parameters to + assign to each rank. + """ + for rank, params in enumerate(params_per_rank): + rank_param_group = copy.copy(param_group) + rank_param_group["params"] = params + self._partition_parameters_cache[rank].append(rank_param_group) + + def _partition_parameters( + self, + params_per_rank: list[list[torch.Tensor]] | None = None, + ) -> list[list[dict]]: + r""" + Partitions parameters across distributed data parallel ranks. + + Arguments: + params_per_rank (list[list[torch.Tensor]], optional): a + :class:`list` of length world size containing :class:`list` s + of parameters to assign to each rank; this provides a way to + specify a partition manually. + If ``None``, the parameters are partitioned according to an + internal algorithm. + (default: ``None``) + + Returns: + A :class:`list` where each element of the list contains the + ``param_groups`` for a rank (which itself is a :class:`list` of + :class:`dict`); element 0 corresponds to rank 0, etc.; each rank + stores the ``param_groups`` for all ranks for the collective + communication in :meth:`step`. + + Raises: + ValueError: see :meth:`_validate_params_per_rank`. + RuntimeError: if ``params_per_rank`` is not ``None`` and this + :class:`ZeroRedundancyOptimizer` instance is using more than + one parameter group. + """ + if params_per_rank is None: + # Partition the parameters optimizing for uniformity + if len(self._partition_parameters_cache) == 0: + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + sizes = [0] * self.world_size + for param_group in self.param_groups: + param_group_params_per_rank: list[list] = [ + [] for _ in range(self.world_size) + ] + # Sort the parameters by size (largest first) + params_sorted = sorted( + param_group["params"], key=lambda t: t.numel(), reverse=True + ) + for param in params_sorted: + # Greedily add the parameter to rank with smallest size so far + rank = self._get_min_index(sizes) + param_group_params_per_rank[rank].append(param) + sizes[rank] += param.numel() + # Apply the constructed partition of the parameter group + self._partition_param_group( + param_group, param_group_params_per_rank + ) + + return self._partition_parameters_cache + + # Partition the parameters according to `params_per_rank` + assert len(self._partition_parameters_cache) == 0, ( + "Specifying `params_per_rank` should only be done when the " + "parameters have not been partitioned yet" + ) + if len(self.param_groups) != 1: + raise RuntimeError( + "Specifying `params_per_rank` only supports a single parameter group" + ) + self._verify_params_per_rank(params_per_rank) + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + + # Apply the passed-in partition of the parameter group + param_group = self.param_groups[0] + self._partition_param_group(param_group, params_per_rank) + + return self._partition_parameters_cache + + @property + def _param_to_rank(self) -> dict[torch.Tensor, int]: + r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition.""" + if len(self._param_to_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + self._param_to_rank_cache[param] = rank + return self._param_to_rank_cache + + @property + def _param_to_index(self) -> dict[torch.Tensor, int]: + r""" + :class:`dict` mapping parameters to their indices in the global optimizer state. + + NOTE: This assumes that the global optimizer state's indexing (in + ``state_dict``) follows a linear ordering over the parameter groups. + """ + if len(self._param_to_index_cache) == 0: + self._param_to_index_cache = { + p: i + for i, p in enumerate( + chain.from_iterable(g["params"] for g in self.param_groups) + ) + } + return self._param_to_index_cache + + @property + def _index_to_param(self) -> list[torch.Tensor]: + r"""List mapping parameter indices in the global optimizer scheme to the actual params.""" + if len(self._index_to_param_cache) == 0: + self._index_to_param_cache = list( + chain.from_iterable(g["params"] for g in self.param_groups) + ) + return self._index_to_param_cache + + def _broadcast_params_from_rank(self, rank: int): + r""" + Broadcast the shard of parameters from a given rank to all other ranks asynchronously. + + Arguments: + rank (int): the source rank. + + Returns: + A :class:`list` of async work handles for the ``broadcast()`` s + performed to synchronize the parameters. + """ + assert not self._overlap_with_ddp, ( + "`_broadcast_params_from_rank()` should not be used if " + "`overlap_with_ddp=True`; instead, the broadcasting should " + "happen in the DDP communication hook" + ) + handles = [] + if self.parameters_as_bucket_view: + for dev_i_buckets in self._buckets: + bucket = dev_i_buckets[rank] + global_rank = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + rank, + ) + handles.append( + dist.broadcast( + tensor=bucket, + src=global_rank, + group=self.process_group, + async_op=True, + ) + ) + else: + param_groups = self._partition_parameters()[rank] + global_rank = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + rank, + ) + for param_group in param_groups: + handles.extend( + dist.broadcast( + tensor=param.data, + src=global_rank, + group=self.process_group, + async_op=True, + ) + for param in param_group["params"] + ) + return handles + + def _sync_params(self): + r""" + Sync all parameter shards across the ranks. + + This rank sends its shard of the parameters to all other ranks and + receives a shard from each other rank. This is done using + ``broadcast()``. Parameters are sent bucket-by-bucket if + ``parameters_as_bucket_view=True``and sent parameter-by-parameter + otherwise. + """ + handles = [] + for rank in range(self.world_size): + handles.extend(self._broadcast_params_from_rank(rank)) + _ = [x.wait() for x in handles] + + @property + def _device_to_params_per_rank( + self, + ) -> dict[torch.device, list[list[torch.Tensor]]]: + r""" + Return device parameters assigned per rank. + + :class:`dict` mapping each device to a :class:`list` of the per-rank parameter + lists filtered to only include the parameters stored on that device. + Each per-rank parameter list gives the parameters assigned to that rank + to update. + + This is used for constructing the parameter buckets if + ``parameters_as_bucket_view=True``. + + Let ``dev_i`` denote the ``i``th device for this rank. Then: + ``dev_0`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_0``, + rank 1's assigned parameters stored on ``dev_0``, + ... + ``dev_1`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_1``, + rank 1's assigned parameters stored on ``dev_1``, + ... + ... + """ + assert self.parameters_as_bucket_view, ( + "`_device_to_params_per_rank` should only be used if " + "`parameters_as_bucket_view=True`" + ) + if len(self._device_to_params_per_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + device = param.device + if device not in self._device_to_params_per_rank_cache: + self._device_to_params_per_rank_cache[device] = [ + [] for _ in range(self.world_size) + ] + self._device_to_params_per_rank_cache[device][rank].append( + param + ) + return self._device_to_params_per_rank_cache + + def _get_min_index( + self, + values: list[int], + disallowed_indices: set[int] | None = None, + ) -> int: + r""" + Return ``values.index(min(values))``, except only uses one pass. + + It also excludes any indices in ``disallowed_indices`` if provided. + + Arguments: + values: (List[int]): :class:`list` of values. + disallowed_indices (Optional[set[int]]): indices that are + disallowed from being the returned min index. + """ + min_index = -1 + min_value = float("inf") + for i, value in enumerate(values): + if disallowed_indices and i in disallowed_indices: + continue + if value < min_value: + min_value = value + min_index = i + assert min_index >= 0, "All indices are disallowed" + return min_index + + def _assign_bucket_subset_to_rank( + self, + bucket_index: int, + bucket_params: list[torch.Tensor], + bucket_offset: int, + assigned_rank: int, + assigned_ranks_per_bucket: list[set[int]], + ) -> None: + r""" + Assign ``bucket_params`` to the rank with the least size assigned so far and collects relevant information. + + The model parameters given by ``bucket_params`` represents a (possibly non-strict) + subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + gradient bucket. + bucket_params (List[torch.Tensor]): subset of the parameters + corresponding to the bucket to assign. + bucket_offset (int): offset giving the index of the first element + in ``bucket_params`` in the bucket's full parameter list. + assigned_rank (int): group rank to assign to. + assigned_ranks_per_bucket (list[set[int]]): :class:`set` of group ranks + assigned to each bucket. + """ + overlap_info = self._overlap_info + if len(bucket_params) == 0: + raise ValueError("Empty bucket assignment") + params_per_rank = overlap_info.params_per_rank + offsets = overlap_info.offsets + + self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = ( + _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) + ) + if self.global_rank == assigned_rank: + offsets[bucket_index] = len(params_per_rank[assigned_rank]) + params_per_rank[assigned_rank].extend(bucket_params) + assigned_ranks_per_bucket[bucket_index].add(assigned_rank) + self._overlap_info.num_bucket_assignments += 1 + + @property + def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]: + r""" + Return DDP bucket parameters assigned per rank. + + :class:`list` of length world size consisting of :class:`dict` s + mapping bucket indices to :class:`_DDPBucketAssignment` s for each + rank. + """ + assert self._overlap_with_ddp, ( + "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" + ) + if len(self._bucket_assignments_per_rank_cache) > 0: + return self._bucket_assignments_per_rank_cache + + overlap_info = self._overlap_info + assert overlap_info.status == _OverlapStatus.INITIALIZED + + self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)] + params_per_bucket = overlap_info.params_per_bucket + + if overlap_info.shard_buckets: + # Define the assignment threshold to approximate uniformity + assert overlap_info.total_size is not None, "`total_size` was not computed" + threshold = overlap_info.total_size / self.world_size # type: ignore[operator] + size_per_rank = [0 for _ in range(self.world_size)] + + num_buckets = len(params_per_bucket) + overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)] + assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket + if not overlap_info.shard_buckets: + # Assign each DDP bucket entirely to a single rank + for bucket_index, bucket_params in enumerate(params_per_bucket): + assert len(bucket_params) > 0, "Empty bucket" + assigned_rank = self._get_assigned_rank(bucket_index) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params, + 0, + assigned_rank, + assigned_ranks_per_bucket, + ) + else: + # Assign each DDP bucket to possibly multiple ranks + # Specifically, sort the DDP buckets by increasing size, and for + # each bucket, iteratively assign the maximal unassigned subset + # with size less than `threshold` to the rank with the least total + # size so far -- each such assignment is represented by a + # `_DDPBucketAssignment` instance and only contains parameters from + # a single DDP bucket + params_per_bucket_enum = sorted( + enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1]) + ) + for bucket_index, bucket_params in params_per_bucket_enum: + assert len(bucket_params) > 0, "Empty bucket" + bucket_offset = 0 + assignment_size = 0 + for param_index, param in enumerate(bucket_params): + param_numel = param.numel() + if ( + # pyrefly: ignore [unbound-name] + assignment_size + param_numel >= threshold + and param_index > bucket_offset + ): + assigned_rank = self._get_min_index( + # pyrefly: ignore [unbound-name] + size_per_rank, + assigned_ranks_per_bucket[bucket_index], + ) + # Include up to but not including the parameter that + # exceeded the threshold + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:param_index], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + # pyrefly: ignore [unbound-name] + size_per_rank[assigned_rank] += assignment_size + bucket_offset = param_index + assignment_size = 0 + assignment_size += param_numel + # Assign the remainder of the bucket so that no assignment + # spans across two buckets + assigned_rank = self._get_min_index( + # pyrefly: ignore [unbound-name] + size_per_rank, + assigned_ranks_per_bucket[bucket_index], + ) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + # pyrefly: ignore [unbound-name] + size_per_rank[assigned_rank] += assignment_size + + return self._bucket_assignments_per_rank_cache + + def _local_step( + self, + gradients: list[torch.Tensor | None] | None = None, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: + r""" + Perform a single optimizer step without syncing parameters across ranks. + + Arguments: + gradients (list[Optional[torch.Tensor]], optional): a :class:`list` + of length equal to the number of parameters assigned to this + rank containing gradient tensors or ``None`` as its elements; + a ``None`` in the :class:`list` indicates that the + corresponding parameter should not be updated. + If the argument itself is ``None``, then all parameters are + updated, and the gradients are assumed to be already populated. + (default: ``None``) + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers and should be + ``None`` if ``gradients`` is not ``None``; (default: ``None``) + Returns: + Optional loss depending on the underlying local optimizer. + + .. warning:: + The argument ``gradients`` should only be specified (i.e. not + ``None``) if ``overlap_with_ddp=True``, in which case + :class:`ZeroRedundancyOptimizer` wraps a functional optimizer. + """ + Join.notify_join_context(self) + # Check if the model trainability has changed + is_trainable_mask = self._get_is_trainable_mask() + if is_trainable_mask != self._is_trainable_mask: + if self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` " + "does not support changing parameter trainability at run " + "time" + ) + logger.warning( + "ZeroRedundancyOptimizer detected that the trainable " + "parameters changed; rebuilding the parameter buckets if " + "enabled" + ) + self._build_param_buckets() + self._is_trainable_mask = is_trainable_mask + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Run the optimizer step on this shard only + if gradients is None: + loss = ( + self.optim.step(**kwargs) + if closure is None + else self.optim.step(closure=closure, **kwargs) + ) + else: + assert self._overlap_with_ddp, ( + "Specifying `gradients` should not " + "be used when `overlap_with_ddp=False`" + ) + assert closure is None, ( + "`closure` is not supported when using a local functional optimizer" + ) + loss = self.optim.step(gradients=gradients) + + # Sync any updated attributes in the local optimizer to the exposed + # `param_groups` + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + return loss + + # pyrefly: ignore [bad-override] + def step( + self, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: + r""" + Perform a single optimizer step and syncs parameters across all ranks. + + Arguments: + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers. + Returns: + Optional loss depending on the underlying local optimizer. + + .. note:: Any extra parameters are passed to the base optimizer as-is. + """ + if self._overlap_with_ddp: + logger.warning( + "`step()` should not be included in the training loop when " + "`overlap_with_ddp=True`" + ) + return None + + # Perform the local optimizer step + loss = self._local_step(closure=closure, **kwargs) + + # Sync all of the updated parameter shards across the ranks + self._sync_params() + + return loss + + def join_hook(self, **kwargs): + r""" + Return the ZeRO join hook. + + It enables training on uneven inputs by + shadowing the collective communications in the optimizer step. + + Gradients must be properly set before this hook is called. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + + This hook does not support any keyword arguments; i.e. ``kwargs`` is + unused. + """ + return _ZeROJoinHook(self) + + @property + def join_device(self) -> torch.device: + r"""Return default device.""" + return self._default_device + + @property + def join_process_group(self) -> Any: + r"""Return process group.""" + return self.process_group + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r""" + Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed. + + Arguments: + state_dict (dict): optimizer state; should be an object returned + from a call to :meth:`state_dict`. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + """ + self._check_overlap_initialized() + + for index, value in state_dict["state"].items(): + param = self._index_to_param[index] + if self._param_to_rank[param] != self.rank: + # Clear any state irrelevant to this rank + state_dict["state"][index] = None + else: + # Load the parameter state to the local optimizer + self.optim.state[param] = _recursive_copy_to_device( + value, non_blocking=True, device=param.device + ) + # Force zero-dimensional tensors (like Adam "step") on CPU + for state_name, state_value in self.optim.state[param].items(): + if torch.is_tensor(state_value) and state_value.dim() == 0: + self.optim.state[param][state_name] = state_value.cpu() + + super().load_state_dict(state_dict) + + # Sync the input state with the exposed and local optimizer states + self._sync_param_groups(state_dict["param_groups"], self.param_groups) + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + def state_dict(self) -> dict[str, Any]: + r""" + Return the last global optimizer state known to this rank. + + .. warning: + If the state has not been consolidated to this rank, this raises a + runtime error, and even if it has, the state may not be up-to-date, + depending on when :meth:`consolidate_state_dict` was last called. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt; or if this method is called without a preceding call + to :meth:`consolidate_state_dict`. + """ + self._check_overlap_initialized() + + if len(self._all_state_dicts) == 0: + raise RuntimeError( + "Optimizer state has not been consolidated on this rank. " + f"Please call `consolidate_state_dict(to={self.rank})` on " + "all ranks beforehand if you meant to save the global state." + ) + + # Get the possibly-stale global optimizer state that uses global + # parameter indexing + state_dict = super().state_dict() + + # Update the global optimizer state with local state information, + # factoring in the translation from local to global indexing + for rank, local_state_dict in enumerate(self._all_state_dicts): + local_param_groups = local_state_dict["param_groups"] + global_param_groups = self._partition_parameters()[rank] + assert len(local_param_groups) == len(global_param_groups), ( + "Mismatch between number of local and global parameter groups" + ) + + for local_param_group, global_param_group in zip( + local_param_groups, global_param_groups + ): + # `local_param_group` stores local indices, while + # `global_param_group` stores the tensors directly + local_param_indices = local_param_group["params"] + global_params = global_param_group["params"] + + assert len(local_param_indices) == len(global_params), ( + "Mismatch between number of local and global parameters in parameter group" + ) + for local_param_index, global_param in zip( + local_param_indices, global_params + ): + # Update the global parameter state, if any + if local_param_index in local_state_dict["state"]: + global_param_index = self._param_to_index[global_param] + state_dict["state"][global_param_index] = local_state_dict[ + "state" + ][local_param_index] + + # Sort the parameters in the state + state_dict["state"] = dict(sorted(state_dict["state"].items())) + return state_dict + + @staticmethod + def _sync_param_groups( + src_param_groups: list[dict[Any, Any]], + dst_param_groups: list[dict[Any, Any]], + ) -> None: + r""" + Sync the attributes from the source parameter groups to the destination parameter groups. + + Example attributes include learning rate or scheduler attributes. The + two parameter groups should have the same length (i.e. same number of + parameter groups). + + Arguments: + src_param_groups (list[dict]): parameter groups giving the + attribute settings to copy. + dst_param_groups (list[dict]): parameter groups giving the + attribute settings to set. + """ + assert len(src_param_groups) == len(dst_param_groups), ( + "Mismatch between number of source and destination parameter groups" + ) + for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups): + # Sync all attributes except the parameters + for attr in filter(lambda x: x != "params", src_param_group.keys()): + dst_param_group[attr] = src_param_group[attr] + + def _build_param_buckets(self) -> None: + r""" + Build parameter buckets if ``parameters_as_bucket_view=True``. + + For each device that stores this rank's parameters, there is a + bucket (represented as a tensor) containing all of the parameters on + that device that are assigned to a given rank in the parameter update + partition. + + This method is called in the constructor and any time parameter + trainability is changed. + + .. warning:: + The current implementation assumes that all of the parameters in a + bucket are of the same dense type when allocating the bucket's + tensor. + + .. warning:: + If the model parameters are stored across more than one device, + then the storage partitioning must be the same across all + processes in order for parameter synchronization to work. + """ + if not self.parameters_as_bucket_view or self._overlap_with_ddp: + return + + # `self._buckets[i][j]` are the parameters stored on device i and + # assigned to rank j + num_devices = len(self._device_to_params_per_rank) + self._buckets = [[] for _ in range(num_devices)] # type: ignore[assignment] + + for dev_i, (device, params_per_rank) in enumerate( + self._device_to_params_per_rank.items() + ): + for params in params_per_rank: + bucket_size = 0 + dtype = None + trainable_params = [] + for param in params: + if not _is_trainable(param): + # Clone in case the parameter was previously part of + # a bucket to avoid the data from being destroyed + param.data = param.data.detach().clone() + else: + bucket_size += param.numel() + trainable_params.append(param) + dtype = param.dtype # assumes all same dtype + + if bucket_size == 0: + # Create a dummy bucket if there are no parameters + bucket = torch.zeros(1, device=device) + else: + # Construct the bucket (assuming all dense and same dtype) + bucket = torch.empty(bucket_size, dtype=dtype, device=device) + offset = 0 + for param in trainable_params: + offset_next = offset + param.numel() + bucket[offset:offset_next].copy_(param.data.flatten()) + param.data = bucket[offset:offset_next].view_as(param.data) + offset = offset_next + self._buckets[dev_i].append(bucket) # type: ignore[arg-type] + + def _build_ddp_param_buckets(self) -> None: + r""" + Build the DDP bucket with parameters assigned to this rank. + + For each DDP bucket with parameters assigned to this rank, flattens the + data of those parameters into a single tensor and saves the tensor to + the ``tensor`` attribute in the corresponding + :class:`_DDPBucketAssignment` instance stored in + ``self._bucket_assignments_per_rank``. + + :class:`DistributedDataParallel` guarantees that the parameters + corresponding to a gradient bucket have the same device and the same + dtype. + """ + for bucket_assignments in self._bucket_assignments_per_rank: + for bucket_assignment in bucket_assignments.values(): + params = bucket_assignment.parameters + bucket_size = 0 + dtype = None + for param in params: + assert _is_trainable(param), ( + "Model parameter " + "corresponding to a gradient in a DDP bucket should " + "require a gradient" + ) + bucket_size += param.numel() + dtype = param.dtype # assumes all same dtype + assert bucket_size > 0, "Empty bucket" + + # Construct the bucket tensor (assuming all dense and same dtype) + tensor = torch.empty( + bucket_size, dtype=dtype, device=bucket_assignment.device + ) + offset = 0 + for param in params: + offset_next = offset + param.numel() + tensor[offset:offset_next].copy_(param.data.flatten()) + param.data = tensor[offset:offset_next].view_as(param.data) + offset = offset_next + bucket_assignment.tensor = tensor + + def _verify_and_init_params( + self, + params: Any, + ) -> list[torch.Tensor] | list[dict]: + r""" + Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters. + + The initializagtion will first make sure that provided ``params`` is valid. + + Arguments: + params (Any): Candidate parameter list or parameter groups to verify. + + Raises: + TypeError: ``params`` has an invalid type. + ValueError: ``params`` is empty. + + Returns: + The persistent form of ``params`` to be passed into the parent + :class:`Optimizer` constructor -- i.e. returns ``params`` as a + :class:`list` to ensure that it can be iterated over again. + """ + if isinstance(params, torch.Tensor): + raise TypeError( + "`params` argument should be an iterable of " + f"Tensors, but got {torch.typename(params)}" + ) + try: + all_params = list(params) + except TypeError as e: + raise TypeError( + "`params` argument should be an iterable of Tensors" + f" or dicts, but got {torch.typename(params)}" + ) from e + if len(all_params) == 0: + raise ValueError("ZeroRedundancyOptimizer got an empty parameter list") + all_tensors = True + all_dicts = True + for param in all_params: + all_tensors &= isinstance(param, torch.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError( + "`params` argument should be an iterable of Tensors or dicts" + ) + # Ensure that `self._all_params` contains a list of all parameters + if all_tensors: + self._all_params = all_params + elif all_dicts: + self._all_params = [] + # `all_params` contains parameter groups (not parameters) + for param_group in all_params: + if "params" not in param_group: + raise ValueError( + "Each parameter group passed-in via `params` must " + "have a 'params' key mapping to the parameters in " + "the group" + ) + self._all_params.extend(param_group["params"]) + return all_params + + def _verify_same_dense_param_type(self) -> None: + r""" + Verify that all parameters are of the same dense type. + + The method assumes that ``self._all_params`` has been initialized + and is non-empty. + + Raises: + ValueError: ``params`` contains sparse parameters or parameters + of varying dense types. + + NOTE: This method can be removed once support for sparse parameters + and varying parameter types is added. + """ + typename = torch.typename(self._all_params[0]) + if self._all_params[0].is_sparse: + raise ValueError( + "ZeroRedundancyOptimizer only supports using " + "the same dense type for all parameters but got " + f"{typename}" + ) + for param in self._all_params[1:]: + other_typename = torch.typename(param) + if other_typename != typename: + raise ValueError( + "ZeroRedundancyOptimizer only supports " + "using the same dense type for all " + f"parameters but got both {typename} and " + f"{other_typename}" + ) + + def _get_is_trainable_mask(self) -> list[bool]: + r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not.""" + return list(map(_is_trainable, self._all_params)) + + def _init_local_optimizer(self) -> None: + r""" + Initialize this rank's local optimizer, responsible for its subset of the parameters. + + The local optimizer is saved in ``self.optim``. + """ + assert self._optim_constructor is not None, ( + "The local optimizer class has not been set" + ) + + param_groups = self._partition_parameters()[self.rank] + # `overlap_with_ddp=True` requires a local functional optimizer + if self._overlap_with_ddp: + # Functional optimizers only support a single parameter group and + # require passing in the parameters as a list + assert len(param_groups) == 1, ( + "Initializing the local " + "functional optimizer with more than one parameter group" + ) + params = param_groups[0]["params"] + # Try to pass `_allow_empty_param_list=True` to avoid erroring + if ( + "_allow_empty_param_list" + in inspect.signature(self._optim_constructor).parameters + ): + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults, _allow_empty_param_list=True + ) + else: + logger.warning( + "%s does not support the argument " + "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " + "error due to an empty parameter list", + self._optim_constructor, + ) + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults + ) # type: ignore[no-redef] + + # Log information about the DDP and ZeRO bucketing + if dist.get_debug_level() != dist.DebugLevel.OFF: + local_numel = sum(p.numel() for p in params) + num_assigned_buckets = len( + self._bucket_assignments_per_rank[self.global_rank] + ) + logger.info( + "rank %s with %s parameters across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, + ) + if self.global_rank == 0: + logger.info( + "%s DDP buckets and %s bucket assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, + ) + else: + # NOTE: Passing `param_groups` into the local optimizer constructor + # bypasses the empty parameter list check + self.optim: Optimizer = self._optim_constructor( + param_groups, **self._optim_defaults + ) # type: ignore[no-redef] + + # TODO: Manually add `self.param_groups` if using a functional + # optimizer; remove this if/when the functional optimizers support + # multiple parameter groups + if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"): + assert hasattr(self.optim, "param_group"), ( + "The functional optimizer should set at least one of the " + "attributes `param_group` or `param_groups`" + ) + self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined] + + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + def _init_zero_for_overlap(self) -> None: + r"""Perform a delayed initialization of the local optimizer and the supporting data structures.""" + assert self._overlap_with_ddp, ( + "`_init_zero_for_overlap()` should only be called when " + "`overlap_with_ddp=True`" + ) + self._overlap_info.status = _OverlapStatus.INITIALIZED + self._clear_cache() + self._partition_parameters(self._overlap_info.params_per_rank) + self._build_ddp_param_buckets() + self._init_local_optimizer() + + def _get_assigned_rank(self, bucket_index: int) -> int: + r""" + Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + bucket for which to get the assigned rank. + """ + assert not self._overlap_info.shard_buckets, ( + "The bucket assignment requires global bucket information and " + "will be computed later; there should be no need to use this " + "method" + ) + return bucket_index % self.world_size + + def _check_overlap_initialized(self): + r""" + Check the delayed initialization depending on the value of ``overlap_with_ddp``. + + The delayed initialization has occurred (see + :meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and + raises a ``RuntimeError`` if not. This should preface methods that + should not be run before that delayed initialization. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and + :meth:`_init_zero_for_overlap` has not been called. + """ + if ( + self._overlap_with_ddp + and self._overlap_info.status != _OverlapStatus.INITIALIZED + ): + raise RuntimeError( + "This method should not be called until this " + "ZeroRedundancyOptimizer instance has been fully " + "initialized" + ) + + def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: + r""" + Return the optimizer constructor using validation and transformation depending on ``overlap_with_ddp``. + + Returns: + - ``optimizer_class`` if ``overlap_with_ddp=False`` and + ``optimizer_class`` is not a functional optimizer. + - ``optimizer_class`` if ``overlap_with_ddp=True`` and + ``optimizer_class`` is already a functional optimizer. + - The functional equivalent of ``optimizer_class`` if + ``overlap_with_ddp=True`` and ``optimizer_class`` is not + already a functional optimizer (assuming the equivalent + exists). + + Raises: + ValueError: + + - if ``overlap_with_ddp=True`` but ``optimizer_class`` is + neither a functional optimizer nor translatable to a + functional optimizer. + - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a + functional optimizer. + """ + functional_optims = functional_optim_map.values() + if not self._overlap_with_ddp: + if optimizer_class in functional_optims: + # Using a functional optimizer is only supported when + # `overlap_with_ddp=True` + raise ValueError( + f"Passing in a functional optimizer {optimizer_class} " + "when `overlap_with_ddp=False`" + ) + else: + return optimizer_class + else: + if optimizer_class in functional_optims: + # Already a functional optimizer + return optimizer_class + elif optimizer_class in functional_optim_map: + # Translate the passed-in optimizer class to its functional + # equivalent if `overlap_with_ddp=True` + optim_constructor = functional_optim_map[optimizer_class] + logger.info( + "Using the functional optimizer %s " + "instead of %s since " + "`overlap_with_ddp=True`", + optim_constructor, + optimizer_class, + ) + return optim_constructor + else: + raise ValueError( + "Using `ddp_with_overlap=True` requires using a " + "functional optimizer, but there is no supported functional " + f"optimizer equivalent for {optimizer_class}" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8ffbb04f13ffcfdba07589eac0594c80cc28968d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi @@ -0,0 +1,85 @@ +# mypy: allow-untyped-defs +import enum +from collections.abc import Callable +from typing import Any, overload + +import torch +from torch.distributed.algorithms.join import Joinable, JoinHook +from torch.optim import Optimizer + +class _ZeROJoinHook(JoinHook): + zero: Any = ... + def __init__(self, zero: Any) -> None: ... + def main_hook(self) -> None: ... + +class _DDPBucketAssignment: + bucket_index: int + parameters: list[torch.Tensor] + offset: int + device: torch.device + tensor: torch.Tensor | None + +class _OverlapStatus(enum.IntEnum): + UNINITIALIZED = ... + DDP_HAS_REBUILT_BUCKETS = ... + INITIALIZED = ... + +class _OverlapInfo: + status: Any = ... + params_per_bucket: Any = ... + params_per_rank: Any = ... + offsets: Any = ... + broadcast_handles: Any = ... + bucket_index_to_future: Any = ... + bucket_index_to_bucket: Any = ... + bucket_indices_seen: Any = ... + assigned_ranks_per_bucket: list[set[int]] = ... + total_size: int = ... + shard_buckets: bool = ... + def __init__(self) -> None: ... + def wait_for_broadcasts(self) -> None: ... + def clear_per_iter_info(self) -> None: ... + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + functional_optim_map: Any = ... + initialized: bool = ... + process_group: Any = ... + world_size: int = ... + rank: int = ... + global_rank: int = ... + parameters_as_bucket_view: bool = ... + optim: Any = ... + _device_to_device_index: dict[torch.device, int] = ... + _overlap_with_ddp: bool = ... + _overlap_info: _OverlapInfo = ... + _buckets: list[list[torch.Tensor]] = ... + _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ... + def __init__( + self, + params: Any, + optimizer_class: type[Optimizer], + process_group: Any | None = ..., + parameters_as_bucket_view: bool = ..., + overlap_with_ddp: bool = ..., + **defaults: Any, + ) -> None: ... + def add_param_group(self, param_group: dict[str, Any]) -> None: ... + def consolidate_state_dict(self, to: int = ...) -> None: ... + @overload + def step(self, closure: None = None, **kwargs: Any) -> None: ... + @overload + def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ... + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... + def state_dict(self) -> dict[str, Any]: ... + def _local_step( + self, + gradients: list[torch.Tensor | None] | None = None, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: ... + def _get_assigned_rank(self, bucket_index: int) -> int: ... + def _init_zero_for_overlap(self) -> None: ... + def join_hook(self, **kwargs): ... + @property + def join_device(self) -> torch.device: ... + def join_process_group(self) -> Any: ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py new file mode 100644 index 0000000000000000000000000000000000000000..eae7def75d5faf1b9427e7a477b866b8229b1651 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py @@ -0,0 +1,1257 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import copy +import logging +import operator +from collections import defaultdict +from collections.abc import Callable +from enum import Enum +from inspect import Parameter, Signature, signature +from types import MethodType +from typing import Any, Union + +import torch +import torch.fx as fx +from torch.distributed import ProcessGroup +from torch.export import ExportedProgram +from torch.export.unflatten import ( + _assign_attr, + _AttrKind, + _sink_params, + InterpreterModule, +) +from torch.fx.node import map_aggregate +from torch.fx.passes.split_module import split_module + +from ._backward import _null_coalesce_accumulate, stage_backward +from ._unflatten import _outline_submodules +from ._utils import PipeInfo +from .stage import _PipelineStage + + +logger = logging.getLogger(__name__) + +# TODO: +# 1. investigate gradient sync for shared parameters. how does DDP do it? +# 2. Add parameter movement to split_module + + +PP_SUBMOD_PREFIX = "submod_pp" + + +def get_submod_name(stage_idx: int): + """Returns the name of the submod for a given stage index. + For example, "submod_pp_0", "submod_pp_1", etc. + """ + return "_".join([PP_SUBMOD_PREFIX, str(stage_idx)]) + + +def _find_loss_from_output_and_spec(output_val, spec_val): + if spec_val is False: + return None + if spec_val is True: + if not isinstance(output_val, fx.Node): + raise RuntimeError( + f"Loss spec must specify a dynamic value but got {output_val}" + ) + return output_val + + if isinstance(spec_val, (tuple, list)): + if not isinstance(output_val, (tuple, list)): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if len(output_val) != len(spec_val): + raise RuntimeError( + f"Output value {output_val} must match length of loss specification " + f"{spec_val}" + ) + for out, spec in zip(output_val, spec_val): + loss_val = _find_loss_from_output_and_spec(out, spec) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + if isinstance(spec_val, dict): + if not isinstance(output_val, dict): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if set(output_val.keys()) != set(spec_val.keys()): + raise RuntimeError( + f"Output value {output_val} must match keys of loss specification " + f"{spec_val}" + ) + for k in spec_val: + loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k]) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification") + + +def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec): + output_nodes = [n for n in g.nodes if n.op == "output"] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + output_val = output_node.args[0] + generated_spec: Any = None + + if isinstance(mod, TrivialLossWrapper): + # TrivialLossWrapper is pre-defined by PiPPy. + # It has loss as the only output so we can safely assume the first output arg is the loss. + assert len(output_node.args) == 1 + loss_node = output_val + generated_spec = TrivialLossWrapper.loss_spec + elif output_loss_value_spec is None: + # Use default spec, i.e. search for "loss" in output values + if isinstance(output_val, dict) and "loss" in output_val: + loss_node = output_val["loss"] + generated_spec = {k: k == "loss" for k in output_val} + else: + loss_node = None + generated_spec = None + else: + loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec) + generated_spec = output_loss_value_spec + + return loss_node, output_node, generated_spec + + +def _insert_stage_symbolic_backward( + g: fx.Graph, + loss_node: fx.Node, + output_node: fx.Node, +): + # Collect metadata about tuple output values. TODO: move this to split_module or FX IR + tuples: dict[fx.Node, tuple] = {} + for node in reversed(g.nodes): + if node.op == "call_function": + # In the forward pass, only emit placeholder, module calls, and + # getitem calls. If we have a target other than getitem in this + # (forward-only) code, there is a bug. + assert node.target is operator.getitem, ( + "Found non-getitem call in forward pass. Please report a bug to PiPPy" + ) + assert len(node.args) == 2, ( + "Found malformed getitem call. Please report a bug to PiPPy" + ) + indexed_value, node_idx = tuple(node.args) + + # indexed_value is a collection that we are indexing into. It could + # exist in the tuples map if we've processed another `getitem` + # already. + existing_list_size = ( + len(tuples[indexed_value]) if indexed_value in tuples else -1 + ) + new_list_size = max(node_idx + 1, existing_list_size) + + reconstructed_list = [None for _ in range(new_list_size)] + + # Copy over existing elements if present + if indexed_value in tuples: + for i, val in enumerate(tuples[indexed_value]): + reconstructed_list[i] = val + + # Populate value represented by this node + reconstructed_list[node_idx] = node + + tuples[indexed_value] = tuple(reconstructed_list) + + # Keep track of nodes that dominate the loss node. + # We will only emit backward operations for nodes that can contribute + # to the specified loss value. + live_nodes = {loss_node: None} + val_to_grad: dict[fx.Node, fx.Node | None] = {loss_node: None} + + def assign_or_accumulate_grad(forward_node, grad_value): + if forward_node in val_to_grad and forward_node.op != "placeholder": + grad_value = g.call_function( + _null_coalesce_accumulate, + (val_to_grad[forward_node], grad_value), + ) + val_to_grad[forward_node] = grad_value + + with g.inserting_before(output_node): + for node in reversed(g.nodes): + if node not in live_nodes: + continue + + def add_to_live_nodes(n): + live_nodes.setdefault(n, None) + + fx.node.map_arg(node.args, add_to_live_nodes) + fx.node.map_arg(node.kwargs, add_to_live_nodes) + if node.op == "call_module": + output_grads: tuple[fx.Node | None, ...] | fx.Node | None + if node in tuples: + stage_output = tuples[node] + output_grads = tuple(val_to_grad.get(n) for n in tuples[node]) + outputs_with_grads_idxs = [ + i for i, n in enumerate(tuples[node]) if n in live_nodes + ] + else: + stage_output = (node,) + output_grads = val_to_grad[node] + outputs_with_grads_idxs = [0] + + output_grads = ( + (output_grads,) + if not isinstance(output_grads, tuple) + else output_grads + ) + + grad_call = g.call_function( + stage_backward, + kwargs={ + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": list(node.all_input_nodes), + "outputs_with_grads_idxs": outputs_with_grads_idxs, + }, + ) + # Insert backward stage debug info + kwargs_copy = dict(grad_call.kwargs) + grad_call.kwargs = kwargs_copy + + grad_call_proxy = fx.Proxy(grad_call) + grads = grad_call_proxy.node + + input_nodes = list(node.all_input_nodes) + grads_proxy = fx.Proxy(grads) + for i, input_node in enumerate(input_nodes): + assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index] + + return g + + +class PipeSequential(torch.nn.Sequential): + @staticmethod + def from_sequential(sequential_instance: torch.nn.Sequential): + return PipeSequential(*[copy.copy(m) for m in sequential_instance]) + + def forward(self, input): + for i, module in enumerate(self): + input = module(input) + if i != len(self) - 1: + pipe_split() + return input + + +class LossWrapper(torch.nn.Module): + """ + LossWrapper is a convenient abstract class that allows you to wrap up both + your model as well as its loss function and specify the connectivity between + the inputs, model, loss function, and output value. Example:: + + class MyModelWrapper(LossWrapper): + def forward(self, x, targets): + model_out = self.module(x) + loss_value = self.loss_fn(model_out, targets) + return loss_value + + The above example defines a connectivity where we expect the forward/loss/backward + training procedure to take two arguments (x and targets), pass x into the module + to get the output of the feedforward computation, pass the model output and the + targets value into the loss function, and get and return the loss value, which will + be backpropagated by PiPPy. The above class would then be instantiated like:: + + model = ... # instantiate the model + loss_fn = torch.nn.MSELoss() # for the sake of demonstration + + wrapper = MyModelWrapper(model, loss_fn) + pipe = Pipe.from_tracing(wrapper, ...) + + """ + + def __init__(self, module, loss_fn): + super().__init__() + self.module = module + self.loss_fn = loss_fn + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This instance of LossWrapper does not have an overridden" + "forward(). Please implement forward() to specify the arguments, " + "connection between the module and loss, and loss output " + "value." + ) + + +class TrivialLossWrapper(LossWrapper): + # pyrefly: ignore [bad-override] + def forward(self, x, targets): + model_out = self.module(x) + return self.loss_fn(model_out, targets) + + loss_spec = True + + +# Pipe model representation +# +# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies +# a single topological ordering of pipeline "stages" that, when run in series, +# constitutes all of the operations of the program. However, unlike `nn.Sequential`, +# Pipe allows non-local usages of values, so long as those uses still respect +# topological ordering. In particular: +# +# 1. Non-local activations. This type of usage can appear in, for example, skip +# connections. These values will be directly transmitted from the "def" stage +# to all stages that use them skipping intermediate stages. During autograd, +# gradients will be propagated back through this skip connection reverse +# to how activations propagated in the forward pass. +# 2. Non-local parameter/module invocations. This occurs when a parameter is used +# in a stage downstream of where it is resident. These values can be carried +# forward similarly to (1), but in addition one might want to replicate the +# value on multiple stages. Gradients for these shared parameters will be +# accumulated separately on each stage, but there will be an additional +# gradient accumulation before the optimizer step. + + +# Register `_pipe_split()` as an ATen operator. This is required for Export to +# preserve this marker in the graph. +torch.library.define("pippy::_pipe_split", "() -> ()") + + +@torch.library.impl("pippy::_pipe_split", "BackendSelect") +def _pipe_split(): + return None + + +@torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef] +def _pipe_split(): # noqa: F811 + return None + + +# Add an alias for convenience +aten_pipe_split_alias = torch.ops.pippy._pipe_split.default + +# Ask Export to preserve the `_pipe_split` op. +# See examples in pytorch/torch/fx/node.py +fx.node._side_effectful_functions.add(aten_pipe_split_alias) + + +# User facing API +def pipe_split(): + """ + pipe_split is a special operator that is used to mark the boundary between + stages in a module. It is used to split the module into stages. It is a + no-op if your annotated module is run eagerly. + + Example: + >>> # xdoctest: +SKIP + >>> def forward(self, x): + >>> x = torch.mm(x, self.mm_param) + >>> x = torch.relu(x) + >>> pipe_split() + >>> x = self.lin(x) + >>> return x + + The above example will be split into two stages. + """ + return torch.ops.pippy._pipe_split() + + +class MultiUseParameterConfig(Enum): + TRANSMIT = 1 + REPLICATE = 2 + + +MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]] + + +class DetachExecutor(fx.Interpreter): + """ + Special interpreter to run the split_gm in testing that detaches all inputs to + a module invocation. This is needed so that the values at the boundary are + leaf modules in autograd execution. + """ + + def __init__(self, module, garbage_collect_values=True): + garbage_collect_values = False + super().__init__(module, garbage_collect_values) + self.value_remap = {} + + def run(self, *args, initial_env=None): # type: ignore[override] + self.value_remap = {} + return super().run(*args, initial_env=initial_env) + + def call_module(self, target, args, kwargs): + def detach_tensors(a): + if isinstance(a, torch.Tensor) and a.requires_grad: + if a not in self.value_remap: + new_val = a.detach().requires_grad_(True) + self.value_remap[a] = new_val + return self.value_remap[a] + else: + return a + + """ + def dont_traverse_size(a): + return type(a) is not torch.Size + """ + + args = map_aggregate( + args, + detach_tensors, # dont_traverse_size + ) + kwargs = map_aggregate( + kwargs, + detach_tensors, # dont_traverse_size + ) + + return super().call_module(target, args, kwargs) + + def call_function(self, target, args, kwargs): + # HACK to reroute saved input tensors to point to the detach()ed version + if target is stage_backward: + kwargs = dict(kwargs) + kwargs["input_values"] = [ + self.value_remap.get(v, v) for v in kwargs["input_values"] + ] + return super().call_function(target, args, kwargs) + + +class _NodeReference: + def __init__(self, name): + self.name = name + + name: str + + +class _LinearNodeList: + def __init__(self, node_list): + self.serialize_node_list = [] + for node in node_list: + node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + serialize_node = fx.Node( + graph=None, # type: ignore[arg-type] + name=node.name, + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + return_type=node.type, + ) + serialize_node.meta = copy.copy(node.meta) + self.serialize_node_list.append(serialize_node) + + def to_graph(self): + graph = fx.Graph() + + ref_str_to_node: dict[str, fx.Node] = {} + + def ref_to_node(arg): + if isinstance(arg, _NodeReference): + return ref_str_to_node[arg.name] + else: + return arg + + for node in self.serialize_node_list: + node_args = map_aggregate(node.args, ref_to_node) + node_kwargs = map_aggregate(node.kwargs, ref_to_node) + deser_node = graph.create_node( + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + name=node.name, + type_expr=node.type, + ) + ref_str_to_node[node.name] = deser_node + + return graph + + +def _direct_serialization_deserialize(body, nodes): + """ + Custom `__reduce__` method for serialization. + DO AS I SAY -- NOT AS I DO. This violates the principle that + GraphModules serialize via code export & re-tracing. We allow + for this here because **PIPE STAGES SHOULD NOT BE PERSISTED + TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting + these instances to disk will expose internal implementation + details of `fx.Graph` and related data structures and is + NOT advised. + """ + + class DummyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__.update(body) + + dummy = DummyModule(body) + + return fx.GraphModule(dummy, nodes.to_graph()) + + +def _direct_serialization_reduce(self): + serialization_dict = dict(self.__dict__) + serialization_dict.pop("_graph") + return ( + _direct_serialization_deserialize, + (serialization_dict, _LinearNodeList(self.graph.nodes)), + ) + + +def _modify_graph_op_device( + gm: torch.fx.GraphModule, + new_device: torch.device, +): + """ + Modify the device argument of all "call_function" nodes in the graph. This + is useful for moving the graph to a different device. In particular for + generator ops, like torch.ones. + """ + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 + ) + node.update_kwarg("device", new_device) + modified = True + elif node.op == "call_module": + # Recursively modify "device" in submodules + submod = gm.get_submodule(node.target) + if isinstance(submod, torch.fx.GraphModule): + _modify_graph_op_device(submod, new_device) + elif isinstance(submod, InterpreterModule): + # If unflattening has been performed, we need to access its graph module by `.graph_module` + _modify_graph_op_device(submod.graph_module, new_device) # type: ignore[arg-type] + else: + logger.warning( + f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 + ) + + if modified: + gm.recompile() + + +class Pipe(torch.nn.Module): + def __init__( + self, + split_gm: fx.GraphModule, + num_stages: int, + has_loss_and_backward: bool, + loss_spec, + ): + # TODO: is there a way not to hard wire init? + torch.nn.Module.__init__(self) + self.split_gm: fx.GraphModule = split_gm + self.executor: DetachExecutor = DetachExecutor(self.split_gm) + self.num_stages: int = num_stages + self.has_loss_and_backward = has_loss_and_backward + self.loss_spec = loss_spec + + for node in split_gm.graph.nodes: + assert ( + node.op in {"call_module", "placeholder", "output"} + or (node.op, node.target) == ("call_function", operator.getitem) + or (node.op, node.target) == ("call_method", "backward") + or (node.op, node.target) == ("call_function", stage_backward) + or (node.op, node.target) + == ("call_function", _null_coalesce_accumulate) + ), node + + # Detect replicated parameters so we know that we have to do an additional allreduce + # before applying the optimizer + # + # Note that this also handles the case where there were multiple calls to a single + # module from different stages, regardless of whether that module invocation + # was handled by the logic above. + + # Map parameter value to a dictionary that maps the user pipeline module + # to the local qualname within that module + params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {} + + for m_qualname, mod in self.split_gm.named_children(): + for p_qualname, param in mod.named_parameters(): + params_to_users.setdefault(param, {}) + params_to_users[param][m_qualname] = p_qualname + + self.replicated_params: list[dict[str, str]] = [ + use_mapping + for _, use_mapping in params_to_users.items() + if len(use_mapping) > 1 + ] + + # We must break the aliasing relationship between the replicated parameters for correct + # numerics in reference runs. If we do not do this, the autograd tape in separate stages + # will have a reference to the same tensor value and will erroneously apply gradient + # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the + # values so that we have separate instances. + for param_mapping in self.replicated_params: + for submod_name, param_qualname in param_mapping.items(): + submod = getattr(self.split_gm, submod_name) + atoms = param_qualname.split(".") + for atom in atoms[:-1]: + submod = getattr(submod, atom) + setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1]))) + + def throw(self, *args, **kwargs): + raise RuntimeError( + "To run pipeline locally, invoke the Pipe object directly, not `split_gm`" + ) + + self.split_gm.forward = throw + + # Make submodules use custom direct-serialized GraphModule + i = 0 + while True: + try: + name = get_submod_name(i) + submod = getattr(self.split_gm, name) + submod.__class__.__reduce__ = _direct_serialization_reduce + i += 1 + except AttributeError: + break + + def forward(self, *args, **kwargs): + executor_args = args + if len(kwargs) > 0: + parameters = [] + for node in self.split_gm.graph.nodes: + if node.op == "placeholder": + if node.args and len(node.args) > 0: + parameters.append( + Parameter( + node.target, + Parameter.POSITIONAL_OR_KEYWORD, + default=node.args[0], + ) + ) + else: + parameter_kind = Parameter.POSITIONAL_OR_KEYWORD + param_name = node.target + if node.target.startswith("**"): + parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment] + param_name = param_name[2:] + elif node.target.startswith("*"): + parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment] + param_name = param_name[1:] + parameters.append(Parameter(param_name, parameter_kind)) + signature = Signature(parameters) + ba = signature.bind(*args, **kwargs) + ba.apply_defaults() + executor_args = ba.arguments.values() # type: ignore[assignment] + + res = self.executor.run(*executor_args) + + return res + + def get_stage_module(self, stage_idx: int) -> torch.nn.Module: + """ + Return a stage module corresponding to `stage_idx` of the `pipe`. + """ + if stage_idx < 0 or stage_idx >= self.num_stages: + raise ValueError(f"Invalid stage index {stage_idx}!") + + submod_name = get_submod_name(stage_idx) + return getattr(self.split_gm, submod_name) + + @staticmethod + def _number_and_count_forward_stages(gm: fx.GraphModule): + num_stages = 0 + found_idxs: dict[int, None] = {} + for node in gm.graph.nodes: + if node.op == "call_module" and node.target.startswith(PP_SUBMOD_PREFIX): + node.meta["stage_idx"] = int(node.target[len(PP_SUBMOD_PREFIX) + 1 :]) + found_idxs.setdefault(node.meta["stage_idx"]) + num_stages += 1 + + # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule + # Update: the following assert may fail against some torch versions >= + # 2.2.0, as: + # submod_0, submod_1, submod_2, ... + # may be named as + # submod_0, submod_2, submod_4, ... + # TODO: investigate + # assert all(i in found_idxs for i in range(num_stages)) + + return num_stages + + @staticmethod + def _from_traced( + mod: torch.nn.Module, + exported_program: ExportedProgram, + multi_use_param_spec: MultiUseParamSpec | None = None, + output_loss_value_spec=None, + split_policy: Callable[[torch.fx.GraphModule], torch.fx.GraphModule] + | None = None, + ): + """ + Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate + which value in the output of `forward` is the loss value on which PiPPy should apply + backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``, + you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns + a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify + ``output_loss_value_spec={'loss': True, 'model_out': False}`` + """ + + traced = exported_program.module(check_guards=False) + + if split_policy is not None: + logger.info("Auto-splitting model") + traced = split_policy(traced) # type: ignore[arg-type] + + logger.debug(traced.print_readable(print_output=False)) # type: ignore[operator] + + # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving + # parameters relies on the invariant that parameter accesses happen once. This is not necessarily + # the case (especially with custom tracers), so fix that up here. + get_attr_nodes: dict[str, fx.Node] = {} + for node in traced.graph.nodes: # type: ignore[union-attr] + if node.op == "get_attr": + get_attr_nodes.setdefault(node.target, node) + + if get_attr_nodes[node.target] != node: + node.replace_all_uses_with(get_attr_nodes[node.target]) + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + # avoid looking at next node by keeping track of previous pipe_split + prev_pipe_split_idx = -1 + pipe_split_nodes_to_erase = set() + for i, node in enumerate(traced.graph.nodes): # type: ignore[arg-type, union-attr] + if (node.op, node.target) == ("call_function", pipe_split): + if prev_pipe_split_idx == i - 1: + pipe_split_nodes_to_erase.add(node) + prev_pipe_split_idx = i + + for node in pipe_split_nodes_to_erase: + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + traced.recompile() # type: ignore[operator] + + part_idx = 0 + + def split_callback(n: fx.Node): + nonlocal part_idx + if (n.op, n.target) == ( + "call_function", + aten_pipe_split_alias, + ): + logger.debug(f"Found pipe_split {part_idx}") # noqa: G004 + part_idx += 1 + return part_idx + + # TODO: what does split do with module invocations? does it move the modules + # into the submodules? + split = split_module(traced, mod, split_callback, partition_affix="pp") # type: ignore[arg-type] + # a (custom) tracer can produce dead code like orphan get_attr nodes + split.graph.eliminate_dead_code() + + # peephole to remove pipe_split + for submodule in split.modules(): + if isinstance(submodule, fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ( + "call_function", + aten_pipe_split_alias, + ): + submodule.graph.erase_node(node) + submodule.recompile() + + for name, submodule in split.named_children(): + if isinstance(submodule, fx.GraphModule): + new_submod = _outline_submodules(submodule.graph) + # Replace old submod + split.register_module(name, new_submod) + + # TODO: backport this into split_module + def delete_user_reference(node, user): + """ + Delete reference of `node` from `user`'s arg list. + Args: + - node: a `get_attr` node at root. + - user: a submodule node that uses `node`. + """ + assert len(user.kwargs) == 0 + use_idxs = [i for i, arg in enumerate(user.args) if arg == node] + assert len(use_idxs) == 1 + args_copy = list(user.args) + args_copy.pop(use_idxs[0]) + user.args = tuple(args_copy) + logger.debug( + f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004 + ) + + # A list of param referrals for deferred deletion. + # To be accumulated in `move_param_to_callee`. + to_delete = [] + + def _recursive_getattr_with_parent(mod, fqn): + # Returns getattr call given a nested FQN, and the last parent + atoms = fqn.split(".") + for atom in atoms[:-1]: + if not hasattr(mod, atom): + return None, None + mod = getattr(mod, atom) + if not hasattr(mod, atoms[-1]): + return mod, None + attr = getattr(mod, atoms[-1]) + return mod, attr + + def move_param_to_callee( + root, + callee_name, + param_fqn, + ): + """ + Move a parameter from the root module to a submodule. + Args: + root: The root module. + callee_name: The name of the submodule to move the parameter to. + param_fqn: The fully qualified name of the parameter to move. + """ + # `atoms` is a list of strings representing the path to the + # parameter in the original model + atoms = param_fqn.split(".") + mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) + # Check whether the parameter is a buffer or a parameter + is_buffer = atoms[-1] in mod_itr._buffers + + # Check whether the parameter is a tensor + assert isinstance(param_val, torch.Tensor), ( + f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}." + + ( + f" It might happen if module '{param_fqn}' was passed to some 'leaf function'" + f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect " + f"usages of '{param_fqn}' in the traced graph." + if isinstance(param_val, torch.nn.Module) + else "" + ) + ) + + # Get submodule + callee = root.get_submodule(callee_name) + assert not hasattr(callee, param_fqn), ( + f"Module {callee_name} already has a parameter named {param_fqn}" + ) + + # Assign the parameter to the submodule + if is_buffer: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.BUFFER, + persistent=True, # TODO: handle non-persistent buffer + ) + else: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.PARAMETER, + ) + logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004 + + # Next step is to replace placeholder of submodule with a get_attr. + # Those placeholders are created by `split_module` inside each + # submodule. + # Update: this step is now moved to `_sink_params` because + # `_sink_params` can do it recursively (i.e. for modules inside + # submodule) + + to_delete.append((mod_itr, atoms[-1])) + + # Get the list of all parameters in the root module + attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes)) + for node in attr_nodes: + # Check whether the parameter is used in only one submodule + if len(node.users) > 1: + logger.info( + f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004 + ) + for user in node.users: + assert user.op == "call_module" + # Move parameter into submodule + move_param_to_callee( + split, + user.target, + node.target, + ) + + # [aliasing] store tensor id -> list of FQNs, built from state dict + # Also assign non-persistent buffers + id_to_fqns: dict[int, set[str]] = defaultdict(set) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + id_to_fqns[id(tensor)].add(fqn) + for fqn, tensor in mod.named_buffers(): + id_to_fqns[id(tensor)].add(fqn) + + # After moving the params to their corresponding hierarchies, we also + # need to move the `get_attr` nodes from the root of the graph to those + # hierarchies. + # [aliasing] use id -> fqn mapping to list out all valid FQNs + inputs_to_state: dict[str, list[str]] = {} + for attr in attr_nodes: + _, tensor = _recursive_getattr_with_parent(mod, attr.target) + fqns = list(id_to_fqns[id(tensor)]) + if fqns: + inputs_to_state[attr.name] = fqns + elif attr.target in exported_program.constants: # lifted constants + inputs_to_state[attr.name] = [attr.target] + + # [aliasing] for each submodule split, assign attributes on FQNs that may be used. + # We determine this based on whether or not the FQN attribute parent exists. + # i.e. if the last submodule exists, assign the attribute. + added_attributes: dict[str, list[str]] = defaultdict(list) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + parent, child = _recursive_getattr_with_parent(submod, fqn) + if ( + parent and child is None + ): # parent exists, attribute doesn't -> assign + added_attributes[name].append(fqn) + setattr(parent, fqn.split(".")[-1], tensor) + + # Deferral deletion: Remove the original attributes (to params) from the + # root GraphModule + for mod_itr, last_atom in to_delete: + try: + delattr(mod_itr, last_atom) + except AttributeError: + # This is expected if the parameter is used in multiple stages + pass + + # This is done by (1) `_sink_params` at each submodule; + for submod in split.children(): + if isinstance(submod, fx.GraphModule): + _sink_params(submod, inputs_to_state, []) + submod.graph.lint() + submod.recompile() + + # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory. + # After _sink_params() routine has run, clean up unused attributes that we previously added. + # Determine this based on the get_attr nodes - if not used, remove it. + for name, attributes in added_attributes.items(): + submod = getattr(split, name) + unused_attributes = set(attributes) + # track used attributes in the submodule, running DFS on subgraph hierarchy + stack = [("", submod)] # (scope, submodule) + while stack: + scope, _mod = stack.pop() + if isinstance(_mod, (fx.GraphModule, InterpreterModule)): + for node in _mod.graph.nodes: + if node.op == "get_attr": + # get_attr might get access deeper level attribute + fqn = scope + "." + node.target if scope else node.target + unused_attributes.discard(fqn) + for _name, _submod in _mod.named_children(): + stack.append((scope + "." + _name if scope else _name, _submod)) + # delete unused attributes + for attr in unused_attributes: + mod_itr, atoms = submod, attr.split(".") + for atom in atoms[:-1]: + mod_itr = getattr(mod_itr, atom) + delattr(mod_itr, atoms[-1]) + + for node in attr_nodes: + # And (2): remove `get_attr` node from submod's arg list + for user in copy.copy(node.users): + assert user.op == "call_module" + delete_user_reference(node, user) + # And (3): remove the `get_attr` node from the root graph. + split.graph.erase_node(node) + + split.delete_all_unused_submodules() + split.graph.lint() + split.recompile() + + num_stages = Pipe._number_and_count_forward_stages(split) + + has_loss_and_backward = False + generated_loss_spec = output_loss_value_spec + + if output_loss_value_spec is not None: + loss_node, output_node, generated_loss_spec = _find_loss_output( + mod, split.graph, output_loss_value_spec + ) + if loss_node is not None: + _insert_stage_symbolic_backward( + split.graph, + loss_node, + output_node, + ) + split.recompile() + has_loss_and_backward = True + logger.debug("Pipeline is in training mode, backward pass generated") + else: + raise RuntimeError( + f"Did not find any loss value according to {output_loss_value_spec=}" + ) + else: + logger.debug("Pipeline is in inference mode, backward pass not generated") + + logger.debug(f"Full pipe model:\n{split}") # noqa: G004 + + return Pipe( + split, + num_stages, + has_loss_and_backward, + generated_loss_spec, + ) + + def print_readable(self): + """ + Print the pipe in a human-readable format. + This will print both the root pipe and each stage module. + """ + self.split_gm.print_readable() + + @staticmethod + def _trace_with_export( + mod: torch.nn.Module, + example_args: tuple[Any, ...], + example_kwargs: dict[str, Any] | None = None, + ) -> ExportedProgram: + logger.info("Tracing model ...") + try: + ep = torch.export.export(mod, example_args, example_kwargs) + except Exception as e: + raise RuntimeError( + "It seems that we cannot capture your model as a full graph. " + "Typical reasons include graph breaks, data/shape-dependent " + "control flow, or missing meta kernels for custom operators. " + "You can use our manual pipeline interfaces, or try to fix the " + "graph breaks, see https://pytorch.org/docs/stable/export.html" + ) from e + + return ep + + @staticmethod + def from_tracing( + mod: torch.nn.Module, + example_args: tuple[Any, ...], + example_kwargs: dict[str, Any] | None = None, + split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None, + ): + # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across + # stages instead of TRANSMIT'ting it + multi_use_param_spec = MultiUseParameterConfig.REPLICATE + + # Figure out which output is loss from output_chunk_spec + output_loss_value_spec: Any = None + # Deprecated + """ + if output_chunk_spec is not None: + output_loss_value_spec = map_aggregate( + output_chunk_spec, lambda v: isinstance(v, _LossReducer) + ) + """ + + # Trace with export + exported_program = Pipe._trace_with_export( + mod, + example_args, + example_kwargs, + ) + + pipe = Pipe._from_traced( + mod, + exported_program, + multi_use_param_spec, + output_loss_value_spec=output_loss_value_spec, + split_policy=split_policy, + ) + + # Users want the first pipeline stage to accept kwargs if the original + # program does. This is controlled by the `_codegen` field of the graph, + # so we make a copy here. Note: we only want the input spec and not the + # output spec, because the output spec is for the last stage. Maybe a + # TODO? Not sure yet. + split = pipe.split_gm + traced = exported_program.module() + submod0 = next(iter(split.children())) + submod0_sign = signature(submod0.forward) + model_sign = signature(traced.forward) + if len(model_sign.parameters) != len(submod0_sign.parameters): + # We don't change the signature of the first stage if it takes + # different number of args than original model + logger.info( + f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004 + f"first pipeline stage takes {len(submod0_sign.parameters)}. " + "Please provide args to respective pipeline stages." + ) + else: + # Support kwargs for the first stage + submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) # type: ignore[union-attr] + # `_replace` is actually not "private" or internal. based on this doc: + # To prevent conflicts with field names, the method and attribute names + # start with an underscore + submod0.graph._codegen.pytree_info = ( # type: ignore[union-attr] + submod0.graph._codegen.pytree_info._replace(out_spec=None) # type: ignore[operator, union-attr] + ) + submod0.recompile() + + return pipe + + def __str__(self): + return self.split_gm.__str__() + + def __repr__(self): + return self.split_gm.__repr__() + + def info(self) -> PipeInfo: + """ + Get information about the pipe. + + Returns + ------- + PipeInfo + A dataclass containing information about the pipe. + """ + return PipeInfo( + graph=self.split_gm.graph, + num_stages=self.num_stages, + has_loss_and_backward=self.has_loss_and_backward, + ) + + def build_stage( + self, + stage_index: int, + device: torch.device, + group: ProcessGroup | None = None, + ) -> _PipelineStage: + """ + Create a `PipelineStage` given a stage index and distributed group. + The `PipelineStage` can run with `PipelineSchedule`s. + """ + # Find stage module + stage_module = self.get_stage_module(stage_index) + + # Move ops argument to device + # Today PT2 tracer does not treat `x.device` as a symbolic device; + # instead, the device of tracing time got burned into the generated + # code. Here we provide a workaround for users to manually modify the + # "device" kwarg of operations. Such operation may include: + # `torch.ones`, `torch.zeros`, `torch.rand`, etc. + if isinstance(stage_module, torch.fx.GraphModule): + _modify_graph_op_device(stage_module, device) + else: + logger.warning( + f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004 + ) + + # Detach pipe info + # Note: be careful what's included in `pipe_info`. We don't want to keep + # a reference to `Pipe` or `Pipe.split_gm` which stops python from + # recycling them. When python recycles them, other stage modules (which + # are irrelevant to current rank) can be automatically freed. + pipe_info = self.info() + return _PipelineStage(stage_module, stage_index, pipe_info, device, group) + + +class SplitPoint(Enum): + """ + Enum representing the points at which a split can occur in the execution of a submodule. + Attributes: + BEGINNING: Represents adding a split point *before* the execution of a certain submodule in the `forward` function. + END: Represents adding a split point *after* the execution of a certain submodule in the `forward` function. + """ + + BEGINNING = 1 + END = 2 + + +# For backward compatibility, we kept the PipeSplitWrapper class because `class +# SplitPoint` used to be defined in this class. +class PipeSplitWrapper: + # Create a class alias for BC + SplitPoint = SplitPoint + + +def _split_before_forward(self, *args, **kwargs): + pipe_split() + return self._orig_forward(*args, **kwargs) + + +def _split_after_forward(self, *args, **kwargs): + try: + return self._orig_forward(*args, **kwargs) + finally: + pipe_split() + + +def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]): + # TODO: make this implementation out-of-place? + for qualname, split_type in spec.items(): + atoms = qualname.split(".") + predecessor_module = mod + for i, atom in enumerate(atoms[:-1]): + try: + predecessor_module = getattr(predecessor_module, atom) + except AttributeError as e: + raise AttributeError( + f"Specified target {qualname} referenced " + f"nonexistent module {'.'.join(atoms[: i + 1])}" + ) from e + + mod_to_wrap = getattr(predecessor_module, atoms[-1]) + mod_to_wrap._orig_forward = mod_to_wrap.forward + if split_type == SplitPoint.BEGINNING: + mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap) + elif split_type == SplitPoint.END: + mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap) + else: + raise ValueError("Unknown split point type.") + + +def pipeline( + module: torch.nn.Module, + mb_args: tuple[Any, ...], + mb_kwargs: dict[str, Any] | None = None, + split_spec: dict[str, SplitPoint] | None = None, + split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None, +) -> Pipe: + """ + Split a module based on a specification. + + See `Pipe` for more details. + + Arguments + --------- + module: + The module to be split. + mb_args: + Example positional inputs, in micro-batch form. + mb_kwargs: + Example keyword inputs, in micro-batch form. (default: `None`) + split_spec: + A dictionary using submodule names as split marker. (default: `None`) + split_policy: + The policy to use for splitting the module. (default: `None`) + + Returns + ------- + A pipeline representation of class `Pipe`. + """ + if split_spec is not None and split_policy is not None: + raise ValueError( + "Cannot specify both `split_spec` and `split_policy`. Please use only one of them." + ) + + if split_spec is not None: + # Annotate split points in the module based on user spec + annotate_split_points(module, split_spec) + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + ) + else: + # Use split policy + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + split_policy=split_policy, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aacaf0b7f5e4ae7f5d221906ebb5b1b6ff93dea9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from ._IR import Pipe, pipe_split, pipeline, SplitPoint +from .schedules import ( + _ScheduleForwardOnly, + Schedule1F1B, + ScheduleDualPipeV, + ScheduleGPipe, + ScheduleInterleaved1F1B, + ScheduleInterleavedZeroBubble, + ScheduleLoopedBFS, + ScheduleZBVZeroBubble, +) +from .stage import build_stage, PipelineStage + + +__all__ = [ + "Pipe", + "pipe_split", + "SplitPoint", + "pipeline", + "PipelineStage", + "build_stage", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", + "ScheduleZBVZeroBubble", + "ScheduleDualPipeV", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb68ca05e741d0e80c26e6f11dc0e8c437aa8bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f48bf751b0b78f3e4465a25352a4a36875a6e556 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77bf7853c9c808b7092ee3641ae85ae6b1f16210 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abd506ea608d3b24630ea134ded41ed29688764d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e86f0641709a5cf5f4cef0fc0466466c2522088 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e87f3909794b1829af1e262438c37f04099e23a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af7b403b0ee5e81ed3059da1228b37a76dcd0806 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93cbcee6b1f6c19609b1a63e8d47f7d11d0a29ec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce7cf2976fb996a66ee8b905871a8b0745a3f856 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_backward.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcf294c2946c5107c7c506b9846cd320155b27c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_backward.py @@ -0,0 +1,418 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import collections +import logging +from collections.abc import Iterator +from typing import Any + +import torch +from torch.autograd.graph import GradientEdge, Node +from torch.nn import Parameter + +from ._debug import map_debug_info + + +logger = logging.getLogger(__name__) + + +def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None: + """ + Get the grad function or grad accumulator for a tensor. + + Accumulate grad nodes are lazily created, so we need to a + dummy view in order to trigger its creation. + """ + if t.requires_grad and t.grad_fn is None: + # if no grad function (leaf tensors) we use view + viewed_t = t.view_as(t) + grad_fn = viewed_t.grad_fn + if grad_fn is not None: + return grad_fn.next_functions[0][0] + else: + raise RuntimeError( + "Attempted to get grad_fn, but got None." + "Is this being created in a no-grad context?" + ) + else: + return t.grad_fn + + +def reverse_closure( + roots: list[Node], target_nodes: set[Node], reverse_edges_dict +) -> tuple[set[Node], set[Node]]: + """ + This function returns the reverse closure of the given roots, + i.e. the set of nodes that can be reached from the roots by following the + reverse edges of the graph. The target_nodes are the nodes that we want to + include in the closure. + """ + # Recurse until we reach a target node + closure: set[Node] = set() + visited_target_nodes = set() + q: collections.deque[Node] = collections.deque() + for node in roots: + if node is not None and node not in closure: + closure.add(node) + q.append(node) + while q: + node = q.popleft() + reverse_edges = reverse_edges_dict[node] + for fn in reverse_edges: + if fn in closure or fn is None: + continue + if fn in target_nodes: + visited_target_nodes.add(fn) + continue + closure.add(fn) + q.append(fn) + return closure, visited_target_nodes + + +def construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]: + q: collections.deque[Node] = collections.deque() + root_seen: set[Node] = set() + reverse_edges_dict: dict[Node, list[Node]] = collections.defaultdict(list) + for node in roots: + if node is not None and node not in root_seen: + q.append(node) + root_seen.add(node) + while q: + node = q.popleft() + for fn, _ in node.next_functions: + if fn is not None: + if len(reverse_edges_dict[fn]) == 0: + q.append(fn) + reverse_edges_dict[fn].append(node) + return reverse_edges_dict + + +def get_param_groups( + inputs: list[Node], params: list[Node], reverse_edges_dict +) -> list[dict[str, Any]]: + """ + Given a list of inputs and a list of parameters, return a list of parameter + groups, where each group contains the parameters and the intermediates that + are connected to the parameters. + + The returned list of parameter groups is a list of dictionaries, where each + dictionary contains the following keys: + - "params": a set of parameters + - "intermediates": a set of intermediates + + The returned list of parameter groups is a list of dictionaries, + """ + # reverse graph that starts with inputs, and goes up to the dOutput or the loss, + # but omits weights and any subgraphs connecting weights to this closure + inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict) + param_groups: dict[Node, dict[str, set]] = dict() # keyed on intermediates + for param in params: + closure, intersected = reverse_closure( + [param], inputs_closure, reverse_edges_dict + ) + param_group: dict[str, set] = { + "params": {param}, + "intermediates": intersected, + } + for input_node in intersected: + existing = param_groups.get(input_node) + if existing is not None: + existing["params"] = existing["params"].union(param_group["params"]) + existing["intermediates"] = existing["intermediates"].union( + param_group["intermediates"] + ) + param_group = existing + else: + param_groups[input_node] = param_group + + # Sanity check: union of all param_groups params should be equal to all params + union_params: set[Node] = set() + seen_ids: set[int] = set() + unique_param_groups = [] + for param_group in param_groups.values(): + if id(param_group) not in seen_ids: + seen_ids.add(id(param_group)) + unique_param_groups.append(param_group) + union_params = union_params.union(param_group["params"]) + + # The assert will only be true if the input tensor requires gradients, + # otherwise the autograd graph will miss the first layer of inputs + # assert union_params == set(params) + return unique_param_groups + + +def stage_backward_input( + stage_outputs_or_loss: list[torch.Tensor], + output_grads: list[torch.Tensor] | None, + input_values: list[torch.Tensor], + weights: Iterator[Parameter], +) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]]]: + """ + Compute the gradients for only the stage inputs with + respect to the stage outputs (if non-last stage) or loss (if last stage) + + After computing input gradients, we save the intermediate nodes in `param_groups` + for later use in stage_backward_weight. We don't need to save any other intermediate nodes + that aren't needed for dW because when we do dW calculation, we start from saved intermediates. + Detaching the stage_outputs_or_loss at the end of this function is important as + it frees up the memory that the autograd graph is anticipating to be used later (but doesn't actually need). + """ + stage_output_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss)) + ) + stage_input_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, input_values)) + ) + weight_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, weights)) + ) + + reverse_edges_dict = construct_reverse_graph(stage_output_grad_fns) + param_groups = get_param_groups( + stage_input_grad_fns, weight_grad_fns, reverse_edges_dict + ) + + handles = [] + for param_group in param_groups: + for i, intermediate in enumerate(param_group["intermediates"]): + + def get_hook(param_group, i): + def hook(grad_inputs): + if param_group.get("grads", None) is None: + param_group["grads"] = [None] * len( + param_group["intermediates"] + ) + param_group["grads"][i] = grad_inputs + + return hook + + # These are always "split" nodes that we need to recompute, so + # save their inputs. + handle = intermediate.register_prehook(get_hook(param_group, i)) + handles.append(handle) + + if output_grads is None: + # In case this is the loss and there are no output_grads, then we just use 1s + output_grads = [ + torch.ones_like(stage_output) for stage_output in stage_outputs_or_loss + ] + + # Some inputs may not be used or may not require gradients, so we filter them out + input_values = [inp for inp in input_values if inp.requires_grad] + dinputs = torch.autograd.grad( + stage_outputs_or_loss, + inputs=input_values, + grad_outputs=output_grads, + retain_graph=True, + ) + # Update the gradients for inputs + for inp, dinput in zip(input_values, dinputs): + if inp.grad is None: + inp.grad = dinput + else: + inp.grad += dinput + + # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph + # this allows autograd to clear up the graph dedicated for this tensor and free up significant memory + for t in stage_outputs_or_loss: + t.detach_() + + # hooks are no longer necessary, clean up for consistency + for handle in handles: + handle.remove() + + return dinputs, param_groups + + +def stage_backward_weight( + weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False +) -> tuple[torch.Tensor | None, ...]: + # map weights to param_group_weights + grad_acc_to_weight = {} + weight_grads: list[torch.Tensor | None] = [] + for index, weight in enumerate(weights): + grad_acc = _get_grad_fn_or_grad_acc(weight) + grad_acc_to_weight[grad_acc] = weight, index + weight_grads.append(weight.grad) + + for param_group in param_groups: + valid_edges = [] + valid_grad_outputs: list[torch.Tensor] = [] + + for grads_tuple, intermediate in zip( + param_group["grads"], param_group["intermediates"] + ): + non_none_grads = [g for g in grads_tuple if g is not None] + if non_none_grads: + summed_grad = sum(non_none_grads) + valid_edges.append(GradientEdge(intermediate, 0)) + # pyrefly: ignore [bad-argument-type] + valid_grad_outputs.append(summed_grad) + + # Break a reference cycle caused inside stage_backward_input->get_hook->hook + # The summarized cycle is: + # `hook` -> cell -> param_group -> intermediates -> `hook` + # because we install the hook function onto each of the intermediate autograd nodes. + # We need to keep intermediates alive up until backward_weight, but we can free it now. + del param_group["intermediates"] + + if valid_edges: # Only call autograd.grad if we have valid gradients + # [NEW!] Able to pass a GradientEdge to autograd.grad as output + weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) + dweights = torch.autograd.grad( + valid_edges, + weights_edges, + grad_outputs=valid_grad_outputs, + retain_graph=retain_graph, + ) + + # release grad memory early after use + del param_group["grads"] + + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw + # return grads in the original order weights were provided in + return tuple(weight_grads) + + +def stage_backward( + stage_output, + output_grads, + input_values, + outputs_with_grads_idxs: list[int] | None = None, # deprecated, not used +) -> tuple[torch.Tensor | None, ...]: + """ + This is a helper function to: + 1. compute the gradients for the stage inputs, and + 2. accumulate gradients for the stage module's parameters. + + Given the input value(s) and the corresponding gradient for the output + value(s), compute and accumulate gradients for all parameter values (leaves + in the autograd trace) as well as return a list of the gradients for the + input values + """ + if outputs_with_grads_idxs is not None: + # Deprecated, not used in runtime calls, only exists in compiler + stage_output = [stage_output[i] for i in outputs_with_grads_idxs] + output_grads = [output_grads[i] for i in outputs_with_grads_idxs] + + try: + # stage_output may be a composite datatype like dict. Extract all individual + # tensor values here + stage_output_tensors: list[torch.Tensor] = [] + output_grad_tensors: list[torch.Tensor | None] = [] + + def extract_tensors_with_grads( + output_val, + grad_val, + # Don't delete me- see [Note: ref cycle] + extract_tensors_with_grads, + ): + if isinstance(output_val, torch.Tensor): + if not output_val.requires_grad and output_val.grad_fn is None: + return + assert isinstance(grad_val, (torch.Tensor, type(None))), ( + f"Expected Tensor or None gradient but got {type(grad_val)}" + ) + stage_output_tensors.append(output_val) + output_grad_tensors.append(grad_val) + elif isinstance(output_val, (tuple, list)): + if grad_val is None: + return + assert isinstance(grad_val, (tuple, list)), ( + f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + ) + assert len(output_val) == len(grad_val) + for ov, gv in zip(output_val, grad_val): + extract_tensors_with_grads( + ov, + gv, + extract_tensors_with_grads, + ) + elif isinstance(output_val, dict): + if grad_val is None: + return + assert isinstance(grad_val, dict) + assert set(output_val.keys()) == set(grad_val.keys()) + for k in output_val: + extract_tensors_with_grads( + output_val[k], grad_val[k], extract_tensors_with_grads + ) + else: + # Output is a non-tensor type; just ignore it + pass + + # Note: ref cycle + # break a ref cycle that would keep tensors alive until GC runs + # 1. extract_tensors_with_grads refers to a cell that holds refs to any vars defined in stage_backward + # and used in extract_tensors_with_grads + # 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors, + # and to itself (extract_tensors_with_grads) since it makes a recursive call + # 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad + # fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore + extract_tensors_with_grads( + stage_output, output_grads, extract_tensors_with_grads + ) + + torch.autograd.backward( + stage_output_tensors, + grad_tensors=output_grad_tensors, # type: ignore[arg-type] + ) + + # Extract gradients wrt the input values + grad_inputs: list[torch.Tensor | None] = [] + for val in input_values: + if isinstance(val, torch.Tensor): + grad_inputs.append(val.grad) + # Since gradients that will pass back to previous stages do not require gradient accumulation, + # by decrementing the gradients' reference count at this point, the memory of gradients will be + # returned to the allocator as soon as the next micro batch's get_bwd_send_ops comes and current + # asynchronous send completes. + # This prevents the gradients from persisting in GPU memory for the entire duration of step_microbatches + # until clear_runtime_states() is called. + val.grad = None + else: + grad_inputs.append(None) + + # Alternative impl: `torch.autograd.grad`. + # Note that `torch.autograd.grad` will not accumulate gradients into the + # model's parameters. + """ + inputs_with_grad = [] + for val in input_values: + if isinstance(val, torch.Tensor) and val.requires_grad: + inputs_with_grad.append(val) + + grad_inputs = torch.autograd.grad( + stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] + ) + """ + + except Exception as e: + exc_msg = f""" + Failed to run stage backward: + Stage output: {map_debug_info(stage_output)} + Output gradient: {map_debug_info(output_grads)} + Input: {map_debug_info(input_values)} + """ + raise RuntimeError(exc_msg) from e + + return tuple(grad_inputs) + + +# TODO: handling requires_grad=False dynamically. Can we analyze this during initial +# IR emission? +def _null_coalesce_accumulate(lhs, rhs): + """ + Coalesce two values, even if one of them is null, returning the non-null + value. + """ + if lhs is None: + return rhs + elif rhs is None: + return lhs + else: + return torch.add(lhs, rhs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_debug.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..a3201d2d3adf1d05921e070d14b4e544844df88f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_debug.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +from torch.fx.node import Argument + + +def friendly_debug_info(v: object) -> Argument: + """ + Helper function to print out debug info in a friendly way. + """ + if isinstance(v, torch.Tensor): + return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})" + else: + return str(v) + + +def map_debug_info(a: Argument) -> Argument: + """ + Helper function to apply `friendly_debug_info` to items in `a`. + `a` may be a list, tuple, or dict. + """ + return torch.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_schedule_visualizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_schedule_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ecc5bf19ab17d83e0f3128290c0d5cc4b862b4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_schedule_visualizer.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +""" +This visualizer requires matplotlib to be installed. + +Example usage: + +ops = get_schedule_ops("InterleavedZeroBubble", 4, 8) +visualize_schedule(ops, "test.png") +""" + +import collections +from typing import NamedTuple +from unittest import mock + +from torch.distributed.pipelining.schedules import ( + _Action, + _ComputationType, + _PipelineSchedule, + _PipelineScheduleRuntime, + get_schedule_class, + PipelineScheduleMulti, + PipelineScheduleSingle, +) +from torch.distributed.pipelining.stage import PipelineStage + + +class OpKey(NamedTuple): + stage_index: int + computation_type: _ComputationType + microbatch_index: int + + +def get_schedule_ops( + schedule: str | type[_PipelineSchedule], + pp_degree: int, + num_microbatches: int, + num_stages_per_rank: int | None = None, + add_spacing: bool = False, + with_comms: bool = False, +) -> list[list[_Action | None]]: + """ + Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists + where each inner list represents a rank and each element in the inner list represents an action. + + The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance. + """ + if add_spacing and with_comms: + raise ValueError("Cannot add spacing and view comms at the same time") + + if isinstance(schedule, str): + schedule_class = get_schedule_class(schedule) + elif issubclass(schedule, _PipelineSchedule): + schedule_class = schedule + else: + raise ValueError(f"Invalid schedule: {schedule}") + + # Create a mock of the PipelineStage class + mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True) + # Set the return values for group_rank and group_size methods + mock_pipeline_stage.group_rank = 0 + mock_pipeline_stage.group_size = pp_degree + mock_pipeline_stage.submod = None + + # Check num_stages_per_rank is valid + if issubclass(schedule_class, PipelineScheduleSingle): + if num_stages_per_rank is None: + num_stages_per_rank = 1 + assert num_stages_per_rank == 1 + stages = mock_pipeline_stage + stages.num_stages = num_stages_per_rank * pp_degree + elif issubclass(schedule_class, PipelineScheduleMulti): + if num_stages_per_rank is None: + num_stages_per_rank = 2 + assert num_stages_per_rank >= 2 + stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)] + for stage in stages: + stage.num_stages = num_stages_per_rank * pp_degree + + else: + raise ValueError(f"Invalid schedule: {schedule_class}") + + # Instantiate the schedule class + # pyrefly: ignore [bad-instantiation, bad-argument-type] + schedule_instance = schedule_class(stages, num_microbatches) + assert schedule_instance.pipeline_order is not None + + # Convert to List[List[_Action]] + all_actions: list[list[_Action | None]] = [] + if with_comms: + runtime = _PipelineScheduleRuntime(stages, num_microbatches) + runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order) + for rank in range(pp_degree): + all_actions.append(list(runtime.pipeline_order_with_comms[rank])) + else: + for rank in range(pp_degree): + all_actions.append(schedule_instance.pipeline_order[rank]) + + # Add spacing + if add_spacing: + # remove all Nones, then respace + # TODO: later we can change this at the schedule creation level to not use Nones + all_actions = [ + [action for action in rank if action is not None] for rank in all_actions + ] + all_actions = add_schedule_op_spacing(all_actions) + + # Return the pipeline order + return all_actions + + +class _ComputationTypeVisual: + def __init__( + self, + color: str, + text: str = "", + width: int = 1, + ): + self.color = color + self.width = width + self.text = text + + +# Update the mapping to use _ComputationTypeVisual instances +action_type_to_color_mapping = { + _ComputationType.FORWARD: _ComputationTypeVisual("blue", "Forward"), + _ComputationType.BACKWARD_INPUT: _ComputationTypeVisual("teal", "Backward Input"), + _ComputationType.BACKWARD_WEIGHT: _ComputationTypeVisual( + "green", "Backward Weight" + ), + _ComputationType.FULL_BACKWARD: _ComputationTypeVisual( + "orange", "Full Backward", 2 + ), + _ComputationType.OVERLAP_F_B: _ComputationTypeVisual("purple", "Overlap F+B", 3), +} + + +def add_schedule_op_spacing( + schedule: list[list[_Action | None]], +) -> list[list[_Action | None]]: + """ + Add spacing to the schedule based on dependencies between ranks. + + Before adding an operation to the list, this function checks if there are + dependencies from other ranks. If there are dependencies (other ranks have + not finished processing the required microbatch), it adds None instead. + + For example, Forward microbatch 0 on rank 1 depends on rank 0 processing + Forward microbatch 0 first. + + Args: + schedule: The original schedule as a list of lists where each inner list + represents a rank and each element represents an action. + + Returns: + A new schedule with proper spacing based on dependencies. + """ + if not schedule: + return schedule + + num_stages = ( + max( + action.stage_index + for rank_actions in schedule + for action in rank_actions + if action is not None + ) + + 1 + ) + + num_ranks = len(schedule) + spaced_schedule: list[list[_Action | None]] = [[] for _ in range(num_ranks)] + rank_ops = [collections.deque(ops) for ops in schedule] + + # Track completion times: (stage_index, action_type, microbatch_index) -> completion_time + scheduled_ops: dict[OpKey, int] = {} + + def is_dependency_ready(dependency_key: OpKey, timestep: int) -> bool: + """Check if a dependency operation has completed by the given timestep.""" + return ( + dependency_key in scheduled_ops + and timestep >= scheduled_ops[dependency_key] + ) + + def get_dependencies(action: _Action) -> list[OpKey]: + """Get the list of dependencies for an action.""" + stage_idx = action.stage_index + comp_type = action.computation_type + mb_idx = action.microbatch_index + + # Ensure mb_idx is not None for dependency tracking + assert mb_idx is not None, f"Action {action} has None microbatch_index" + + # First stage forward has no dependencies + if stage_idx == 0 and comp_type == _ComputationType.FORWARD: + return [] + + # Last stage backward depends on forward from previous stage + if stage_idx == num_stages - 1 and comp_type in ( + _ComputationType.FULL_BACKWARD, + _ComputationType.BACKWARD_INPUT, + ): + return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)] + + # Forward depends on previous stage forward + if comp_type == _ComputationType.FORWARD: + return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)] + + # Backward depends on next stage backward + if comp_type in ( + _ComputationType.FULL_BACKWARD, + _ComputationType.BACKWARD_INPUT, + ): + return [ + OpKey(stage_idx + 1, _ComputationType.FULL_BACKWARD, mb_idx), + OpKey(stage_idx + 1, _ComputationType.BACKWARD_INPUT, mb_idx), + ] + + # Weight backward depends on input backward + if comp_type == _ComputationType.BACKWARD_WEIGHT: + return [OpKey(stage_idx, _ComputationType.BACKWARD_INPUT, mb_idx)] + + raise RuntimeError(f"Unknown computation type: {comp_type}") + + def is_action_ready(action: _Action, timestep: int) -> bool: + """Check if an action is ready to be scheduled at the given timestep.""" + # For OR dependencies (like backward), check if any dependency is satisfied + if action.computation_type in ( + _ComputationType.FULL_BACKWARD, + _ComputationType.BACKWARD_INPUT, + _ComputationType.BACKWARD_WEIGHT, + ): + dependencies = get_dependencies(action) + return any(is_dependency_ready(dep, timestep) for dep in dependencies) + # For AND dependencies, all must be satisfied + elif action.computation_type == _ComputationType.FORWARD: + dependencies = get_dependencies(action) + return all(is_dependency_ready(dep, timestep) for dep in dependencies) + elif action.computation_type == _ComputationType.OVERLAP_F_B: + assert action.sub_actions is not None, ( + f"OVERLAP_F_B action {action} has None sub_actions" + ) + dep_list: list[bool] = [] + for sub_action in action.sub_actions: + dep_list.append(is_action_ready(sub_action, timestep)) + return all(dep_list) + else: + raise RuntimeError(f"Unknown computation type: {action.computation_type}") + + def schedule_action(action: _Action, rank: int, timestep: int) -> int: + """Schedule an action and return completion time.""" + spaced_schedule[rank].append(action) + comp_type = action.computation_type + comp_time = action_type_to_color_mapping[comp_type].width + completion_time = timestep + comp_time + + if comp_type == _ComputationType.OVERLAP_F_B: + # For overlap actions, schedule each sub-action with cumulative timing + assert action.sub_actions is not None, ( + f"OVERLAP_F_B action {action} has None sub_actions" + ) + cumulative_time = 0 + for sub_action in action.sub_actions: + assert sub_action.microbatch_index is not None, ( + f"Sub-action {sub_action} has None microbatch_index" + ) + sub_comp_time = action_type_to_color_mapping[ + sub_action.computation_type + ].width + cumulative_time += sub_comp_time + scheduled_ops[ + OpKey( + sub_action.stage_index, + sub_action.computation_type, + sub_action.microbatch_index, + ) + ] = timestep + cumulative_time + else: + assert action.microbatch_index is not None, ( + f"Action {action} has None microbatch_index" + ) + scheduled_ops[ + OpKey(action.stage_index, comp_type, action.microbatch_index) + ] = completion_time + + return completion_time + + # Main scheduling loop + current_timestep = 0 + timesteps_without_progress = 0 + rank_completion_times = dict.fromkeys(range(num_ranks), 0) + while rank_ops: + print(f"Current timestep: {current_timestep}") + # Process all operations during timestep until we run out of ready operations + for rank, op_queue in enumerate(rank_ops): + if not op_queue: + continue + + op_queue = rank_ops[rank] + action = op_queue[0] + print(f"Rank: {rank}, {action=}") + if action is None: + spaced_schedule[rank].append(None) + op_queue.popleft() + timesteps_without_progress = 0 + elif current_timestep >= rank_completion_times[rank] and is_action_ready( + action, current_timestep + ): + rank_completion_times[rank] = schedule_action( + action, rank, current_timestep + ) + op_queue.popleft() + timesteps_without_progress = 0 + + # Add None for ranks that are waiting + for rank in range(num_ranks): + if current_timestep >= rank_completion_times[rank]: + spaced_schedule[rank].append(None) + + # Remove empty queues and advance timestep + rank_ops = [op_queue for op_queue in rank_ops if op_queue] + current_timestep += 1 + timesteps_without_progress += 1 + + if timesteps_without_progress > max( + visual.width for visual in action_type_to_color_mapping.values() + ): + raise RuntimeError("No progress made in scheduling - possible deadlock") + + return spaced_schedule + + +def visualize_schedule( + schedule: list[list[_Action | None]], + filename: str | None = None, +) -> None: + """ + Visualize the schedule using matplotlib. + The schedule is a list of lists where each inner list represents a rank and each element in the inner list represents an action. + The actions are represented as rectangles with different colors based on their computation type. + The filename is optional and if provided, the plot will be saved to that file. + + Args: + schedule: The schedule to visualize. + filename: The filename to save the plot to. If not provided, the plot will be displayed. + add_schedule_spacing: If True, add spacing to the schedule based on dependencies between ranks. + + """ + + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + plt.rcParams["font.family"] = ( + "DejaVu Sans" # or any other font available on your system + ) + num_ranks = len(schedule) + max_actions = max(len(rank) for rank in schedule) + + # Increase the figure size to provide more space for the legend + fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2)) + max_draw_position = -1 + # Calculate dynamic font size based on figure size + font_size = min(max_actions, num_ranks) + 4 + used_computation = set() + for rank_idx, actions in enumerate(schedule): + draw_position = 0 # Initialize drawing position for each rank + for action in actions: + if action is not None: + comp_type_color = action_type_to_color_mapping.get( + action.computation_type, _ComputationTypeVisual("black") + ) + used_computation.add(action.computation_type) + color = comp_type_color.color + width = comp_type_color.width + + # Check if action has sub_actions to determine styling + if action.sub_actions is not None: + linewidth = 2 # Thicker border for compound actions + text_weight = "normal" # Bold text for compound actions + else: + linewidth = 1 # Default linewidth for regular actions + text_weight = "normal" # Default text weight + + # Draw the rectangle to represent the action duration + rect = Rectangle( + (draw_position, num_ranks - rank_idx - 1), + width, + 1, + facecolor=color, + edgecolor="black", + linewidth=linewidth, + ) + ax.add_patch(rect) + + # Draw the text centered within the rectangle + ax.text( + draw_position + width / 2, + num_ranks - rank_idx - 1 + 0.5, + str(action), + ha="center", + va="center", + fontsize=font_size, + color="white", + weight=text_weight, + ) + + draw_position += width + else: + draw_position += 1 # Move to the next + max_draw_position = max(max_draw_position, draw_position) + ax.set_xlim(-0.5, max_draw_position + 1) + ax.set_ylim(-0.5, num_ranks + 0.5) # Add extra space at the top + # Set y-ticks to be in the middle of each rank's row + ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)]) + ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size) + ax.set_xticklabels([]) + + # Remove grid lines and ticks + ax.grid(False) + # Add legend with larger font size + legend_elements = [ + Rectangle( + (0, 0), + 1, + 1, + facecolor=action_type_to_color_mapping[comp_type].color, + edgecolor="black", + label=action_type_to_color_mapping[comp_type].text, + ) + for comp_type in used_computation + ] + ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size) + # Save to file if filename is provided, otherwise display the plot + if filename: + plt.savefig(filename, bbox_inches="tight") + else: + plt.show() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_unflatten.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed592f2f8d832de0703fbfa296225f17698afbf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_unflatten.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections import defaultdict + +import torch +from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry + + +def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule: + # Create an empty GraphModule to hold the outlined modules + new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + seen_nodes: dict[str, torch.fx.Node] = {} + seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) + seen_attrs: dict[str, set[str]] = defaultdict(set) + created_modules: dict[str, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + seen_attrs, + created_modules, + None, + [("", None, 0)], + "", + {}, + module=new_module, + ).run_outer() + new_module.graph.lint() + new_module.recompile() + return new_module diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79b74be40681425ab4a5c97198bf0a2020d1d10e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/_utils.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import logging +from dataclasses import dataclass + +import torch +from torch import fx + + +logger = logging.getLogger(__name__) + + +def flatten_args_detach(args): + """ + Flatten the args into a list form and detach the tensors from computational graph. + """ + flat_detached_args = [] + + def extract_tensor_args(a): + nonlocal flat_detached_args + if isinstance(a, torch.Tensor): + val = a.detach().requires_grad_(a.requires_grad) + flat_detached_args.append(val) + return val + else: + flat_detached_args.append(a) + return a + + new_args = fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return new_args, flat_detached_args + + +def flatten_args(args): + """ + Flatten the args into a list form. + """ + flat_args = [] + + def extract_tensor_args(a): + nonlocal flat_args + flat_args.append(a) + return a + + fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return flat_args + + +class PipeliningShapeError(RuntimeError): + """Shape mismatch between configured and runtime values.""" + + +def validate_tensor_metadata(desc, expected, given): + if not expected.shape == given.shape: + raise PipeliningShapeError( + f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}" + ) + if not expected.dtype == given.dtype: + raise PipeliningShapeError( + f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}" + ) + if not expected.stride() == given.stride(): + raise PipeliningShapeError( + f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}" + ) + + +def validate_tensors_metadata( + desc, + expected_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], + actual_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], +): + if len(expected_tensors) != len(actual_tensors): + raise PipeliningShapeError( + f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})" + ) + for i in range(len(expected_tensors)): + validate_tensor_metadata( + f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] + ) + + +def generate_stage_to_rank_mapping( + pp_size: int, num_stages: int, style: str = "loop" +) -> dict[int, int]: + """ + Compute the stage id to rank mapping for either a looped or V-style schedule. + + Most commonly num_stages == pp_size * 2, but this function can be used to + compute the mapping for any number of stages per rank. + """ + mapping = {} + if style == "loop": + for stage_index in range(num_stages): + mapping[stage_index] = stage_index % pp_size + elif style == "v": + if num_stages % pp_size != 0: + raise ValueError( + f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules" + ) + + rank_index = 0 + for stage_index in range(num_stages): + mapping[stage_index] = rank_index + # dont change rank if we are on the border (to keep v shape) + if (stage_index + 1) % pp_size == 0: + continue + if (stage_index // pp_size) % 2 == 0: + rank_index += 1 + else: + rank_index -= 1 + else: + raise ValueError(f"Style {style} is not supported.") + return mapping + + +def generate_rank_to_stage_mapping( + pp_size: int, num_stages: int, style: str = "loop" +) -> dict[int, list[int]]: + """ + Compute the rank to stage id mapping for either a looped or V-style schedule. + + This function inverts the stage_to_rank_mapping to get which stages are assigned to each rank. + + Returns a dictionary mapping rank -> list of stage indices assigned to that rank. + """ + stage_to_rank = generate_stage_to_rank_mapping(pp_size, num_stages, style) + + # Invert the mapping: rank -> list of stages + rank_to_stages: dict[int, list[int]] = {} + for stage_id, rank in stage_to_rank.items(): + if rank not in rank_to_stages: + rank_to_stages[rank] = [] + rank_to_stages[rank].append(stage_id) + + # Sort the stage lists for each rank to ensure consistent ordering + for stages in rank_to_stages.values(): + stages.sort() + + return rank_to_stages + + +@dataclass +class PipeInfo: + """ + Captures information for a pipeline (`Pipe` object). + """ + + graph: fx.Graph + num_stages: int + has_loss_and_backward: bool diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/microbatch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/microbatch.py new file mode 100644 index 0000000000000000000000000000000000000000..a82f83072fa1897c1738e8bf879911921720cfe5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/microbatch.py @@ -0,0 +1,544 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from collections.abc import Sequence +from typing import Any + +import torch +from torch.fx.node import map_aggregate +from torch.nn.attention.flex_attention import BlockMask +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + + +__all__ = [ + "TensorChunkSpec", + "split_args_kwargs_into_chunks", + "merge_chunks", +] + +logger = logging.getLogger(__name__) + +""" +_debug_mask_minibatches specifies to send masked versions of the mini-batch +through instead of micro-batch slices--this can be used for more stable +numerical testing (see [A Note About Correctness Testing]) +""" +_debug_mask_minibatches = False + + +class _CustomReducer: + """ + Custom reducer class that can be used to specify a custom operation that + reduces losses of multiple microbatches into one value. + + Example: + >>> # xdoctest: +SKIP + >>> sum_reducer = _CustomReducer( + >>> torch.tensor(0.0), + >>> lambda a, b: a + b + >>> ) + """ + + def __init__(self, init_value, reduce_fn): + self.init_value = init_value + self.reduce_fn = reduce_fn + + +class _LossReducer(_CustomReducer): + pass + + +sum_reducer = _LossReducer(torch.tensor(0.0), operator.add) + +# Default chunking dimension is 0. This is used for the case where the user did +# not specify a chunking dimension. +DEFAULT_CHUNK_DIM = 0 + + +class TensorChunkSpec: + """ + Class used to specify chunking of inputs + """ + + def __init__(self, split_dim): + self.split_dim = split_dim + + split_dim: int + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})" + ) + + def __str__(self): + return f"TensorChunkSpec({self.split_dim})" + + @staticmethod + def from_tuple( + chunk_dims: tuple[int, ...], + ): + """ + A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk + dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # There are three positional arguments to the model, and + >>> # we are chunking them along dimension 0, 0 and 1, respectively + >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1)) + """ + args_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return args_chunk_spec + + @staticmethod + def from_dict( + chunk_dims: dict[str, int], + ): + """ + A helper for creating a dictionary of `TensorChunkSpec` from a + dictionary of chunk dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument + >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1}) + """ + kwargs_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return kwargs_chunk_spec + + +# Class used to specify replication of inputs +class _Replicate: + pass + + +def _split_block_mask( + block_mask: BlockMask, + num_chunks: int, +) -> list[BlockMask]: + """Given a block mask, split the block mask along the batch dimension (dim0). + + Args: + block_mask: Block mask to split + num_chunks: Number of chunks to split the block mask into + + Returns: + chunk_block_masks: List of chunked block masks + """ + + # BlockMask will broadcast if B is 1. + if block_mask.kv_num_blocks.size(0) == 1: + return [block_mask] * num_chunks + + assert block_mask.kv_num_blocks.size(0) >= num_chunks, ( + "Block mask has fewer batch size than the number of chunks. " + ) + + batch_dim = 0 + kv_num_blocks_chunks = torch.tensor_split( + block_mask.kv_num_blocks, num_chunks, batch_dim + ) + kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim) + full_kv_num_blocks_chunks = ( + torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim) + if block_mask.full_kv_num_blocks is not None + else [None] * num_chunks + ) + full_kv_indices_chunks = ( + torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim) + if block_mask.full_kv_indices is not None + else [None] * num_chunks + ) + + chunk_block_masks = [] + batch_offset = 0 + for chunk_idx in range(num_chunks): + + def create_mask_mod(idx): + def batch_offset_mask_mod(b, h, q_idx, kv_idx): + b_offset = torch.full_like(b, idx) + return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx) + + return batch_offset_mask_mod + + chunk_block_masks.append( + BlockMask.from_kv_blocks( + kv_num_blocks=kv_num_blocks_chunks[chunk_idx], + kv_indices=kv_indices_chunks[chunk_idx], + full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx], + full_kv_indices=full_kv_indices_chunks[chunk_idx], + BLOCK_SIZE=block_mask.BLOCK_SIZE, + mask_mod=create_mask_mod(batch_offset), + seq_lengths=block_mask.seq_lengths, + ) + ) + batch_offset += kv_num_blocks_chunks[chunk_idx].size(0) + return chunk_block_masks + + +def _split_tensor( + tensor: torch.Tensor, + spec: TensorChunkSpec, + num_chunks: int, +) -> Sequence[torch.Tensor]: + """Given a tensor, and a chunking spec, split the tensor. + Args: + + tensor: Tensor to split + spec: Chunking spec + num_chunks: Number of chunks to split the tensor into + + Returns: + chunk_tensors: List of chunked tensors + """ + + assert tensor.size(spec.split_dim) >= num_chunks, ( + f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks" + ) + chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim) + + if not _debug_mask_minibatches: + return chunk_tensors + + expanded_chunks = [] + split_dim_idx = 0 + for chunk_tensor in chunk_tensors: + new_val = torch.zeros_like(tensor) + upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim) + + slice_indices = [slice(None, None, None)] * new_val.ndim + slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx) + new_val[slice_indices] = chunk_tensor + + expanded_chunks.append(new_val) + + split_dim_idx += chunk_tensor.size(spec.split_dim) + + return expanded_chunks + + +def _shard_dict_of_args( + args_dict, + args_chunk_spec, + num_chunks, +): + """ + Given a dictionary of args, and a dictionary of chunking specs, shard the + args according to the chunking specs. + + Args: + args_dict: Dictionary of args + args_chunk_spec: Dictionary of chunking specs + num_chunks: Number of chunks to shard the args into + + Returns: + args_split: List of sharded args + """ + + if not args_dict: + return [{} for _ in range(num_chunks)] + + assert len(args_dict) == len(args_chunk_spec), ( + f"args_dict.keys() = {list(args_dict.keys())} " + f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + ) + assert args_chunk_spec is not None # Should have been set by caller + + values, tree_spec = tree_flatten( + args_dict, is_leaf=lambda x: isinstance(x, BlockMask) + ) + chunk_specs, _ = tree_flatten( + args_chunk_spec, is_leaf=lambda x: isinstance(x, BlockMask) + ) + + # First check and find the actual number of chunks + split_sizes = [] + for v, spec in zip(values, chunk_specs, strict=True): + # The original logic is "spec is _Replicate". This doesn't seem to be + # correct. But we keep it for backward compatibility. + if spec is _Replicate or isinstance(spec, _Replicate): + split_sizes.append(num_chunks) + elif isinstance(v, torch.Tensor): + assert isinstance(spec, TensorChunkSpec) + split_sizes.append(v.size(spec.split_dim)) + elif isinstance(v, BlockMask): + assert isinstance(spec, TensorChunkSpec) + assert spec.split_dim == 0, "BlockMask only supports split_dim=0" + # BlockMask will broadcast if B is 1. + if v.kv_num_blocks.size(0) == 1: + split_sizes.append(num_chunks) + else: + split_sizes.append(v.kv_num_blocks.size(0)) + else: + raise ValueError( + f"Unsupported chunk spec: {spec} and value: {v} combination." + ) + result_num_chunks = min(*split_sizes, num_chunks) + + flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)] + for v, spec in zip(values, chunk_specs, strict=True): + v_splits: Sequence[Any] = [] + if spec is _Replicate or isinstance(spec, _Replicate): + v_splits = [v] * result_num_chunks + elif isinstance(v, torch.Tensor): + v_splits = _split_tensor(v, spec, result_num_chunks) + elif isinstance(v, BlockMask): + v_splits = _split_block_mask(v, result_num_chunks) + else: + raise ValueError( + f"Unsupported chunk spec: {spec} and value: {v} combination." + ) + + for _flat_split_result, _v_split in zip( + flat_split_results, v_splits, strict=True + ): + _flat_split_result.append(_v_split) + + return [ + tree_unflatten(_flat_split_result, tree_spec) + for _flat_split_result in flat_split_results + ] + + +def split_args_kwargs_into_chunks( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, + chunks: int, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, +) -> tuple[list[tuple], list[dict]]: + """ + Given a sequence of args and kwargs, split them into a number of chunks + according to their respective chunking specs. + + Args: + args: Tuple of args + kwargs: Dict of kwargs + chunks: Number of chunks to split the args and kwargs into + args_chunk_spec: chunking specs for args, in same shape as args + kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs + + Returns: + args_split: List of sharded args + kwargs_split: List of sharded kwargs + """ + # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that + # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec` + # and `kwargs_chunk_spec` specifications. The steps are as follows: + # + # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values. + # To use a running example: suppose our inputs look like + # + # args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None) + # (kwargs not shown but it's a similar process) + # + # Then for this step we would end up with + # + # args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None) + # + # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2 + # + # args = ([[A, A], [B, B], [C_1, C_2]], [D, D]) + # + # 3. Rotate the nesting order such that chunks are the outer dimension + # + # args_chunks = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 4. Unflatten each chunk according to the spec + # + # args_chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + + # TODO: _debug_mask_minibatches + # Handle the case where kwargs is None + if kwargs is None: + kwargs = {} + + # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend + # their format and use default chunking along dim 0 + def default_spec(v): + if isinstance(v, torch.Tensor | BlockMask): + return TensorChunkSpec(DEFAULT_CHUNK_DIM) + else: + return _Replicate() + + if args_chunk_spec is None: + args_chunk_spec = tree_map( + default_spec, args, is_leaf=lambda v: isinstance(v, BlockMask) + ) + + if kwargs_chunk_spec is None: + kwargs_chunk_spec = tree_map( + default_spec, kwargs, is_leaf=lambda v: isinstance(v, BlockMask) + ) + + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + chunks, + ) + real_num_chunks = len(args_split_dict) + + kwargs_split = _shard_dict_of_args( + kwargs, + kwargs_chunk_spec, + real_num_chunks, + ) + + if len(kwargs_split) < real_num_chunks: + # In case kwargs are sharded into less chunks + # e.g. when `args` has no tensor, just values + real_num_chunks = len(kwargs_split) + # Re-shard args + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + real_num_chunks, + ) + + if len(args_split_dict) != len(kwargs_split): + raise RuntimeError( + "args and kwargs are split into different number of chunks: " + f"{len(args_split_dict)}, {len(kwargs_split)}" + ) + + args_split = [ + tuple(chunk_args[i] for i in range(len(chunk_args))) + for chunk_args in args_split_dict + ] + + return args_split, kwargs_split + + +def merge_chunks( + chunks: list[Any], + chunk_spec, +): + """ + Given a list of chunks, merge them into a single value according to + the chunk spec. + + Args: + chunks: list of chunks + chunk_spec: Chunking spec for the chunks + + Returns: + value: Merged value + """ + # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the + # steps are similar to the steps in that function but in reverse. Given the + # input values: + # + # chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + # args_spec = ([None, [None, TensorChunkSpec]], None) + # + # 1. Flatten the chunks according to the chunk_spec + # + # chunks_flat = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 2. Rotate the nesting order such that chunks are the inner dimension + # + # value_inner = ([A, B, [C_1, C_2]], D) + # + # 3. Concatenate sharded arguments + # + # value_combined = ([A, B, C], D) + # + # 4. Unflatten the combined args given the spec + # + # value = ([A, [B, C]], D) + + # Preliminary: flatten the chunk spec + if chunk_spec is not None: + spec_flattened, flatten_spec = tree_flatten(chunk_spec) + else: + # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields + # We obtain the output structure by flattening chunk 0 and generate the chunk_spec + chunk0_flat, flatten_spec = tree_flatten(chunks[0]) + spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat) + + # Stage 1: flatten chunks + # chunks_flattened : [num chunks, num args] + chunks_flattened = [] + + for chunk in chunks: + chunk_flattened, _ = tree_flatten(chunk) + if len(chunk_flattened) != len(spec_flattened): + raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}") + + chunks_flattened.append(chunk_flattened) + + # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and + # concatenate sharded operands + # args_flattened : [num args] + args_flattened = [] + for arg_idx, arg in enumerate(spec_flattened): + if isinstance(arg, TensorChunkSpec): + partial_values = [ + chunks_flattened[chunk_idx][arg_idx] + for chunk_idx in range(len(chunks_flattened)) + ] + + if _debug_mask_minibatches: + # Infer size of individual chunks by running `tensor_split` again + overall_shape = partial_values[0].shape + for val in partial_values[1:]: + assert val.shape == overall_shape + meta_chunks = torch.tensor_split( + torch.empty(*overall_shape, device="meta"), + sections=len(partial_values), + dim=arg.split_dim, + ) + + values_to_cat = [] + chunk_start_idx = 0 + assert len(partial_values) == len(meta_chunks) + for partial_value, meta_chunk in zip( + partial_values, meta_chunks, strict=True + ): + chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim) + + slice_indices = [slice(None, None, None)] * partial_value.ndim + slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx) + sliced = partial_value[slice_indices] + values_to_cat.append(sliced) + + chunk_start_idx = chunk_end_idx + + else: + values_to_cat = partial_values + + args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim)) + elif isinstance(arg, _CustomReducer): + reduced_val = arg.init_value + + for chunk_idx in range(len(chunks_flattened)): + reduced_val = arg.reduce_fn( + reduced_val, chunks_flattened[chunk_idx][arg_idx] + ) + + args_flattened.append(reduced_val) + else: + value = chunks_flattened[0][arg_idx] + for chunk_idx in range(1, len(chunks_flattened)): + assert chunks_flattened[chunk_idx][arg_idx] == value + args_flattened.append(value) + + # Stage 4: Unflatten combined args + return tree_unflatten(args_flattened, flatten_spec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..5657068f0bcd7008a0dc9b4a2a56e364bcc92428 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py @@ -0,0 +1,3438 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import copy +import csv +import itertools +import logging +import re +from abc import ABC, abstractmethod +from collections import Counter, defaultdict +from collections.abc import Callable +from enum import Enum +from functools import lru_cache +from typing import Any, cast, NamedTuple, Protocol + +import torch +import torch.distributed as dist +from torch._dynamo import OptimizedModule +from torch.distributed.fsdp import FSDPModule, UnshardHandle +from torch.nn.modules.loss import _Loss +from torch.profiler import record_function + +from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping +from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec +from .stage import _PipelineStageBase + + +__all__ = [ + "get_schedule_class", + "PipelineScheduleSingle", + "PipelineScheduleMulti", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", + "ScheduleZBVZeroBubble", + "ScheduleDualPipeV", +] + +logger = logging.getLogger(__name__) + + +class _ComputationType(Enum): + # TODO(whc) rename to _ActType? + FORWARD = 1 + BACKWARD_INPUT = 2 + BACKWARD_WEIGHT = 3 + UNSHARD = 4 + RESHARD = 5 + SEND_F = 6 + RECV_F = 7 + SEND_B = 8 + RECV_B = 9 + FULL_BACKWARD = 10 + OVERLAP_F_B = 11 + REDUCE_GRAD = 12 + + def __str__(self): + str_map = { + _ComputationType.FORWARD: "F", + _ComputationType.BACKWARD_INPUT: "I", + _ComputationType.BACKWARD_WEIGHT: "W", + _ComputationType.UNSHARD: "UNSHARD", + _ComputationType.RESHARD: "RESHARD", + _ComputationType.SEND_F: "SEND_F", + _ComputationType.RECV_F: "RECV_F", + _ComputationType.SEND_B: "SEND_B", + _ComputationType.RECV_B: "RECV_B", + _ComputationType.FULL_BACKWARD: "B", + _ComputationType.OVERLAP_F_B: "OVERLAP_F_B", + _ComputationType.REDUCE_GRAD: "REDUCE_GRAD", + } + return str_map[self] + + @staticmethod + def from_str(action): + if action == "F": + return _ComputationType.FORWARD + elif action == "I": + return _ComputationType.BACKWARD_INPUT + elif action == "W": + return _ComputationType.BACKWARD_WEIGHT + elif action == "UNSHARD": + return _ComputationType.UNSHARD + elif action == "RESHARD": + return _ComputationType.RESHARD + elif action == "SEND_F": + return _ComputationType.SEND_F + elif action == "RECV_F": + return _ComputationType.RECV_F + elif action == "SEND_B": + return _ComputationType.SEND_B + elif action == "RECV_B": + return _ComputationType.RECV_B + elif action == "B": + return _ComputationType.FULL_BACKWARD + elif action == "OVERLAP_F_B": + return _ComputationType.OVERLAP_F_B + elif action == "REDUCE_GRAD": + return _ComputationType.REDUCE_GRAD + else: + raise RuntimeError(f"Invalid computation type {action}") + + +FORWARD = _ComputationType.FORWARD +BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT +BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT +UNSHARD = _ComputationType.UNSHARD +RESHARD = _ComputationType.RESHARD +SEND_F = _ComputationType.SEND_F +RECV_F = _ComputationType.RECV_F +SEND_B = _ComputationType.SEND_B +RECV_B = _ComputationType.RECV_B +FULL_BACKWARD = _ComputationType.FULL_BACKWARD +OVERLAP_F_B = _ComputationType.OVERLAP_F_B +REDUCE_GRAD = _ComputationType.REDUCE_GRAD + +# Convenience shorthand for compute actions only since they are used in 'simple schedule format' +F = FORWARD +I = BACKWARD_INPUT +W = BACKWARD_WEIGHT +B = FULL_BACKWARD + +# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) +_action_regex = re.compile( + r"(\d+)(F|I|B|W|UNSHARD|RESHARD|REDUCE_GRAD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" +) + + +class _Action(NamedTuple): + stage_index: int + computation_type: _ComputationType + microbatch_index: int | None = None + sub_actions: tuple["_Action", ...] | None = None + + def __str__(self): + return self.__repr__() + + def __repr__(self): + if self.sub_actions is not None: + # Use recursive repr for sub_actions + sub_action_reprs = [repr(sub_action) for sub_action in self.sub_actions] + return f"({';'.join(sub_action_reprs)}){self.computation_type}" + else: + repr_str = str(self.stage_index) + repr_str += str(self.computation_type) + if self.microbatch_index is not None: + repr_str += str(self.microbatch_index) + return repr_str + + @property + def is_compute_op(self) -> bool: + return self.computation_type in ( + FORWARD, + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + OVERLAP_F_B, + ) + + @staticmethod + def from_str(action_string: str): + """ + Reverse of __repr__ + + String should be formatted as [stage][action type][(microbatch)] + e.g. `2F0`, `1UNSHARD`, `3SEND_F1` + """ + action_string = action_string.strip() + if action_string == "": + return None + + # Check for sub_actions format: [sub_action1;sub_action2;...]ComputationType + if action_string.startswith("(") and ")" in action_string: + # Find the closing bracket to separate sub_actions from computation type + bracket_end = action_string.find(")") + sub_part = action_string[ + 1:bracket_end + ] # Remove '[' and get content before ']' + computation_type_part = action_string[ + bracket_end + 1 : + ] # Get part after ']' + + # Parse sub_actions + sub_actions = [] + if sub_part.strip(): + for sub_str in sub_part.split(";"): + sub_action = _Action.from_str(sub_str.strip()) + if sub_action is not None: + sub_actions.append(sub_action) + + # For sub_actions format, we create an action with just the computation type + # The stage_index and microbatch_index are not meaningful for the container action + return _Action( + stage_index=-1, # Placeholder, not meaningful for sub_actions container + computation_type=_ComputationType.from_str(computation_type_part), + microbatch_index=None, + sub_actions=tuple(sub_actions) if sub_actions else None, + ) + + # Handle regular single action format + if match := _action_regex.match(action_string): + stage_index, computation_type, microbatch_index = match.groups() + return _Action( + int(stage_index), + _ComputationType.from_str(computation_type), + int(microbatch_index) if len(microbatch_index) else None, + ) + elif action_string == "": + return None + raise RuntimeError( + f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0" + ) + + +@lru_cache +def _get_profiler_function_name(action: _Action) -> str: + return f"PP:{str(action)}" + + +def _format_pipeline_order( + pipeline_order: dict[int, list[_Action | None]], + error_step_number: int | None = None, +) -> str: + """ + Formats the pipeline order in a timestep (row) x rank (column) grid of actions + and returns the formatted string. + + If `error_step_number` is passed in, an additional label will be added to signify which step + that it is erroring on. + """ + + # don't mutate the original + pipeline_order = copy.deepcopy(pipeline_order) + + # Replace None with "" + for rank in pipeline_order: + for i in range(len(pipeline_order[rank])): + if pipeline_order[rank][i] is None: + # TODO make a real 'None action' that prints as empty string and make mypy happy + pipeline_order[rank][i] = "" # type: ignore[call-overload] + + # Calculate the maximum number of steps across all ranks + num_steps = max(len(actions) for actions in pipeline_order.values()) + step_labels = [ + "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) + ] + # Sorting the dictionary by keys and retrieving values in that order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + # Transpose the list of lists (rows to columns) + # pyrefly: ignore [no-matching-overload] + transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) + # Generate column labels for ranks + num_ranks = len(pipeline_order) + rank_labels = ["Rank " + str(i) for i in range(num_ranks)] + # Calculate the maximum length of each column, considering labels + max_lengths = [ + max(len(str(item)) if item is not None else 0 for item in col) + for col in zip(step_labels, *transposed_actions) + ] + # Format the header row with rank labels + header_row = " " * (len(step_labels[0]) + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + # Format each row with its corresponding label + formatted_rows = [ + f"{label}: " + + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) + + ( + " <-- ERROR HERE" + if error_step_number is not None + and int(label.split()[1]) == error_step_number + else "" + ) + for label, row in zip(step_labels, transposed_actions) + ] + # Join the rows into a single string + formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n" + return formatted_table + + +class _PipelineSchedule(ABC): + def __init__( + self, + n_microbatches: int, + loss_fn: Callable[..., torch.Tensor] | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + ): + # From arguments + self._n_microbatches = n_microbatches + self._loss_fn = loss_fn + + # See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti` + self.scale_grads = scale_grads + + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec + self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + + # Derived + self._has_backward = self._loss_fn is not None + + # Holds the losses for each microbatch. + self._internal_losses: list[torch.Tensor] = [] + logger.info("Using %s", self.__class__.__name__) + + def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): + if stage.is_last and self._loss_fn is not None: + loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] + self._internal_losses.append(loss) + + def _maybe_get_loss(self, stage, mb_index): + valid_index = 0 <= mb_index < len(self._internal_losses) + if stage.is_last and self._loss_fn is not None and valid_index: + return self._internal_losses[mb_index] + elif len(self._internal_losses) != 0 and not valid_index: + raise RuntimeError( + f"Loss for microbatch {mb_index} is not available. " + f"Available losses for microbatches: {self._internal_losses}" + ) + else: + return None + + def _update_losses(self, stages, losses): + """ + Update the losses to those in the internal state + """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any(stage.is_last for stage in stages) + + # Return losses if there is a container passed in + if contains_last_stage and losses is not None: + if len(self._internal_losses) != self._n_microbatches: + raise RuntimeError( + f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" + ) + + # Clean external container first + losses.clear() + # Copy internal losses to external container + losses.extend(self._internal_losses) + + self._internal_losses.clear() + + @abstractmethod + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the schedule + implementation. + + Args: + microbatches: list of microbatch args. + return_outputs: whether to return the outputs from the last stage. + """ + raise NotImplementedError + + @abstractmethod + def step( + self, + *args, + target=None, + losses: list | None = None, + return_outputs=True, + **kwargs, + ): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + return_outputs: whether to return the outputs from the last stage. + """ + raise NotImplementedError + + def eval(self, *args, target=None, losses: list | None = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches, calling forward only. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target values for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Save the original has_backward state + original_has_backward = self._has_backward + try: + self._has_backward = False + return self.step(*args, target=target, losses=losses, **kwargs) + finally: + # Restore the original state + self._has_backward = original_has_backward + + def _check_inputs( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ) -> tuple[list, list]: + """ + Pre-process/check inputs + """ + + def check_type_and_len(mbs, name: str): + if not isinstance(mbs, list): + raise TypeError(f"{name} must be a list but got a {type(mbs)}") + if len(mbs) != self._n_microbatches: + raise ValueError( + f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" + ) + + if arg_mbs is not None: + check_type_and_len(arg_mbs, "arg_mbs") + else: + arg_mbs = [()] * self._n_microbatches + + if kwarg_mbs is not None: + check_type_and_len(kwarg_mbs, "kwarg_mbs") + else: + kwarg_mbs = [{}] * self._n_microbatches + + if target_mbs is not None: + check_type_and_len(target_mbs, "target_mbs") + + if losses is not None: + if not isinstance(losses, list): + raise TypeError(f"losses must be a list but got a {type(losses)}") + + return arg_mbs, kwarg_mbs + + def _compute_loss(self, output, target): + return self._loss_fn(output, target) # type: ignore[misc] + + def _split_inputs( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ): + """ + Splits a full-batch input into chunks (i.e. microbatches) and returns + the chunks + """ + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self._n_microbatches, + self._args_chunk_spec, + self._kwargs_chunk_spec, + ) + return args_split, kwargs_split + else: + # Empty inputs (e.g. when called on middle stages) + # Return a list of empty tuples/dicts with matching length as chunks + return [()] * self._n_microbatches, [{}] * self._n_microbatches + + def _merge_outputs(self, output_chunks: list[Any]) -> Any: + """ + Merge output chunks back to a batch state. + If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). + """ + return merge_chunks( + output_chunks, + self._output_merge_spec, + ) + + +def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None) -> list[dist.Work]: + """ + Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. + """ + if len(p2p_ops) == 0: + return [] + desc_str = f"{desc}, " if desc else "" + logger.debug("batch_p2p %s%s", desc_str, p2p_ops) + return dist.batch_isend_irecv(p2p_ops) + + +def _sorted_batch_p2p( + p2p_ops: list[dist.P2POp], desc: str | None = None +) -> dict[int, list[dist.Work]]: + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # List is the list of ops towards the peer + ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list) + work_by_peer: dict[int, list[dist.Work]] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +def _wait_batch_p2p(work: list[dist.Work]): + """ + Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p). + """ + for w in work: + w.wait() + + +class PipelineScheduleSingle(_PipelineSchedule): + """ + Base class for single-stage schedules. + Implements the `step` method. + Derived classes should implement `_step_microbatches`. + + Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting + should match the configuration of your loss_fn, which may either average losses (scale_grads=True) + or sum losses (scale_grads=False). + """ + + def __init__( + self, + stage: _PipelineStageBase, + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + # Self attributes + self._stage = stage + self._num_stages = stage.num_stages + self._stage_forward_initialized = False + self._stage_backward_initialized = False + + if n_microbatches < self._num_stages: + raise ValueError( + f"Number of microbatches ({n_microbatches}) must be greater than \ +or equal to the number of stages ({self._num_stages})." + ) + + self.pipeline_order: dict[int, list[_Action | None]] | None = ( + self._get_pipeline_order() + ) + + def _initialize_stage(self, args, kwargs): + if not self._stage_forward_initialized: + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + all_ops.extend(self._stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) + self._stage_forward_initialized = True + + if self._has_backward and not self._stage_backward_initialized: + self._stage._prepare_backward_infra(self._n_microbatches) + self._stage_backward_initialized = True + + def step( + self, + *args, + target=None, + losses: list | None = None, + return_outputs: bool = True, + **kwargs, + ): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + return_outputs: whether to return the outputs from the last stage. + """ + if self._has_backward and not torch.is_grad_enabled(): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward + + # Clean per iteration + self._stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches( + args_split, kwargs_split, targets_split, losses, return_outputs + ) + + # Return merged results per original format + if self._stage.is_last and return_outputs: + return self._merge_outputs(self._stage.output_chunks) + else: + return None + + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: + """ + Returns the pipeline execution order as a schedule IR. + + The returned IR is a dictionary mapping rank IDs to lists of actions. + Each action is either an _Action object representing computation to perform, + or None representing a deliberate idle step. + + The None values are used to represent pipeline bubbles where a rank + must wait for dependencies from other ranks before proceeding. However + during execution, with the _PipelineScheduleRuntime, these Nones are + skipped since the relevant communication (send/recv) will be scheduled and waited on. + + Returns: + A dictionary mapping rank -> list of actions + """ + return None + + +class _ScheduleForwardOnly(PipelineScheduleSingle): + """ + The forward-only schedule. + Will go through all the microbatches and perform only the forward pass + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule + """ + if target_mbs is not None or losses is not None: + raise RuntimeError( + "Forward-only schedule does not support loss computation" + ) + + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: list[list[dist.Work]] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + _wait_batch_p2p(work) + + +class ScheduleGPipe(PipelineScheduleSingle): + """ + The GPipe schedule. + Will go through all the microbatches in a fill-drain manner. + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the GPipe schedule. + + Args: + microbatches: list of microbatch args. + return_outputs: whether to return the outputs from the last stage. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: list[list[dist.Work]] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + output = self._stage.forward_one_chunk( + i, arg_mbs[i], kwarg_mbs[i], save_forward_output=return_outputs + ) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + _wait_batch_p2p(work) + + # Run backward + # Delay send waits + bwd_sends_to_wait: list[list[dist.Work]] = [] + for i in range(self._n_microbatches): + with record_function(f"Backward {i}"): + ops = self._stage.get_bwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + loss = self._maybe_get_loss(self._stage, i) + self._stage.backward_one_chunk( + i, + loss=loss, + last_backward=i == self._n_microbatches - 1, + ) + + ops = self._stage.get_bwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i) + + # Wait for all backward sends to finish + for work in bwd_sends_to_wait: + _wait_batch_p2p(work) + + # Update losses if there is a container passed in + self._update_losses(self._stage, losses) + + self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) + + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: + """ + Returns the pipeline order for GPipe schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[_Action | None] = [] + + # 1. Initial delay based on rank position + warmup_delay = rank + actions.extend([None] * warmup_delay) + + # 2. Forward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx)) + + # 3. Wait period before backward passes can begin + backward_delay = 3 * (pp_group_size - 1 - rank) + actions.extend([None] * backward_delay) + + # 4. Backward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx)) + + pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches) + + return pipeline_order # type: ignore[return-value] + + +class Schedule1F1B(PipelineScheduleSingle): + """ + The 1F1B schedule. + Will perform one forward and one backward on the microbatches in steady state. + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the 1F1B schedule. + + Args: + microbatches: list of microbatch args. + return_outputs: whether to return the outputs from the last stage. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Last stage has 1 warmup, second-to-last 2 warmups, ... + # first stage `num_stages` warmups + warmup_chunks = min( + self._n_microbatches, + self._num_stages - self._stage.stage_index, + ) + + # Chunk counters + fwd_mb_index = 0 + bwd_mb_index = 0 + + # Warmup phase + send_work: list[dist.Work] = [] + fwd_sends = [] + for _ in range(warmup_chunks): + # Receive activations + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + _wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv")) + + # Compute + output = self._stage.forward_one_chunk( + fwd_mb_index, + arg_mbs[fwd_mb_index], + kwarg_mbs[fwd_mb_index], + save_forward_output=return_outputs, + ) # type: ignore[index] + + # Clear previous chunk's forward sends (hopefully they have well + # finished, otherwise, we are heavily communication bound, in which + # case it doesn't create a lot of benefit to compute next chunk + # eagerly either) + _wait_batch_p2p(send_work) + + # Send activations + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + if fwd_mb_index != warmup_chunks - 1: + # Safe to fire + send_work = _batch_p2p(fwd_sends, desc="fwd_send") + # otherwise: + # The last forward send is left for fuse with first 1B in 1B1F below + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + fwd_mb_index += 1 + + # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. + + # 1B1F phase + while True: # Don't worry, we have a break inside + # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + + # Now, we need to fire the fwd_sends and bwd_recvs together + _wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv")) + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Get the bwd send ops, but don't fire, to be fused with the 1F below + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + bwd_mb_index += 1 + + if fwd_mb_index == self._n_microbatches: + # We are done with 1B1F, so break with some left-over bwd_sends + break + + # We prepare 1F of the `1B1F` + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + + # Fuse it with bwd_sends above + _wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv")) + + # Now do the fwd + output = self._stage.forward_one_chunk( + fwd_mb_index, + arg_mbs[fwd_mb_index], + kwarg_mbs[fwd_mb_index], + save_forward_output=return_outputs, + ) # type: ignore[index] + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + + # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + fwd_mb_index += 1 + + # Remember we still have some bwd_sends left over after the break? Now it is time to fire it + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + + # Cooldown + while bwd_mb_index < self._n_microbatches: + # prepare bwd recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + _wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv")) + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Clear previous chunk's backward sends (hopefully they have well finished) + _wait_batch_p2p(send_work) + + # Get the bwd send ops, fire it + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + bwd_mb_index += 1 + + # Wait for the last backward send to finish + _wait_batch_p2p(send_work) + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) + + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: + """ + Returns the pipeline order for 1F1B schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[_Action | None] = [] + + # 1. Warmup phase: initial delay based on rank + actions.extend([None] * rank) + + # 2. Initial forward passes before 1F1B phase + num_forward = (pp_group_size - 1) - rank + forward_mb = 0 + for i in range(num_forward): + actions.append(_Action(rank, _ComputationType.FORWARD, i)) + forward_mb = i + + # 3. Wait for backward to be ready + wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank)) + actions.extend([None] * wait_for_1f1b) + + # 4. 1F1B steady state phase + backward_mb = 0 + remaining_forward = self._n_microbatches - num_forward + + while remaining_forward > 0: + # One forward + forward_mb += 1 + actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb)) + remaining_forward -= 1 + + # One backward + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + + # 5. Cooldown phase: remaining backward passes + remaining_backward = self._n_microbatches - backward_mb + + while remaining_backward > 0: + # Add None and backward actions in alternating pattern + # based on distance from the last stage + if (pp_group_size - rank) > 0: + actions.append(None) + # Decrement the wait counter only if we still have backward passes to do + if remaining_backward > 0: + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + else: + # If we're at the last stage, just add backward actions without None + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + + pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches) + return pipeline_order + + +def _requires_reduce_grad(action_type: _ComputationType) -> bool: + return action_type in (W, B) + + +def _add_reduce_grad( + actions: list[_Action | None], n_microbatches: int +) -> list[_Action | None]: + """ + REDUCE_GRAD refers to joint across minibatches grad reduction. + reduce_grad frees memory and we want to schedule it just after the last "backward"-like stage. + """ + actions_with_reduce_grad: list[_Action | None] = [] + cnt: dict[int, int] = defaultdict(int) + + def _leaf_action(a, to_schedule): + if _requires_reduce_grad(a.computation_type): + stage_index = a.stage_index + cnt[stage_index] += 1 + if cnt[stage_index] == n_microbatches: + to_schedule.append(stage_index) + + for a in actions: + if a is None: + continue + actions_with_reduce_grad.append(a) + schedule_reduce_grad_stage_idxs: list[int] = [] + if a.computation_type == OVERLAP_F_B and a.sub_actions is not None: + for sub_action in a.sub_actions: + _leaf_action(sub_action, schedule_reduce_grad_stage_idxs) + else: + _leaf_action(a, schedule_reduce_grad_stage_idxs) + + for stage_idx in schedule_reduce_grad_stage_idxs: + actions_with_reduce_grad.append(_Action(stage_idx, REDUCE_GRAD, None)) + return actions_with_reduce_grad + + +def _add_unshard_reshard( + compute_actions: list[_Action | None], + max_active_stages: int = 3, +) -> list[_Action]: + """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP. + + UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. + RESHARD does the opposite, releasing memory (but doing no communication) + + We abandon the "timestep lock" during lowering + + max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice + 3 stages is probably the thing we want? + (to account for having one f and one b active, and something else prefetching?) + """ + + def next_stage_indices(count: int, next_actions: list[_Action | None]) -> list[int]: + """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" + seen: set[int] = set() + ret: list[int] = [] + + for a in next_actions: + if a is not None: + # Handle OVERLAP_F_B actions by checking their sub_actions + if a.computation_type == OVERLAP_F_B and a.sub_actions is not None: + for sub_action in a.sub_actions: + if sub_action.stage_index not in seen: + seen.add(sub_action.stage_index) + ret.append(sub_action.stage_index) + if len(ret) >= count: + break + else: + # Regular action + if a.stage_index not in seen: + seen.add(a.stage_index) + ret.append(a.stage_index) + if len(ret) == count: + break + return ret + + active_stages: set[int] = set() + fsdp_aware_actions: list[_Action] = [] + + def _unshard(stage_index: int): + active_stages.add(stage_index) + fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None)) + + def _reshard(stage_index: int): + active_stages.remove(stage_index) + fsdp_aware_actions.append(_Action(stage_index, RESHARD, None)) + + for i, action in enumerate(compute_actions): + if action is None: + continue + + # We prefetch the next N stages we'll see, dropping existing stages to make room + next_n = next_stage_indices(max_active_stages, compute_actions[i:]) + # Fetch needs to be ordered correctly, so don't use a set + fetch = list(filter(lambda s: s not in active_stages, next_n)) + # Unclear what the best policy is for eviction, but we can maintain order so we do + evict = list(filter(lambda s: s not in next_n, active_stages)) + + # logger.debug( + # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s", + # i, + # active_stages, + # fetch, + # evict, + # ) + + for stage in evict: + _reshard(stage) + for stage in fetch: + _unshard(stage) + fsdp_aware_actions.append(action) + + # Reshard all remaining active stages after processing all operations + for stage in list(active_stages): + _reshard(stage) + + return fsdp_aware_actions + + +def _merge_bw( + compute_actions: list[_Action | None], +) -> list[_Action]: + """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops. + (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD) + + B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient + in some cases. + """ + merged_actions = [] + while compute_actions: + action = compute_actions.pop(0) + if action is None: + continue + + # Remove any None actions and find the next non-None action + while len(compute_actions) and compute_actions[0] is None: + compute_actions.pop(0) + + # Get the next action if it exists + next_action = compute_actions[0] if len(compute_actions) > 0 else None + + if ( + action.computation_type == BACKWARD_INPUT + and next_action is not None + and next_action.computation_type == BACKWARD_WEIGHT + and action.stage_index == next_action.stage_index + and action.microbatch_index == next_action.microbatch_index + ): + merged_actions.append( + _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index) + ) + compute_actions.pop(0) + else: + merged_actions.append(action) + return merged_actions + + +def _add_send_recv( + compute_actions: dict[int, list[_Action]], + stage_to_rank: Callable[[int], int], + num_stages: int, +) -> dict[int, list[_Action]]: + """ + Transforms a compute-only schedule into a complete schedule with communication actions. + + For actions with sub-actions (OVERLAP_F_B) we ensure that all the subactions have been + computed and the communication is ready + """ + comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions} + prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions} + + def _has_comms(action: _Action) -> bool: + if action.computation_type == F: + return action.stage_index != num_stages - 1 and stage_to_rank( + action.stage_index + 1 + ) != stage_to_rank(action.stage_index) + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + return action.stage_index != 0 and stage_to_rank( + action.stage_index - 1 + ) != stage_to_rank(action.stage_index) + return False + + def _get_comms(action: _Action) -> tuple[_Action, _Action]: + assert _has_comms(action), f"{action} is not a valid comm action" + stage_idx = action.stage_index + ctype = action.computation_type + mb_idx = action.microbatch_index + send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx) + recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1 + recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx) + return send, recv + + def _ready_to_schedule(action: _Action | None, prev_actions: set[_Action]) -> bool: + """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. + This helps ensure a sane (non-hanging) ordering of sends and recvs. + But it also means we might not be able to schedule our next compute action yet. + """ + if action is None: + return True + elif action.computation_type == F and action.stage_index != 0: + if ( + _Action(action.stage_index, RECV_F, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) + in prev_actions + ): + return True + return False + elif ( + action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD) + and action.stage_index != num_stages - 1 + ): + if ( + _Action(action.stage_index, RECV_B, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_actions + ): + return True + return False + else: + return True + + while compute_actions: + progress = False + # go in order of ranks even if dict keys aren't ordered + for rank in sorted(compute_actions): + assert len(compute_actions[rank]) > 0, ( + f"{rank=}, {len(compute_actions[rank])=}" + ) + action = compute_actions[rank][0] + # handle case where parent action (e.g. OVERLAP_F_B) can be comprised of subactions + if action is not None and action.sub_actions is not None: + all_actions = action.sub_actions + else: + all_actions = (action,) + + if not all(_ready_to_schedule(a, prev_actions[rank]) for a in all_actions): + continue + + # The action's dependencies are satisfied, so add to schedule + if action is not None: + comm_actions[rank].append(action) + for a in all_actions: + prev_actions[rank].add(a) + if _has_comms(a): + send, recv = _get_comms(a) + # TODO we can avoid send/recv if the 2 stages are on the same rank. + # should we avoid that in the runtime or here? + comm_actions[rank].append(send) + prev_actions[rank].add(send) + comm_actions[stage_to_rank(recv.stage_index)].append(recv) + prev_actions[stage_to_rank(recv.stage_index)].add(recv) + + compute_actions[rank].pop(0) + if len(compute_actions[rank]) == 0: + del compute_actions[rank] + progress = True + assert progress, "Malformed compute schedule, can't schedule sends/recvs" + return comm_actions + + +def _validate_schedule( + actions: dict[int, list[_Action | None]], + pp_group_size: int, + num_stages: int, + num_microbatches: int, +) -> dict[int, int]: + assert len(actions) == pp_group_size, ( + f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}" + ) + for rank in range(pp_group_size): + assert rank in actions, f"Schedule is missing actions for rank {rank}" + + # We will count all the actions per stage and ensure they happen in a valid order + # (e.g. F before (B, I) before W for a given microbatch) + stage_actions: dict[int, dict[_ComputationType, set]] = { + stage_id: { + F: set(), + B: set(), + I: set(), + W: set(), + } + for stage_id in range(num_stages) + } + stage_index_to_rank_mapping = {} + + def _process_action(action: _Action, rank: int, step: int): + """Process a single action and update stage_actions and stage_index_to_rank_mapping""" + s_id = action.stage_index + ctype = action.computation_type + mb_id = action.microbatch_index + + if ctype == F: + stage_actions[s_id][F].add(mb_id) + elif ctype == B: + if mb_id not in stage_actions[s_id][F]: + error_msg = ( + f"Rank {rank}, step {step}: Running Full Backward for stage {s_id}, " + f"microbatch {mb_id} without first running Forward" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][B].add(mb_id) + elif ctype == I: + if mb_id not in stage_actions[s_id][F]: + error_msg = ( + f"Rank {rank}, step {step}: Running Backward Input for stage {s_id}, " + f"microbatch {mb_id} without first running Forward" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][I].add(mb_id) + elif ctype == W: + if mb_id not in stage_actions[s_id][I]: + error_msg = ( + f"Rank {rank}, step {step}: Running Backward Weight for stage {s_id}, " + f"microbatch {mb_id} without first running Backward Input" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][W].add(mb_id) + + if s_id not in stage_index_to_rank_mapping: + stage_index_to_rank_mapping[s_id] = rank + else: + existing_rank = stage_index_to_rank_mapping[s_id] + assert rank == existing_rank, ( + f"Rank {rank}, step {step}: Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" + ) + + for rank in actions: + for step, action in enumerate(actions[rank]): + if action is None: + continue + assert isinstance(action, _Action), ( + f"Rank {rank}, step {step}: Got an invalid action: {action}, expected instance of _Action" + ) + + # Check if action has sub_actions + if action.sub_actions is not None: + # Process each sub_action instead of the main action + for sub_action in action.sub_actions: + _process_action(sub_action, rank, step) + else: + # Process the main action normally + _process_action(action, rank, step) + + for s_id in stage_actions: + f_mb = len(stage_actions[s_id][F]) + b_mb = len(stage_actions[s_id][B]) + i_mb = len(stage_actions[s_id][I]) + w_mb = len(stage_actions[s_id][W]) + + assert f_mb == num_microbatches, ( + f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" + ) + + assert i_mb == w_mb, ( + f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, \ + but got I={i_mb}, W={w_mb}" + ) + + assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, ( + f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \ + but got B={b_mb}, I={i_mb}, W={w_mb}" + ) + return stage_index_to_rank_mapping + + +class PipelineScheduleMulti(_PipelineSchedule): + """ + Base class for multi-stage schedules. + Implements the `step` method. + + Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting + should match the configuration of your loss_fn, which may either average losses (scale_grads=True) + or sum losses (scale_grads=False). + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + use_full_backward: bool | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + # Self attributes + self._stages = stages + self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank + # Set the pipeline stage states + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + self._stages_forward_initialized = False + self._stages_backward_initialized = False + + # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle + has_loss: bool = self._loss_fn is not None + self._should_compute_loss = lambda stage: stage.is_last and has_loss + + # This will be set during init of derived schedules + self.pipeline_order: dict[int, list[_Action | None]] = {} + + # When using a custom backward function, we may or may not need autograd to be used + # for the backward pass. This flag is used to determine whether or torch.is_grad_enabled() + # check should be performed before the step function. + self._backward_requires_autograd = backward_requires_autograd + + if use_full_backward is not None: + logger.warning( + "Deprecation warning: 'use_full_backward' is no longer supported. " + "Simply stop passing it, and everything should still work fine." + ) + + def _initialize_stages(self, args: tuple[Any, ...], kwargs): + if not self._stages_forward_initialized: + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + for stage in self._stages: + all_ops.extend(stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) + # or real value (if this stage and next stage are on the same device) + next_stage_args: tuple[Any, ...] = tuple() + for stage in self._stages: + if stage.is_first: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, args, kwargs + ) + else: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, next_stage_args, kwargs + ) + self._stages_forward_initialized = True + + if self._has_backward and not self._stages_backward_initialized: + for stage in self._stages: + stage._prepare_backward_infra(self._n_microbatches) + self._stages_backward_initialized = True + + def _validate_and_set_stage_mapping( + self, actions: dict[int, list[_Action | None]] + ) -> None: + """ + Allocates the stage index to rank mapping which is needed for communication + """ + self.stage_index_to_group_rank = _validate_schedule( + actions, + self.pp_group_size, + self._num_stages, + self._n_microbatches, + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + def _dump_csv(self, filename): + """Dump a CSV representation of the schedule into a file with the provided filename.""" + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order: + writer.writerow(self.pipeline_order[rank]) + + def _load_csv(self, filename, format="compute_only"): + """Load a CSV representation of the schedule from a file with the provided filename. + This API will most likely get renamed/refactored so is marked as internal for now. + + format must be "compute_only" for PipelineScheduleMulti. + """ + assert format == "compute_only" + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + self.pipeline_order[rank] = [_Action.from_str(s) for s in row] + + # Validates the order of the pipeline actions and infers the stage_to_rank_mapping. + # This will overwrite the default stage_to_rank_mapping created in the constructor + self._validate_and_set_stage_mapping(self.pipeline_order) + + def step( + self, + *args, + target=None, + losses: list | None = None, + return_outputs: bool = True, + **kwargs, + ): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + return_outputs: whether to return the outputs from the last stage. + """ + if ( + self._has_backward + and self._backward_requires_autograd + and not torch.is_grad_enabled() + ): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + + # Clean per iteration + for stage in self._stages: + stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches( + args_split, kwargs_split, targets_split, losses, return_outputs + ) + + # Return merged results per original format + for stage in self._stages: + if stage.is_last and return_outputs: + return self._merge_outputs(stage.output_chunks) + # Does not contain the last stage or we do not return output chunks + return None + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + # determine prev_rank and next_rank based on which ranks are next to + # the stages in the pipeline_order + all_prev_ranks: set[int] = set() + all_next_ranks: set[int] = set() + for stage_index in stage_index_to_stage: + # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) + if stage_index > 0: + all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) + if stage_index < self._num_stages - 1: + all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1]) + # count either full_backward or backward_weight together, to determine when to sync DP grads + backward_counter: Counter[int] = Counter() + for time_step, action in enumerate(self.pipeline_order[self.rank]): + try: + ops: list[dist.P2POp] = [] + if action is not None: + computation_type = action.computation_type + mb_index = action.microbatch_index + stage_index = action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, + arg_mbs[mb_index], + kwarg_mbs[mb_index], + save_forward_output=return_outputs, + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ComputationType.FULL_BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_index] += 1 + last_backward = ( + backward_counter[stage_index] == self._n_microbatches + ) + grad_scale_factor = ( + self._n_microbatches if self.scale_grads else 1 + ) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_INPUT: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_WEIGHT: + # perform weight update + stage = stage_index_to_stage[stage_index] + backward_counter[stage_index] += 1 + last_backward = ( + backward_counter[stage_index] == self._n_microbatches + ) + grad_scale_factor = ( + self._n_microbatches if self.scale_grads else 1 + ) + stage.backward_weight_one_chunk( + mb_index, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + for prev_rank in all_prev_ranks: + prev_rank_ops = self.pipeline_order[prev_rank] + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type = prev_rank_action.computation_type + mb_index = prev_rank_action.microbatch_index + stage_index = prev_rank_action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + # Only handle sends for the forward from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index + 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif computation_type in ( + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + ): + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + for next_rank in all_next_ranks: + next_rank_ops = self.pipeline_order[next_rank] + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type = next_rank_action.computation_type + mb_index = next_rank_action.microbatch_index + stage_index = next_rank_action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + # Only handle receives for the backwards from a next rank + if computation_type in (FORWARD, BACKWARD_WEIGHT): + # Next rank doing forward or weight update has no influence for the current rank backward recv + pass + elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + # If not the first stage, then receive bwd gradients + if stage_index - 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + + # do the communication + _wait_batch_p2p(_batch_p2p(ops)) + except Exception as e: + logger.error( # noqa: G200 + "[Rank %s] pipeline schedule %s caught the following exception '%s' \ +at time_step %s when running action %s", + self.rank, + self.__class__.__name__, + str(e), + time_step, + action, + ) + logger.error( + "%s", + _format_pipeline_order( + self.pipeline_order, error_step_number=time_step + ), + ) + raise e + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class _PipelineContext: + def __init__( + self, + schedule_ref: _PipelineSchedule, + arg_mbs: list[tuple] | None = None, + kwarg_mbs: list[dict] | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ): + self.schedule_ref = schedule_ref + self.arg_mbs = arg_mbs + self.kwarg_mbs = kwarg_mbs + self.target_mbs = target_mbs + self.losses = losses + + +class _CustomFunctionProtocol(Protocol): + def __call__(self, action: _Action, ctx: _PipelineContext) -> None: ... + + +class _PipelineScheduleRuntime(PipelineScheduleMulti): + """ + Provides a simple runtime that requires a 'schedule IR' including specified communication operations. + + Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be + subclassed and the subclass can be responsible for creating a schedule IR. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Action to custom function mapping + self._comp_type_to_function_map: dict[_ComputationType, Callable] = {} + # count either full_backward or backward_weight together, to determine when to sync DP grads + self.backward_counter: Counter[int] = Counter() + + # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use + self.bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} + self.fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} + + # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages + self.unshard_ops: dict[int, list[UnshardHandle]] = defaultdict(list) + self.unsharded_stages = set() + + def register_custom_function( + self, + computation_type: _ComputationType, + custom_function: _CustomFunctionProtocol, + ) -> None: + """ + Register a custom function to be executed for a specific computation type. + + Args: + computation_type: The computation type for which to register the custom function + custom_function: The function to execute when this computation type is encountered. + Must have signature: (action: _Action, ctx: _PipelineContext) -> None + """ + # Ensure that the computation type is valid + if computation_type not in ( + FORWARD, + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + OVERLAP_F_B, + UNSHARD, + RESHARD, + REDUCE_GRAD, + ): + raise ValueError( + f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \ + BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported." + ) + + # Check if computation_type is already registered + if computation_type in self._comp_type_to_function_map: + logger.warning( + "Computation type %s is already registered. " + "Overwriting the existing custom function.", + computation_type, + ) + + self._comp_type_to_function_map[computation_type] = custom_function + + def _prepare_schedule_with_comms( + self, + actions: dict[int, list[_Action | None]], + format: str = "compute_only", + ): + """ + Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including + communication actions. Stores the schedule in self, and must be called before running step_mo() + """ + # validate the provided actions are valid and overrides the default stage_index_to_group_rank + super()._validate_and_set_stage_mapping(actions) + + self.pipeline_order_with_comms: dict[int, list[_Action]] = {} + if format == "compute_comms": + for rank in actions: + self.pipeline_order_with_comms[rank] = [] + for action in actions[rank]: + assert action is not None + self.pipeline_order_with_comms[rank].append(action) + # TODO what level of validation should we offer for compute+comms schedule? + elif format == "compute_only": + # Validate that the schedule does not have comms already added to it + for rank, action_list in actions.items(): + for i, action in enumerate(action_list): + if action is not None and not action.is_compute_op: + raise ValueError( + f"Expected compute-only schedule but found communication action " + f"'{action}' at rank {rank}, position {i}. " + f"Communication actions (e.g. SEND_F, RECV_F, etc.) " + f"should not be present when format='compute_only'." + ) + + # Perform schedule lowering + for rank in actions: + self.pipeline_order_with_comms[rank] = _add_unshard_reshard( + actions[rank] + ) + self.pipeline_order_with_comms[rank] = _add_reduce_grad( # type: ignore[assignment] + self.pipeline_order_with_comms[rank], # type: ignore[arg-type] + self._n_microbatches, + ) + + self.pipeline_order_with_comms = _add_send_recv( + self.pipeline_order_with_comms, + stage_to_rank=lambda s: self.stage_index_to_group_rank[s], + num_stages=self._num_stages, + ) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _load_csv(self, filename: str, format: str = "compute_only"): + """Loads a csv in simple format and then lowers it to include communication actions + + format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes + will automatically be run to generate a compute_comms schedule. + """ + if format == "compute_only": + # this will populate self.pipeline_order + super()._load_csv(filename) + # this will populate self.pipeline_order_with_comms + self._prepare_schedule_with_comms(self.pipeline_order) + elif format == "compute_comms": + actions = {} + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + actions[rank] = [_Action.from_str(s) for s in row] + self._prepare_schedule_with_comms(actions, format=format) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _dump_csv(self, filename: str, format: str = "compute_comms"): + """Dump a CSV representation of the schedule into a file with the provided filename.""" + if format == "compute_only": + assert self.pipeline_order is not None, ( + "Compute only schedule must be available" + ) + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order: + writer.writerow(self.pipeline_order[rank]) + elif format == "compute_comms": + assert self.pipeline_order_with_comms is not None, ( + "Must initialize compute_comms schedule before dump_csv" + ) + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order_with_comms: + writer.writerow(self.pipeline_order_with_comms[rank]) + + def _simulate(self): + return _simulate_comms_compute( + self.pipeline_order_with_comms, + lambda s: self.stage_index_to_group_rank[s], + self._num_stages, + ) + + def _assert_unsharded(self, stage: _PipelineStageBase): + """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared.""" + stage_uses_fsdp = isinstance(stage.submod, FSDPModule) + if stage_uses_fsdp: + stage_idx = stage.stage_index + if stage_idx in self.unshard_ops: + for op in self.unshard_ops[stage_idx]: + op.wait() + del self.unshard_ops[stage_idx] + self.unsharded_stages.add(stage_idx) + assert stage_idx in self.unsharded_stages, ( + f"Attempted to compute on sharded {stage_idx=}" + ) + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + assert self.pipeline_order_with_comms is not None, ( + "Must call _prepare_schedule_with_comms() before calling _step_microbatches()" + ) + + # send ops should be waited on before step() exists, mainly for hygiene + send_ops: list[list[dist.Work]] = [] + + def _perform_action(action: _Action) -> None: + comp_type = action.computation_type + mb_index: int = ( + action.microbatch_index if action.microbatch_index is not None else -1 + ) + assert mb_index >= 0 or comp_type in ( + UNSHARD, + RESHARD, + REDUCE_GRAD, + ), f"{action=} missing mb_index" + stage_idx = action.stage_index + stage = stage_index_to_stage[stage_idx] + stage_uses_fsdp = isinstance(stage.submod, FSDPModule) + # see [Note: V-schedule special case] + is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage + is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage + + # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, + # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be + # safe to use instead. + # However, I was wondering if I should avoid calling batched operators at all in the case that there is + # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them. + if comp_type == SEND_F: + send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index))) + elif comp_type == SEND_B: + send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index))) + elif comp_type == RECV_F: + assert ( + stage_idx, + mb_index, + ) not in self.fwd_recv_ops, ( + f"Recv twice for {stage_idx=} {mb_index=} without executing forward" + ) + self.fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_fwd_recv_ops(mb_index) + ) + elif comp_type == RECV_B: + assert ( + stage_idx, + mb_index, + ) not in self.bwd_recv_ops, ( + f"Recv twice for {stage_idx=} {mb_index=} without executing backward" + ) + self.bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_bwd_recv_ops(mb_index) + ) + elif comp_type == UNSHARD: + if stage_uses_fsdp: + assert ( + stage_idx not in self.unsharded_stages + and stage_idx not in self.unshard_ops + ), f"Unsharding the same {stage_idx=} twice" + for submodule in stage.submod.modules(): + if not isinstance(submodule, FSDPModule): + continue + handle = cast(UnshardHandle, submodule.unshard(async_op=True)) + self.unshard_ops[stage_idx].append(handle) + elif comp_type == RESHARD: + if stage_uses_fsdp: + assert stage_idx in self.unsharded_stages, ( + f"Resharding {stage_idx=} without unsharding" + ) + assert stage_idx not in self.unshard_ops, ( + f"Resharding {stage_idx=} before finishing unshard" + ) + for submodule in stage.submod.modules(): + if not isinstance(submodule, FSDPModule): + continue + submodule.reshard() + self.unsharded_stages.remove(stage_idx) + elif comp_type == FORWARD: + self._assert_unsharded(stage) + + if ( + not stage.is_first + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_prev_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in self.fwd_recv_ops, ( + f"Computing {action=} before receiving input" + ) + _wait_batch_p2p(self.fwd_recv_ops.pop((stage_idx, mb_index))) + + output = stage.forward_one_chunk( + mb_index, + arg_mbs[mb_index], # type: ignore[index] + kwarg_mbs[mb_index], # type: ignore[index] + save_forward_output=return_outputs, + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_next_stage_on_this_rank: + stage_index_to_stage[stage_idx + 1].set_local_fwd_input( + output, mb_index + ) + + elif comp_type == FULL_BACKWARD: + self._assert_unsharded(stage) + + if ( + not stage.is_last + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_next_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in self.bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + self.backward_counter[stage_idx] += 1 + last_backward = self.backward_counter[stage_idx] == self._n_microbatches + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_INPUT: + self._assert_unsharded(stage) + + if not stage.is_last and not is_next_stage_on_this_rank: + assert ( + stage_idx, + mb_index, + ) in self.bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_WEIGHT: + self._assert_unsharded(stage) + self.backward_counter[stage_idx] += 1 + last_backward = self.backward_counter[stage_idx] == self._n_microbatches + stage.backward_weight_one_chunk( + mb_index, + last_backward=last_backward, + ) + elif comp_type == REDUCE_GRAD: + grad_scale_factor = self._n_microbatches if self.scale_grads else 1 + stage.perform_reduce_grad(grad_scale_factor) + else: + raise ValueError(f"{action=} is unknown or unsupported") + + # count either full_backward or backward_weight together, to determine when to sync DP grads + self.backward_counter.clear() + for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) + try: + with record_function(_get_profiler_function_name(action)): + if action.computation_type in self._comp_type_to_function_map: + ctx = _PipelineContext( + self, + arg_mbs, + kwarg_mbs, + target_mbs, + losses, + ) + self._comp_type_to_function_map[action.computation_type]( + action, ctx + ) + elif action.computation_type == OVERLAP_F_B: + assert action.sub_actions is not None, "sub_actions must be set" + for sub_a in action.sub_actions: + _perform_action(sub_a) + else: + _perform_action(action) + except Exception as e: + logger.error( + "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:", + time_step, + action, + ) + logger.error( + _format_pipeline_order( + self.pipeline_order_with_comms, # type: ignore[arg-type] + error_step_number=time_step, + ) + ) + raise e + + # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them + while send_ops: + _wait_batch_p2p(send_ops.pop()) + + assert len(self.unshard_ops) == 0, "Unused unshard operations" + + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class ScheduleLoopedBFS(_PipelineScheduleRuntime): + """ + Breadth-First Pipeline Parallelism. + See https://arxiv.org/abs/2211.05953 for details. + Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. + What is different is that when microbatches are ready for multiple local + stages, Loops BFS will prioritizes the earlier stage, running all available + microbatches at once. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | _Loss | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank): + n_local_stages = len(self._stages) + stage_indices = range( + rank, self.pp_group_size * n_local_stages, self.pp_group_size + ) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: list[_Action | None] = [None for _ in range(rank)] + + for stage_index in stage_indices: + rank_ops.extend( + _Action(stage_index, _ComputationType.FORWARD, mb_index) + for mb_index in range(self._n_microbatches) + ) + + # wait for the first backward to trickle up + # which is 2 for every hop away + post_warmup_ops = 2 * (self.pp_group_size - 1 - rank) + rank_ops.extend([None] * post_warmup_ops) + + for stage_index in reversed(stage_indices): + rank_ops.extend( + _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index) + for mb_index in reversed(range(self._n_microbatches)) + ) + return rank_ops + + +def _get_1f1b_rank_ops( + n_local_stages, + pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches=0, + enable_zero_bubble=False, +): + # All stages start with handling microbatch 0 + fwd_stage_mb_index: dict[int, int] = defaultdict(int) + bwd_stage_mb_index: dict[int, int] = defaultdict(int) + weight_stage_mb_index: dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: list[_Action | None] = [None for _ in range(rank)] + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + if enable_zero_bubble: + post_warmup_ops = pp_group_size - rank - 1 + + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + + backward_op_ids = [] + weight_op_count = 0 + + FULL_BACKWARD_OR_BACKWARD_INPUT = ( + BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD + ) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index) + ) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + if not enable_zero_bubble: + rank_ops.append(None) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + + while enable_zero_bubble and weight_op_count < len(backward_op_ids): + weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count]) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, _ComputationType.BACKWARD_WEIGHT, weight_mb_index + ) + ) + weight_op_count += 1 + + return rank_ops + + +class ScheduleInterleaved1F1B(_PipelineScheduleRuntime): + """ + The Interleaved 1F1B schedule. + See https://arxiv.org/pdf/2104.04473 for details. + Will perform one forward and one backward on the microbatches in steady + state and supports multiple stages per rank. When microbatches are ready for + multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch + (also called "depth first"). + + This schedule is mostly similar to the original paper. + It differs by being relaxing the requirement of num_microbatch % pp_size == 0. + Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and + it works as long as n_microbatches % num_rounds is 0. As a few examples, support + + 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. + 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + ) + + +class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime): + """ + The Interleaved Zero Bubble schedule. + See https://arxiv.org/pdf/2401.10241 for details. + Will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses the backward for weights to fill in + the pipeline bubble. + + In particular this is implementing the ZB1P schedule in the paper. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # TODO: we dont support input/weight backward split with torch.compile + _check_torch_compile_compatibility(stages, self.__class__.__name__) + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Zero bubble requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # This function add bubbles to the generated schedule based on dependencies of actions + # Note that the ZB1P schedule will not require bubbles to be manually added and it is + # only useful when n_microbatches <= microbatches_per_round + self.pipeline_order = self._add_bubbles_to_actions( + self.n_local_stages * self.pp_group_size, + ) + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 1 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + num_1f1b_microbatches = rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches, + enable_zero_bubble=True, + ) + + def _add_bubbles_to_actions(self, num_stages_global): + actions = self.pipeline_order + + def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): + if op == _ComputationType.FORWARD: + if stage != 0 and (stage - 1, op, microbatch) not in seen_ops: + return True + elif op == _ComputationType.FULL_BACKWARD: + if stage == num_stages_global - 1: + return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops + return (stage + 1, op, microbatch) not in seen_ops + return False + + seen_ops: set[tuple[int, _ComputationType, int]] = set() + result: dict[int, list[_Action | None]] = {} + next_pointer: dict[int, int] = {} + bubbles_added: dict[int, int] = {} + total_bubbles_added = 0 + + for rank in range(self.pp_group_size): + result[rank] = [] + next_pointer[rank] = 0 + bubbles_added[rank] = 0 + + while True: + should_stop = True + + temp_seen_ops: set[tuple[int, _ComputationType, int]] = set() + + for rank in range(self.pp_group_size): + timestamp = next_pointer[rank] + if timestamp >= len(actions[rank]): + continue + + should_stop = False + + if actions[rank][timestamp] is not None: + temp_action = actions[rank][timestamp] + assert temp_action is not None + stage_index, op, microbatch, _ = temp_action + if not need_bubble( + stage_index, op, microbatch, num_stages_global, seen_ops + ): + result[rank].append(actions[rank][timestamp]) + if microbatch is not None: + temp_seen_ops.add((stage_index, op, microbatch)) + next_pointer[rank] += 1 + else: + result[rank].append(None) + bubbles_added[rank] += 1 + else: + next_pointer[rank] += 1 + result[rank].append(None) + + seen_ops.update(temp_seen_ops) + if should_stop: + break + + if total_bubbles_added > 0: + logger.warning( + "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s", + total_bubbles_added, + bubbles_added, + ) + return result + + +class ScheduleZBVZeroBubble(_PipelineScheduleRuntime): + """ + The Zero Bubble schedule (ZBV variant). + See https://arxiv.org/pdf/2401.10241 Section 6 for details. + + This schedules requires exactly two stages per rank. + + This schedule will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses backward with respect to weights to fill in + the pipeline bubble. + + This ZB-V schedule would have the "zero bubble" property only if time forward == time backward input == time backward weights. + In practice, this is not likely true for real models so alternatively + a greedy scheduler could be implemented for unequal/unbalanced time. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # TODO: we dont support input/weight backward split with torch.compile + _check_torch_compile_compatibility(stages, self.__class__.__name__) + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages, style="v" + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + self.n_local_stages = len(stages) + if self.n_local_stages != 2: + raise ValueError( + "ZBV requires exactly 2 stages per rank, but got " + f"{self.n_local_stages}." + ) + + self.rank = stages[0].group_rank + self.num_stages = stages[0].num_stages + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + # max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least + # as large of the number of microbatches needed to fully utilize the pipeline + n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches) + rank_ops: list[_Action | None] = [None for _ in range(rank)] + + # Forward and backward action counts for stage chunk 0 and chunk 1 + f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0 + # warm-up phase + warmup_n1 = 2 * (self.pp_group_size - rank) - 1 + stage_id_chunk0 = rank + stage_id_chunk1 = self.num_stages - 1 - rank + + for _ in range(warmup_n1): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt) + ) + f0_cnt += 1 + warmup_n2 = rank + for _ in range(warmup_n2): + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt) + ) + f0_cnt += 1 + warmup_n3 = self.pp_group_size - rank + for _ in range(warmup_n3): + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + # stable phase + while f1_cnt < f0_cnt or f0_cnt < n_micro: + if f0_cnt < n_micro: + rank_ops.append( + _Action( + stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt + ) + ) + f0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + # cool-down phase + w0_cnt, w1_cnt = b0_cnt, b1_cnt + cooldown_n1 = rank + for _ in range(cooldown_n1): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + cooldown_n2 = self.pp_group_size - rank + for _ in range(cooldown_n2): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt) + ) + w0_cnt += 1 + while w1_cnt < b1_cnt: + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=w1_cnt) + ) + w1_cnt += 1 + while w0_cnt < b0_cnt: + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt) + ) + w0_cnt += 1 + + assert w0_cnt == b0_cnt and b0_cnt == f0_cnt + assert w1_cnt == b1_cnt and b1_cnt == f1_cnt + # We use max() in the n_micro computation above, so we may need to + # remove redundant microbatches + rank_ops = [ + ( + action + if action is not None + and action.microbatch_index is not None + and action.microbatch_index < self._n_microbatches + else None + ) + for action in rank_ops + ] + return rank_ops + + +class ScheduleDualPipeV(_PipelineScheduleRuntime): + """ + The DualPipeV schedule. A more efficient schedule variant based on the + DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437 + + Based on the open sourced code from https://github.com/deepseek-ai/DualPipe + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # TODO: we dont support input/weight backward split with torch.compile + _check_torch_compile_compatibility(stages, self.__class__.__name__) + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages, style="v" + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + self.n_local_stages = len(stages) + if self.n_local_stages != 2: + raise ValueError( + "ZBV requires exactly 2 stages per rank, but got " + f"{self.n_local_stages}." + ) + if n_microbatches < self._num_stages: + raise ValueError( + "DualPipeV requires at least as many microbatches as stages, but got " + f"{n_microbatches} microbatches and {self._num_stages} stages." + ) + + self.rank = stages[0].group_rank + self.num_stages = stages[0].num_stages + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + actions: list[_Action | None] = [] + counters: dict[ + tuple[int, _ComputationType], int + ] = {} # (stage_index, computation_type) -> mb_index + weight_queue = [] # Queue of (stage_index, mb_index) for pending weight actions + + num_ranks = self.pp_group_size + num_chunks = self._n_microbatches + + rank_to_stages = generate_rank_to_stage_mapping( + num_ranks, num_ranks * 2, style="v" + ) + stage0_index, stage1_index = rank_to_stages[rank] + + def increment_backward_counts(stage_index: int): + """Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used.""" + input_key = (stage_index, BACKWARD_INPUT) + weight_key = (stage_index, BACKWARD_WEIGHT) + counters[input_key] = counters.get(input_key, 0) + 1 + counters[weight_key] = counters.get(weight_key, 0) + 1 + + def add_overlap_f_b( + actions: list, + forward_stage: int, + backward_stage: int, + ): + """Helper method to add an overlapped forward+backward action which tracks microbatch index.""" + # Create new overlapped forward+backward action with sub_actions + forward_key = (forward_stage, FORWARD) + backward_key = (backward_stage, BACKWARD_INPUT) + + forward_mb = counters.get(forward_key, 0) + backward_mb = counters.get(backward_key, 0) + + sub_actions = ( + _Action(forward_stage, FORWARD, forward_mb), + _Action(backward_stage, FULL_BACKWARD, backward_mb), + ) + actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions)) + + # Update counters for sub_actions + counters[forward_key] = forward_mb + 1 + increment_backward_counts(backward_stage) + + def add_action( + actions: list, + stage_index: int, + computation_type: _ComputationType, + ): + # Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter + key = ( + (stage_index, computation_type) + if computation_type != FULL_BACKWARD + else (stage_index, BACKWARD_INPUT) + ) + mb_index = counters.get(key, 0) + actions.append(_Action(stage_index, computation_type, mb_index)) + + # If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters + if computation_type == FULL_BACKWARD: + increment_backward_counts(stage_index) + else: + # If BACKWARD_INPUT is updated, add corresponding weight action to queue + if computation_type == BACKWARD_INPUT: + # Add weight action to queue for later processing + weight_queue.append((stage_index, mb_index)) + counters[key] = mb_index + 1 + + def add_weight_action_if_pending(actions: list): + """Helper method to add a weight action from the queue.""" + if not weight_queue: + return # No pending weight actions, skip + # Pop the oldest weight action from the queue + actual_stage_index, weight_mb_index = weight_queue.pop(0) + actions.append( + _Action( + actual_stage_index, + BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + # Update the counter for the actual stage that was processed + weight_key = (actual_stage_index, BACKWARD_WEIGHT) + counters[weight_key] = counters.get(weight_key, 0) + 1 + + # Step 1: F0 + step_1 = (num_ranks - rank - 1) * 2 + for _ in range(step_1): + add_action(actions, stage0_index, FORWARD) + + # Step 2: F0F1 + step_2 = rank + 1 + for _ in range(step_2): + add_action(actions, stage0_index, FORWARD) + add_action(actions, stage1_index, FORWARD) + + # Step 3: I1W1F1 (Use zero bubble) + step_3 = num_ranks - rank - 1 + for _ in range(step_3): + add_action(actions, stage1_index, BACKWARD_INPUT) + add_weight_action_if_pending(actions) + add_action(actions, stage1_index, FORWARD) + + # Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward) + step_4 = num_chunks - num_ranks * 2 + rank + 1 + for i in range(step_4): + if i == 0 and rank == num_ranks - 1: + # NOTE: We don't overlap these two chunks to further reduce bubble size. + add_action(actions, stage0_index, FORWARD) + add_action(actions, stage1_index, FULL_BACKWARD) + else: + add_overlap_f_b( + actions, + forward_stage=stage0_index, + backward_stage=stage1_index, + ) + add_overlap_f_b( + actions, + forward_stage=stage1_index, + backward_stage=stage0_index, + ) + + # Step 5: B1-F1B0 + step_5 = num_ranks - rank - 1 + for _ in range(step_5): + add_action(actions, stage1_index, FULL_BACKWARD) + add_overlap_f_b( + actions, + forward_stage=stage1_index, + backward_stage=stage0_index, + ) + + # Step 6: B1B0 (The second half of the chunks use zero bubble) + step_6 = rank + 1 + enable_zb = False + for i in range(step_6): + if i == step_6 // 2 and rank % 2 == 1: + enable_zb = True + comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD + add_action(actions, stage1_index, comp_type) + if i == step_6 // 2 and rank % 2 == 0: + enable_zb = True + comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD + add_action(actions, stage0_index, comp_type) + + # Step 7: W0B0 + step_7 = num_ranks - rank - 1 + for _ in range(step_7): + add_weight_action_if_pending(actions) + comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD + add_action(actions, stage0_index, comp_type) + + # Step 8: W0 + step_8 = rank + 1 + for _ in range(step_8): + add_weight_action_if_pending(actions) + + return actions + + +def get_schedule_class(schedule_name: str): + """ + Maps a schedule name (case insensitive) to its corresponding class object. + + Args: + schedule_name (str): The name of the schedule. + """ + schedule_map = { + "1F1B": Schedule1F1B, + "Interleaved1F1B": ScheduleInterleaved1F1B, + "GPipe": ScheduleGPipe, + "LoopedBFS": ScheduleLoopedBFS, + "InterleavedZeroBubble": ScheduleInterleavedZeroBubble, + "PipelineScheduleSingle": PipelineScheduleSingle, + "PipelineScheduleMulti": PipelineScheduleMulti, + "ZBVZeroBubble": ScheduleZBVZeroBubble, + "DualPipeV": ScheduleDualPipeV, + } + lowercase_keys = {k.lower(): k for k in schedule_map} + lowercase_schedule_name = schedule_name.lower() + if lowercase_schedule_name not in lowercase_keys: + raise ValueError( + f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}" + ) + return schedule_map[lowercase_keys[lowercase_schedule_name]] + + +def _simulate_comms_compute( + pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int +): + """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags + any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank + can not execute any action due to waiting for unmet dependencies. The total number of simulator steps can be used + as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number + of simulated steps. + + The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams. + Future work may be to enhance this and model the compute time, comms overlap, and even memory. + """ + pipeline_order = { + rank: [a for a in pipeline_order[rank] if a is not None] + for rank in sorted(pipeline_order) + } + _schedule: dict[int, list[_Action | None]] = { + rank: [] for rank in sorted(pipeline_order) + } + + _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule} + + def add_to_schedule(rank: int, action: _Action | None): + _schedule[rank].append(action) + if action is not None: + _prev_ops_rank[rank].add(action) + + def _ready_to_schedule(action: _Action | None) -> bool: + if action is None: + return True + + stage_idx = action.stage_index + prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)] + if action.computation_type == F: + if action.stage_index == 0: + return True + elif ( + _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops + ): + return True + return False + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + if action.stage_index == num_stages - 1: + return True + if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops: + return True + if ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_ops + ): + return True + if ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_ops + ): + return True + return False + elif action.computation_type == BACKWARD_WEIGHT: + return True + elif action.computation_type == SEND_F: + expected_f = _Action(action.stage_index, F, action.microbatch_index) + return expected_f in prev_ops + elif action.computation_type == RECV_F: + peer_stage_idx = stage_idx - 1 + expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + elif action.computation_type == SEND_B: + expected_b = _Action( + action.stage_index, BACKWARD_INPUT, action.microbatch_index + ) + expected_bw = _Action( + action.stage_index, FULL_BACKWARD, action.microbatch_index + ) + return expected_b in prev_ops or expected_bw in prev_ops + elif action.computation_type == RECV_B: + peer_stage_idx = stage_idx + 1 + expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + else: + raise ValueError(f"Unsupported action type {action}") + + while pipeline_order: + progress = False + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + add_to_schedule(rank, action) + pipeline_order[rank].pop(0) + progress = True + else: + add_to_schedule(rank, None) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked + # by one of the later ranks + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + if _schedule[rank][-1] is not None: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + _schedule[rank][-1] = action + _prev_ops_rank[rank].add(action) + pipeline_order[rank].pop(0) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + if not progress: + print("WIP comms schedule:\n", _format_pipeline_order(_schedule)) + for rank in pipeline_order: + print(f"{rank=} next action= {pipeline_order[rank][0]}") + raise ValueError("Schedule is not progressing") + + return _schedule + + +def _dump_chrometrace(schedule, filename): + """ + This function dumps a schedule IR into a chrometrace format so it can be visualized. + + It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text. + + As future work we may extend this to include more accurate heuristics for durations, or let users input durations, + add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute + as separate streams on the chrometrace view. + """ + events = [] + for rank in sorted(schedule): + for timestep, action in enumerate(schedule[rank]): + if action is None: + continue + events.append( + { + "name": str(action), + "cat": ( + "computation" + if action.computation_type in (F, B, W) + else "communication" + ), + "ph": "X", + "pid": rank, + "tid": rank, + "ts": timestep, + "dur": 1, + } + ) + import json + + with open(filename, "w") as f: + json.dump({"traceEvents": events}, f) + + +def _check_torch_compile_compatibility( + stages: list[_PipelineStageBase], schedule_name: str +): + """ + Check if the schedule is compatible with torch.compile. + + Args: + stages: List of pipeline stages to check + schedule_name: Name of the schedule for error message + + Raises: + RuntimeError: If any stage uses torch.compile + """ + for stage in stages: + if not isinstance(stage.submod, torch.nn.Module): + continue + + for module in stage.submod.modules(): + if isinstance(module, OptimizedModule): + raise RuntimeError( + f"The {schedule_name} schedule is not supported with " + "stage modules that have used torch.compile. " + f"Found OptimizedModule in {type(module).__name__}" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0d51020458bcfd45cdc34c45868dc374bc2564 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py @@ -0,0 +1,1588 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, cast, Union + +import torch +import torch.distributed as dist +import torch.fx as fx +import torch.nn as nn +from torch._subclasses.fake_tensor import FakeTensor +from torch.distributed._composable.replicate_with_fsdp import replicate, ReplicateModule +from torch.distributed.fsdp import FSDPModule, fully_shard +from torch.fx.node import Argument, map_aggregate +from torch.nn.parallel import DistributedDataParallel +from torch.utils._pytree import tree_map_only + +from ._backward import stage_backward, stage_backward_input, stage_backward_weight +from ._debug import map_debug_info +from ._utils import flatten_args, PipeInfo, validate_tensors_metadata + + +__all__ = [ + "PipelineStage", + "build_stage", +] + +logger = logging.getLogger(__name__) + + +def _normalize_model_output_as_tuple(output: Any) -> tuple[Any]: + """[Note: pipeline model output type] + + The output of the model passed to pipelining can be any type, controlled by the user. + + However, there are 2 API surfaces that complicate this. + (1) the outputs of intermediate stages are passed via Send/Recv ops to subsequent stages. The implicit assumption + is that each element of the outputs is a tensor. Otherwise, Send/Recv would not be supported. The exception + is the last layer of the model, which can output anything any which won't be communicated via Send/Recv. + (2) the outputs of the last layer of the model are returned to the user, or, passed to the loss function. + The loss function can be written in any way, such that its inputs match the outputs of the model. + + It would be convenient if we could strictly type the output signature of the pipeline stage wrapping the model, + but we do not want to impose an unnecessary constraint on user provided models. + + Currently, we let user provided models return either a Tensor or a tuple of Tensors from each stage. Due to + torch.export tracing, compiled models may also return a list instead of a Tuple, which we will normalize back to a + tuple for consistency. + + TODO: should we be stricter about asserting that stage modules (intermediate and output) all return only Tensor + values? + """ + if type(output) is list: + # HACK: this is a hacky workaround for the fact that export creates + # output in list format + output = tuple(output) + + # Unify output form to tuple for easy correspondence with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + return output_tuple + + +class _RootArgPlaceholder: + """ + Placeholder for model-level inputs. + """ + + def __init__(self, tensor): + self.meta = tensor.to("meta") + + +class _RecvInfo: + """ + Represents a stage input. + """ + + def __init__( + self, + input_name: str, + source: int, + buffer: torch.Tensor, + ): + # Name of this input + self.input_name = input_name + # Stage index of the source of this input + self.source = source + # Buffer to receive the input into. + self.buffer = buffer + + def __repr__(self): + return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" + + +# An input can be either a received activation or a model input +InputInfo = Union[_RecvInfo, _RootArgPlaceholder] + + +def _make_tensor_from_meta( + example: torch.Tensor | FakeTensor, + device: torch.device, +) -> torch.Tensor: + """ + Create a real tensor from a tensor. + """ + return torch.empty( + example.size(), + dtype=example.dtype, + layout=example.layout, + device=device, + ) + + +class _PipelineStageBase(ABC): + """ + Base class for pipeline stages. + Defines or implements common methods used by the `_PipelineStage` used by + the tracing frontend and `PipelineStage` used by manual frontend. + """ + + def __init__( + self, + submodule: torch.nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + group: dist.ProcessGroup | None = None, + dw_builder: Callable[[], Callable[..., None]] | None = None, + ): + """ + Args: + submodule (torch.nn.Module): The module to be executed in this stage. + stage_index (int): The index of this stage. + num_stages (int): The total number of stages in this pipeline. + device (torch.device): The device to run this stage on. + group (Optional[dist.ProcessGroup]): The process group to use for communication. + If `None`, the default process group will be used. + Default: `None`. + dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder is a builder function + that will build a new dw_runner function that will run parts of module backward that were intentionally + skipped during the module's actual backward pass. The builder must be invoked by stage after stage runs + model backwards, and stage should save the latest dw_runner to run during weight pas (W). + If not provided, a dw_runner will be generated automatically by traversing the autograd graph. + When used with schedules that only have F and B steps, the fresh dw_runner function will be called as + part of I (input backwards). When used with F,I,W schedules, the dw_runner function implements 'W'. + """ + super().__init__() + if stage_index >= num_stages: + raise ValueError( + f"Stage index {stage_index} is out of range of {num_stages}" + ) + + self.submod = submodule + self.stage_index = stage_index + self.num_stages = num_stages + # pyrefly: ignore [read-only] + self.device = device + self.group = group + + self.dw_builder = dw_builder + + # backward state + self.backward_state: dict[int, tuple[Any, ...]] = {} + + # store dw_runner per microbatch_id + self.dw_runner: dict[int, Callable[..., None]] = {} + + # `group_rank` is rank in process group `group`. + self.group_rank = dist.get_rank(self.group) + self.group_size = dist.get_world_size(self.group) + if self.group_size > self.num_stages: + raise RuntimeError( + f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}" + ) + + # Run time states + self._outputs_meta: tuple[torch.Tensor, ...] | None = None + # map microbatch ID to list of forward tensor args + self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {} + # map microbatch ID to list of backward grad tensor args + self.bwd_cache: dict[int, tuple[torch.Tensor | None, ...]] = {} + # Caching chunk outputs for final output merge or reduction + self.output_chunks: list[Any] = [] + + # Initialize has_backward to false; this will be set to true if loss + # function is passed to pipeline schedule + self.has_backward = False + # Log prefix + self.log_prefix = f"[Stage {self.stage_index}]" + + # Forward infra + self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {} + self.act_send_info: dict[int, list] = {} + + # Backward infra will created lazily + self.grad_recv_info: dict = {} + self.grad_send_info: list | None = None + + # To be populated later by the Schedule + self.chunks: int | None = None + self.stage_index_to_group_rank: dict[int, int] = { + i: i % self.group_size for i in range(self.num_stages) + } + + @property + def has_backward(self) -> bool: + """ + Returns true if this stage has a backward pass. + """ + return self._has_backward + + @has_backward.setter + def has_backward(self, has_backward: bool): + self._has_backward = has_backward + + @property + def is_first(self): + """ + Returns true if this stage is the first stage in the pipeline. + """ + return self.stage_index == 0 + + @property + def is_last(self): + """ + Returns true if this stage is the last stage in the pipeline. + """ + return self.stage_index == self.num_stages - 1 + + def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." + ) + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + + def _configure_outputs_meta(self, outputs_meta: tuple[torch.Tensor, ...]): + """ + Track the output shapes/dtype of this stage since they determine the send operation(s) which must match + recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial + configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches + which could show up as hangs, silent corruption, or other errors. + """ + assert self._outputs_meta is None, ( + "Attempting to reconfigure output_meta, which is not supported" + ) + self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] + + def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: + """Get the output metadata (meta tensors) representing the outputs of this stage""" + assert self._outputs_meta is not None, ( + "Attempted to get_outputs_meta() without configuring output meta" + ) + return self._outputs_meta + + def _create_grad_send_info( + self, + args_recv_info: tuple, + ) -> list[int | None]: + """ + Create a list of stage indices to send gradients to. + """ + grad_send_info: list[int | None] = [] + + def map_recv_to_send(a): + # Note: we send gradients back to previous stage as long as in + # forward it is a received input, regardless of whether it requires + # grad. It is up to the previous stage to discard this gradient. + if isinstance(a, _RecvInfo): + grad_send_info.append(a.source) + return a.source + else: + grad_send_info.append(None) + return None + + map_aggregate(args_recv_info, map_recv_to_send) + + logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info) + return grad_send_info + + @abstractmethod + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int): + # TODO: this is needed for backward_maybe_with_nosync + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + + @abstractmethod + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + raise NotImplementedError + + def _get_recv_ops( + self, + recv_infos: tuple[InputInfo, ...], + ) -> list[dist.P2POp]: + """ + Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. + Returns a list of ops that correspond to the recv infos. + """ + ops: list[dist.P2POp] = [] + for info in recv_infos: + if not isinstance(info, _RecvInfo): + continue + + peer_rank = self.stage_index_to_group_rank[info.source] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append( + dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group) + ) + + return ops + + """[Note: V-schedule special case] + + V-Schedules have a special case where 2 stages with adjacent stage_id are on the same rank. + + ex: 2 ranks, 4 stages forms a simple V: + rank0: stage 0 stage 3 + rank1: stage 1 stage 2 + + stage 0,1 and 2,3 communicate activations using send/recv as usual, but stage 1,2 do not need to + use communication ops. Instead, they should pass tensor data directly via function call. + + set_local_fwd_input and (get_local_bwd_output + set_local_bwd_input) facilitate this optimization, and + should be called at the appropriate time during the pipeline schedule (after forward or backward execution). + """ + + def set_local_fwd_input(self, prev_stage_outputs: Any, mb_index: int) -> None: + """ + Moves 'prev_stage_outputs' from another stage on the same rank into place as inputs for this stage. Avoids + copying tensor data or using send/recv op. Detaches original tensor and sets requires_grad so the + tensor can serve as a leaf for autograd and gradients can be collected from it during backward. + """ + recv_infos: tuple[InputInfo, ...] = self.args_recv_info[mb_index] + + # See [Note: pipeline model output type] + prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs) + + for info, tensor in zip(recv_infos, prev_stage_outputs): + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + ) + + # We don't need to do a data copy here, since we can directly pass the activation tensor reference from + # one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve + # as the input tensor for a fresh autograd graph, not part of the previous stage's autograd graph. + # TODO: confirm, do we use this activation as the root of the backward call for the previous stage? does + # detach have any affect on that? + info.buffer = tensor.detach().requires_grad_(True) + + def get_local_bwd_output(self, mb_index): + """ + Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. + """ + assert self.has_backward, ( + "can't steal_bwd_input if this stage doesn't have backward" + ) + assert not self.is_first, "can't get bwd output if this stage is first" + + self._check_chunk_id(mb_index) + return self.bwd_cache.pop(mb_index) + + def set_local_bwd_input( + self, next_stage_bwd_outputs: tuple[torch.Tensor | None, ...], mb_index: int + ) -> None: + """ + Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. + Does not detach or set '_requires_grad'. + """ + assert isinstance(next_stage_bwd_outputs, tuple), ( + f"Expected tuple, got {type(next_stage_bwd_outputs)}" + ) + + assert self.has_backward, ( + "can't set bwd input if this stage doesn't have backward" + ) + assert not self.is_last, "can't set bwd input if this stage is last" + recv_infos = self.grad_recv_info[mb_index] + for info, tensor in zip(recv_infos, next_stage_bwd_outputs): + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + f"Expected a recv info, got {type(info)}" + ) + info.buffer = tensor + + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the input arguments + for this stage. + """ + recv_infos: tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] + + return self._get_recv_ops(recv_infos) + + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the gradients + for this stage. + """ + if not self.has_backward or self.is_last: + return [] + + recv_infos = self.grad_recv_info[bwd_chunk_id] + return self._get_recv_ops(recv_infos) + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the activation send ops for current stage's forward. + """ + output_tuple, _ = self.fwd_cache[fwd_chunk_id] + + ops: list[dist.P2POp] = [] + + for idx, out in enumerate(output_tuple): + dst_stages = self.act_send_info[idx] + for dst in dst_stages: + if dst is None: + continue + logger.debug( + "%s Sending tensor to Stage %s: %s", + self.log_prefix, + dst, + out.size(), + ) + peer_rank = self.stage_index_to_group_rank[dst] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group)) + + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the gradient send ops for current stage's backward. + """ + if not self.has_backward or self.is_first: + return [] + + self._check_chunk_id(bwd_chunk_id) + # Create bwd send infra lazily + if self.grad_send_info is None: + # Send info for input grads during backward: + # List of destinations corresponding to input grads + # Can be None if an input has no grad + # `grad_send_info` is a mirror of `args_recv_info` + self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0]) + + ops: list[dist.P2POp] = [] + grads_input = self.bwd_cache.pop(bwd_chunk_id) + for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): + if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: + logger.debug( + "%s Sending gradient to Stage %s: %s", + self.log_prefix, + grad_recv_stage, + grad.size(), + ) + peer_rank = self.stage_index_to_group_rank[grad_recv_stage] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group)) + else: + if not (grad is None and grad_recv_stage is None): + raise RuntimeError( + f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} " + f"and is expecting to send gradients to stage {grad_recv_stage}" + ) + return ops + + def clear_runtime_states(self) -> None: + """ + Clear runtime states of the stage. + """ + # map microbatch ID to list of forward tensor args + self.fwd_cache.clear() + # Caching chunk outputs for final output merge or reduction + self.output_chunks.clear() + + # Clear grad of input buffers in between schedule steps. This is because + # `torch.autograd.backward()` will accumulate gradients into leaf + # tensors by default. For gradients to pass back to previous stages, we + # don't want such accumulation. + for recv_tuple in self.args_recv_info.values(): # iterate over all chunks + for a in recv_tuple: # iterate over all input args + if isinstance(a, _RecvInfo): + # Set to None is the newer and recommended way to clear grads, compared to `zero_()`. + # See https://github.com/pytorch/pytorch/pull/92731 + a.buffer.grad = None + + def _map_tensor_from_recv_info( + self, + recv_infos: tuple[InputInfo, ...], + ): + """ + Map tensors from recv infos to a list. + """ + + def get_recv_tensor(info): + if isinstance(info, _RecvInfo): + return info.buffer + else: + raise AssertionError(f"Expected _RecvInfo but got {type(info)}") + + return map_aggregate(cast(Argument, recv_infos), get_recv_tensor) + + def _retrieve_recv_activations(self, fwd_chunk_id: int): + """ + Retrieve the activations received for the current stage during forward. + """ + recv_infos = self.args_recv_info[fwd_chunk_id] + activations = self._map_tensor_from_recv_info(recv_infos) + return activations + + def _retrieve_recv_grads( + self, + bwd_chunk_id: int, + ): + """ + Retrieve the gradients received for the current stage during backward. + """ + recv_infos = self.grad_recv_info[bwd_chunk_id] + grads = self._map_tensor_from_recv_info(recv_infos) + return grads + + def forward_maybe_with_nosync(self, *args, **kwargs): + # If submod is wrapped with DDP, we use the `no_sync` context manager to + # avoid gradient all-reduce per microbatch + if isinstance(self.submod, DistributedDataParallel): + with self.submod.no_sync(): # type: ignore[operator] + out_val = self.submod(*args, **kwargs) + else: + out_val = self.submod(*args, **kwargs) + return out_val + + def scale_grads(self, grad_scale_factor: int) -> None: + """Scale gradients model gradients by `grad_scale_factor`, which should be specified in coordination with the + loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor` + should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should + be set to 1. + + Should only be called once per pipeline schedule step, after all backwards passes have completed. + """ + + # PP scales only for its own contribution (microbatches), but relies on DP to scale further + # for DP degree. + if grad_scale_factor != 1: + for p in self.submod.parameters(): + if p.grad is not None: + p.grad.div_(grad_scale_factor) + + def backward_maybe_with_nosync( + self, + backward_type, + bwd_kwargs: dict, + last_backward: bool = False, + ) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None]: + """ + Whether using PP with FSDP, DDP, or replicate there are some runtime differences between the last backward step and the + other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but + there are additional state-variables and performance considerations depending on the data parallelism used. + This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. + """ + + def perform_backward( + backward_type, + ) -> Callable[ + [], + tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None], + ]: + if backward_type == "full": + return lambda: ( + stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ), + None, + ) + elif backward_type == "input": + return lambda: stage_backward_input( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + self.submod.parameters(), + ) + elif backward_type == "weight": + return lambda: ( + stage_backward_weight( + self.submod.parameters(), bwd_kwargs["param_groups"] + ), + None, + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") + + # If submod is wrapped by DDP + if isinstance(self.submod, DistributedDataParallel): + if last_backward: + # Last chunk, prepare for gradient reduction + # HACK: reaching into DDP implementation details here. Is there a better way? + self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] + list( + torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] + bwd_kwargs["stage_output"] + ) + ) + ) + result = perform_backward(backward_type)() + else: + with self.submod.no_sync(): # type: ignore[operator] + result = perform_backward(backward_type)() + + # If submod is a FSDP or replicate module + elif isinstance(self.submod, FSDPModule): + self.submod.set_is_last_backward(False) + self.submod.set_reshard_after_backward(False) + self.submod.set_requires_gradient_sync(False) + result = perform_backward(backward_type)() + + else: + # Non-DP submodule, regular backward + result = perform_backward(backward_type)() + + grads, param_groups = result + return grads, param_groups + + def forward_one_chunk( + self, + fwd_chunk_id: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + save_forward_output: bool = True, + ): + """ + Perform forward pass on the stage with one microbatch. + `args` and `kwargs` are the inputs from *external* to this stage. + As of Sept 2024: + - `args` applies to the first stage only, other stages receives args + through activation transmission. + - `kwargs` can be passed to all stages via respective `step` calls. + """ + + if self.is_first: + # First stage doesn't need to receive anything + composite_args = args + else: + # Receive activations for this chunk + # Activations only come in args form + composite_args = self._retrieve_recv_activations(fwd_chunk_id) + + composite_kwargs = kwargs or {} + + self._validate_fwd_input(args, kwargs) + + # Compute forward + try: + output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs) + + except Exception as e: + exc_msg = f""" + {self.log_prefix} failed to run forward: + args: {map_debug_info(composite_args)} + kwargs: {map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e + + # See [Note: pipeline model output type] + output_tuple = _normalize_model_output_as_tuple(output) + + # Prepare for final output merge or reduction + # Output chunks is only used for the last stage since we only merge the output of the last stage + if self.is_last and save_forward_output: + self.output_chunks.append(output) + # Save activations and inputs for backward + flat_args = flatten_args(composite_args) + flat_kwargs = flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + self.fwd_cache[fwd_chunk_id] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + + logger.debug( + "%s Forwarded chunk %s, outputs: %s", + self.log_prefix, + fwd_chunk_id, + map_debug_info(output), + ) + self._validate_fwd_outputs(output_tuple) + + # We return the original user-provided output, not normalized to tuple. + # See [Note: pipeline model output type] + return output + + def backward_one_chunk( + self, + bwd_chunk_id: int, + loss=None, + full_backward: bool = True, + last_backward=False, + ): + """ + Perform backward pass on the module. + This should only be called once per microbatch. + + If full_backward is True (the default), the full backward pass including weight and input gradients will be run, + and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id. + + If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__ time, + and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward. + + last_backward is controlled by the schedule and signals synchronization of gradients across DP groups + after the last backward. + """ + # skip backward computation if backward is not enabled + if not self.has_backward: + return + + self._check_chunk_id(bwd_chunk_id) + + ( + stage_output, + input_values, + ) = self.fwd_cache.pop(bwd_chunk_id) + + # Compute backward + if self.is_last: + # Last stage computes gradients from loss and has no gradients from + # next stage + bwd_kwargs = { + "stage_output": loss, + "output_grads": None, + "input_values": input_values, + } + else: + # Otherwise, receive gradients from next stage + grads_output = self._retrieve_recv_grads(bwd_chunk_id) + # If an input to the pipeline requires gradient, + # `torch.autograd.backward` will accumulate the gradient into the + # `.grad` field of such input + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": grads_output, + "input_values": input_values, + } + + grads_input: tuple[torch.Tensor | None, ...] = () + + # Custom backward function + if self.dw_builder: + # TODO: We may want to change our semantics so we are allowed to ignore + # the 'dw_builder' and call full_backward directly when it is a full_backward op. + grads_input, _ = self.backward_maybe_with_nosync( + "full", + bwd_kwargs, + last_backward=last_backward, + ) + if full_backward: + self.dw_builder()() + else: + self.dw_runner[bwd_chunk_id] = self.dw_builder() + else: + if full_backward: + grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + else: + param_groups: list[dict[str, Any]] | None = None + # Skip the backward for the first stage since we will perform the weight update with + # autograd.backward in backward_weight_one_chunk + if not self.is_first: + if isinstance(bwd_kwargs["stage_output"], torch.Tensor): + bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],) + + # perform the partial backwards for the inputs with a custom backward function + # when the "stage_ouput" is a loss, then it is a tensor, otherwise it is a tuple of tensors + grads_input, param_groups = self.backward_maybe_with_nosync( + "input", bwd_kwargs, last_backward=last_backward + ) + + # TODO: we dont need to save this, add to dw_runner? + self.backward_state[bwd_chunk_id] = ( + bwd_kwargs["input_values"], + param_groups, + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + ) + # Save a placeholder for the dw_runner + self.dw_runner[bwd_chunk_id] = lambda: None + + self.bwd_cache[bwd_chunk_id] = grads_input + + if self.is_last and not self.is_first: + # Autograd dependencies: + # rest_of_autograd_graph -> stage_output -> loss + # stage_output is no longer used in the last stage for backward and only needed + # to return to the user in merge_output_chunks, therefore + # this should be detached to release autograd graph context and free memory earlier + for t in stage_output: + if not t._is_view(): # views are not detachable in-place + t.detach_() + + logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) + + def backward_weight_one_chunk(self, bwd_chunk_id: int, last_backward=False): + # skip backward computation if backward is not enabled + if not self.has_backward: + return + + assert bwd_chunk_id in self.dw_runner, ( + f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}" + " without first calling `backward_one_chunk(full_backward=False)`" + ) + + if self.dw_builder is not None: + self.dw_runner.pop(bwd_chunk_id)() + else: + ( + input_values, + param_groups, + stage_output, + output_grads, + ) = self.backward_state.pop(bwd_chunk_id) + + if self.stage_index != 0: + bwd_kwargs = { + "stage_output": stage_output, + "param_groups": param_groups, + } + self.backward_maybe_with_nosync( + "weight", bwd_kwargs, last_backward=last_backward + ) + else: + # TODO: figure out a better way to do this: + # if inputs does not require gradient, + # then the parameter group will not be fully captured during stage_backward_input + # in this case, we need call grad directly on the parameters + # To solve: make input fn do the intersect compute and then finish it off during W + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": input_values, + } + self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + + def _validate_fwd_input(self, args, kwargs): + """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" + + if self.is_first: + # TODO why is there a separate recv_info for each pipeline chunk? + # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we + # check all chunks against args_recv_info[0] + expected_args = self.args_recv_info[0] + else: + # We don't check inputs for non-0 stages assuming they don't accept + # user inputs in canonical pipeline scenarios + return + + if len(kwargs): + # TODO- need a mapping of kwarg to position in self.args_recv_info + # Without it, we are not 100% sure how to match the args and + # expected_args. + return + + # TODO- need a mapping of kwarg to position in self.args_recv_info + # maybe it's impossible to tell whether the len mismatches because + # (a) the user passed an extra arg or missed an arg + # (b) the user did not pass a kwarg, which has a default value baked into expected_args + expected_tensors_meta = [ + e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer + for e in expected_args + ] + validate_tensors_metadata( + f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args + ) + + def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): + """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype. + Most likely, this could be cause either by incorrect user specification of output shapes, or because + shape inference was done on the original model but then at runtime the model is wrapped with something like + mixed precision which changes output dtype. + """ + expected_tensors_meta = self.get_outputs_meta() + validate_tensors_metadata( + f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs + ) + + def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: + """ + Get the operations to initialize the p2p communicators between previous and next stages. + This is done so by creating a dummy tensor and sending it to the next stage and receiving + from the previous stage. + """ + ops: list[dist.P2POp] = [] + next_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index + 1) + prev_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index - 1) + + recv_tensor = torch.zeros(1, device=self.device, dtype=torch.float32) + send_tensor = torch.tensor( + self.stage_index, device=self.device, dtype=torch.float32 + ) + # forward + if not self.is_first: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + # backward + if not self.is_first: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + return ops + + def perform_reduce_grad(self, grad_scale_factor: int): + """ + Called as a part of schedule IR. + REDUCE_GRAD action is scheduled after all microbatches W, B actions. + + Currently contains "post_backward" functionality for FSDP. + We can try to extract post_backward in a separate IR action in future. + """ + # Manually call post backward for FSDP + if isinstance(self.submod, FSDPModule): + fsdp_module = self.submod + fsdp_module.set_is_last_backward(True) + fsdp_module.set_reshard_after_backward(True) + fsdp_module.set_requires_gradient_sync(True) + + if isinstance(fsdp_module, ReplicateModule): + distributed_state = replicate.state(fsdp_module) # type: ignore[arg-type] + else: + distributed_state = fully_shard.state(fsdp_module) # type: ignore[attr-defined] + + for state in distributed_state._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.post_backward() + + # it would be much better if pipelining backward invoked .backward so autograd hooks + # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being, + # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream. + distributed_state._root_post_backward_final_callback() + # Call gradient scaling at the end of the backward pass + # NOTE: this must happen after FSDP post_backward is FSDP is enabled + if grad_scale_factor != 1: + self.scale_grads(grad_scale_factor) + + +class _PipelineStage(_PipelineStageBase): + def __init__( + self, + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: dist.ProcessGroup | None = None, + ): + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and a `pipe_info` describing the stage relationship of the pipeline. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + """ + _PipelineStageBase.__init__( + self, + stage_module, + stage_index, + pipe_info.num_stages, + device, + group, + ) + self.pipe_info = pipe_info + + # Find stage nodes in graph + submod_nodes = [ + node for node in pipe_info.graph.nodes if node.op == "call_module" + ] + if len(submod_nodes) != self.num_stages: + raise AssertionError( + f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}" + ) + + # Find my stage node in graph + self.node = submod_nodes[self.stage_index] + self.name = self.node.name + logger.info( + "[%s] Creating PipelineStage %s for %s", + self.group_rank, + stage_index, + self.name, + ) + + # Create mapping from stage name to stage index + self.submod_to_stage_index: dict[str, int] = {} + for i, node in enumerate(submod_nodes): + self.submod_to_stage_index.setdefault(node.name, i) + + # Cast submodule to device + self._move_submod_to_device() + + def _move_submod_to_device(self): + # Move submodule to indicated device if possible + # Note: we cannot move meta module to real devices because meta tensors + # do not support to() method. One needs to do an in-place tensor swap in + # that case. + has_meta_param = any( + isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters() + ) + if has_meta_param: + logger.debug("%s Found meta parameters!", self.log_prefix) + else: + self.submod.to(self.device) + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + """ + Create send/recv infrastructures for activations (during forward) + """ + # TODO(whc) + # this method should be deleted once lazy buffer allocation is implemented + # for now, it ignores args/kwargs because it should not need to do shape inference + for chunk in range(num_microbatches): + self.args_recv_info[chunk] = self._create_act_recv_info() + + # Send info during forward for each activation + self.act_send_info = self._create_act_send_info() + return tuple() + + def get_stage_index_of_submod( + self, + submod_name: str, + ): + """ + Given a submodule name, return the stage index of the submodule. + """ + if submod_name not in self.submod_to_stage_index: + raise AssertionError(f"Stage id of {submod_name} not found") + + return self.submod_to_stage_index[submod_name] + + def _create_act_recv_info( + self, + ): + """ + Create a tuple of `_RecvInfo` for inputs to the stage. + """ + + def create_recv_tensor(placeholder, arg_node): + """ + Create a receive buffer for a placeholder. + """ + example_value = placeholder.meta["val"] + if arg_node.op == "placeholder": + # This is a root level placeholder, thus an input argument to the entire model. + # We are likely at stage 0, hence no need to create a receive buffer. + return _RootArgPlaceholder(example_value) + + # Figure out the source stage of this input + while arg_node.target is operator.getitem: + # If the input is a getitem, we need to go deeper + arg_node = arg_node.args[0] + + assert arg_node.op == "call_module", ( + f"Expecting call_module, got {arg_node.op}" + ) + src_stage = self.get_stage_index_of_submod(arg_node.name) + + # Create a receive buffer for this placeholder + logger.debug( + "%s Creating recv buffer for input '%s' : %s, %s", + self.log_prefix, + placeholder.name, + example_value.shape, + example_value.dtype, + ) + buffer = _make_tensor_from_meta(example_value, self.device) + # In case there is backward pass, set requires_grad for receive buffers + # before first forward + if self.has_backward: + buffer.requires_grad_(True) + + return _RecvInfo( + arg_node.name, + src_stage, + buffer, + ) + + args_recv_info: list[InputInfo] = [] + # Filter out placeholder nodes from `self.submod` (a GraphModule) + placeholders = filter( # type: ignore[var-annotated] + lambda node: node.op == "placeholder", # type: ignore[arg-type] + self.submod.graph.nodes, # type: ignore[arg-type,union-attr] + ) + # `placeholders` are nodes internal to submod. + # `self.node.args` are dependency nodes in the outer graph. + # The two are 1:1. + for placeholder, arg_node in zip(placeholders, self.node.args): + # Create a receive buffer for this placeholder + recv_info = create_recv_tensor(placeholder, arg_node) + args_recv_info.append(recv_info) + + logger.debug( + "%s Activation recv / args info: %s", self.log_prefix, args_recv_info + ) + # `args` is a Tuple, hence we will return a Tuple[InputInfo] + return tuple(args_recv_info) + + def find_dst_rank( + self, + user: fx.Node, + ) -> int | None: + """ + Find the destination rank of a `user` node. + If the `user` is not a submod, `None` may be returned. + """ + if user.op == "call_module": + # User is a stage (`call_module`) + return self.get_stage_index_of_submod(user.name) + else: + # - If user.op == "output": + # No need to send back to rank 0 + # - If user.target is stage_backward: + # No need to send assuming submod output is stored locally or + # should be re-calculated in case of activation checkpointing + return None + + def _create_act_send_info(self): + """ + Create a dict of send info for activations. + The dict is of the form: + { + output_index: [dst_rank_0, dst_rank_1, ...], + ... + } + where the list of `dst_rank`s covers the case where an output value may + be consumed by multiple stages. + """ + # Output index: List of receiver ranks + act_send_info: dict[int, list] = {} + out_idx = 0 + + for user in self.node.users: + if user.target is operator.getitem: + # Recursively find the real destination + gi_dsts = act_send_info.setdefault(out_idx, []) + for gi_user in user.users: + dst_rank = self.find_dst_rank(gi_user) + if dst_rank is not None: + gi_dsts.append(dst_rank) + # Next `getitem` will point to the next output index + out_idx += 1 + else: + # In case of single output value, `out_idx` will not increase + dsts = act_send_info.setdefault(out_idx, []) + dst_rank = self.find_dst_rank(user) + if dst_rank is not None: + dsts.append(dst_rank) + + output_node = self._get_output_node() + output_vals: tuple[torch.Tensor] = tuple( + v.meta["val"] for v in flatten_args(output_node.args) + ) + self._configure_outputs_meta(output_vals) + + logger.debug("%s Send info: %s", self.log_prefix, act_send_info) + return act_send_info + + def _get_output_node(self): + output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"] # type: ignore[union-attr] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + return output_node + + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + """ + Create a tuple of `_RecvInfo` for gradients. + """ + # Dict[output_index, _RecvInfo] + grad_recv_info: dict[int, _RecvInfo] = {} + output_node = self._get_output_node() + + # The output node may take multiple args, meaning the submod having multiple output values. + output_vals = flatten_args(output_node.args) + + for out_idx, dst_list in act_send_info.items(): + if not dst_list: + # No actual receiver for activation so no grad coming back + continue + + output = output_vals[out_idx] + example_value = output.meta["val"] + logger.debug( + f"{self.log_prefix} Creating grad recv buffer for output {output.name} " # noqa: G004 + f": {example_value.shape}, {example_value.dtype}" + ) + + # TODO: otherwise needs grad accumulation + assert len(dst_list) == 1, "Backward of skip connections not supported yet" + grad_src = dst_list[0] + grad_recv_info[out_idx] = _RecvInfo( + f"{grad_src}", # noqa: G004 + grad_src, + _make_tensor_from_meta(example_value, self.device), + ) + + # Convert to tuple for convenience in get_ops and retrieve tensor + grad_recv_info_tuple = tuple(grad_recv_info.values()) + logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple) + return grad_recv_info_tuple + + +# A helper function to create a pipeline stage based on traced pipeline information +def build_stage( + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: dist.ProcessGroup | None = None, +) -> _PipelineStage: + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and pipeline information. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + + Returns: + _PipelineStage: a pipeline stage that can run with `PipelineSchedules`. + """ + return _PipelineStage( + stage_module, + stage_index, + pipe_info, + device, + group, + ) + + +class PipelineStage(_PipelineStageBase): + """ + A class representing a pipeline stage in a pipeline parallelism setup. + + PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from + one chunk feed into inputs of the next chunk, with no skip connections. + + PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to + stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each + PipelineStage instance. + + Args: + submodule (nn.Module): The PyTorch module wrapped by this stage. + stage_index (int): The ID of this stage. + num_stages (int): The total number of stages. + device (torch.device): The device where this stage is located. + input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule. + output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule. + group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. + dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder will build a new dw_runner function + that will the W action (input weights) for F, I, W (Fwd, Input, Weight) zero bubble schedules. + """ + + def __init__( + self, + submodule: nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + input_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, + output_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, + group: dist.ProcessGroup | None = None, + dw_builder: Callable[[], Callable[..., None]] | None = None, + ): + super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) + self.inputs: list[torch.Tensor] | None = None + self.inputs_meta: tuple[torch.Tensor, ...] | None = None + # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it + # might be breaking for existing users. + if input_args is None: + assert output_args is None, ( + "If specifying output_args, input_args must also be specified. " + "Otherwise, shape inference will be performed at runtime" + ) + else: + self.inputs_meta = ( + (input_args,) if isinstance(input_args, torch.Tensor) else input_args + ) + if output_args is None: + logger.warning( + "Deprecation warning: passing input_args and performing init-time shape inference is deprecated. " + "PipelineStage now supports runtime shape inference using the real inputs provided to schedule step(). " + "Either delete `input_args` arg to `PipelineStage` to opt-into runtime shape inference, " + "or additionally pass `output_args` to `PipelineStage` to fully override shape inference. " + ) + try: + with torch.no_grad(): + output_args = submodule(*self.inputs_meta) + output_args = tree_map_only( + torch.Tensor, lambda x: x.to("meta"), output_args + ) + except Exception as e: + raise RuntimeError( + "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" + ) from e + assert output_args is not None, ( + "If passing input_args, also pass output_args to override shape inference" + ) + self._configure_outputs_meta( + (output_args,) if isinstance(output_args, torch.Tensor) else output_args + ) + + # these are the buffers used in backwards send/recv, they are allocated later + self.outputs_grad: list[torch.Tensor] = [] + + dbg_str = ( + f"Finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + f"{self.is_last=}, {self.num_stages=}, " + ) + if self.inputs_meta is not None: + dbg_str += ( + f"inputs: {[inp.shape for inp in self.inputs_meta]}, " + f"output: {[output.shape for output in self.get_outputs_meta()]}" + ) + else: + dbg_str += " running shape-inference at runtime" + + logger.debug(dbg_str) + + def _shape_inference( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + assert args is not None, "Args may be an empty tuple but not None" + + # We skip recv communication if we're the first stage, but also if the previous stage is on the same rank + # and can pass its output shapes in as args instead of using send/recv. + if ( + self.is_first + # if not first stage, then check if prev stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping recv, because shape info passed in via `args`", + self.stage_index, + ) + args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args) + else: + assert len(args) == 0, ( + "Can't supply input args for shape inference on non-first stage" + ) + objects = [None] + logger.debug( + "Shape inference: stage %s receiving from stage %s", + self.stage_index, + self.stage_index - 1, + ) + dist.recv_object_list( + objects, + src=dist.get_global_rank( + self.group or dist.distributed_c10d._get_default_group(), + self.stage_index_to_group_rank[self.stage_index - 1], + ), + group=self.group, + device=self.device, + use_batch=True, + ) + recv_args = objects[0] + assert isinstance(recv_args, tuple), type(recv_args) + args = recv_args + + # cache input shapes for use during recv buffer allocation + self.inputs_meta = args + args = tree_map_only( + torch.Tensor, lambda x: torch.zeros_like(x, device=self.device), args + ) + + # set attributes needed for forward + with torch.no_grad(): + outputs = self.submod(*args, **kwargs) + + # if single tensor, convert so it is always a list + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + + # communicate meta outputs not real outputs for two reasons + # 1 - its faster (esp. since obj coll pickles tensor data!) + # 2 - avoid activating a cuda context for the src rank when unpickling on the recv end! + outputs_meta = tuple( + tree_map_only(torch.Tensor, lambda x: x.to("meta"), outputs) + ) + logger.debug( + "Shape inference: stage %s inputs %s, outputs %s", + self.stage_index, + self.inputs_meta, + outputs_meta, + ) + self._configure_outputs_meta(outputs_meta) + + # Passing outputs to the next stage: + # two cases- + # 1. Usually: use send/recv communication to pass the output + # 2. Special case: for V-schedules, 2 'adjacent' stages (e.g. stage 3, 4 in an 8-stage 4-rank V) + # pass their shape info via return value and function args rather than send/recv. + if ( + self.is_last + # if not last stage, then check if next stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank + ): + # Case (2) above: pass shape info via return value and caller passes it as args to next stage's + # _shape_inference call + logger.debug( + "Shape inference: stage %s skipping send to next stage", + self.stage_index, + ) + + else: + # Case (1): send shapes via send operation, and ensure not to return it to the caller + logger.debug( + "Shape inference: stage %s sending to stage %s", + self.stage_index, + self.stage_index + 1, + ) + dist.send_object_list( + [outputs_meta], + dst=dist.get_global_rank( + self.group or dist.distributed_c10d._get_default_group(), + self.stage_index_to_group_rank[self.stage_index + 1], + ), + group=self.group, + device=self.device, + use_batch=True, + ) + outputs_meta = tuple() + + return outputs_meta + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + # TODO move self.device to an argument from step API (from its input tensors)? + assert num_microbatches is not None, "TODO fix num_microbatches" + + outputs: tuple[Any, ...] = tuple() + if self.inputs_meta is None: + outputs = self._shape_inference(args, kwargs) + + assert self.inputs_meta is not None + # Receive info during forward + # TODO: create args_recv_info lazily? (same needed for PipelineStage) + for chunk_id in range(num_microbatches): + if not self.is_first: + # We assume that we always receive from stage - 1 + recv_infos = tuple( + _RecvInfo( + f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", + self.stage_index - 1, + _make_tensor_from_meta(inp, self.device), + ) + for inp in self.inputs_meta + ) + # In case there is backward pass, set requires_grad for receive buffers + if self.has_backward: + for r in recv_infos: + r.buffer.requires_grad_(True) + + self.args_recv_info[chunk_id] = recv_infos + else: + self.args_recv_info[chunk_id] = tuple( + _RootArgPlaceholder(i) for i in self.inputs_meta + ) + + # Send info during forward for each activation + # only need the rank that is being sent to + self.act_send_info: dict[int, list] = {} + + for idx in range(len(self.get_outputs_meta())): + # We assume we always send to stage + 1 + if not self.is_last: + self.act_send_info[idx] = [self.stage_index + 1] + else: + self.act_send_info[idx] = [] + + return outputs + + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + grad_recv_info: tuple[_RecvInfo, ...] = () + if not self.is_last: + # Receiving gradients from multiple sources is not supported + # hence we only take the first destination + grad_recv_info = tuple( + _RecvInfo( + f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}", + dst_list[0], + _make_tensor_from_meta(self.get_outputs_meta()[idx], self.device), + ) + for idx, dst_list in act_send_info.items() + ) + return grad_recv_info diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adf901d6b6e3e693f69464e5c64d58a857ae6014 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__init__.py @@ -0,0 +1,257 @@ +# mypy: allow-untyped-defs +import logging +import os +import threading +import warnings +from collections.abc import Generator +from datetime import timedelta +from urllib.parse import urlparse + +import torch +import torch.distributed as dist + + +__all__ = ["is_available"] + + +logger = logging.getLogger(__name__) + + +_init_counter = 0 +_init_counter_lock = threading.Lock() + + +def is_available() -> bool: + return hasattr(torch._C, "_rpc_init") + + +if is_available() and not torch._C._rpc_init(): + raise RuntimeError("Failed to initialize torch.distributed.rpc") + + +if is_available(): + _is_tensorpipe_available = hasattr( + torch._C._distributed_rpc, "_TensorPipeRpcBackendOptionsBase" + ) + + import numbers + + import torch.distributed.autograd as dist_autograd + from torch._C._distributed_c10d import Store + from torch._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _disable_jit_rref_pickle, + _disable_server_process_global_profiler, + _enable_jit_rref_pickle, + _enable_server_process_global_profiler, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, + _set_rpc_timeout, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + WorkerInfo, + ) + + if _is_tensorpipe_available: + from torch._C._distributed_rpc import ( # noqa: F401 + _DEFAULT_NUM_WORKER_THREADS, + _TensorPipeRpcBackendOptionsBase, + TensorPipeAgent, + ) + + from . import api, backend_registry, functions + from .api import * # noqa: F401,F403 + from .backend_registry import BackendType + from .options import TensorPipeRpcBackendOptions # noqa: F401 + from .server_process_global_profiler import _server_process_global_profile + + rendezvous_iterator: Generator[tuple[Store, int, int], None, None] + + __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"] + __all__ = __all__ + api.__all__ + backend_registry.__all__ # noqa: PLE0605 + + def init_rpc( + name, + backend=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + r""" + Initializes RPC primitives such as the local RPC agent + and distributed autograd, which immediately makes the current + process ready to send and receive RPCs. + + Args: + name (str): a globally unique name of this node. (e.g., + ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``) + Name can only contain number, alphabet, underscore, colon, + and/or dash, and must be shorter than 128 characters. + backend (BackendType, optional): The type of RPC backend + implementation. Supported values is + ``BackendType.TENSORPIPE`` (the default). + See :ref:`rpc-backends` for more information. + rank (int): a globally unique id/rank of this node. + world_size (int): The number of workers in the group. + rpc_backend_options (RpcBackendOptions, optional): The options + passed to the RpcAgent constructor. It must be an agent-specific + subclass of :class:`~torch.distributed.rpc.RpcBackendOptions` + and contains agent-specific initialization configurations. By + default, for all agents, it sets the default timeout to 60 + seconds and performs the rendezvous with an underlying process + group initialized using ``init_method = "env://"``, + meaning that environment variables ``MASTER_ADDR`` and + ``MASTER_PORT`` need to be set properly. See + :ref:`rpc-backends` for more information and find which options + are available. + """ + torch._C._log_api_usage_once("torch.distributed.init_rpc") + if backend is not None and not isinstance( + backend, backend_registry.BackendType + ): + raise TypeError("Argument backend must be a member of BackendType") + + if rpc_backend_options is not None and not isinstance( + rpc_backend_options, RpcBackendOptions + ): + raise TypeError( + "Argument rpc_backend_options must be an instance of RpcBackendOptions" + ) + + # Try to detect the backend from the options + if backend is None and rpc_backend_options is not None: + for candidate_backend in BackendType: + if isinstance( + rpc_backend_options, + type( + backend_registry.construct_rpc_backend_options( + candidate_backend + ) + ), + ): + backend = candidate_backend + break + else: + raise TypeError( + f"Could not infer backend for options {rpc_backend_options}" + ) + # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865) + if backend != BackendType.TENSORPIPE: # type: ignore[attr-defined] + logger.warning( + "RPC was initialized with no explicit backend but with options " # type: ignore[attr-defined] + "corresponding to %(backend)s, hence that backend will be used " + "instead of the default BackendType.TENSORPIPE. To silence this " + "warning pass `backend=%(backend)s` explicitly.", + {"backend": backend}, + ) + + if backend is None: + backend = BackendType.TENSORPIPE # type: ignore[attr-defined] + + if rpc_backend_options is None: + # default construct a set of RPC backend options. + rpc_backend_options = backend_registry.construct_rpc_backend_options( + backend + ) + + # Create store, performs rendezvous for static RPC group. + if not world_size: + # If world_size is not set in construction and also not set in environment variables + # The store will be created for the dynamic group setting + store = dist._create_store_from_options(rpc_backend_options, rank) + else: + # This rendezvous state sometimes is destroyed before all processes + # finishing handshaking. To avoid that issue, we make it global to + # keep it alive. + global rendezvous_iterator + rendezvous_iterator = dist.rendezvous( + rpc_backend_options.init_method, rank=rank, world_size=world_size + ) + store, _, _ = next(rendezvous_iterator) + # Use same timeout as RPC. + store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout)) + + # Use a PrefixStore to distinguish multiple invocations. + with _init_counter_lock: + global _init_counter + store = dist.PrefixStore(str(f"rpc_prefix_{_init_counter}"), store) + _init_counter += 1 + + # Initialize autograd before RPC since _init_rpc_backend guarantees all + # processes sync via the store. If we initialize autograd after RPC, + # there could be a race where some nodes might have initialized autograd + # and others might not have. As a result, a node calling + # torch.distributed.autograd.backward() would run into errors since + # other nodes might not have been initialized. + dist_autograd._init(rank) + + _set_profiler_node_id(rank) + # Initialize RPC. + _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options) + + def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options): + type_mapping = { + backend: backend_registry.BackendType, + store: dist.Store, + name: str, + rank: numbers.Integral, + # world_size can be None for a dynamic group + world_size: (numbers.Integral, type(None)), + rpc_backend_options: RpcBackendOptions, + } + for arg, arg_type in type_mapping.items(): + if not isinstance(arg, arg_type): # type: ignore[arg-type] + raise RuntimeError( + f"Argument {arg} must be of type {arg_type} but got type {type(arg)}" + ) + + def _init_rpc_backend( + backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] + store=None, + name=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) + + if _is_current_rpc_agent_set(): + raise RuntimeError("RPC is already initialized") + + # Initialize RPC. + rpc_agent = backend_registry.init_backend( + backend, + store=store, + name=name, + rank=rank, + world_size=world_size, + rpc_backend_options=rpc_backend_options, + ) + + api._init_rpc_states(rpc_agent) + + @api._require_initialized + def _get_debug_info(): + info = _rref_context_get_debug_info() + info.update(api._get_current_rpc_agent().get_debug_info()) + info.update(dist_autograd._get_debug_info()) + return info diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40c5d8d8ec20d3bf46dc4f9a77da712b002da8f2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b7fdf07ed7e7a829ad00c061192c0119fcb791 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee144f4e3155e32d1184c6f1368a288c05d1295a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af9b266a48b9b2e72e45209e76a94a10c76bc559 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc50d80700f5dde44e032f63521d8c30008ba7b8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b6128b2ba581a487bc5091bdeeecf72dce6e78b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd3a39f55ce85fe30b86f9e3be2f365a32d7028d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/options.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/options.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dbc3756c353692e3dcc15b5f8edfbaca0609249 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/options.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a1cb9d594e06a3eaf0387b753dbf661644cf42d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..240f65245c2164a45259526fafe9653383b147ff Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0abd737becafbae33b0b63799c1eb43c913e1998 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__init__.py @@ -0,0 +1,18 @@ +import torch + + +def is_available() -> bool: + return hasattr(torch._C, "_faulty_agent_init") + + +if is_available() and not torch._C._faulty_agent_init(): + raise RuntimeError("Failed to initialize torch.distributed.rpc._testing") + +if is_available(): + # Registers FAULTY_TENSORPIPE RPC backend. + from torch._C._distributed_rpc_testing import ( + FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, + ) + + from . import faulty_agent_backend_registry diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e057ae3fca4ec4bd8d811544914ffb6051cf50f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39be0ca053011312aea3ae1ca1ceb36c3669698a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d04882e16e79a94f74ddc1350e94f547ef625611 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +import torch.distributed as dist +import torch.distributed.rpc as rpc + + +def _faulty_tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads, + messages_to_fail, + messages_to_delay, + num_fail_sends, + **kwargs, +): + from . import FaultyTensorPipeRpcBackendOptions + + return FaultyTensorPipeRpcBackendOptions( + num_worker_threads=num_worker_threads, + rpc_timeout=rpc_timeout, + init_method=init_method, + messages_to_fail=messages_to_fail, + messages_to_delay=messages_to_delay, + num_fail_sends=num_fail_sends, + ) + + +def _faulty_tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from torch.distributed.rpc import api + + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + agent = FaultyTensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, # reverse_device_map + [], # devices + ) + api._init_rpc_states(agent) + + return agent + + +rpc.backend_registry.register_backend( + "FAULTY_TENSORPIPE", + _faulty_tensorpipe_construct_rpc_backend_options_handler, + _faulty_tensorpipe_init_backend_handler, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a0021ff1e43d8653df457cb99e7ea3637a508851 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/_utils.py @@ -0,0 +1,47 @@ +# mypy: allow-untyped-defs +import logging +from contextlib import contextmanager +from typing import cast + + +logger = logging.getLogger(__name__) + + +@contextmanager +def _group_membership_management(store, name, is_join): + token_key = "RpcGroupManagementToken" + join_or_leave = "join" if is_join else "leave" + my_token = f"Token_for_{name}_{join_or_leave}" + while True: + # Retrieve token from store to signal start of rank join/leave critical section + returned = store.compare_set(token_key, "", my_token).decode() + if returned == my_token: + # Yield to the function this context manager wraps + yield + # Finished, now exit and release token + # Update from store to signal end of rank join/leave critical section + store.set(token_key, "") + # Other will wait for this token to be set before they execute + store.set(my_token, "Done") + break + else: + # Store will wait for the token to be released + try: + store.wait([returned]) + except RuntimeError: + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) + raise + + +def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): + from . import api, TensorPipeAgent + + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) + return ret diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/api.py new file mode 100644 index 0000000000000000000000000000000000000000..845ce0b7faf6c4cb1390c4d7089f745a1861f335 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/api.py @@ -0,0 +1,965 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +import collections +import contextlib +import functools +import inspect +import logging +import threading +from typing import Any, Generic, TYPE_CHECKING, TypeVar + +import torch +from torch._C._distributed_rpc import ( + _cleanup_python_rpc_handler, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + WorkerInfo, +) +from torch.futures import Future + +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT +from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, + PythonUDF, + RPCExecMode, +) + + +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + + +logger = logging.getLogger(__name__) + +# NB: Ignoring RRef leaks during shutdown. Without this, applications have to +# make sure there is no references to any RRef in the application code and +# Python GC has done its job to delete those RRefs. This is could result in bad +# debugging experiences especially when for large applications. Therefore, by +# default, we are going to ignore RRef leaks during shutdown. This is usually +# fine as shutdown means applications have done training and no longer care +# about states. +# +# To enable RRef leak checking, set this _ignore_rref_leak to False +_ignore_rref_leak = True +_default_pickler = _internal_rpc_pickler + + +@contextlib.contextmanager +def _use_rpc_pickler(rpc_pickler): + r""" + rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler + """ + global _default_pickler + _default_pickler = rpc_pickler + try: + yield + finally: + _default_pickler = _internal_rpc_pickler + + +def _require_initialized(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not _is_current_rpc_agent_set(): + raise RuntimeError( + "RPC has not been initialized. Call " + "torch.distributed.rpc.init_rpc first." + ) + return func(*args, **kwargs) + + return wrapper + + +class AllGatherStates: + def __init__(self): + # Each `gathered_objects` is an empty dict at beginning. + # The leader worker is elected as the first worker in a sorted worker + # name list. Whenever there is a worker entering `_all_gather()`, it + # runs `_gather_to_leader()` on the leader to add its own name and + # data obj to this dict. The leader also adds itself's name to the dict + # on calling `_all_gather()`. + # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader + # will broadcast the gathered dict to all follower workers and set their + # `gathered_objects` field and the `proceed_signal` field. + self.gathered_objects = {} + # All workers wait on this signal until it receives all gathered + # objects. + self.proceed_signal = threading.Event() + + +# States used by `def _all_gather()`. +# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer. +_ALL_WORKER_NAMES: set[Any] = set() +_all_gather_dict_lock = threading.RLock() +_all_gather_sequence_id: dict[str, int] = {} +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) + + +def _init_rpc_states(agent): + worker_infos = agent.get_worker_infos() + global _ALL_WORKER_NAMES + _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos} + + # NB: backend implementation might have already set the rpc_agent. + if not _is_current_rpc_agent_set(): + _set_and_start_rpc_agent(agent) + + +def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None): + with _all_gather_dict_lock: + if not worker_names: + worker_names = _ALL_WORKER_NAMES + assert worker_name in worker_names, ( + f"{worker_name} is not expected by leader." + ) + states = _all_gather_sequence_id_to_states[sequence_id] + assert worker_name not in states.gathered_objects, ( + f"{worker_name} reported intent sequence id {sequence_id} twice. " + ) + states.gathered_objects[worker_name] = obj + if worker_names == set(states.gathered_objects.keys()): + states.proceed_signal.set() + + +def _broadcast_to_followers(sequence_id, objects_map): + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + assert not states.proceed_signal.is_set(), ( + f"Termination signal sequence id {sequence_id} got set twice." + ) + states.gathered_objects = objects_map + states.proceed_signal.set() + + +_thread_local_var = threading.local() + + +@contextlib.contextmanager +def _wait_all(): + r""" + A context manager that collects all futures returned by ``rpc_async`` and + waits them on the context manager's exit; relieving the user of needing + to explicitly call wait. + + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> with rpc._wait_all(): + >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> #fut_1 and fut_2 are waited on + """ + _thread_local_var.future_list = [] + try: + yield + finally: + try: + torch.futures.wait_all(_thread_local_var.future_list) + finally: + del _thread_local_var.future_list + + +@_require_initialized +def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + This is similar to torch.distributed.all_gather(), but is using RPC. It + picks the worker with the smallest name (alphabetic order) as the leader. + Then all followers send their data ``obj`` to the leader. After the leader + has received all, it will broadcast the results back to all followers. This + function blocks until all workers have received the gathered results. + """ + if not worker_names: + assert _ALL_WORKER_NAMES is not None, ( + "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." + ) + worker_names = _ALL_WORKER_NAMES + leader_name = min(worker_names) + + self_name = _get_current_rpc_agent().get_worker_info().name + + with _all_gather_dict_lock: + concat_names = "".join(sorted(worker_names)) + sequence_num = _all_gather_sequence_id.get(concat_names, 0) + _all_gather_sequence_id[concat_names] = sequence_num + 1 + sequence_id = concat_names + str(sequence_num) + + is_leader = leader_name == self_name + + if timeout == UNSET_RPC_TIMEOUT: + # Timeout is specified by agent for RPC calls + rpc_timeout = get_rpc_timeout() + # No timeout for signal + signal_timeout = None + elif timeout == DEFAULT_SHUTDOWN_TIMEOUT: + # No timeout for RPC + rpc_timeout = timeout + # No timeout for signal + signal_timeout = None + else: + # Signal and RPC timeout use the same timeout + signal_timeout = rpc_timeout = timeout + + # Phase 1: Followers send it's object to the leader + if is_leader: + _gather_to_leader(sequence_id, self_name, obj, worker_names) + else: + rpc_sync( + leader_name, + _gather_to_leader, + args=(sequence_id, self_name, obj, worker_names), + timeout=rpc_timeout, + ) + + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + # Timeout is either set by function parameter or None (which is indefinite) + states.proceed_signal.wait(timeout=signal_timeout) + + # Phase 2: Leader broadcast gathered results to all followers + # Leader's signal is the first to be unblocked, after receiving all + # followers' data objects. + if is_leader: + worker_name_to_response_future_dict = {} + for follower_name in worker_names - {leader_name}: + fut = rpc_async( + follower_name, + _broadcast_to_followers, + args=(sequence_id, states.gathered_objects), + timeout=rpc_timeout, + ) + worker_name_to_response_future_dict[follower_name] = fut + + errors = [] + for follower_name, fut in worker_name_to_response_future_dict.items(): + try: + fut.wait() + except RuntimeError as ex: + errors.append((follower_name, ex)) + + if errors: + raise RuntimeError( + f"Followers {[e[0] for e in errors]} timed out in _all_gather " + f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}" + ) + + # Clean up for the states using the sequence_id + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states.pop(sequence_id) + return states.gathered_objects + + +@_require_initialized +def _barrier(worker_names): + r""" + Synchronizes local and remote RPC processes. + + This will block until all local and remote RPC processes specified under worker_names + reach this method to wait for all outstanding work to complete. + + Args: + worker_names (List[str]): The set of workers to synchronize. + + """ + try: + _all_gather(None, set(worker_names)) + except RuntimeError: + logger.exception("Failed to complete barrier") + + +@_require_initialized +def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Block until all local and remote RPC processes reach this method and wait + for all outstanding work to complete. Every RPC process must call this + method before exit to perform a graceful shutdown. This should be used to + terminate the RPC framework, and there is no guarantee that the RPC + framework will work after this method returns. + """ + try: + _all_gather(None, timeout=timeout) + except RuntimeError as ex: + logger.exception("Failed to respond to 'Shutdown Proceed' in time") + raise ex + + +@_require_initialized +def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Perform a shutdown of the RPC agent, and then destroy the RPC agent. This + stops the local agent from accepting outstanding requests, and shuts + down the RPC framework by terminating all RPC threads. If ``graceful=True``, + this will block until all local and remote RPC processes reach this method + and wait for all outstanding work to complete. Otherwise, if + ``graceful=False``, this is a local shutdown, and it does not wait for other + RPC processes to reach this method. + + .. warning:: + For :class:`~torch.futures.Future` objects returned by + :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not + be called after ``shutdown()``. + + Args: + graceful (bool): Whether to do a graceful shutdown or not. If True, + this will 1) wait until there is no pending system + messages for ``UserRRefs`` and delete them; 2) block + until all local and remote RPC processes have reached + this method and wait for all outstanding work to + complete. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> # do some work + >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1)) + >>> # ready to shutdown + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + if graceful: + try: + agent = _get_current_rpc_agent() + from torch._C._distributed_rpc import TensorPipeAgent + + if not isinstance(agent, TensorPipeAgent) or agent.is_static_group: + _wait_all_workers(timeout) + _delete_all_user_and_unforked_owner_rrefs() + agent.join(shutdown=True, timeout=timeout) + else: + # This is a dynamic group so we need to grab the token for the operation + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + with _group_membership_management(agent.store, my_name, False): + all_worker_infos = agent.get_worker_infos() + for worker in all_worker_infos: + if worker.name != my_name: + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) + agent.join(shutdown=True, timeout=timeout) + finally: + # In case of errors, continue to complete the local shutdown. + _finalize_shutdown() + else: + _finalize_shutdown() + + +def _finalize_shutdown(): + try: + # This raises a `TORCH_CHECK()` exception on RRef leak detected. + _destroy_rref_context(_ignore_rref_leak) + finally: + _get_current_rpc_agent().shutdown() + # clean up python rpc handler in shutdown(), see comments in + # PythonRpcHandler::cleanup(), call it in python API because the + # cleanup() function has python dependency, it assumes python + # interpreter exists. + # No matter if RRef leak exception is raised, this clean-up code + # must run to avoid destruction segfault in Python 3.5. + # + # future.wait() should not be called after shutdown(). + # pythonRpcHandler is cleaned up in shutdown(), after + # shutdown(), python objects returned from rpc python call can not be + # resolved. + _cleanup_python_rpc_handler() + _reset_current_rpc_agent() + + +@_require_initialized +def get_worker_info(worker_name=None): + r""" + Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name. + Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an + expensive string on every invocation. + + Args: + worker_name (str): the string name of a worker. If ``None``, return the + the id of the current worker. (default ``None``) + + Returns: + :class:`~torch.distributed.rpc.WorkerInfo` instance for the given + ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the + current worker if ``worker_name`` is ``None``. + """ + if worker_name is not None: + return _get_current_rpc_agent().get_worker_info(worker_name) + else: + return _get_current_rpc_agent().get_worker_info() + + +def _to_worker_info(to): + if isinstance(to, WorkerInfo): + return to + elif isinstance(to, (str, int)): + return get_worker_info(to) + else: + raise ValueError(f"Cannot get WorkerInfo from name {to}") + + +def _rref_typeof_on_owner(rref, blocking: bool = True): + rref_type = type(rref.local_value()) + if blocking: + return rref_type + else: + # Wrap result into a completed Future. This is so that if blocking=`False` + # is specified, we return a future regardless of if this call is on user + # or owner. + future = Future[type]() + future.set_result(rref_type) + return future + + +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) + if blocking: + return fut.wait() + else: + return fut + + +T = TypeVar("T") +# pyrefly: ignore [invalid-annotation] +GenericWithOneTypeVar = Generic[T] + + +if TYPE_CHECKING: + + class RRef(PyRRef[T], Generic[T]): + pass + +else: + try: + # Combine the implementation class and the type class. + class RRef(PyRRef, Generic[T]): + pass + + except TypeError: + # TypeError: metaclass conflict: the metaclass of a derived class + # must be a (non-strict) subclass of the metaclasses of all its bases + # Mypy doesn't understand __class__ (mypy bug #4177) + class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type] + pass + + # Combine the implementation class and the type class. + # Types for classes expecting a certain generic parameter (mypy bug #7791) + class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type] + pass + + +# Install docstrings from `PyRRef` to `RRef`. +# +# This is for the fact that pybind11 generates the parameter +# `self` as type `rpc.PyRRef`, so a `:inherited-members:` +# under `.. autoclass:: RRef` does not work. +# we have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`. +# +def method_factory(method_name, docstring): + def method(self, *args, **kwargs): + return getattr(super(RRef, self), method_name)(*args, **kwargs) + + if method.__doc__: + method.__doc__ = docstring + return method + + +for method_name, method in inspect.getmembers(PyRRef): + # Ignore magic methods, except "__str__". + if method_name.startswith("_") and method_name != "__str__": + continue + + # Get pybind11 generated docstring. + # It's like, + """ + to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object + + Blocking call that copies the value of the RRef from the owner + to the local node and returns it. If the current node is the + owner, returns a reference to the local value. + """ + docstring = getattr(method, "__doc__", None) + assert docstring is not None, "RRef user-facing methods should all have docstrings." + + # Do surgery on pybind11 generated docstrings. + docstring = docstring.replace( + "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" + ) + + # Attach user-facing RRef method with modified docstring. + new_method = method_factory(method_name, docstring) + setattr(RRef, method_name, new_method) + + +@_require_initialized +def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a remote call to run ``func`` on worker ``to`` and return an + :class:`~torch.distributed.rpc.RRef` to the result value immediately. + Worker ``to`` will be the owner of the returned + :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is + a user. The owner manages the global reference count of its + :class:`~torch.distributed.rpc.RRef`, and the owner + :class:`~torch.distributed.rpc.RRef` is only destructed when globally there + are no living references to it. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + + timeout (float, optional): timeout in seconds for this remote call. If the + creation of this + :class:`~torch.distributed.rpc.RRef` on worker + ``to`` is not successfully processed on this + worker within this timeout, then the next time + there is an attempt to use the RRef (such as + ``to_here()``), a timeout will be raised + indicating this failure. A value of 0 indicates + an infinite timeout, i.e. a timeout error will + never be raised. If not provided, the default + value set during initialization or with + ``_set_rpc_timeout`` is used. + + Returns: + A user :class:`~torch.distributed.rpc.RRef` instance to the result + value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here` + to retrieve the result value locally. + + .. warning :: + The ``remote`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned RRef is + confirmed by the owner, which can be checked using the + :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API. + + .. warning :: + Errors such as timeouts for the ``remote`` API are handled on a + best-effort basis. This means that when remote calls initiated by + ``remote`` fail, such as with a timeout error, we take a best-effort + approach to error handling. This means that errors are handled and set + on the resulting RRef on an asynchronous basis. If the RRef has not been + used by the application before this handling (such as ``to_here`` or + fork call), then future uses of the ``RRef`` will appropriately raise + errors. However, it is possible that the user application will use the + ``RRef`` before the errors are handled. In this case, errors may not be + raised as they have not yet been handled. + + Example:: + + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) + >>> x = rref1.to_here() + rref2.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> rref.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + torch._C._log_api_usage_once("torch.distributed.rpc_remote") + qualified_name = torch.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, torch.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) + elif isinstance(func, torch.jit.ScriptFunction): + rref = _invoke_remote_torchscript( + dst_worker_info.name, + torch._jit_internal._qualified_name(func), + timeout, + is_async_exec, + *args, + **kwargs, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + rref = _invoke_remote_python_udf( + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec + ) + # attach profiling information + if should_profile: + assert torch.autograd._profiler_enabled() + assert rf is not None + fut = rf._call_end_callbacks_on_future(rref._get_future()) + rref._set_profiling_future(fut) + + return rref + + +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): + if not callable(func): + raise TypeError("function should be callable.") + + qualified_name = torch.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + # pyrefly: ignore [missing-attribute] + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, torch.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + fut = _invoke_rpc_builtin( + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs + ) + elif isinstance(func, torch.jit.ScriptFunction): + fut = _invoke_rpc_torchscript( + dst_worker_info.name, + torch._jit_internal._qualified_name(func), + args, + kwargs, + rpc_timeout, + is_async_exec, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + fut = _invoke_rpc_python_udf( + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec + ) + if should_profile: + assert torch.autograd._profiler_enabled() + assert rf is not None + # Schedule profiling callbacks to run when the future completes. + # This returns a future that is completed when the original future + # completes and the profiling callbacks have been completed as well, + # to guarantee that fut.wait() completes the profiling. This new + # future will contain the same value as the original future. + fut = rf._call_end_callbacks_on_future(fut) + return fut + + +@_require_initialized +def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + Make a blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + Returns: + Returns the result of running ``func`` with ``args`` and ``kwargs``. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + """ + torch._C._log_api_usage_once("torch.distributed.rpc_sync") + fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout) + return fut.wait() + + +@_require_initialized +def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. This method will immediately return a + :class:`~torch.futures.Future` that can be awaited on. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + + Returns: + Returns a :class:`~torch.futures.Future` object that can be waited + on. When completed, the return value of ``func`` on ``args`` and + ``kwargs`` can be retrieved from the :class:`~torch.futures.Future` + object. + + .. warning :: + Using GPU tensors as arguments or return values of ``func`` is not + supported since we don't support sending GPU tensors over the wire. You + need to explicitly copy GPU tensors to CPU before using them as + arguments or return values of ``func``. + + .. warning :: + The ``rpc_async`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned + :class:`~torch.futures.Future` completes. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3)) + >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2)) + >>> result = fut1.wait() + fut2.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> ret = fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + torch._C._log_api_usage_once("torch.distributed.rpc_async") + fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + if hasattr(_thread_local_var, "future_list"): + _thread_local_var.future_list.append(fut) + return fut + + +def _get_should_profile(): + # Legacy profiler should be enabled. RPC profiling is not supported with + # Kineto profiler. + ActiveProfilerType = torch._C._profiler.ActiveProfilerType + return ( + torch.autograd._profiler_enabled() + and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + ) + + +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): + ctx_manager = contextlib.nullcontext() + + if should_profile: + # Create appropriate string representation based on type of func + # (builtin, script, python) + if qualified_name is None: + func_name = ( + torch._jit_internal._qualified_name(func) + if isinstance(func, torch.jit.ScriptFunction) + else func.__qualname__ + ) + else: + func_name = qualified_name + # Build RPC profiling key. + rpc_profiling_key = _build_rpc_profiling_key( + rpc_type, + func_name, + get_worker_info().name, + dst_worker_info.name, + ) + RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) + # Mypy doesn't support re-def of a variable not in the same block (#1174) + ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] + + return ctx_manager diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/backend_registry.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/backend_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3f30252bd825665280a9b4cf96613bd6a676d190 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/backend_registry.py @@ -0,0 +1,431 @@ +# mypy: allow-untyped-defs + + +import collections +import enum +from typing import cast + +import torch +import torch.distributed as dist + +from . import api, constants as rpc_constants +from ._utils import _group_membership_management, _update_group_membership + + +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] + +BackendValue = collections.namedtuple( + "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] +) + + +def _backend_type_repr(self): + return "BackendType." + self.name + + +_backend_type_doc = """ + An enum class of available backends. + + PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend. + Additional ones can be registered using the + :func:`~torch.distributed.rpc.backend_registry.register_backend` function. +""" + +# Create an enum type, `BackendType`, with empty members. +# Can't handle Function Enum API (mypy bug #9079) +BackendType = enum.Enum(value="BackendType", names={}) # type: ignore[misc] +# Unable to assign a function a method (mypy bug #2427) +BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + +if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + + +def backend_registered(backend_name): + """ + Checks if backend_name is registered as an RPC backend. + + Args: + backend_name (str): string to identify the RPC backend. + Returns: + True if the backend has been registered with ``register_backend``, else + False. + """ + return backend_name in BackendType.__members__ + + +def register_backend( + backend_name, construct_rpc_backend_options_handler, init_backend_handler +): + """Registers a new RPC backend. + + Args: + backend_name (str): backend string to identify the handler. + construct_rpc_backend_options_handler (function): + Handler that is invoked when + rpc_backend.construct_rpc_backend_options(**dict) is called. + init_backend_handler (function): Handler that is invoked when the + `_init_rpc_backend()` function is called with a backend. + This returns the agent. + """ + global BackendType + if backend_registered(backend_name): + raise RuntimeError(f"RPC backend {backend_name}: already registered") + # Create a new enum type, `BackendType`, with extended members. + existing_enum_dict = {member.name: member.value for member in BackendType} + extended_enum_dict = dict( + { + backend_name: BackendValue( + construct_rpc_backend_options_handler=construct_rpc_backend_options_handler, + init_backend_handler=init_backend_handler, + ) + }, + **existing_enum_dict, + ) + # Can't handle Function Enum API (mypy bug #9079) + BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] + # Unable to assign a function a method (mypy bug #2427) + BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + # pyrefly: ignore [unsupported-operation] + return BackendType[backend_name] + + +def construct_rpc_backend_options( + backend, + rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, + init_method=rpc_constants.DEFAULT_INIT_METHOD, + **kwargs, +): + return backend.value.construct_rpc_backend_options_handler( + rpc_timeout, init_method, **kwargs + ) + + +def init_backend(backend, *args, **kwargs): + return backend.value.init_backend_handler(*args, **kwargs) + + +def _init_process_group(store, rank, world_size): + # Initialize ProcessGroup. + process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT + + # We're using a bunch of private APIs here since `new_group` requires the + # default group to be initialized. + group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout) + + assert group is not None, "Failed to initialize default ProcessGroup." + + if (rank != -1) and (rank != group.rank()): + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") + if (world_size != -1) and (world_size != group.size()): + raise RuntimeError( + f"world_size argument {world_size} doesn't match pg size {group.size()}" + ) + return group + + +def _tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, + _transports=None, + _channels=None, + **kwargs, +): + from . import TensorPipeRpcBackendOptions + + return TensorPipeRpcBackendOptions( + rpc_timeout=rpc_timeout, + init_method=init_method, + num_worker_threads=num_worker_threads, + _transports=_transports, + _channels=_channels, + ) + + +def _tensorpipe_validate_devices(devices, device_count): + return all( + d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count) + for d in devices + ) + + +# detect if any worker has invalid device_map configurations, and return +# reverse device maps +def _tensorpipe_exchange_and_check_all_device_maps( + my_name, my_device_count, my_device_maps, my_devices, group +): + gathered: list[ + tuple[str, int, dict[str, dict[torch.device, torch.device]], list[torch.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] + dist.all_gather_object( + gathered, (my_name, my_device_count, my_device_maps, my_devices), group + ) + all_names = [name for name, _, _, _ in gathered] + all_device_counts = {name: count for name, count, _, _ in gathered} + all_device_maps = {name: map_ for name, _, map_, _ in gathered} + all_devices = {name: devices for name, _, _, devices in gathered} + + _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices) + + # passed all checked, construct reverse mapping and get list of devices handled by this agent + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) + return reverse_device_maps, my_devices + + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): + for node in all_names: + devices = all_devices[node] + if len(set(devices)) != len(devices): + raise ValueError(f"Node {node} has duplicated devices\ndevices = {devices}") + if not _tensorpipe_validate_devices(devices, all_device_counts[node]): + raise ValueError( + f"Node {node} has devices with invalid indices\n" + f"devices = {devices}\n" + f"device count = {all_device_counts[node]}" + ) + + for source_node in all_names: + # For dynamic group (non-static) do not check the target node name since it may not have joined yet + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): + raise ValueError( + f"Node {source_node} has invalid target node names in its device maps\n" + f"device maps = {all_device_maps[source_node].keys()}\n" + f"node names = {all_names}" + ) + for target_node, map_ in all_device_maps[source_node].items(): + if len(set(map_.values())) != len(map_): + raise ValueError( + f"Node {source_node} has duplicated target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}" + ) + if all_devices[source_node]: + if not set(map_.keys()).issubset(all_devices[source_node]): + raise ValueError( + f"Node {source_node} has unexpected source devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[source_node]}" + ) + elif not _tensorpipe_validate_devices( + map_.keys(), all_device_counts[source_node] + ): + raise ValueError( + f"Node {source_node} has source devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[source_node]}" + ) + if all_devices.get(target_node, []): + if not set(map_.values()).issubset(all_devices[target_node]): + raise ValueError( + f"Node {source_node} has unexpected target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[target_node]}" + ) + elif target_node in all_device_counts and not _tensorpipe_validate_devices( + map_.values(), all_device_counts[target_node] + ): + raise ValueError( + f"Node {source_node} has target devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[target_node]}" + ) + + +def _create_device_list(my_devices, my_device_maps, reverse_device_maps): + if not my_devices: + devices_set: set[torch.device] = set() + for map_ in my_device_maps.values(): + devices_set.update(map_.keys()) + for map_ in reverse_device_maps.values(): + devices_set.update(map_.keys()) + devices_set.discard(torch.device("cpu")) + my_devices = list(devices_set) + my_devices = sorted(my_devices, key=lambda d: d.index) + return my_devices + + +def _create_reverse_mapping(my_name, all_names, all_device_maps): + reverse_device_maps: dict[str, dict[torch.device, torch.device]] = {} + for node in all_names: + if my_name in all_device_maps[node]: + reverse_device_maps[node] = { + v: k for k, v in all_device_maps[node][my_name].items() + } + return reverse_device_maps + + +def _get_device_infos(): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + opts = agent._get_backend_options() + device_count = torch.cuda.device_count() + if torch.cuda.is_available() and opts.devices: + torch.cuda.init() + return device_count, opts.device_maps, opts.devices + + +def _set_devices_and_reverse_device_map(agent): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, agent) + # Group state is retrieved from local agent + # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + all_worker_infos = agent.get_worker_infos() + # One round to get device_maps of all workers and construct reverse device maps + all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, [] + for worker_info in all_worker_infos: + worker_name = worker_info.name + if worker_name != my_name: + # TODO: make async? + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) + else: + opts = agent._get_backend_options() + device_count, device_map, devices = ( + torch.cuda.device_count(), + opts.device_maps, + opts.devices, + ) + all_device_counts[worker_name] = device_count + all_device_maps[worker_name] = device_map + all_devices[worker_name] = devices + all_names.append(worker_name) + + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + + # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps + for worker_name in all_names: + # Set device list for each worker + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + device_count = torch.cuda.device_count() + + is_static_group = bool(world_size) + # world_size is specified so this is a static group (ranks cannot join and leave) + if is_static_group: + # The agent's join method is required to behave like a barrier and perform + # collective operations, for which it relies on a process group, instead of + # re-implementing this on top of RPCs. + group = _init_process_group(store, rank, world_size) + + reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps( + name, + device_count, + rpc_backend_options.device_maps, + rpc_backend_options.devices, + group, + ) + + if torch.cuda.is_available() and devices: + # It's necessary to initialize PyTorch CUDA states here (e.g., + # CUDACachingAllocator). If this is missing, we could hit errors like + # "allocator not initialized", because other processes might send + # CUDA-related RPC request to this process before user code in this + # process initializes its PyTorch CUDA states. + torch.cuda.init() + + # TODO: add try-except and destroy _agent in all processes if any fails. + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + reverse_device_maps, + devices, + ) + + api._init_rpc_states(agent) + + # Run one dummy round of RPC to initialize channels/transports. Without + # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC + # on that process before rpc.shutdown(), as the agent initialization can + # take longer than 5s. + api._all_gather(None, timeout=rpc_backend_options.rpc_timeout) + # Need a barrier here to make sure no peers leave before the rank0 finishes + # _all_gather + group.barrier().wait() + + return agent + # initialization for dynamic rpc (ranks can join and leave) + else: + with _group_membership_management(store, name, True): + # Construct TPAgent with empty reverse_device_map and devices + # these properties will be updated after initialization + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, + [], + ) + api._init_rpc_states(agent) + + try: + # Notify all workers in group this rank has joined and set devices and reverse_device_map + # This is a synchronous operation that completes once all existing ranks are updated + _set_devices_and_reverse_device_map(agent) + except Exception: + api.shutdown() + raise + return agent + + +register_backend( + "TENSORPIPE", + _tensorpipe_construct_rpc_backend_options_handler, + _tensorpipe_init_backend_handler, +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/constants.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f0eaf92b8aef56dc96700c1ddb42bfb988542650 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/constants.py @@ -0,0 +1,24 @@ +from datetime import timedelta + +from torch._C._distributed_rpc import ( + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _UNSET_RPC_TIMEOUT, +) + + +# For any RpcAgent. +DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC +DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD +DEFAULT_SHUTDOWN_TIMEOUT: float = 0 + +# For TensorPipeAgent. +DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS +# Ensure that we don't time out when there are long periods of time without +# any operations against the underlying ProcessGroup. +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) +# Value indicating that timeout is not set for RPC call, and the default should be used. +UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT + +__all__: list[str] = [] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/functions.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e48ea8cc534ab87838965c947bbd0ed76d4d64c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/functions.py @@ -0,0 +1,169 @@ +# mypy: allow-untyped-defs +import functools + + +def async_execution(fn): + r""" + A decorator for a function indicating that the return value of the function + is guaranteed to be a :class:`~torch.futures.Future` object and this + function can run asynchronously on the RPC callee. More specifically, the + callee extracts the :class:`~torch.futures.Future` returned by the wrapped + function and installs subsequent processing steps as a callback to that + :class:`~torch.futures.Future`. The installed callback will read the value + from the :class:`~torch.futures.Future` when completed and send the + value back as the RPC response. That also means the returned + :class:`~torch.futures.Future` only exists on the callee side and is never + sent through RPC. This decorator is useful when the wrapped function's + (``fn``) execution needs to pause and resume due to, e.g., containing + :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals. + + .. note:: To enable asynchronous execution, applications must pass the + function object returned by this decorator to RPC APIs. If RPC detected + attributes installed by this decorator, it knows that this function + returns a ``Future`` object and will handle that accordingly. + However, this does not mean this decorator has to be outmost one when + defining a function. For example, when combined with ``@staticmethod`` + or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the + inner decorator to allow the target function be recognized as a static + or class function. This target function can still execute asynchronously + because, when accessed, the static or class method preserves attributes + installed by ``@rpc.functions.async_execution``. + + + Example:: + The returned :class:`~torch.futures.Future` object can come from + :meth:`~torch.distributed.rpc.rpc_async`, + :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future` + constructor. The example below shows directly using the + :class:`~torch.futures.Future` returned by + :meth:`~torch.futures.Future.then`. + + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @rpc.functions.async_execution + >>> def async_add_chained(to, x, y, z): + >>> # This function runs on "worker1" and returns immediately when + >>> # the callback is installed through the `then(cb)` API. In the + >>> # mean time, the `rpc_async` to "worker2" can run concurrently. + >>> # When the return value of that `rpc_async` arrives at + >>> # "worker1", "worker1" will run the lambda function accordingly + >>> # and set the value for the previously returned `Future`, which + >>> # will then trigger RPC to send the result back to "worker0". + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> # xdoctest: +SKIP + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add_chained, + >>> args=("worker2", torch.ones(2), 1, 1) + >>> ) + >>> print(ret) # prints tensor([3., 3.]) + + When combined with TorchScript decorators, this decorator must be the + outmost one. + + >>> from torch import Tensor + >>> from torch.futures import Future + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @torch.jit.script + >>> def script_add(x: Tensor, y: Tensor) -> Tensor: + >>> return x + y + >>> + >>> @rpc.functions.async_execution + >>> @torch.jit.script + >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]: + >>> return rpc.rpc_async(to, script_add, (x, y)) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add, + >>> args=("worker2", torch.ones(2), 1) + >>> ) + >>> print(ret) # prints tensor([2., 2.]) + + When combined with static or class method, this decorator must be the + inner one. + + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> class AsyncExecutionClass: + >>> + >>> @staticmethod + >>> @rpc.functions.async_execution + >>> def static_async_add(to, x, y, z): + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> @classmethod + >>> @rpc.functions.async_execution + >>> def class_async_add(cls, to, x, y, z): + >>> ret_fut = torch.futures.Future() + >>> rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: ret_fut.set_result(fut.wait() + z) + >>> ) + >>> return ret_fut + >>> + >>> @rpc.functions.async_execution + >>> def bound_async_add(self, to, x, y, z): + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.static_async_add, + >>> args=("worker2", torch.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.class_async_add, + >>> args=("worker2", torch.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + + This decorator also works with RRef helpers, i.e., . + :meth:`torch.distributed.rpc.RRef.rpc_sync`, + :meth:`torch.distributed.rpc.RRef.rpc_async`, and + :meth:`torch.distributed.rpc.RRef.remote`. + + >>> from torch.distributed import rpc + >>> + >>> # reuse the AsyncExecutionClass class above + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait() + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here() + >>> print(ret) # prints tensor([4., 4.]) + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Can't declare and use attributes of function objects (mypy#2087) + wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/internal.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/internal.py new file mode 100644 index 0000000000000000000000000000000000000000..faef8afddfc2caac25c8360c216509aed5acf8c1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/internal.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import collections +import copyreg +import io +import pickle +import sys +import threading +import traceback +from enum import Enum + +import torch +import torch.distributed as dist +from torch._C._distributed_rpc import _get_current_rpc_agent + + +__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] + +# Thread local tensor tables to store tensors while pickling torch.Tensor +# objects +_thread_local_tensor_tables = threading.local() +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +class RPCExecMode(Enum): + SYNC = "sync" + ASYNC = "async" + ASYNC_JIT = "async_jit" + REMOTE = "remote" + + +class _InternalRPCPickler: + r""" + This class provides serialize() and deserialize() interfaces to serialize + data to be "binary string + tensor table" format + So for RPC python UDF function and args, non tensor data will be serialized + into regular binary string, tensor data will be put into thread local tensor + tables, this serialization format is consistent with builtin operator and args + using JIT pickler. This format will make tensor handling in C++ much easier, + e.g. attach tensor to distributed autograd graph in C++ + """ + + def __init__(self): + # Ignore type error because dispatch_table is defined in third-party package + self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined] + self._dispatch_table[torch.Tensor] = self._tensor_reducer + # Used for registering customized picklers. + self._class_reducer_dict = {} + + def _register_reducer(self, obj_class, reducer): + # For the same class, only register the reducer once. + if obj_class not in self._class_reducer_dict: + self._class_reducer_dict[obj_class] = reducer + + @classmethod + def _tensor_receiver(cls, tensor_index): + global _thread_local_tensor_tables + return _thread_local_tensor_tables.recv_tables[tensor_index] + + def _tensor_reducer(self, tensor): + global _thread_local_tensor_tables + _thread_local_tensor_tables.send_tables.append(tensor) + tensor_index = len(_thread_local_tensor_tables.send_tables) - 1 + return (_InternalRPCPickler._tensor_receiver, (tensor_index,)) + + @classmethod + def _py_rref_receiver(cls, rref_fork_data): + return dist.rpc.PyRRef._deserialize(rref_fork_data) + + def _py_rref_reducer(self, py_rref): + rref_fork_data = py_rref._serialize() + return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,)) + + def _rref_reducer(self, rref): + return self._py_rref_reducer(rref) + + @classmethod + def _script_module_receiver(cls, script_module_serialized): + """ + Given a serialized representation of a ScriptModule created with torch.jit.save, + loads and returns the ScriptModule. + """ + f = io.BytesIO(script_module_serialized) + m = torch.jit.load(f) + return m + + def _script_module_reducer(self, script_module): + """ + Serializes a ScriptModule. + """ + f = io.BytesIO() + torch.jit.save(script_module, f) + return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),)) + + def serialize(self, obj): + r""" + Serialize non tensor data into binary string, tensor data into + tensor table + """ + f = io.BytesIO() + p = _pickler(f) + p.dispatch_table = self._dispatch_table + + # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref, + # user picklers could have different initialization function from _InternalRPCPickler, + # but all the user picklers should call serialize() and use _rref_reducer to pickle rref + # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not + # compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor, + # so putting rref's dispatch table here + # + # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`. + # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index] + # An RRef created locally by RRef Python constructor is type of `rpc.RRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index] + + # Add dispatch pickling for ScriptModule or its subclass. + if isinstance(obj, torch.jit.ScriptModule): + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index] + + # Install customized picklers. + for class_name in self._class_reducer_dict: + p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index] + + # save _thread_local_tensor_tables.send_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "send_tables"): + old_send_tables = _thread_local_tensor_tables.send_tables + else: + old_send_tables = None + _thread_local_tensor_tables.send_tables = [] + + p.dump(obj) + + # restore _thread_local_tensor_tables.send_tables if return + # from nested call, otherwise clean up the table + tensors = _thread_local_tensor_tables.send_tables + if old_send_tables is not None: + _thread_local_tensor_tables.send_tables = old_send_tables + else: + del _thread_local_tensor_tables.send_tables + + return (f.getvalue(), tensors) + + def deserialize(self, binary_data, tensor_table): + r""" + Deserialize binary string + tensor table to original obj + """ + # save _thread_local_tensor_tables.recv_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "recv_tables"): + old_recv_tables = _thread_local_tensor_tables.recv_tables + else: + old_recv_tables = None + _thread_local_tensor_tables.recv_tables = tensor_table + + try: + unpickler = _unpickler(io.BytesIO(binary_data)) + ret = unpickler.load() + except AttributeError as e: + # Occurs when function is not found on module/class during + # unpickling. + except_str = ( + str(e) + + """ Default RPC pickler does not serialize + function code. Ensure that UDFs are defined on both caller and + callee modules.""" + ) + ret = AttributeError(except_str) + # Ensure the stack trace gets preserved + ret.__cause__ = e + + # restore _thread_local_tensor_tables.recv_tables if return + # from nested call, otherwise clean up the table + if old_recv_tables is not None: + _thread_local_tensor_tables.recv_tables = old_recv_tables + else: + del _thread_local_tensor_tables.recv_tables + + return ret + + +# Create _internal_rpc_pickler only once to initialize _dispatch_table only once +_internal_rpc_pickler = _InternalRPCPickler() + + +def serialize(obj): + return _internal_rpc_pickler.serialize(obj) + + +def deserialize(binary_data, tensor_table): + return _internal_rpc_pickler.deserialize(binary_data, tensor_table) + + +def _run_function(python_udf): + r""" + This function is exclusively called from C++. + See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``. + + Runs a Python UDF and returns its return value. + Wraps any exception in ``RemoteException`` if the function raises. + """ + try: + if isinstance(python_udf, AttributeError): + raise python_udf + result = python_udf.func(*python_udf.args, **python_udf.kwargs) + except Exception as e: + # except str = exception info + traceback string + except_str = ( + f"On {_get_current_rpc_agent().get_worker_info()}:\n" + f"{repr(e)}\n{traceback.format_exc()}" + ) + print(except_str, file=sys.stderr) + result = RemoteException(except_str, type(e)) + return result + + +def _handle_exception(result): + if isinstance(result, RemoteException): + exception_msg = result.msg.encode("utf-8").decode("unicode_escape") + # We wrap exception re-creation here in case some exception classes + # cannot be constructed directly from a string. + exc = None + try: + exc = result.exception_type(exception_msg) + except BaseException as e: # noqa: B036 + raise RuntimeError( # noqa: B904 + f"Failed to create original exception type. Error msg was {str(e)}" + f" Original exception on remote side was {exception_msg}" + ) from e + + if exc is not None: + raise exc + + +def _build_rpc_profiling_key( + exec_type, func_name, current_worker_name, dst_worker_name +): + """ + Builds the key that RPC calls are profiled with using the autograd profiler. + This will be the name of the corresponding Event recorded in the profiler. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dst_worker_name (str): Name of the destination worker. + + Returns: + String representing profiling key + """ + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) + return profile_key + + +def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name): + """ + This function should be called from RPC/RRef functions to create a + RecordFunction object for profiling. This function also runs the before + callbacks that start the profiling, though the user is responsible for + running the appropriate callbacks when the function to be profiled finishes. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dest_worker_name (str): Name of the destination worker. + + Returns: + An instance of `torch.autograd._RecordFunction`. + """ + assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled." + profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})" + rf = torch.autograd._RecordFunction() # type: ignore[attr-defined] + torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined] + return rf + + +PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"]) +RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/options.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/options.py new file mode 100644 index 0000000000000000000000000000000000000000..c58a2bf923910039502ed98f1fd742b827800f20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/options.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from typing import Union + +import torch + +from . import _is_tensorpipe_available, constants as rpc_contants + + +DeviceType = Union[int, str, torch.device] + +__all__ = ["TensorPipeRpcBackendOptions"] + + +def _to_device(device: DeviceType) -> torch.device: + device = torch.device(device) + if device.type != "cuda": + raise ValueError( + "`set_devices` expect a list of CUDA devices, but got " + f"device type {device.type}." + ) + return device + + +def _to_device_map( + device_map: dict[DeviceType, DeviceType], +) -> dict[torch.device, torch.device]: + full_device_map: dict[torch.device, torch.device] = {} + reverse_map: dict[torch.device, torch.device] = {} + for k, v in device_map.items(): + k, v = torch.device(k), torch.device(v) + if v in reverse_map: + raise ValueError( + "`device_map` only supports 1-to-1 mapping, " + f"trying to map {k} and {reverse_map[v]} to {v}" + ) + full_device_map[k] = v + reverse_map[v] = k + return full_device_map + + +def _to_device_list(devices: list[DeviceType]) -> list[torch.device]: + return list(map(_to_device, devices)) + + +if _is_tensorpipe_available: # type: ignore[has-type] + from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase +else: + _TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc] + + +# pyrefly: ignore [invalid-inheritance] +class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): + r""" + The backend options for + :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from + :class:`~torch.distributed.rpc.RpcBackendOptions`. + + Args: + num_worker_threads (int, optional): The number of threads in the + thread-pool used by + :class:`~torch.distributed.rpc.TensorPipeAgent` to execute + requests (default: 16). + rpc_timeout (float, optional): The default timeout, in seconds, + for RPC requests (default: 60 seconds). If the RPC has not + completed in this timeframe, an exception indicating so will + be raised. Callers can override this timeout for individual + RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and + :meth:`~torch.distributed.rpc.rpc_async` if necessary. + init_method (str, optional): The URL to initialize the distributed + store used for rendezvous. It takes any value accepted for the + same argument of :meth:`~torch.distributed.init_process_group` + (default: ``env://``). + device_maps (Dict[str, Dict], optional): Device placement mappings from + this worker to the callee. Key is the callee worker name and value + the dictionary (``Dict`` of ``int``, ``str``, or ``torch.device``) + that maps this worker's devices to the callee worker's devices. + (default: ``None``) + devices (List[int, str, or ``torch.device``], optional): all local + CUDA devices used by RPC agent. By Default, it will be initialized + to all local devices from its own ``device_maps`` and corresponding + devices from its peers' ``device_maps``. When processing CUDA RPC + requests, the agent will properly synchronize CUDA streams for + all devices in this ``List``. + """ + + def __init__( + self, + *, + num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS, + rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC, + init_method: str = rpc_contants.DEFAULT_INIT_METHOD, + device_maps: dict[str, dict[DeviceType, DeviceType]] | None = None, + devices: list[DeviceType] | None = None, + _transports: list | None = None, + _channels: list | None = None, + ): + full_device_maps = ( + {} + if device_maps is None + else {k: _to_device_map(v) for k, v in device_maps.items()} + ) + full_device_list = [] if devices is None else _to_device_list(devices) + super().__init__( + num_worker_threads, + _transports, + _channels, + rpc_timeout, + init_method, + full_device_maps, + full_device_list, + ) + + def set_device_map(self, to: str, device_map: dict[DeviceType, DeviceType]): + r""" + Set device mapping between each RPC caller and callee pair. This + function can be called multiple times to incrementally add + device placement configurations. + + Args: + to (str): Callee name. + device_map (Dict of int, str, or torch.device): Device placement + mappings from this worker to the callee. This map must be + invertible. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> # both workers + >>> def add(x, y): + >>> print(x) # tensor([1., 1.], device='cuda:1') + >>> return x + y, (x + y).to(2) + >>> + >>> # on worker 0 + >>> options = TensorPipeRpcBackendOptions( + >>> num_worker_threads=8, + >>> device_maps={"worker1": {0: 1}} + >>> # maps worker0's cuda:0 to worker1's cuda:1 + >>> ) + >>> options.set_device_map("worker1", {1: 2}) + >>> # maps worker0's cuda:1 to worker1's cuda:2 + >>> + >>> rpc.init_rpc( + >>> "worker0", + >>> rank=0, + >>> world_size=2, + >>> backend=rpc.BackendType.TENSORPIPE, + >>> rpc_backend_options=options + >>> ) + >>> + >>> x = torch.ones(2) + >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1)) + >>> # The first argument will be moved to cuda:1 on worker1. When + >>> # sending the return value back, it will follow the invert of + >>> # the device map, and hence will be moved back to cuda:0 and + >>> # cuda:1 on worker0 + >>> print(rets[0]) # tensor([2., 2.], device='cuda:0') + >>> print(rets[1]) # tensor([2., 2.], device='cuda:1') + """ + full_device_map = _to_device_map(device_map) + curr_device_maps = super().device_maps + + if to in curr_device_maps: + for k, v in full_device_map.items(): + if k in curr_device_maps[to] and v != curr_device_maps[to][k]: + raise ValueError( + "`set_device_map` only supports 1-to-1 mapping, trying" + f" to map {k} to {v} and {curr_device_maps[to][k]}" + ) + + super()._set_device_map(to, full_device_map) + + def set_devices(self, devices: list[DeviceType]): + r""" + Set local devices used by the TensorPipe RPC agent. When processing + CUDA RPC requests, the TensorPipe RPC agent will properly synchronize + CUDA streams for all devices in this ``List``. + + Args: + devices (List of int, str, or torch.device): local devices used by + the TensorPipe RPC agent. + """ + self.devices = _to_device_list(devices) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/rref_proxy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/rref_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..46eecf19e22c9bcb11a475963f9be0461261b0a4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/rref_proxy.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from functools import partial + +import torch +from torch.futures import Future + +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + +def _local_invoke(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +@functions.async_execution +def _local_invoke_async_execution(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): + def _rref_type_cont(rref_fut): + rref_type = rref_fut.value() + + _invoke_func = _local_invoke + # Bypass ScriptModules when checking for async function attribute. + bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass( + rref_type, torch._C.ScriptModule + ) + if not bypass_type: + func = getattr(rref_type, func_name) + if hasattr(func, "_wrapped_async_rpc_function"): + _invoke_func = _local_invoke_async_execution + + return rpc_api( + rref.owner(), + _invoke_func, + args=(rref, func_name, args, kwargs), + timeout=timeout, + ) + + rref_fut = rref._get_type(timeout=timeout, blocking=False) + + if rpc_api is not rpc_async: + rref_fut.wait() + return _rref_type_cont(rref_fut) + else: + # A little explanation on this. + # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]` + # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]` + # To address that, we return a Future that is completed with the result of the async call. + result: Future = Future() + + def _wrap_rref_type_cont(fut): + try: + _rref_type_cont(fut).then(_complete_op) + except BaseException as ex: # noqa: B036 + result.set_exception(ex) + + def _complete_op(fut): + try: + result.set_result(fut.value()) + except BaseException as ex: # noqa: B036 + result.set_exception(ex) + + rref_fut.then(_wrap_rref_type_cont) + return result + + +# This class manages proxied RPC API calls for RRefs. It is entirely used from +# C++ (see python_rpc_handler.cpp). +class RRefProxy: + def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): + self.rref = rref + self.rpc_api = rpc_api + self.rpc_timeout = timeout + + def __getattr__(self, func_name): + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/server_process_global_profiler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/server_process_global_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..29a916772d330b555673645a3e38308788b31535 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/rpc/server_process_global_profiler.py @@ -0,0 +1,190 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs + +import itertools + +import torch + +# pyrefly: ignore [deprecated] +from torch.autograd.profiler_legacy import profile + +from . import ( + _disable_server_process_global_profiler, + _enable_server_process_global_profiler, +) + + +__all__: list[str] = [] + + +class _server_process_global_profile(profile): + """ + It has the same API as ``torch.autograd.profiler.profile`` class, + except that it enables profiling on all threads running RPC server request callbacks. + + Context manager that manages autograd profiler state and holds a summary of results. + Under the hood it just records events of functions being executed in C++ and + exposes those events to Python. You can wrap any code into it and it will + only report runtime of PyTorch functions. + Note: profiler is thread local and is automatically propagated into the async tasks + + Args: + enabled (bool, optional): Setting this to False makes this context manager a no-op. + Default: ``True``. + + use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API. + Adds approximately 4us of overhead to each tensor operation. + Default: ``False`` + + record_shapes (bool, optional): If shapes recording is set, information + about input dimensions will be collected. This allows one to see which + dimensions have been used under the hood and further group by them + using prof.key_averages(group_by_input_shape=True). Please note that + shape recording might skew your profiling data. It is recommended to + use separate runs with and without shape recording to validate the timing. + Most likely the skew will be negligible for bottom most events (in a case + of nested function calls). But for higher level functions the total + self cpu time might be artificially increased because of the shape + collection. + + profile_memory (bool, optional): Whether to report memory usage, default: ``False`` + + .. warning:: + Enabling memory profiling incurs additional profiler overhead + + .. warning:: + Due to some CUDA multiprocessing limitations (see :ref:`multiprocessing-cuda-note`), + one cannot use the profiler with ``use_cuda = True`` to benchmark + DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading, + please use ``use_cuda = False`` or ``num_workers = 0``. + + Example: + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> x, y = torch.tensor(1), torch.tensor(2) + >>> outer_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) + >>> outer_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y)) + >>> inner_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) + >>> inner_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y)) + >>> inner_profile_rref.rpc_sync().__exit__(None, None, None) + >>> outer_profile_rref.rpc_sync().__exit__(None, None, None) + >>> print(inner_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 85.06% 76.275us 100.00% 89.667us 89.667us 1 + empty 14.94% 13.392us 14.94% 13.392us 13.392us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 89.667us + >>> print(outer_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 35.65% 76.275us 41.91% 89.667us 89.667us 1 + empty 12.67% 27.101us 12.67% 27.101us 13.551us 2 + add 51.68% 110.550us 58.09% 124.259us 124.259us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 213.926us + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __enter__(self): + """ + Turn on server-side process-global profiling. + This enables thread-local profiler on all RPC threads running server-side request callbacks. + """ + if not self.enabled: + return + + if self.entered: # type: ignore[has-type] + raise RuntimeError("autograd profiler traces are not reentrant") + self.entered = True + + profiler_kind = ( + torch.autograd.ProfilerState.CUDA + if self.use_cuda + else torch.autograd.ProfilerState.CPU + ) + profiler_config = torch.autograd.ProfilerConfig( + profiler_kind, + self.record_shapes, + self.profile_memory, + False, + False, + False, + torch.profiler._ExperimentalConfig(), + ) + _enable_server_process_global_profiler(profiler_config) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Turn off server-side process-global profiling. + Aggregate all profiling events recorded by RPC threads. + + These attributes are assigned on exiting context. + + Attributes: + function_events (torch.autograd.profiler.EventList). It's a list that has helper + methods, like 1) show record items in a pretty-print table. + 2) do averaging by grouping on keys. 3) and more. + + process_global_function_events (List[torch.autograd.profiler.FunctionEvent]). + It's a list of ``FunctionEvent`` elements. Every element is a profiling result + of an RPC request handling within the profiling range. + """ + if not self.enabled: + return + + process_global_events = _disable_server_process_global_profiler() + + # Every element in this list is a thread profiling result from an RPC request handling. + process_global_function_events = [] + for thread_local_events in process_global_events: + # Parse from ``Event``s to ``FunctionEvent``s. + thread_local_function_events = ( + torch.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) + ) + thread_local_function_events.sort( + key=lambda function_event: [ + function_event.time_range.start, + -(function_event.time_range.end), + ] + ) + process_global_function_events.append(thread_local_function_events) + + flattened_function_events = list( + itertools.chain.from_iterable(process_global_function_events) + ) + # pyrefly: ignore [bad-assignment] + self.function_events = torch.autograd.profiler_util.EventList( + flattened_function_events, + use_device="cuda" if self.use_cuda else None, + profile_memory=self.profile_memory, + ) + # pyrefly: ignore [missing-attribute] + self.function_events._build_tree() + + self.process_global_function_events = process_global_function_events + + return False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..067d4c0917e9de33b516c7ed47c678be2ac6c692 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/__init__.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +import torch.distributed.tensor._ops # force import all built-in dtensor ops +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401 +from torch.distributed.tensor._api import ( + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + ones, + rand, + randn, + zeros, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "Shard", + "Replicate", + "Partial", + "Placement", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +# For weights_only torch.load +from ._dtensor_spec import ( + DTensorSpec as _DTensorSpec, + ShardOrderEntry as _ShardOrderEntry, + TensorMeta as _TensorMeta, +) + + +torch.serialization.add_safe_globals( + [ + DeviceMesh, + _DTensorSpec, + _TensorMeta, + _ShardOrderEntry, + DTensor, + Partial, + Replicate, + Shard, + ] +) + + +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) # type: ignore[arg-type] + + +# Set namespace for exposed private names +DTensor.__module__ = "torch.distributed.tensor" +distribute_tensor.__module__ = "torch.distributed.tensor" +distribute_module.__module__ = "torch.distributed.tensor" +ones.__module__ = "torch.distributed.tensor" +empty.__module__ = "torch.distributed.tensor" +full.__module__ = "torch.distributed.tensor" +rand.__module__ = "torch.distributed.tensor" +randn.__module__ = "torch.distributed.tensor" +zeros.__module__ = "torch.distributed.tensor" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_api.py new file mode 100644 index 0000000000000000000000000000000000000000..78e00d5137ea075fdcda11d3e97f2ed7ed7f3f0a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_api.py @@ -0,0 +1,1385 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import copy +import inspect +import warnings +from collections.abc import Callable, Sequence +from typing import Any +from typing_extensions import deprecated + +import torch +import torch.distributed.tensor._dispatch as op_dispatch +import torch.distributed.tensor._random as random +import torch.nn as nn +from torch._export.wrappers import mark_subclass_constructor_exportable_experimental +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from torch.distributed.tensor._utils import ( + compute_global_tensor_info, + compute_local_shape_and_global_offset, + normalize_to_torch_size, +) +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) + + +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +aten = torch.ops.aten + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. from_local, to_local) to ensure DTensor to work +# together with torch.Tensor within the autograd engine. This +# allows DTensor to only exist on part of the module hierarchy. +# +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DTensor params, the following forward/backward should work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input (from_local) -> Sharded Module B -> DTensor output +# -> torch.Tensor output (to_local) -> Module C +# +# So from_local/to_local must be Autograd functions. +# +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + input: "DTensor", + grad_placements: Sequence[Placement] | None, + ): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. + return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + if ( + tensor_stride == dtensor_meta.stride + and grad_placements == dtensor_spec.placements + ): + # Avoid actual sharing of specs in case they're modified during (e.g.) + # sharding propagation. + grad_spec = copy.copy(dtensor_spec) + else: + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + return ( + # pyrefly: ignore [bad-argument-type] + DTensor( + # pyrefly: ignore [bad-argument-count] + grad_output, + grad_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=grad_output.requires_grad, + ), + None, + ) + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: tuple[Placement, ...], + run_check: bool, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if shape and stride: + tensor_shape, tensor_stride = shape, stride + elif not shape and not stride: + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements + ) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + else: + raise RuntimeError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time.", + ) + + if device_mesh.get_coordinate() is None: + # if the global rank is not participating in the device mesh, we + # simply set the local tensor to an empty tensor + input = input.new_empty(0, requires_grad=input.requires_grad) + elif run_check: + # TODO: support uneven sharding when global shape/stride not passed, by + # building the global TensorMeta during check_tensor_meta + check_shape_stride = not shape and not stride + check_tensor_meta(input, check_shape_stride=check_shape_stride) + # TODO: See if we need to make this run_check logic + # have a corresponding backward. + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + mesh_broadcast(input, device_mesh, mesh_dim=idx) + + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + + # We want a fresh Tensor object that shares memory with the input tensor + # pyrefly: ignore [bad-argument-type] + dist_tensor = DTensor( + # pyrefly: ignore [bad-argument-count] + input.view_as(input), + dist_spec, + # requires_grad of the dist tensor depends on if input + # requires_grad or not + # pyrefly: ignore [unexpected-keyword] + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + current_spec = grad_output._spec + target_spec = DTensorSpec( + previous_device_mesh, + previous_placement, + tensor_meta=grad_output._spec.tensor_meta, + ) + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, is_backward=True + ) + # TODO: return the redistributed local tensor directly without + # differentiable backward. see if this make sense for all cases. + return output, None, None, None, None, None + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None, None, None + + +class DTensor(torch.Tensor): + """ + ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like + abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding + layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: + + * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension + * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension + * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension + + When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue + communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the + placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. + + To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` + requires every Tensor argument of the operator be DTensor. + + .. note:: Directly using the Tensor subclass constructor here is not the recommended way to create a ``DTensor`` + (i.e. it does not handle autograd correctly hence is not the public API). Please refer to the `create_dtensor`_ + section to see how to create a ``DTensor``. + """ + + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # _op_dispatcher instance as a class attribute to handle runtime dispatching logic + _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() + + # This implementation is just to convince mypy _spec and _local_tensor are + # initialized; it is immediately overridden below. + def __new__( + cls, + local_tensor: torch.Tensor, + spec: DTensorSpec, + *, + requires_grad: bool, + ) -> "DTensor": + r = torch.Tensor._dtensor__new__( + cls, local_tensor, spec, requires_grad=requires_grad + ) + r._spec = spec + r._local_tensor = local_tensor + return r + + __new__ = torch.Tensor._dtensor__new__ # type: ignore[assignment] # noqa: F811 + + @torch._disable_dynamo + @mark_subclass_constructor_exportable_experimental + def __init__(self, *args, **kwargs): + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + .. note:: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using ``DTensor.from_local``, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using ``distribute_tensor``. + """ + super().__init__() + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): # type: ignore[override] + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + def __tensor_flatten__(self): + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + assert flatten_spec is not None, ( + "Expecting spec to be not None from `__tensor_flatten__` return value!" + ) + local_tensor = inner_tensors["_local_tensor"] + spec, requires_grad = flatten_spec + unflatten_tensor_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + local_tensor, + unflatten_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=requires_grad, + ) + + def __coerce_tangent_metadata__(self): + if not any(isinstance(p, Partial) for p in self.placements): + return self + placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute(device_mesh=self.device_mesh, placements=placements) + + def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): + if expected_type is not None: + return None + + (spec, _) = flatten_spec # Result of tensor_flatten() + return self.redistribute( + device_mesh=self.device_mesh, + placements=spec.placements, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + # We just need to have an implementation here; the __torch_dispatch__ machinery + # calls into a specific C++ fast path that doesn't call here. + # See #167051 for details + # python_arg_parser.cpp: dispatch_on_subclass() + # -> python_variable.cpp: dispatchDTensorOp() + raise NotImplementedError( + "DTensor.__torch_dispatch__ should not actually get called" + ) + + @staticmethod + def from_local( + local_tensor: torch.Tensor, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + run_check: bool = False, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the ``device_mesh`` and ``placements`` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + + Keyword args: + run_check (bool, optional): at a cost of extra communications, perform + sanity check across ranks to check each local tensor's meta information + to ensure correctness. If have :class:`Replicate` in ``placements``, the + data on first rank of the device mesh dimension will be broadcasted + to other ranks. default: False + shape (torch.Size, optional): A List of int which specifies the size of + DTensor which build on top of `local_tensor`. Note this needs to be + provided if the shape of ``local_tensor`` are different across the ranks. + If not provided, ``shape`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + stride (tuple, optional): A List of int which specifies the stride of DTensor. + If not provided, ``stride`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + + Returns: + A :class:`DTensor` object + + .. note:: When ``run_check=False``, it is the user's responsibility to ensure the + local tensor passed in is correct across ranks (i.e. the tensor is sharded for + the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). + If not, the behavior of the created DTensor is undefined. + + .. note:: ``from_local`` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. + """ + # `local_tensor` argument cannot be DTensor + if isinstance(local_tensor, DTensor): + raise RuntimeError( + f"the local_tensor argument only accepts torch.Tensor but got {type(local_tensor)} value." + ) + + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if isinstance(placement, Shard | _StridedShard): + if placement.dim < 0: + normalized_dim = placement.dim + local_tensor.ndim + if type(placement) is _StridedShard: + placements[idx] = _StridedShard( + normalized_dim, split_factor=placement.split_factor + ) + elif type(placement) is Shard: + placements[idx] = Shard(normalized_dim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, + ) + + def to_local( + self, *, grad_placements: Sequence[Placement] | None = None + ) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the Tensor returned from this + function. + `to_local` converts DTensor to local tensor and the returned local tensor + might not be used as the original DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original DTensor layout. + If not specified, we will assume the gradient layout remains the same + as the original DTensor and use that for gradient computation. + + Returns: + A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the + local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, + it means the local tensor is not ready yet (i.e. communication is not finished). In this + case, user needs to call ``wait`` to wait the local tensor to be ready. + + .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + if not torch.is_grad_enabled(): + return self._local_tensor + + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) + return _ToTorchTensor.apply( + self, grad_placements + ) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + async_op: bool = False, + forward_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, + ) -> "DTensor": + """ + ``redistribute`` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from its current DeviceMesh + to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by + specifying a Replicate placement for each dimension of the DeviceMesh. + + When redistributing from current to the new placements on one device mesh dimension, we + will perform the following operations including communication collective or local operation: + + 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` + 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` + 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) + 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` + 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` + + + ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors + that are created either on 1-D or N-D DeviceMesh. + + Args: + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. If not specified, it would use the current DTensor's DeviceMesh. + default: None + placements (List[:class:`Placement`], optional): the new placements that + describes how to place the DTensor into the DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + default: replicate on all mesh dimensions + + Keyword args: + async_op (bool, optional): whether to perform the DTensor redistribute operation + asynchronously or not. Default: False + forward_dtype (torch.dtype, optional): the local tensor datatype can be converted to + ``forward_dtype`` before redistributing the local tensor in its forward. + The result DTensor will be in ``forward_dtype`` Default: None. + backward_dtype (torch.dtype, optional): the local tensor datatype can be converted to + ``backward_dtype`` before redistributing the local tensor in its backward. + The result DTensor gradient would be converted back to the current DTensor dtype. Default: None + + Returns: + A :class:`DTensor` object + + .. note:: ``redistribute`` is differentiable, which means user do not need to worry about + the backward formula of the redistribute operation. + + .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, + Please file an issue if you need to redistribute DTensor to different DeviceMesh. + """ + # NOTE: This redistribute API currently only supports out + # of place redistribution, i.e. it always create a new + # DTensor object and leave the original one unchanged. + + # if device_mesh is not specified, use the current device_mesh + device_mesh = device_mesh or self.device_mesh + # raise error if new placements not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + + placements = list(placements) + for i, placement in enumerate(placements): + if placement.is_partial() and self.placements[i] != placement: + raise RuntimeError( + f"Can not redistribute from {self.placements[i]} to {placement}, " + "redistributing to Partial is for internal use only!" + ) + elif isinstance(placement, Shard) and placement.dim < 0: + # normalize shard dim to be positive + placements[i] = Shard(placement.dim + self.ndim) + elif isinstance(placement, _StridedShard) and placement.dim < 0: + placements[i] = _StridedShard( + placement.dim + self.ndim, split_factor=placement.split_factor + ) + placements = tuple(placements) + + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + return Redistribute.apply( + self, device_mesh, placements, async_op, forward_dtype, backward_dtype + ) + + def full_tensor( + self, *, grad_placements: Sequence[Placement] | None = None + ) -> torch.Tensor: + """ + Return the full tensor of this DTensor. It will perform necessary collectives + to gather the local tensors from other ranks in its DeviceMesh and concatenate + them together. It's a syntactic sugar of the following code: + + ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the full Tensor returned from this + function. + `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor + might not be used as the original replicated DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original replicated DTensor layout. + If not specified, we will assume the gradient layout of the full tensor be replicated. + + Returns: + A :class:`torch.Tensor` object that represents the full tensor of this DTensor. + + .. note:: ``full_tensor`` is differentiable. + """ + + redist_res = self.redistribute( + placements=[Replicate()] * self.device_mesh.ndim, async_op=False + ) + return _ToTorchTensor.apply(redist_res, grad_placements) + + @property + def device_mesh(self) -> DeviceMesh: + """ + The :class:`DeviceMesh` attribute that associates with this DTensor object. + + .. note:: ``device_mesh`` is a read-only property, it can not be set. + """ + return self._spec.mesh + + @property + def placements(self) -> tuple[Placement, ...]: + """ + The placements attribute of this DTensor that describes the layout of this + DTensor on the its DeviceMesh. + + .. note:: ``placements`` is a read-only property, it can not be set. + """ + return self._spec.placements + + def _raise_if_contains_partial_placements(self) -> None: + """ + Raise an error if the DTensor contains partial placements. + """ + for placement in self._spec.placements: + if not isinstance(placement, Partial): + continue + + raise ValueError( + "Any checkpointing related operations are not supported for " + "DTensor with partial placements!" + ) + + def __create_write_items__(self, fqn: str, object: Any): + self._raise_if_contains_partial_placements() + from torch.distributed.checkpoint.planner_helpers import ( + _create_write_items_for_dtensor, + ) + + if hasattr(self._local_tensor, "__create_write_items__"): + return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_write_items_for_dtensor(fqn, object)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __create_chunk_list__(self): + """ + Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica + on current rank. For DTensor, each rank will have a single local shard/replica, so the returned list usually only + has one element. + + This dunder method is primariy used for distributed checkpoint purpose. + + Returns: + A List[:class:`ChunkStorageMetadata`] object that represents the shard size/offset on the current rank. + """ + self._raise_if_contains_partial_placements() + from torch.distributed.checkpoint.planner_helpers import ( + _create_chunk_from_dtensor, + ) + + if hasattr(self._local_tensor, "__create_chunk_list__"): + return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_chunk_from_dtensor(self)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __get_tensor_shard__(self, index): + self._raise_if_contains_partial_placements() + if hasattr(self._local_tensor, "__get_tensor_shard__"): + return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return self.to_local() + else: + raise RuntimeError("Unsupported tensor type!") + + @classmethod + def __metadata_guard__( + cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool] + ) -> bool: + # TODO - delete this - This is now unused after the PR - + # https://github.com/pytorch/pytorch/pull/165824 + orig_spec, orig_requires_grad = orig + other_spec, other_requires_grad = other + return ( + orig_spec._check_equals(other_spec, skip_shapes=True) + and orig_requires_grad == other_requires_grad + ) + + +def distribute_tensor( + tensor: torch.Tensor, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + src_data_rank: int | None = 0, +) -> DTensor: + """ + Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according + to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the + same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use + the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to preserve + the single-device semantic. If you want to construct a DTensor in the middle of the Autograd + computation, please use :meth:`DTensor.from_local` instead. + + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you + want to shard a tensor on a dimension that is not evenly divisible by + the number of devices in that mesh dimension, we use ``torch.chunk`` + semantic to shard the tensor and scatter the shards. The uneven sharding + behavior is experimental and subject to change. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the + tensor, if not specified, must be called under a DeviceMesh context + manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the tensor on DeviceMesh, must have the same + number of elements as ``device_mesh.ndim``. If not specified, we will + by default replicate the tensor across the ``device_mesh`` from the + first rank of each dimension of the `device_mesh`. + + Keyword args: + src_data_rank (int, optional): the rank of the source data for the logical/global tensor, it is + used by :meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. + By default, we use ``group_rank=0`` on each DeviceMesh dimension as the source data to preserve + the single-device semantic. If passing ``None`` explicitly, :meth:`distribute_tensor` simply uses + its local data instead of trying to preserve the single-device semantic via scatter/broadcast. + Default: 0 + + Returns: + A :class:`DTensor` or ``XLAShardedTensor`` object. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` + return `XLAShardedTensor` instead. see `this issue `__ + for more details. The XLA integration is experimental and subject to change. + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") + + # get default device mesh if there's nothing specified + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # call PyTorch/XLA SPMD for `xla` backend type device mesh. + # This returns XLAShardedTensor + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_tensor, + ) + + return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + if not tensor.is_leaf: + raise RuntimeError( + "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" + ) + + # convert tensor to the corresponding device type if it's not in that device type + if device_type != tensor.device.type and not tensor.is_meta: + tensor = tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + + if len(placements) != device_mesh.ndim: + raise ValueError( + f"`placements` must have the same length as `device_mesh.ndim`! " + f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." + ) + if isinstance(tensor, DTensor): + # if the tensor is already a DTensor, we need to check: + # 1. if the we can further shard this DTensor if the two device mesh belong to + # the same parenet mesh and further sharding is possible. + # 2. check if device mesh and placements are the same + if tensor.device_mesh != device_mesh: + raise ValueError( + f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " + f"to a different device mesh {device_mesh}." + ) + if tensor.placements != tuple(placements): + raise ValueError( + f"Cannot distribute a DTensor with placements {tensor.placements} " + f"to a different placements {placements}. do you want to call " + f"`redistribute` instead?" + ) + return tensor + + local_tensor = tensor.detach() + + # TODO(xilun): address sharding order + # distribute the tensor according to the placements. + placements = list(placements) + for idx, placement in enumerate(placements): + if isinstance(placement, Shard | _StridedShard): + placement_dim = ( + placement.dim + tensor.ndim if placement.dim < 0 else placement.dim + ) + if isinstance(placement, Shard): + local_tensor = Shard._make_shard_tensor( + placement_dim, local_tensor, device_mesh, idx, src_data_rank + ) + placements[idx] = Shard(placement_dim) + else: + local_tensor = _StridedShard._make_shard_tensor( + placement_dim, + local_tensor, + device_mesh, + idx, + src_data_rank, + split_factor=placement.split_factor, + ) + placements[idx] = _StridedShard( + placement_dim, split_factor=placement.split_factor + ) + elif isinstance(placement, Replicate): + local_tensor = Replicate._make_replicate_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + elif isinstance(placement, Partial): + local_tensor = Replicate._make_replicate_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + local_tensor = placement._partition_value(local_tensor, device_mesh, idx) + else: + raise RuntimeError( + f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" + ) + placements = tuple(placements) + + assert local_tensor is not None, "distributing a tensor should not be None" + # detach the local tensor passed to DTensor since after the construction + # of DTensor, autograd would work on top of DTensor instead of local tensor + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + local_tensor.requires_grad_(tensor.requires_grad), + spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=tensor.requires_grad, + ) + + +@deprecated("Please use `distribute_tensor` with `src_data_rank=None` instead.") +def _shard_tensor( + full_tensor: torch.Tensor, + placements: Sequence[Shard], + device_mesh: DeviceMesh | None = None, +) -> "DTensor": + """ + Locally shards a full tensor based on indicated sharding arrangement, and + returns a DTensor containing the local shard. + + .. warning:: This is a private API that is subject to change. It skips the + communication otherwise required by `distribute_tensor`. It is only + applicable to cases where all ranks have the same `full_tensor`. For + example, in distributed inference all ranks load from the same + checkpoint. This API will not check for data equality between ranks, it + is thus user's responsibility to ensure the `full_tensor` is the same + across ranks. + + Args: + full_tensor (torch.Tensor): the full tensor to be sharded. + placements (Sequence[:class:`Shard`]): the placements that + describes how to place the local tensor on DeviceMesh. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. Must have same dimension as the number of placements. + If not specified, would be retrieve from current context. + + Returns: + A :class:`DTensor` object with the shard as its local tensor. + + Examples: + >>> # xdoctest: +SKIP("need world_size and rank") + >>> device_mesh = dist.init_device_mesh("cuda", (world_size,)) + >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}") + >>> dtensor = _shard_tensor(full_tensor, [Shard(1)], device_mesh) + """ + return distribute_tensor(full_tensor, device_mesh, placements, src_data_rank=None) + + +def distribute_module( + module: nn.Module, + device_mesh: DeviceMesh | None = None, + partition_fn: Callable[[str, nn.Module, DeviceMesh], None] | None = None, + input_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None, + output_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None, +) -> nn.Module: + """ + This function expose three functions to control the parameters/inputs/outputs of the module: + + 1. To perform sharding on the module before runtime execution by specifying the + ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` + parameters according to the `partition_fn` specified). + 2. To control the inputs or outputs of the module during runtime execution by + specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to + :class:`DTensor`, convert the output back to ``torch.Tensor``) + + Args: + module (:class:`nn.Module`): user module to be partitioned. + device_mesh (:class:`DeviceMesh`): the device mesh to place the module. + partition_fn (Callable): the function to partition parameters (i.e. shard certain + parameters across the ``device_mesh``). If ``partition_fn`` is not specified, + by default we replicate all module parameters of ``module`` across the mesh. + input_fn (Callable): specify the input distribution, i.e. could control how the + input of the module is sharded. ``input_fn`` will be installed as a module + ``forward_pre_hook`` (pre forward hook). + output_fn (Callable): specify the output distribution, i.e. could control how the + output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be + installed as a module ``forward_hook`` (post forward hook). + + Returns: + A module that contains parameters/buffers that are all ``DTensor`` s. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` + return nn.Module with PyTorch/XLA SPMD annotated parameters. See + `this issue `__ + for more details. The XLA integration is experimental and subject to change. + + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_module") + + already_distributed = getattr(module, "_distribute_module_applied", False) + if already_distributed: + raise RuntimeError( + "distribute_module should only be called once on a module, " + "but it has already been called on this module!" + ) + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # This function annotates all module parameters for auto-partitioning with + # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters + # according to the `partition_fn` specified. + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_module, + ) + + return xla_distribute_module( + module, device_mesh, partition_fn, input_fn, output_fn + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: + # This function loop over the immediate module parameters and + # buffers, replicate all non DTensor params/buffers to DTensor + # parameters/buffers, if they have not been partitioned in the + # partition_fn, we can't easily use `module._apply` here + # because we don't know what happened inside partition_fn as + # user could do anything, i.e. install hooks, and we want to + # preserve those. + full_replicate = [Replicate()] * mesh.ndim + for key, param in m._parameters.items(): + if param is not None and not isinstance(param, DTensor): + m.register_parameter( + key, + nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), + ) + for key, buffer in m._buffers.items(): + if buffer is not None and not isinstance(buffer, DTensor): + m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) + + if partition_fn is None: + # if partition_fn not specified, we by default replicate + # all module params/buffers + for submod in module.modules(): + replicate_module_params_buffers(submod, device_mesh) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) + replicate_module_params_buffers(submod, device_mesh) + + # register input_fn as module forward pre hook + if input_fn is not None: + # check the input_fn signature + num_args = len(inspect.signature(input_fn).parameters) + if num_args == 2: + # input_fn only takes in inputs and device mesh + warnings.warn( + "Deprecating input_fn that takes two arguments (inputs, device_mesh), " + "please use input_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_pre_hook( + lambda _, inputs: input_fn(inputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + # input_fn takes in module, inputs, device mesh + module.register_forward_pre_hook( + lambda mod, inputs: input_fn(mod, inputs, device_mesh) + ) + else: + raise ValueError( + f"input_fn should take in 3 arguments, but got {num_args} arguments!" + ) + # register output_fn as module forward hook + if output_fn is not None: + num_args = len(inspect.signature(output_fn).parameters) + if num_args == 2: + # output_fn only takes in outputs and device mesh + warnings.warn( + "Deprecating output_fn that takes two arguments (inputs, device_mesh), " + "please use output_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) + ) + else: + raise ValueError( + f"output_fn should take in 3 arguments, but got {num_args} arguments!" + ) + + module._distribute_module_applied = True # type: ignore[assignment] + return module + + +# Below are tensor factory function APIs, which are used to create a DTensor directly. We need +# to make separate factory function APIs because tensor subclass could not override the tensor +# factory methods, and we need user to call the factory functions with user intended device_mesh +# and placements to create a proper DTensor. + + +def _dtensor_init_helper( # type: ignore[no-untyped-def] + init_op, + size: torch.Size, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + **kwargs, +) -> DTensor: + # if device_mesh is None, use the one from mesh resources + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + kwargs["device"] = device_mesh.device_type + + # set default placements to replicated if not specified + placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) + + # check device_mesh against placements + assert device_mesh.ndim == len(placements), ( + "mesh dimension does not match the length of placements" + ) + + assert kwargs["layout"] == torch.strided, "layout value not supported!" + torch_stride = torch._prims_common.make_contiguous_strides_for(size) + + # get local tensor shape + local_shape, _ = compute_local_shape_and_global_offset( + size, device_mesh, placements, skip_offset=True + ) + + # initialize the local tensor + if init_op is torch.full: + fill_value = kwargs.pop("fill_value", 0) + local_tensor = init_op(local_shape, fill_value, **kwargs) + elif init_op is torch.rand or init_op is torch.randn: + # this tensor meta is not used except `shape` + dtype = kwargs.get("dtype", torch.get_default_dtype()) + + tensor_meta = TensorMeta(size, (0,), dtype) + spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) + + if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: + random._rng_tracker = random.OffsetBasedRNGTracker(device_mesh) + + assert random._rng_tracker is not None + with random._rng_tracker._distribute_region(spec): + local_tensor = init_op(local_shape, **kwargs) + else: + local_tensor = init_op(local_shape, **kwargs) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + local_tensor, + spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=kwargs["requires_grad"], + ) + + +def ones( # type: ignore[no-untyped-def] + *size, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined + by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( # type: ignore[no-untyped-def] + *size, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( # type: ignore[no-untyped-def] + size, + fill_value, + *, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and + ``placements``, with the shape defined by the argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_argmin_argmax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_argmin_argmax.py new file mode 100644 index 0000000000000000000000000000000000000000..730291a7926b3130e23a0d1b98b6d6170fca03c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_argmin_argmax.py @@ -0,0 +1,120 @@ +import operator +from functools import reduce + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.tensor.placement_types import Partial, Replicate, Shard + + +_REDUCTION_OPS = { + torch.ops.aten.argmax.default: torch.max, + torch.ops.aten.argmin.default: torch.min, +} + + +def argmin_argmax_handler( + op_call: torch._ops.OpOverload, + args: tuple["dtensor.DTensor", int] | tuple["dtensor.DTensor", int, bool], + kwargs: dict[str, object], +): + """ + Handles reduces on sharded dimensions locally to limit calls to replicate. + """ + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + if op_call not in _REDUCTION_OPS: + raise NotImplementedError(f"Unsupported reduction op: {op_call}") + val_op = _REDUCTION_OPS[op_call] + + input_dtensor = args[0] + if not isinstance(input_dtensor, dtensor.DTensor): + raise NotImplementedError + + dim: int | None = args[1] if len(args) > 1 else None # type: ignore[assignment] + keepdim = args[2] if len(args) > 2 else False + + placements = input_dtensor.placements + + # check for partial placements and handle it as replicate. + if any(isinstance(p, Partial) for p in placements): + target_placements = [ + Replicate() if isinstance(p, Partial) else p for p in placements + ] + input_dtensor = input_dtensor.redistribute( + device_mesh=input_dtensor.device_mesh, placements=target_placements + ) + placements = input_dtensor.placements + local_tensor = input_dtensor.to_local() + + input_shape = list(local_tensor.shape) + if dim is None: + expected_shape = ( + torch.Size([1] * len(input_shape)) if keepdim else torch.Size([]) + ) + elif keepdim: + if input_shape: + input_shape[dim] = 1 + expected_shape = torch.Size(input_shape) + else: + if input_shape: + input_shape.pop(dim) + expected_shape = torch.Size(input_shape) + + shard_mesh_dims = [] + for mesh_dim, p in enumerate(placements): + if isinstance(p, Shard): + if dim is None or p.dim == (dim if dim >= 0 else local_tensor.ndim + dim): + shard_mesh_dims.append(mesh_dim) + + device_mesh = input_dtensor.device_mesh + + if dim is None: + local_idx = op_call(local_tensor) + local_max = local_tensor.flatten()[local_idx] + else: + local_max, local_idx = val_op(local_tensor, dim=dim, keepdim=True) + + if not shard_mesh_dims: + return dtensor.DTensor._op_dispatcher.wrap( + local_idx.reshape(expected_shape), output_sharding.output_spec + ) + + # find the correct offset for sharded dim + global_shape = input_dtensor.shape + _, global_offset = compute_local_shape_and_global_offset( + global_shape, device_mesh, placements + ) + gathered_maxes = local_max + if dim is None: + local_coord = torch.unravel_index(local_idx, local_tensor.shape) + global_coord = torch.stack(local_coord) + gather_dim = 0 + for i, offset in enumerate(global_offset): + global_coord[i] += offset + # compute with proper striding + gathered_idxs = torch.tensor(0, device=local_tensor.device, dtype=torch.long) + for i, coord in enumerate(global_coord): + gathered_idxs += coord * reduce(operator.mul, global_shape[i + 1 :], 1) + else: + gather_dim = dim + gathered_idxs = local_idx + global_offset[dim] + + for mesh_dim in shard_mesh_dims: + gathered_maxes = funcol.all_gather_tensor( + gathered_maxes, gather_dim=gather_dim, group=(device_mesh, mesh_dim) + ) + gathered_idxs = funcol.all_gather_tensor( + gathered_idxs, gather_dim=gather_dim, group=(device_mesh, mesh_dim) + ) + + rank_winner = op_call(gathered_maxes, dim, True) + + final_idx = torch.gather(gathered_idxs, dim=gather_dim, index=rank_winner) + + return dtensor.DTensor._op_dispatcher.wrap( + final_idx.reshape(expected_shape), output_sharding.output_spec + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_collective_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_collective_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..766b030ad9524d7c3e8dc185ac3ff3056e4bcc77 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_collective_utils.py @@ -0,0 +1,396 @@ +# mypy: allow-untyped-defs +import logging +import math +from dataclasses import dataclass +from functools import lru_cache +from typing import Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._dtensor_spec as dtensor_spec +from torch._C._distributed_c10d import _resolve_process_group +from torch._logging import warning_once +from torch.distributed._local_tensor import ( + local_tensor_mode, + maybe_run_for_local_tensor, +) +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + broadcast, + get_group_rank, + get_rank, + ProcessGroup, + scatter, + Work, +) + + +logger = logging.getLogger(__name__) + + +@torch.library.register_fake("_dtensor::shard_dim_alltoall") +def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) + + return ( + torch.cat(stacked_list, dim=gather_dim) + .chunk(group_size, dim=shard_dim)[group_rank] + .contiguous() + ) + + +def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): + if mesh.device_type == "cpu" and local_tensor_mode() is None: + # Gloo does not support alltoall, so falling back to allgather + chunk + warning_once( + logger, + "CPU process group does not support alltoall yet, falling back with allgather + chunk!", + ) + out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim)) + if isinstance(out, funcol.AsyncCollectiveTensor): + # stick to the same behavior for the alltoall case, remove this once we enable alltoall async + out = out.wait() + out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[ + mesh.get_local_rank(mesh_dim) + ] + return out.contiguous() + + group_name = funcol._resolve_group_name((mesh, mesh_dim)) + # TODO: enable async op for shard_dim_alltoall + return torch.ops._dtensor.shard_dim_alltoall( + input, gather_dim, shard_dim, group_name + ) + + +def mesh_scatter( + output: torch.Tensor, + scatter_list: list[torch.Tensor], + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, + *, + group_src: int = 0, +) -> Work | None: + """ + scatter a list of tensors to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will + scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank + 2 to rank 2/3. + + Args: + output (torch.Tensor): the tensor to receive the scattered list. + scatter_list (List[torch.Tensor]): the tensor list to be scattered. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Keyword args: + group_src (int, optional): the group rank of the source data for the + logical/global tensor, on the specific mesh dimension. By default, we + use ``group_rank=0`` on each DeviceMesh dimension as the source data + to preserve the single-device semantic. If passing ``None`` explicitly, + this method simply uses its local data with no communication. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. + if output.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + if group_src == get_rank(dim_group): + fut = scatter( + output, + scatter_list=scatter_list, + group=dim_group, + async_op=async_op, + group_src=group_src, + ) + else: + fut = scatter( + output, + scatter_list=None, + group=dim_group, + async_op=async_op, + group_src=group_src, + ) + + return fut + + +def mesh_broadcast( + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, + *, + group_src: int = 0, +) -> Work | None: + """ + broadcast the tensor to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will + broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 + to rank 2/3. + + Args: + tensor (torch.Tensor): tensor to broadcast. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Keyword args: + group_src (int, optional): the group rank of the source data for the + logical/global tensor, on the specific mesh dimension. By default, we + use ``group_rank=0`` on each DeviceMesh dimension as the source data + to preserve the single-device semantic. If passing ``None`` explicitly, + this method simply uses its local data with no communication. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. + if tensor.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + return broadcast(tensor, group=dim_group, async_op=async_op, group_src=group_src) + + +@maybe_run_for_local_tensor +def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + pad = [0, 0] * (tensor.ndim - pad_dim) + pad[-1] = pad_size + return torch.nn.functional.pad(tensor, pad) + + +@maybe_run_for_local_tensor +def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + return tensor.narrow( + pad_dim, + start=0, + length=tensor.size(pad_dim) - pad_size, + ) + + +def fill_empty_tensor_to_shards( + shards: list[torch.Tensor], shard_dim: int, num_empty_tensors: int +) -> list[torch.Tensor]: + if num_empty_tensors == 0: + return shards + tensor_size = list(shards[0].size()) + tensor_size[shard_dim] = 0 + tensor = shards[0].new_zeros(tensor_size) + shards.extend(tensor for _ in range(num_empty_tensors)) + return shards + + +def check_tensor_meta( + local_tensor, check_shape_stride=False +) -> Optional["dtensor_spec.TensorMeta"]: + local_metadata = { + "dtype": local_tensor.dtype, + "requires_grad": local_tensor.requires_grad, + } + + if check_shape_stride: + local_metadata.update( + {"shape": local_tensor.shape, "stride": local_tensor.stride()} + ) + + gathered_metadata = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_metadata, local_metadata) + + # Check if metadata is consistent across ranks + if not all(meta == local_metadata for meta in gathered_metadata): + raise ValueError( + "Inconsistent tensor metadata (including shape and stride) across ranks." + ) + return None + + +def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: + assert spec.tensor_meta is not None, "spec should have tensor meta defined!" + return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) + + +@dataclass +class MeshTopoInfo: + """ + Mesh information for collective cost estimation + """ + + mesh: DeviceMesh + mesh_dim_devices: list[int] + mesh_dim_bandwidth: list[float] + mesh_dim_latency: list[float] + + @staticmethod + @lru_cache(None) + def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo": + # Generate mesh topology info for intra-host/inter-host communication pattern + # Note that we made bunch of assumptions for simplicity: + # 1. we assume the mesh is homogeneous, and it's gpu/nccl model + # 2. we assume gpu arch is Ampere or Hopper + # 3. we assume collectives are all ring base algo for now + num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type) + # the base bw number (intra-node), GB/s + base_bw = 87.7 + mesh_dim_bandwidth = [base_bw] * mesh.ndim + # the latency in terms of us (intra-node, nv-link) + mesh_dim_latency = [0.6] * mesh.ndim + mesh_dim_devices = [1] * mesh.ndim + + total_num_devices = 1 + for mesh_dim in reversed(range(mesh.ndim)): + num_devices = mesh.size(mesh_dim) + mesh_dim_devices[mesh_dim] = num_devices + total_num_devices *= num_devices + if total_num_devices > num_devices_per_host: + # magic number for inter-host communication bandwidth/latency factor + # This number assumes latest GPU arch, i.e. Ampere or Hopper + # TODO: see if we need to tweak this or offer a way for user + # to specify the bandwidths/latency + mesh_dim_bandwidth[mesh_dim] *= 0.22 + # set to ethernet latency for inter-host + mesh_dim_latency[mesh_dim] = 2.7 + + return MeshTopoInfo( + mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency + ) + + +def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + num_hops = num_devices_on_mesh_dim - 1 + # base latency + comm latency + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s + return latency + bw * 1e6 # rescale to us + + +def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + # allreduce have almost 2x comm bytes compare to allgather/reduce_scatter + num_hops = 2 * (num_devices_on_mesh_dim - 1) + + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth + return latency + bw * 1e6 + + +def reduce_scatter_cost( + bytes_gb: float, + mesh_topo: MeshTopoInfo, + mesh_dim: int, +) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + num_hops = num_devices_on_mesh_dim - 1 + # base latency + comm latency + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth + return latency + bw * 1e6 + + +def redistribute_cost( + current_spec: "dtensor_spec.DTensorSpec", + target_spec: "dtensor_spec.DTensorSpec", +) -> float: + """ + This function returns the cost of redistribute from current to target DTensorSpec. + + NOTE: + 1. Only consider communication cost here, since computation costs for redistribute + are quite trivial (i.e. we only need to narrow or simple division) + 2. Only consider redistribute cost on same mesh, cross mesh communication cost is + not quite needed for operator strategy estimation/selection. + """ + if current_spec.mesh != target_spec.mesh: + # make infinite cost if meshes are not same + # TODO: see if we want to support this once there's cross mesh communication + return float("inf") + + if current_spec.is_replicated(): + # short-cut: + # comm cost is 0 if current spec is already full replication + return 0.0 + + mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) + cost = 0.0 + comm_bytes_gb = ( + spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 + ) + # Transformation that considered for redistribute cost: + # 1. allgather 2. alltoall + # 3. allreduce 4. reduce_scatter + from torch.distributed._functional_collectives import _are_we_tracing + from torch.distributed.tensor._redistribute import ( + _gen_transform_infos, + _gen_transform_infos_non_cached, + ) + + # No redistribution needed when placements are already identical. + # This also prevents potential failures in _gen_transform_infos for certain configurations + # (e.g., sub-meshes) where finding a transform path between identical states may error out. + # TODO(zpcore): test placements with _StridedShard. + if current_spec.placements == target_spec.placements: + return cost + if _are_we_tracing(): + transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) + else: + transform_infos = _gen_transform_infos(current_spec, target_spec) + for transform_info in transform_infos: + assert current_spec.tensor_meta is not None, ( + "spec should have tensor meta defined!" + ) + current = transform_info.src_dst_placements[0] + target = transform_info.src_dst_placements[1] + if current == target: + continue + mesh_dim = transform_info.mesh_dim + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + if current.is_shard() and target.is_replicate(): + # allgather gives larger comm bytes + comm_bytes_gb *= num_devices_on_mesh_dim + # add up allgather comm cost + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + elif current.is_shard() and target.is_shard(): + # should be alltoall comm, since we haven't implement it yet, add 1.0 as penalty + # to favor allgather instead + # TODO: add alltoall_cost + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + 1.0 + elif current.is_partial() and target.is_replicate(): + # add up allreduce comm cost + cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim) + elif current.is_partial() and target.is_shard(): + # add up reduce_scatter comm cost + cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim) + # after reduce_scatter the comm bytes for further collectives halved. + comm_bytes_gb /= num_devices_on_mesh_dim + elif current.is_shard() and target.is_partial(): + # ban shard -> partial as it does not make sense to perform + # this redistribute + return float("inf") + + return cost diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..54c0cf63440b947587eca96781371c98ffa58407 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py @@ -0,0 +1,653 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import logging +import warnings +from collections.abc import Sequence +from typing import cast + +import torch +import torch.distributed as dist +import torch.distributed.tensor._api as dtensor +import torch.distributed.tensor._random as random +from torch._library.utils import fill_defaults +from torch.distributed._functional_collectives import _are_we_tracing +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._argmin_argmax import argmin_argmax_handler +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpInfo, + OpSchema, + OutputSharding, + OutputSpecType, +) +from torch.distributed.tensor._random import is_rng_supported_mesh +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor._sharding_prop import ShardingPropagator +from torch.distributed.tensor._tp_conv import ( + convolution_backward_handler, + convolution_handler, +) +from torch.distributed.tensor._utils import ( + ExplicitRedistributionContext, + try_find_mesh_from_args, +) +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate +from torch.utils._debug_mode import get_active_debug_mode +from torch.utils._python_dispatch import return_and_correct_aliasing + + +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +def as_strided_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +): + args, kwargs = fill_defaults(op_call._schema, args, kwargs) + assert not kwargs + tensor, size, stride, storage_offset = args + if ( + tensor.size() == tuple(size) + and tensor.stride() == tuple(stride) + and (storage_offset is None or tensor.storage_offset() == storage_offset) + ): + return torch.ops.aten.alias.default(tensor) + raise RuntimeError("as_strided not supported with DTensor") + + +def is_same_size_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> bool: + lhs = cast(torch.Tensor, args[0]) + rhs = cast(torch.Tensor, args[1]) + return lhs.shape == rhs.shape + + +def found_inf_reduce_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> None: + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + local_tensor_args = pytree.tree_unflatten( + cast(list[object], op_info.local_args), + op_info.args_tree_spec, # type: ignore[arg-type] + ) + local_tensor_args = cast(tuple[object, ...], local_tensor_args) + op_call(*local_tensor_args, **op_info.local_kwargs) + + grad_dtensor = cast(list[dtensor.DTensor], args[0])[0] + grad_placements = grad_dtensor.placements + mesh = grad_dtensor.device_mesh + + found_inf_placements: list[Placement] = [] + for placement in grad_placements: + if isinstance(placement, Replicate): + found_inf_placements.append(placement) + else: + found_inf_placements.append(Partial("max")) + + target_tensor = cast(torch.Tensor, args[1]) + spec = DTensorSpec( + mesh=mesh, + placements=tuple(found_inf_placements), + tensor_meta=TensorMeta( + shape=target_tensor.size(), + stride=target_tensor.stride(), + dtype=target_tensor.dtype, + ), + ) + # pyrefly: ignore [bad-argument-type] + found_inf_dtensor = dtensor.DTensor( + local_tensor=target_tensor, # pyrefly: ignore [unexpected-keyword] + spec=spec, # pyrefly: ignore [unexpected-keyword] + requires_grad=False, # pyrefly: ignore [unexpected-keyword] + ) + found_inf = found_inf_dtensor.full_tensor() + target_tensor.copy_(found_inf) + + +class OpDispatcher: + """ + Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding + propagation, redistribute local args, local compute, and post-processing (re-wrapping). It + also handles any op specific logic if necessary. + + NOTE: Given the runtime overhead of Tensor subclass (__torch_dispatch__), the OpDispatcher + is designed to minimize the CPU overhead by using the tricks of proper unflattening, faster + pytree if needed, and leveraging various caching mechanisms implemented in the sharding + propagation and redistribute modules. The CPU overhead is critical to eager mode performance, + one need to carefully measure the CPU overhead when making significant changes to the + OpDispatcher and ShardingPropagator. + """ + + def __init__(self) -> None: + self.sharding_propagator = ShardingPropagator() + # NOTE: must stay in sync with is_random_op in + # torch/csrc/autograd/python_variable.cpp + self._random_ops = { + aten.native_dropout.default, + aten.normal_.default, + aten.rand.default, + aten.rand_like.default, + aten.randn.default, + aten.randn_like.default, + aten.randint_like.default, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + aten.uniform_.default, + aten.bernoulli.default, + aten.bernoulli_.float, + } + self._custom_op_handlers = { + aten.is_same_size.default: is_same_size_handler, + aten.convolution.default: convolution_handler, + aten.convolution_backward.default: convolution_backward_handler, + aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler, + aten.as_strided.default: as_strided_handler, + aten.argmin.default: argmin_argmax_handler, + aten.argmax.default: argmin_argmax_handler, + } + + # ******************************************************************************************** + # def dispatch(...) + # + # NOTE: this class no longer contains the top-level dispatch entrypoint! + # See #167051 for details + # + # The entrypoint has been moved to C++, and it handles common cases and then calls back into + # OpDispatcher python to handle corner cases. + # See dispatchDTensorOp() defined in python_variable.cpp and called from python_arg_parser.cpp + # ******************************************************************************************** + + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) + # as implicitly replicated or we throw error to user. + # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave + # it as False by default. + @property + def _allow_implicit_replication(self) -> bool: + return torch._C._get_dtensor_allow_implicit_replication() + + @_allow_implicit_replication.setter + def _allow_implicit_replication(self, value: bool) -> None: + return torch._C._set_dtensor_allow_implicit_replication(value) + + def _propagate_op_sharding_dispatch_slow_path( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + op_info: OpInfo, + # The logic here is a bit messy. There are several reasons why the + # C++ fastpath may have bailed out. If we just cache missed, we will + # come here because we need to actually calculate the real thing. + # There's no need to have a SECOND Python cache lookup; the C++ native + # cache completely subsumes it. But sometimes, we will have failed + # to compute the cache key in C++ entirely. In this case, we DO need + # to do a cache lookup in Python, as the missing cache key in C++ + # means we don't have access to it all. Furthermore, without duping + # this function, we need to do the try_cache test inside of the + # try-except block so that either case hits the inference mode / + # exception rewrapping case. + # + # This should be cleaned up. First, ensuring the C++ codepath can + # always compute a key will be a big help. Second, we should properly + # fastpath inference mode composite implicit autograd so that you + # don't have to throw an exception even in "fastpath". + try_cache: bool, + ) -> object: + try: + # We have basically inlined propagate() here, but WITHOUT the + # output_sharding assignment + if try_cache and not _are_we_tracing(): + return self.sharding_propagator.propagate_op_sharding(op_info.schema) + else: + return self.sharding_propagator.propagate_op_sharding_non_cached( + op_info.schema + ) + except NotImplementedError: + if torch._C._dispatch_has_kernel_for_dispatch_key( + op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + # When running under inference mode, CompositeImplicitAutograd ops show up in __torch_dispatch__, + # so we manually decompose them, here + out = op_call.decompose(*args, **kwargs) + assert out is not NotImplemented + return out + else: + raise + except Exception as e: + raise RuntimeError( + f"{e}\n\nSharding propagation failed for {op_info.schema}" + ) from e + + def _dispatch_get_local_results_slow_path( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + op_info: OpInfo, + ) -> object: + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + + mesh = op_info.compute_mesh + participating = mesh.get_coordinate() is not None + local_results = None + if participating: + # computation that happens in the current rank of the mesh, normal case + if output_sharding.needs_redistribute: + # If sharding propagation decision needs redistribute, perform redistribute + # on args first, which could potentially modify args (i.e. allgather certain arg) + assert output_sharding.redistribute_schema is not None + self.redistribute_local_args( + op_info, + output_sharding.redistribute_schema, + output_sharding.use_val_from_redistribute_schema, + ) + + local_tensor_args = ( + pytree.tree_unflatten( + cast(list[object], op_info.local_args), + # pyrefly: ignore [bad-argument-type] + op_info.args_tree_spec, + ) + if op_info.args_tree_spec + else op_info.local_args + ) + + # run local op computation with potentially modified args/kwargs + local_tensor_args = cast(tuple[object, ...], local_tensor_args) + if op_call in self._random_ops: + if not random._rng_tracker and is_rng_supported_mesh(mesh): + # Default to `OffsetBasedRNGTracker` if the parallelism API + # did not already construct one + random._rng_tracker = random.OffsetBasedRNGTracker(mesh) + + first_arg, first_local_arg = ( + cast(dtensor.DTensor, args[0]), + cast(torch.Tensor, local_tensor_args[0]), + ) + + # If the user provided a generator, we hook it up to our RNG manager, but we also pop it from kwargs + # so the op_call does not directly use it (we want op_call to fall back to the 'default' which is + # our RNG manager) + maybe_user_generator = op_info.local_kwargs.pop("generator", None) + assert maybe_user_generator is None or isinstance( + maybe_user_generator, torch.Generator + ) + # maybe_user_generator = None + rng_context = ( + random._rng_tracker._distribute_region( + first_arg._spec, generator=maybe_user_generator + ) + if random._rng_tracker and not first_local_arg.is_meta + else contextlib.nullcontext() + ) + # For DTensor random operator, run it within a RNGTracker context to + # ensure the random number generator is properly distributed. + with rng_context: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + else: + # normal case, run local sharded op computation + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + else: + # For a non-participating device (happens on rank that does not belong to + # the device mesh), we do: + # 1. if the return type is scalar, set the local result to None. + # 2. if the return type is Tensor or List[Tensor], return empty + # tensor(s) with correct dtype. + spec = output_sharding.output_spec + ret_list = op_call._schema.returns + + if spec is None: + # For a scalar return type, the non-participating device has None + # as its local result + local_results = None + else: + + def default_tensor(spec: DTensorSpec) -> torch.Tensor: + if spec.tensor_meta is not None: + shape = spec.tensor_meta.shape + dtype = spec.tensor_meta.dtype + if len(shape) == 0: + # scalar tensor + return torch.zeros((), dtype=dtype) + else: + # non-scalar tensor + return torch.tensor([], dtype=dtype) + else: + raise RuntimeError(f"{spec} has no tensor metadata.") + + if isinstance(spec, DTensorSpec): + # return a Tensor value + local_results = default_tensor(spec) + elif isinstance(spec, Sequence): + # return a List[Tensor] value + local_results = [ + default_tensor(s) if s is not None else None for s in spec + ] + assert isinstance(local_results, list) + if None in local_results: + ret_type = str(ret_list[0].type) + raise NotImplementedError( + f"return type {ret_type} in DTensor op is not supported" + ) + return local_results + + def _dispatch_fast_path_python_tail( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + compute_mesh: DeviceMesh, + output_sharding: OutputSharding, + local_results: object, + participating: bool, + is_inplace_op: bool, + is_out_variant_op: bool, + ) -> object: + """ + Tail of main dispatching logic, called from C++ fast path. + """ + + if output_sharding.output_spec is None: + if op_call == aten.equal.default: + # The output of the equal op is a bool, by converting it into a + # a single value tensor, we can use all-reduce with min reduce op + # to simulate logical and. + assert local_results is None or isinstance(local_results, bool) + r = torch.tensor( + int(local_results) if local_results is not None else 1, + device=compute_mesh.device_type, + ) + dist.all_reduce(r, op=dist.ReduceOp.MIN) + local_results = bool(r.item()) + + if is_inplace_op: + # inplace op should return self instead of re-wrapping + if output_sharding.output_spec is not None: + output_spec = output_sharding.output_spec + assert isinstance(output_spec, DTensorSpec) + assert isinstance(args[0], dtensor.DTensor) + + # NOTE: aten.squeeze_.dim is an inplace op but it also may change + # the inplace argument's tensor meta. Here we choose to special case + # this op because as far as I know this is the only inplace op that + # has such as behavior. We can extend this special case if necessary. + if op_call == aten.squeeze_.dim: + # update the spec to handle tensor meta changes + args[0]._spec = output_spec + # use return_and_correct_aliasing to match the outer and the inner + # aliasing. See https://github.com/pytorch/pytorch/pull/158954 + return return_and_correct_aliasing(op_call, args, kwargs, args[0]) + else: + # For all other inplace ops, check if placement changes are required + # Inplace operations that change placement are not supported because + # they would require redistribution, which breaks aliasing semantics. + # If there are views into the tensor, the views would not be updated. + if args[0]._spec.placements != output_spec.placements: + raise RuntimeError( + f"{op_call}: in-place operations that require placement changes " + f"are not supported. The operation would change placement from " + f"{args[0]._spec.placements} to {output_spec.placements}, " + f"which requires redistribution and breaks aliasing semantics. " + f"Please use the out-of-place version of this operation instead." + ) + # Most inplace ops don't change tensor meta, so no spec update needed + return args[0] + else: + return None + elif is_out_variant_op: + # out variant could possibly have multiple out args (i.e. lu_unpack.out) + output_specs = ( + (output_sharding.output_spec,) + if not isinstance(output_sharding.output_spec, tuple) + else output_sharding.output_spec + ) + out_dts = [] + spec_idx = 0 + for argument in op_call._schema.arguments: + if argument.is_out: + out_dt = cast(dtensor.DTensor, kwargs[argument.name]) + out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) + out_dts.append(out_dt) + spec_idx += 1 + + assert len(out_dts) >= 1, "out variant should have at least one out arg" + return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] + else: + assert op_call == aten.equal.default, op_call + ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + if participating and op_call._schema._is_view_op(): + return return_and_correct_aliasing(op_call, args, kwargs, ret) + else: + return ret + + @staticmethod + def redistribute_local_args( + op_info: OpInfo, + suggested_input_schema: OpSchema, + use_val_from_redistribute_schema: bool, + ) -> None: + debug_mode = get_active_debug_mode() + + # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it + if op_info.args_tree_spec is not None: + flatten_args_schema_to_reshard = tuple( + pytree.tree_leaves(suggested_input_schema.args_schema) + ) + else: + flatten_args_schema_to_reshard = suggested_input_schema.args_schema + + new_local_args: list[object] = [] + for i, arg_spec in enumerate(op_info.flat_args_schema): + reshard_arg_spec = flatten_args_schema_to_reshard[i] + if isinstance(arg_spec, DTensorSpec): + local_tensor = cast(torch.Tensor, op_info.local_args[i]) + if arg_spec != reshard_arg_spec: + redistribute_context = ( + debug_mode.record_redistribute_calls( # type: ignore[union-attr] + i, arg_spec, reshard_arg_spec + ) + if debug_mode is not None + else contextlib.nullcontext() + ) + ExplicitRedistributionContext.observe_redistribution( + arg_spec, + # pyrefly: ignore [bad-argument-type] + reshard_arg_spec, + message=f"Implicit redistribution occurred for {op_info.schema} " + "while ExplicitRedistributionContext was active", + ) + with redistribute_context: + resharded_local_tensor = redistribute_local_tensor( + local_tensor, + arg_spec, + # pyrefly: ignore [bad-argument-type] + reshard_arg_spec, + ) + new_local_args.append(resharded_local_tensor) + else: + new_local_args.append(local_tensor) + else: + if use_val_from_redistribute_schema: + # args can be updated for view related ops, we refer to the + # update in redistribute_schema. + new_local_args.append(reshard_arg_spec) + else: + new_local_args.append(arg_spec) + + op_info.local_args = tuple(new_local_args) + + def unwrap_to_op_info( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> OpInfo: + return self._unwrap_to_op_info_impl(op_call, args, kwargs, True) + + def _unwrap_to_op_info_impl( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + create_schema: bool, + ) -> OpInfo: + # get runtime schema info to determine whether to use pytree to flatten inputs + runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( + op_call, None + ) + + if runtime_schema_info is not None and runtime_schema_info.needs_pytree: + # flatten args/kwargs when op says necessary + tree_args, args_spec = pytree.tree_flatten(args) + args_list: Sequence[object] = tree_args + else: + args_list, args_spec = args, None + + args_schema: list[object] = [] + kwargs_schema: dict[str, object] = {} + local_args: list[object] = [] + local_kwargs: dict[str, object] = {} + compute_mesh: DeviceMesh | None = None + + for arg in args_list: + if isinstance(arg, dtensor.DTensor): + local_args.append(arg._local_tensor) + args_schema.append(arg._spec) + if compute_mesh is None: + # record the first compute device mesh from args + compute_mesh = arg.device_mesh + elif isinstance(arg, torch.Tensor): + compute_mesh = compute_mesh or try_find_mesh_from_args( + op_call, args_list + ) + args_schema.append( + self._try_replicate_spec_for_scalar_tensor( + op_call, arg, compute_mesh + ) + ) + local_args.append(arg) + else: + # non DTensor/Tensor args (i.e. int/float/bool), just add to args_schema/local_args + args_schema.append(arg) + local_args.append(arg) + + for k, v in kwargs.items(): + if isinstance(v, dtensor.DTensor): + local_kwargs[k] = v._local_tensor + kwargs_schema[k] = v._spec + elif isinstance(v, torch.Tensor): + compute_mesh = compute_mesh or try_find_mesh_from_args( + op_call, args_list + ) + kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( + op_call, + v, + # pyrefly: ignore [bad-argument-type] + compute_mesh, + ) + local_kwargs[k] = v + else: + # non DTensor/Tensor args (i.e. int/float/bool), just add to args_schema/local_args + kwargs_schema[k] = v + local_kwargs[k] = v + + assert compute_mesh is not None, ( + f"found no DeviceMesh from dtensor args for {op_call}!" + ) + op_info = OpInfo( + compute_mesh, + OpSchema( + op_call, + ( + # pyrefly: ignore [bad-argument-type] + pytree.tree_unflatten(args_schema, args_spec) + if args_spec + else tuple(args_schema) + ), + kwargs_schema, + schema_info=runtime_schema_info, + ) + if create_schema + else None, # type: ignore[arg-type] + args_schema, + tuple(local_args), + local_kwargs, + args_spec, + ) + return op_info + + @staticmethod + def wrap(res: object, spec: OutputSpecType) -> object: + if isinstance(res, torch.Tensor): + if spec is not None: + assert isinstance(spec, DTensorSpec), ( + f"output spec does not match with output! Expected DTensorSpec, got {spec}." + ) + # pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword] + return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) + else: + # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor + assert res.ndim == 0, "output tensor should be scalar!" + return res + elif isinstance(res, (list, tuple)): + assert spec is not None and isinstance(spec, (list, tuple)), ( + f"output spec does not match with output! Expected list/tuple, got {spec}." + ) + res_list = [] + for e, s in zip(res, spec): + res_list.append(OpDispatcher.wrap(e, s)) + + return tuple(res_list) if isinstance(res, tuple) else res_list + else: + # if the res contains only non tensor values (i.e. int/float/none), we simply return it + # without rewrapping to DTensor. + return res + + def _try_replicate_spec_for_scalar_tensor( + self, + op_call: torch._ops.OpOverload, + tensor_arg: torch.Tensor, + compute_mesh: DeviceMesh, + ) -> DTensorSpec: + # util function to produce a replicate spec for a scalar tensor arg/kwarg + if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: + warnings.warn( + "Found a non-scalar tensor with numel=1 and ndim!=0, " + "we are implicitly creating a replicated DTensor for it. " + "However, please consider changing it to a scalar tensor " + "or explicitly create a DTensor under distributed environment.", + stacklevel=2, + ) + + if tensor_arg.numel() == 1 or self._allow_implicit_replication: + # scalar tensor can be safely treated as replicated + replication_spec = DTensorSpec( + compute_mesh, + (Replicate(),) * compute_mesh.ndim, + tensor_meta=TensorMeta( + shape=tensor_arg.shape, + stride=tensor_arg.stride(), + dtype=tensor_arg.dtype, + ), + ) + else: + raise RuntimeError( + f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all" + " torch.Tensor to DTensor before calling distributed operators!" + " Please see https://docs.pytorch.org/docs/main/distributed.tensor.html#mixed-tensor-and-dtensor-operations" + " for more details." + ) + return replication_spec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_dtensor_spec.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_dtensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..629bf104e11632beee286256d6f0f77609a289eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_dtensor_spec.py @@ -0,0 +1,710 @@ +import itertools +import math +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast, NamedTuple, Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + _StridedShard, + MaskPartial, + Partial, + Placement, + Replicate, + Shard, +) +from torch.utils._debug_mode import _stringify_shape +from torch.utils._dtype_abbrs import dtype_abbrs + + +class ShardOrderEntry(NamedTuple): + """ + Represents how a single tensor dimension is sharded across mesh dimensions. + + Attributes: + tensor_dim: The tensor dimension being sharded (e.g., 0, 1, 2 for a 3D tensor). + mesh_dims: Tuple of mesh dimensions across which this tensor dimension is sharded, + in execution order. The first mesh dim is applied first, second is applied + second, etc. This tuple is guaranteed to be non-empty. + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DISTRIBUTED) + >>> # Tensor dim 1 sharded across mesh dim 2, then mesh dim 0 + >>> ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0)) + + >>> # Tensor dim 0 sharded only on mesh dim 1 + >>> ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)) + """ + + tensor_dim: int + mesh_dims: tuple[int, ...] # guaranteed to be non-empty + + +# Type alias for the complete shard order specification +# A tuple of ShardOrderEntry, one per sharded tensor dimension +# +# Example: +# shard_order = ( +# ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)), +# ShardOrderEntry(tensor_dim=2, mesh_dims=(0, 3)), +# ) +# This means: +# - Tensor dimension 0 is sharded on mesh dimension 1 +# - Tensor dimension 2 is sharded on mesh dimension 0 first, then mesh dimension 3 +ShardOrder = tuple[ShardOrderEntry, ...] + + +class TensorMeta(NamedTuple): + # simple named tuple to represent tensor metadata + # intentionally to stay simple only for sharding + # propagation purposes. + shape: torch.Size + stride: tuple[int, ...] + dtype: torch.dtype + + +# used internally to propagate the placements +@dataclass +class DTensorSpec: + mesh: DeviceMesh + placements: tuple[Placement, ...] + + # tensor meta will only be set during sharding propagation + tensor_meta: TensorMeta | None = None + + # When a tensor dimension is sharded across multiple mesh axes, + # `shard_order` specifies the sequence in which these shardings are applied. + # This order determines how tensor shards are mapped and distributed across + # devices. + # + # Example: + # For a tensor of shape [8, 16] and a 3D device mesh, if dim 0 is sharded over + # mesh dim 1, and dim 1 is sharded over mesh dim 0 and then mesh dim 2, + # the shard_order would be: + # shard_order = ( + # ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)), + # ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)), + # ) + shard_order: ShardOrder = None # type: ignore[assignment] + + def __post_init__(self) -> None: + if not isinstance(self.placements, tuple): + self.placements = tuple(self.placements) + if self.shard_order is None: + # pyrefly: ignore [bad-assignment] + + _, self.shard_order = self._normalize_placements_into_shard_order( + self.placements, self.mesh + ) + self._hash: int | None = None + + @staticmethod + def _normalize_placements_into_shard_order( + placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> tuple[tuple[Placement, ...], Optional[ShardOrder]]: + # If the returned shard_order is None, it means the StridedShard/Shard + # combinations can't be interpreted as shard order. + # If no _StridedShard in placements, we create default order. + if not any(isinstance(p, _StridedShard) for p in placements): + return placements, DTensorSpec.compute_default_shard_order(placements) + # _StridedShard in placements, try check if it can be decoded as shard order + shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order( + placements, mesh + ) + if shard_order is not None: + normalized_placements = tuple( + [ + p if not isinstance(p, _StridedShard) else Shard(p.dim) + for p in placements + ] + ) + return normalized_placements, shard_order + # unable to decode placements to shard order(e.g., the _StridedShard is + # also used by `view` op shard propagation). + return placements, None + + @staticmethod + def compute_default_shard_order( + placements: tuple[Placement, ...], + ) -> ShardOrder: + """ + Compute the default shard order from placements. + + Returns a ShardOrder where each ShardOrderEntry maps a tensor dimension + to the mesh dimensions it's sharded on, in left-to-right order. + """ + # follow default left-to-right device order if shard_order is not specified + tensor_dim_to_mesh_dims: defaultdict[int, list[int]] = defaultdict(list) + mesh_ndim = len(placements) + for mesh_dim in range(mesh_ndim): + # shard_order doesn't work with _StridedShard + if isinstance(placements[mesh_dim], _StridedShard): + return () + if isinstance(placements[mesh_dim], Shard): + placement = cast(Shard, placements[mesh_dim]) + shard_dim = placement.dim + assert shard_dim >= 0, ( + f"Shard dim {shard_dim} in placements {placements} must be normalized" + ) + tensor_dim_to_mesh_dims[shard_dim].append(mesh_dim) + + # Convert dict into ShardOrderEntry tuples + default_shard_order = tuple( + ShardOrderEntry(tensor_dim=key, mesh_dims=tuple(value)) + for key, value in sorted(tensor_dim_to_mesh_dims.items()) + if value + ) + return default_shard_order + + @staticmethod + def _convert_shard_order_to_StridedShard( + shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> tuple[Placement, ...]: + """ + Convert ShardOrder to placements with _StridedShard. + + This function converts a ShardOrder specification into a tuple of Placement objects, + using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions + in a non-default order. The split_factor of each _StridedShard is determined by the + product of mesh dimension sizes that appear earlier in the shard order but later in + the placement tuple. + + Args: + shard_order: ShardOrder specification indicating which tensor dimensions are + sharded on which mesh dimensions and in what execution order. + placements: Tuple of Placement objects that does not contain _StridedShard. + mesh: DeviceMesh containing the size information for each mesh dimension. + + Returns: + Updated tuple of Placement objects with Shard or _StridedShard placements. + + Algorithm: + For each ShardOrderEntry in shard_order: + - For each mesh dimension in the entry's mesh_dims (in order): + - Calculate split_factor as the product of mesh sizes for all mesh dimensions + that appear: + 1. Earlier in the shard order (lower index in mesh_dims), and + 2. Later in the placement tuple (higher mesh dimension index) + - If split_factor == 1: use normal Shard + - Otherwise: use _StridedShard with the calculated split_factor + + Example: + >>> # xdoctest: +SKIP("Requires DeviceMesh") + >>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order + >>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2 + >>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),) + >>> placements = (Shard(0), Shard(0), Shard(0)) + >>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1 + >>> # -> placements[2] = Shard(0) + >>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0 + >>> # -> split_factor = mesh.size(2) = 2 + >>> # -> placements[0] = _StridedShard(0, split_factor=2) + >>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1 + >>> # -> split_factor = mesh.size(2) = 2 + >>> # -> placements[1] = _StridedShard(0, split_factor=2) + >>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0)) + """ + placements_list = list(placements) + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + for idx in range(len(mesh_dims)): + # TODO(zpcore): split_factor from `view` and `shard order` + # should be able to be multiplied into one. Need to loosen the + # condition here. + mesh_dim = mesh_dims[idx] + if type(placements[mesh_dim]) is not Shard: + raise ValueError( + f"Only Shard placement can be converted to _StridedShard, " + f"found {placements[mesh_dim]} in {placements=}." + ) + split_factor = math.prod( + mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim + ) + if split_factor == 1: + # use normal Shard + placements_list[mesh_dim] = Shard(tensor_dim) + else: + placements_list[mesh_dim] = _StridedShard( + tensor_dim, split_factor=split_factor + ) + return tuple(placements_list) + + @staticmethod + def _maybe_convert_StridedShard_to_shard_order( + placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> ShardOrder | None: + """ + Try to convert _StridedShard placements to ShardOrder. + + This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard + order by examining the split_factor of each _StridedShard and determining its position + in the execution order. If the _StridedShard configuration cannot be represented as a + valid ShardOrder (i.e., there's no shard order that produces the observed split_factors), + this function returns None. + + Args: + placements: Tuple of Placement objects that may contain _StridedShard. + mesh: DeviceMesh containing the size information for each mesh dimension. + + Returns: + ShardOrder if conversion is possible, None otherwise. For placements without + _StridedShard, returns the default shard order. + + Algorithm: + 1. If no _StridedShard in placements, return default shard order + 2. Create an empty list for each tensor dimension to represent mesh dim ordering + 3. Iterate through placements in reverse order (right to left): + - For each Shard/_StridedShard on a tensor dimension: + - Extract its split_factor (1 for Shard, split_factor for _StridedShard) + - Find the position in mesh_dims_order where accumulated_sf equals split_factor + - accumulated_sf is the product of mesh sizes of mesh dimensions that appear + earlier in mesh_dims_order (lower indices) + - Insert mesh_dim at the found position + 4. If no valid position found for any split_factor, return None (unable to convert) + 5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order + + Example: + >>> # xdoctest: +SKIP("Requires DeviceMesh") + >>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2 + >>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0)) + >>> # Process tensor_dim=0 from right to left: + >>> # - mesh_dim=2: Shard(0) with sf=1 + >>> # Try position 0: accumulated_sf=1, matches! Insert at position 0 + >>> # Current mesh_dims_order order: [2] + >>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2 + >>> # Try position 0: accumulated_sf=1, no match + >>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1 + >>> # Current mesh_dims_order order: [2, 1] + >>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2 + >>> # Try position 0: accumulated_sf=1, no match + >>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1 + >>> # Final mesh_dims_order order: [2, 0, 1] + >>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)) + >>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1 + + Note: + This function validates that _StridedShard can be represented as a ShardOrder. + Not all _StridedShard configurations are valid - the split_factor must match + the product of mesh sizes in some execution order. + """ + if not any(isinstance(p, _StridedShard) for p in placements): + return DTensorSpec.compute_default_shard_order(placements) + max_tensor_dim = ( + max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1 + ) + shard_order = [] + + tensor_dim_to_mesh_dims_order: list[list[int]] = [ + [] for i in range(max_tensor_dim) + ] + for mesh_dim in reversed(range(len(placements))): + cur_placement = placements[mesh_dim] + # _StridedShard may not be a subclass of Shard in the future, so write in this way: + if isinstance(cur_placement, Shard | _StridedShard): + tensor_dim = cur_placement.dim + mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim] + cur_sf = 1 + if isinstance(cur_placement, _StridedShard): + cur_sf = cur_placement.split_factor + accumulated_sf = 1 + find_order = False + for i in range(len(mesh_dims_order) + 1): + if accumulated_sf == cur_sf: + mesh_dims_order.insert(i, mesh_dim) + find_order = True + break + if i < len(mesh_dims_order): + accumulated_sf *= mesh.size(mesh_dims_order[i]) + if not find_order: + # _StridedShard is not convertible to ShardOrder + return None + else: + if not isinstance(cur_placement, Replicate | Partial | MaskPartial): + raise ValueError( + f"Unsupported placement type {type(cur_placement)} encountered in " + f"{placements}; expected Replicate, Partial, or MaskPartial." + ) + for tensor_dim in range(max_tensor_dim): + if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0: + shard_order.append( + ShardOrderEntry( + tensor_dim=tensor_dim, + mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]), + ) + ) + return tuple(shard_order) + + def _verify_shard_order(self, shard_order: ShardOrder) -> None: + """Verify that the shard_order is valid and matches the placements.""" + total_shard = 0 + if any(isinstance(p, _StridedShard) for p in self.placements): + return + prev_tensor_dim = -1 + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + assert len(mesh_dims) > 0, f"shard_order {shard_order} has empty mesh dim" + assert tensor_dim >= 0, ( + f"shard_order {shard_order} has invalid tensor dim {tensor_dim}" + ) + assert tensor_dim > prev_tensor_dim, ( + "tensor dim should be sorted in shard_order" + ) + prev_tensor_dim = tensor_dim + total_shard += len(mesh_dims) + for mesh_dim in mesh_dims: + assert 0 <= mesh_dim < len(self.placements), ( + f"shard_order {shard_order} has invalid mesh dim {mesh_dims}" + ) + assert self.placements[mesh_dim] == Shard(tensor_dim), ( + f"placement[{mesh_dim}] doesn't have a matching shard in shard_order" + ) + assert total_shard == sum(1 for p in self.placements if isinstance(p, Shard)) + + def __setattr__(self, attr: str, value: Any) -> None: + if attr == "shard_order" and value is not None: + self._verify_shard_order(value) + super().__setattr__(attr, value) + # Make sure to recompute the hash in case any of the hashed attributes + # change (though we do not expect `mesh`, `placements` or `shard_order` + # to change) + if hasattr(self, "_hash") and attr in ( + "mesh", + "placements", + "tensor_meta", + "shard_order", + ): + self._hash = None + # This assert was triggered by buggy handling for dict outputs in some + # FX passes, where you accidentally iterate over a dict and try to put + # keys into TensorMeta. See https://github.com/pytorch/pytorch/issues/157919 + if attr == "tensor_meta" and value is not None: + from torch.fx.passes.shape_prop import TensorMetadata + + # TODO: the TensorMetadata arises from + # test/distributed/tensor/experimental/test_tp_transform.py::TensorParallelTest::test_tp_transform_e2e + # but I actually can't reproduce it, maybe it is also a bug! + assert isinstance(value, TensorMeta | TensorMetadata), value + + def _hash_impl(self) -> int: + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements, shape + # dtype and stride. + # Caveat: we need to keep this in mind and sync hash and eq if we add more + # fields to them. + if self.tensor_meta is not None: + return hash( + ( + self.mesh, + self.placements, + self.shard_order, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + ) + return hash((self.mesh, self.placements, self.shard_order)) + + def __hash__(self) -> int: + # We lazily cache the spec to avoid recomputing the hash upon each + # use, where we make sure to update the hash when the `tensor_meta` + # changes by overriding `__setattr__`. This must be lazy so that Dynamo + # does not try to hash non-singleton `SymInt`s for the stride. + if self._hash is None: + self._hash = self._hash_impl() + return self._hash + + def _check_equals(self, other: object, skip_shapes: bool = False) -> bool: + if not ( + isinstance(other, DTensorSpec) + and self.mesh == other.mesh + and self.placements == other.placements + and self.shard_order == other.shard_order + ): + return False + if self.tensor_meta is None or other.tensor_meta is None: + return self.tensor_meta == other.tensor_meta + + if skip_shapes: + return self.tensor_meta.dtype == other.tensor_meta.dtype + return ( + self.tensor_meta.shape == other.tensor_meta.shape # type: ignore[union-attr] + and self.tensor_meta.stride == other.tensor_meta.stride # type: ignore[union-attr] + and self.tensor_meta.dtype == other.tensor_meta.dtype # type: ignore[union-attr] + ) + + def __eq__(self, other: object, /) -> bool: + return self._check_equals(other) + + def __str__(self) -> str: + """ + human readable representation of the DTensorSpec + """ + placement_str = self.format_shard_order_str(self.placements, self.shard_order) + if self.tensor_meta is not None: + tensor_shape = _stringify_shape(self.tensor_meta.shape) + tensor_dtype = dtype_abbrs[self.tensor_meta.dtype] + else: + tensor_shape = "unknown shape" + tensor_dtype = "unknown dtype" + + return f"Spec({tensor_dtype}{tensor_shape}({placement_str}))" + + @staticmethod + def is_default_device_order(shard_order: ShardOrder) -> bool: + """ + Check if the device order is the default left-to-right order. + """ + for entry in shard_order: + mesh_dims = entry.mesh_dims + is_increasing = all( + prev < nxt for prev, nxt in itertools.pairwise(mesh_dims) + ) + if not is_increasing: + return False + return True + + @staticmethod + def format_shard_order_str( + placements: tuple[Placement, ...], + shard_order: ShardOrder | None = None, + ) -> str: + """ + Format DTensor sharding information as a human-readable string. + + This method formats the sharding pattern in mesh-centric order, showing the placement + for each mesh dimension sequentially. When a tensor dimension is sharded across multiple + mesh dimensions, the order index indicates the execution sequence of the sharding operations. + + Args: + placements: Tuple of placement objects for each mesh dimension. + shard_order: Optional ShardOrder specifying the sharding order. + + Returns: + String representation of the sharding pattern in mesh-centric format. + + Example: + For a 3D tensor on a 2x2x2x2 mesh (16 devices) with:: + + placements = [Partial(), Shard(1), Shard(1), Replicate()] + shard_order = (ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 1)),) + + Mesh configuration: + - mesh_dim_0: Partial reduction (sum) + - mesh_dim_1: Shard tensor dimension 1 (executed second, order index 1) + - mesh_dim_2: Shard tensor dimension 1 (executed first, order index 0) + - mesh_dim_3: Replicate + + Output: ``"PS(1)[1]S(1)[0]R"`` + + Explanation: + - ``P``: mesh dimension 0 has partial reduction + - ``S(1)[1]``: mesh dimension 1 shards tensor dimension 1 (order index 1 means second) + - ``S(1)[0]``: mesh dimension 2 shards tensor dimension 1 (order index 0 means first) + - ``R``: mesh dimension 3 replicates + + The format follows mesh dimension order (0, 1, 2, 3), and when a tensor dimension + is sharded across multiple mesh dimensions, the bracketed index shows the execution + order: ``[0]`` is executed first, ``[1]`` is executed second, etc. + """ + out_str = "" + # native dtensor-style sharding representation: map from mesh + # dim to tensor dim + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + if shard_order is not None: + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + + if placement.dim == tensor_dim: + assert mesh_dim in mesh_dims + if len(mesh_dims) > 1: + out_str += f"{placement}[{mesh_dims.index(mesh_dim)}]" + else: + # no need to show device order if the tensor dim is + # only sharded in one mesh dim + out_str += str(placement) + break + else: + out_str += str(placement) + else: + out_str += str(placement) + return out_str + + @property + def shape(self) -> torch.Size: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.shape + + @property + def stride(self) -> tuple[int, ...]: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.stride + + @property + def ndim(self) -> int: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return len(self.tensor_meta.shape) + + @property + def num_shards(self) -> int: + num_shards = 1 + for i, placement in enumerate(self.placements): + if placement.is_shard(): + num_shards *= self.mesh.size(i) + return num_shards + + @property + def device_mesh(self) -> DeviceMesh: + # simple aliasing for the mesh field, make some + # checks that mixes DTensor/DTensorSpec easier + return self.mesh + + @property + def dim_map(self) -> list[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 0, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `_Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. + """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if r[shard_dim] > -1: + raise ValueError( + f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + " DTensor operator implementation does not support things like hybrid" + " sharding strategies yet (i.e. [Shard(0), Shard(0)])" + ) + r[shard_dim] = i + return r + + @property + def num_shards_map(self) -> list[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. Unlike `dim_map`, `num_shards_map` + denotes how many shards each tensor dim has. Like `dim_map`: + len(num_shards_map) == dist_tensor.ndim + num_shards_map[i] = 1: means tensor dim i is not sharded + num_shards_map[i] = j: means tensor dim i has j shards in total + + For example, we have a dist tensor of shape [18, 20, 30], + a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements + ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor + would be: [4, 2, 1]. + """ + r = [1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + r[shard_dim] *= self.mesh.size(i) + + return r + + @property + def sums(self) -> list[int]: + """ + sums is a property we derive from `placements` of the + distributed tensor. It simply return a list of ints where + sums[i] denotes the pending sum (partial) on mesh dim i + """ + return [ + idx + for idx, placement in enumerate(self.placements) + if placement.is_partial() + ] + + @classmethod + def from_dim_map( + cls, + mesh: DeviceMesh, + dim_map: list[int], + sums: list[int], + tensor_meta: TensorMeta | None = None, + ) -> "DTensorSpec": + """ + Construct a DTensorSpec from dim_map list and pending sum. + + Args: + mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec + dim_map (List[int]): a list of integer that represents sharding on each + tensor dimension, see `dim_map` property doc for details + sums (List[int]): a list of integer that represents the dist tensor have + pending sum on which device mesh dimension. + tensor meta (TensorMeta): DTensor metadata + + Return: + a class:`DTensorSpec` object + """ + # by default replicate on device mesh dims + placements: list[Placement] = [Replicate() for _ in range(mesh.ndim)] + + # find all mesh dims that need pending reductions + for s in sums: + placements[s] = Partial() + + for i, m in enumerate(dim_map): + if m >= 0: + placement = placements[m] + if placement.is_shard(): + placement = cast(Shard, placement) + raise RuntimeError( + f"DeviceMesh dimension can't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + ) + elif placement.is_partial(): + raise RuntimeError( + f"DeviceMesh dimension {m} cannot be both shard and partial!" + ) + placements[m] = Shard(i) + + return cls(mesh, tuple(placements), tensor_meta=tensor_meta) + + def is_replicated(self) -> bool: + """ + return True if the current DTensorSpec replicates on all mesh dims (devices) + """ + return all(placement.is_replicate() for placement in self.placements) + + def is_sharded(self) -> bool: + """ + return True if the current DTensorSpec uses Shard() placement on any mesh dims (devices) + """ + return any(placement.is_shard() for placement in self.placements) + + def shallow_copy_with_tensor_meta( + self, tensor_meta: TensorMeta | None + ) -> "DTensorSpec": + """ + Shallow copy the DTensorSpec with a new tensor_meta. + """ + assert tensor_meta is not None, "shallow copy with no tensor_meta!" + return DTensorSpec( + self.mesh, + self.placements, + tensor_meta=tensor_meta, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_op_schema.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_op_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..4fec0293554ac1c0bb4031b91953386d3dc6d541 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_op_schema.py @@ -0,0 +1,612 @@ +# mypy: allow-untyped-defs +""" +DTensor operator schema definitions and utilities. + +This module defines the core data structures and utilities for describing and managing +distributed tensor operations in PyTorch's DTensor system. It provides the foundational +schema types used for sharding propagation, operator strategy selection, and distributed +execution planning. + +Key components: +- OpSpec: Describes acceptable sharding placements for operations +- OpStrategy: Represents the possible sharding strategies for an operator +- TupleStrategy: Container for multiple strategies when ops have tuple/list of tensors input +- OpSchema: Describes operator input/output schemas with DTensorSpecs +- OutputSharding: Manages output sharding specifications and redistribution +- RuntimeSchemaInfo: Runtime execution metadata for operators +- OpInfo: Complete runtime operator execution information + +These schema definitions enable the DTensor system to: +1. Propagate tensor sharding information to the operator outputs +2. Greedily select sharding strategies for distributed operations +3. Plan and execute tensor redistributions when needed +4. Cache sharding decisions for performance optimization +""" + +from collections.abc import Sequence +from dataclasses import dataclass +from functools import cached_property +from typing import Any +from typing_extensions import deprecated + +import torch +from torch._C import ( + _DTensor_OpSchema_post_init, + _DTensor_OpSchema_recompute_comparison_key, +) +from torch._ops import OpOverload +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Placement + + +try: + from torch.utils._cxx_pytree import ( + register_pytree_node, + tree_leaves, + tree_map_only, + TreeSpec, + ) +except ImportError: + from torch.utils._pytree import ( # type: ignore[no-redef, assignment] + register_pytree_node, + tree_leaves, + tree_map_only, + TreeSpec, + ) + + +# Common type aliases +ArgsType = tuple[object, ...] +KwargsType = dict[str, object] + +PlacementList = list[Placement | None] + +# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should +# be the same set of possibilities. +OutputSpecType = DTensorSpec | Sequence[DTensorSpec | None] | None + + +def _rebuild_tensor_from_dtensor_meta(arg) -> object: + """ + This is used to propagate tensor metadata, must be under fake mode + """ + assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta." + return torch.empty_strided( + arg.tensor_meta.shape, + arg.tensor_meta.stride, + dtype=arg.tensor_meta.dtype, + ) + + +def _pretty_print_spec(spec: object) -> str: + if spec is None: + return "None" + elif isinstance(spec, DTensorSpec): + return "".join([str(p) for p in spec.placements]) + elif isinstance(spec, Sequence): + return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")" + else: + raise RuntimeError(f"Unknown spec type to print: spec={spec}") + + +@dataclass +class OpSpec: + """ + An OpSpec describes an acceptable sharding placements of an operation, with the + specified DTensorSpecs for both the output and the inputs. + + note: when the op return value is a single DTensor object, output_specs is + DTensorSpec; when the return value is a tuple of Optional[DTensor], + output_specs is a tuple of Optional[DTensorSpec]. + + note: we MUST produce an DTensorSpec for every output that is a Tensor. None + entries only occur for non-Tensor outputs (e.g., operators that return Optional[Tensor], + or non-Tensor outputs.) + + invariant: the DeviceMesh on all DTensorSpec must be the same + """ + + # output_specs and input_specs are related: for this op, given these input_specs, + # this is the way the output would look + output_specs: DTensorSpec | tuple[DTensorSpec | None, ...] + input_specs: Sequence[DTensorSpec] | None = None + + """ + redistribute_cost tells how expensive it is to redistribute a given input into the + placement specified in this OpSpec. + + outer list: one entry (list) per (tensor) input in the op's arg schema + inner list: one entry (cost value) per possible sharding spec for that input + + Example: + ------- + another_op() -> tensor_a # another_op produces the output that becomes our first input + my_op(tensor_a) + + Let's assume this OpSpec's input_specs are [Replicate()], + but another_op() supports 2 strategies (OpSpecs) which produce outputs of + Replicate() + Shard(0) + + In this example, redistribute_costs would look like this + [ + # one row representing "my_op's first input" (tensor_a) + [ + # two entries, one for each strategies supported by another_op + 0.0, # cost of redistributing tensor_a from 'Replicate()' + K, # cost of redistributing tensor_a from 'Shard(0)' + ], + """ + redistribute_cost: list[list[float]] | None = None + + @cached_property + def output_spec(self) -> DTensorSpec: + """ + This function requires that the strategy have exactly one DTensorSpec as the + output spec. If the output_specs is a tuple, we throw an exception. + """ + if isinstance(self.output_specs, DTensorSpec): + return self.output_specs + else: + raise ValueError( + f"function output_spec expects a single DTensorSpec but got: {self.output_specs}" + ) + + @cached_property + def mesh(self): + if isinstance(self.output_specs, DTensorSpec): + return self.output_specs.mesh + elif isinstance(self.output_specs, tuple): + out_spec = self.output_specs[0] + assert isinstance(out_spec, DTensorSpec) + return out_spec.mesh + else: + raise ValueError( + f"function output_spec expects a single DTensorSpec or a tuple of DTensorSpec but got: {self.output_specs}" + ) + + def input_spec(self, index: int = 0) -> DTensorSpec: + assert self.input_specs is not None, "input_specs of OpSpec is None!" + assert len(self.input_specs) > index, ( + f"Invalid index {index} for input_specs of length " + f"{len(self.input_specs)}: {self.input_specs}" + ) + return self.input_specs[index] + + def __str__(self) -> str: + if self.input_specs is not None: + input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> " + else: + input_specs_str = "" + output_spec_str = _pretty_print_spec(self.output_specs) + return f"{input_specs_str}{output_spec_str}" + + +class StrategyType: + """ + Base class type for op strategy, We have two StrategyType: + OpStrategy and TupleStrategy + """ + + +class OpStrategy(StrategyType): + """ + OpStrategy that consists of a list of sharding strategies associated with the op, + where each strategy is an OpSpec that describes the acceptable input/output sharding. + + invariant: the DeviceMesh on all OpSpec must be the same + """ + + def __init__(self, strategies: list[OpSpec]) -> None: + super().__init__() + self.strategies: list[OpSpec] = strategies + + def __str__(self) -> str: + strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies]) + mesh_shape = self.mesh_shape + return f"OpStrategy[{strategy_list_str}] @ mesh: {mesh_shape}" + + def max_num_shards(self) -> int: + """ + Returns the max number of shards across all OpSpecs + """ + return max(strategy.output_spec.num_shards for strategy in self.strategies) + + @property + def mesh(self): + return self.strategies[0].mesh + + @property + def mesh_shape(self): + return self.strategies[0].mesh.shape + + @property + def ndim(self): + return self.strategies[0].output_spec.ndim + + @property + def shape(self): + return self.strategies[0].output_spec.shape + + +class TupleStrategy(StrategyType): + """ + TupleStrategy is a special case for operators that are fundamentally compound or batched such that some subset + of the inputs and outputs are completely unrelated to some other subset. + + Generally, foreach_* ops are the most common use-case for TupleStrategy, because they accept lists of inputs, + but operate independently on each input or tuple of zipped inputs. + + For example, [out_a, out_b] = torch.foreach_add([a, b], scalar): input a's sharding only affects out_a's sharding, + independent of b and out_b. + + An example of an operator that should NOT use TupleStrategy is torch.split. It produces a List[Tensor] + as its output, but the sharding decision of one output is bound together with the decision + of each other output and the common input. + """ + + def __init__( + self, + children: Sequence[StrategyType], + ) -> None: + super().__init__() + self.children: Sequence[StrategyType] = children + + @property + @deprecated( + "TupleStrategy.childs is deprecated, use TupleStrategy.children instead.", # codespell:ignore childs + category=FutureWarning, + ) + def childs(self) -> Sequence[StrategyType]: # codespell:ignore childs + """ + Alias for children, to maintain backward compatibility. + """ + return self.children + + def child_mesh(self, index: int) -> DeviceMesh: + op_strategy = self.children[index] + assert isinstance(op_strategy, OpStrategy) + return op_strategy.mesh + + def __str__(self) -> str: + child_strategies_str = ", ".join( + [f"{str(strat)}" for idx, strat in enumerate(self.children)] + ) + return f"TupleStrategy({child_strategies_str})" + + +try: + register_pytree_node( + TupleStrategy, + lambda node: (node.children, None), + lambda children, _: TupleStrategy(tuple(children)), + ) +except ValueError: + # already registered TupleStrategy, skip + pass + + +@dataclass +class RuntimeSchemaInfo: + """ + RuntimeSchemaInfo stores the operator schema related information for runtime (eager) + execution. This is mainly used for two ways: 1. to generate hash for args to determine + whether to re-run sharding prop or not 2. to determine if we need pytree + """ + + # This static_argnum records static arg "starting index" for ops that have non-tensor + # args/kwargs which would affect sharding propagation results. All args starting from + # this index would be hashed to our sharding cache. + # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc. + static_argnum: int = 100 + # This static_kwargkey records static kwarg names which would affect sharding prop + static_kwargkey: list[str] | None = None + # each op can decide if it wants to use pytree flatten/unflatten during operator + # eager execution, by default we don't need to do flatten/unflatten, only if the + # op indicate it needs to, this is to accelerate eager performance. + needs_pytree: bool = False + + +@dataclass +class OpSchema: + """ + OpSchema is a data class that describes an operator input schemas, it includes + DTensorSpecs/OpStrategies (instead of DTensor) and non-tensor args/kwargs (positional + order preserved). It is mainly used by the DTensor's dispatching logic to perform various + actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.) + + NOTE: this must be used as a read only data class + TODO: make this a frozen dataclass + + Args: + op: the operator overload we are intercepting + args_schema: contains args except that the DTensor args have been replaced + with its DTensorSpec or OpStrategy + kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced + with its DTensorSpec or OpStrategy + """ + + op: OpOverload + args_schema: ArgsType + kwargs_schema: KwargsType + + schema_info: RuntimeSchemaInfo | None = None + + _comparison_key: tuple[object, ...] | None = None + + @property + def args_spec(self) -> tuple[DTensorSpec, ...]: + """ + args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec list + with NO non-DTensor positional arguments (i.e. int/float/tuple, etc) + mainly used by sharding propagation to propagate the output spec + """ + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, DTensorSpec)) + + @property + def args_strategy(self) -> tuple[OpStrategy, ...]: + # filter out non-relevant values from args schema to get a clean OpStrategy list + # separate with args_spec for the ease of type annotation + # TODO: see if we should merge this with args_spec + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, OpStrategy)) + + @property + def kwargs_strategy(self) -> tuple[OpStrategy, ...]: + # returns OpStrategy items from kwargs_schema. + kwargs_vals = ( + tree_leaves(self.kwargs_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.kwargs_schema.values() + ) + return tuple(item for item in kwargs_vals if isinstance(item, OpStrategy)) + + def __repr__(self) -> str: + args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) + return ( + f"OpSchema(op={self.op}," + f" args_schema=({args_schema})," + f" kwargs_schema={self.kwargs_schema})" + ) + + def __str__(self) -> str: + args_schema: list[str] = [] + device_mesh = None + + for arg in self.args_schema: + if isinstance(arg, DTensorSpec): + args_schema.append(str(arg)) + device_mesh = arg.mesh + elif isinstance(arg, OpStrategy): + assert len(arg.strategies) == 1 + args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs)) + device_mesh = arg.mesh + elif isinstance(arg, TupleStrategy): + first_op_strategy = arg.children[0] + assert isinstance(first_op_strategy, OpStrategy) + device_mesh = first_op_strategy.mesh + args_schema.append(str(arg)) + else: + args_schema.append(str(arg)) + + return f"{self.op}({', '.join(args_schema)}) on {device_mesh})" + + def __post_init__(self) -> None: + _DTensor_OpSchema_post_init(self) + + def arg_type_tensor_or_tensor_list_like(self, arg: object) -> bool: + is_tensor = isinstance(arg, DTensorSpec) + if is_tensor: + return True + + if not isinstance(arg, list): + return False + + return all(isinstance(e, DTensorSpec) or e is None for e in arg) + + def return_type_tuple_tensor_like(self) -> bool: + # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats + # in the tuple, but the first element must be a Tensor, so this check is enough + return_types = self.op._schema.returns + return len(return_types) > 1 and isinstance( + return_types[0].type, torch.TensorType + ) + + def return_type_list_tensor_like(self) -> bool: + # returns True if the return type is a List + return_types = self.op._schema.returns + return len(return_types) == 1 and isinstance( + return_types[0].type, torch.ListType + ) + + def return_type_tensor(self) -> bool: + return_types = self.op._schema.returns + # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like + # return types, so this check is enough for tensor like types + return isinstance(return_types[0].type, torch.TensorType) + + def get_mesh_from_args(self, validate: bool = True) -> DeviceMesh: + """ + This util can be used to get a mesh from the OpSchema that contains multiple + DTensors as arguments. When `validate` is True, it will try to validate that all the + arguments have the same mesh to avoid unexpected cross mesh errors. + + NOTE: this util currently does not handle TupleStrategy when `validate=True`, + this is because for TupleStrategy there could be different types of checks, i.e.: + - for stack and cat like op, we need to check within a TupleStrategy is every + input is on the same mesh + - for foreach like ops we need to check "zipped" inputs are on the same mesh + for each index. + """ + first_arg = self.args_schema[0] + if isinstance(first_arg, (DTensorSpec, OpStrategy)): + mesh = first_arg.mesh + elif isinstance(first_arg, (list, tuple, TupleStrategy)): + first_elem = ( + first_arg.children[0] + if isinstance(first_arg, TupleStrategy) + else first_arg[0] + ) + assert isinstance(first_elem, (DTensorSpec, OpStrategy)) + mesh = first_elem.mesh + else: + raise ValueError(f"Cannot find device mesh from args for op : {self.op}.") + + if validate: + for arg in self.args_schema[1:]: + if isinstance(arg, (DTensorSpec, OpStrategy)) and arg.mesh != mesh: + raise RuntimeError( + f"DTensor does not support cross-mesh operation on {self.op}! " + f"Got meshes: {mesh} {arg.mesh}. " + f"Please make sure all the arguments have the same DeviceMesh." + ) + + return mesh + + def is_inplace_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an inplace variant, it might not + # be entirely correct, but it's good enough for now. + return self.op._schema.name[-1] == "_" + + def is_out_variant_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an out variant, it might not + # be entirely correct, but it's good enough for now. + return "out" in self.op._schema.overload_name + + def is_view_op(self) -> bool: + return self.op._schema._is_view_op() + + def _recompute_comparison_key(self) -> None: + _DTensor_OpSchema_recompute_comparison_key(self) + + def __hash__(self) -> int: + return hash(self._comparison_key) + + def __eq__(self, other: object) -> bool: + # early return checks + if not isinstance(other, OpSchema): + return False + + if self.op != other.op: + return False + + if len(self.args_schema) != len(other.args_schema): + return False + + return self._comparison_key == other._comparison_key + + def gen_fake_args(self) -> ArgsType: + """ + gen_fake_args: generate fake args for the operator, this is mainly used + by sharding propagation rules to generate fake args for the operator + to run the local tensor operator and get the output spec. + """ + return tree_map_only( + DTensorSpec, + _rebuild_tensor_from_dtensor_meta, + self.args_schema, + is_leaf=lambda x: isinstance(x, DTensorSpec), + ) + + def gen_fake_kwargs(self) -> KwargsType: + """ + gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used + by sharding propagation rules to generate fake kwargs for the operator + to run the local tensor operator and get the output spec. + """ + return tree_map_only( + DTensorSpec, + _rebuild_tensor_from_dtensor_meta, + self.kwargs_schema, + is_leaf=lambda x: isinstance(x, DTensorSpec), + ) + + def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None: + suggestion_args_spec = self.args_spec + new_arg_schema: list[object] = [] + idx_of_args_spec = 0 + if ( + origin_schema.schema_info is not None + and origin_schema.schema_info.needs_pytree + ): + args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema) + else: + args_schema = origin_schema.args_schema + for arg in args_schema: + if isinstance(arg, DTensorSpec): + new_arg_schema.append(suggestion_args_spec[idx_of_args_spec]) + idx_of_args_spec += 1 + else: + new_arg_schema.append(arg) + self.args_schema = tuple(new_arg_schema) + self.kwargs_schema = origin_schema.kwargs_schema + self._recompute_comparison_key() + + +@dataclass +class OutputSharding: + """ + OutputSharding is a data class that is used by the sharding propagation, + it could set the output_spec upon successful propagation. If needs_redistribute + is set to True, a redistribute_schema would be returned together to indicate + the input arguments needs to be redistributed before the op execution. + + NOTE: the redistribute_schema generated by sharding propagation should be + exactly the same as the operator OpSchema, except the DTensorSpecs + """ + + # specifies the output sharding pattern + output_spec: OutputSpecType + # schema for redistribution if needed + redistribute_schema: OpSchema | None = None + # flag indicating if inputs need redistribution + needs_redistribute: bool = False + # flag to use values from `redistribute_schema` + use_val_from_redistribute_schema: bool = False + + @cached_property + def mesh(self): + if isinstance(self.output_spec, DTensorSpec): + return self.output_spec.mesh + elif isinstance(self.output_spec, tuple): + out_spec = self.output_spec[0] + if isinstance(out_spec, DTensorSpec): + return out_spec.mesh + else: + raise ValueError(f"Unknown output spec type: {type(out_spec)}") + else: + raise ValueError(f"Unknown output spec type: {type(self.output_spec)}") + + +@dataclass +class OpInfo: + """ + All Runtime Op execution info are packed here + """ + + # The first compute device mesh recorded from args + # NOTE: one op could have multiple meshes from its args. We just record the first + # mesh here to check if current rank should participate in computation or not. + compute_mesh: DeviceMesh + + # compete runtime operator infos + schema: OpSchema + flat_args_schema: list[object] + local_args: Sequence[object] + local_kwargs: dict[str, object] + args_tree_spec: TreeSpec | None = None + + # the output sharding info + output_sharding: OutputSharding | None = None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfaa668a18373df8576804a8cb730d8e030ad46 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from ._conv_ops import * # noqa: F403 +from ._embedding_ops import * # noqa: F403 +from ._math_ops import * # noqa: F403 +from ._matrix_ops import * # noqa: F403 +from ._pointwise_ops import * # noqa: F403 +from ._random_ops import * # noqa: F403 +from ._tensor_ops import * # noqa: F403 +from ._view_ops import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_conv_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_conv_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1f456d505c12789f82e6c16aabf2692e871d3dfc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_conv_ops.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema, OutputSharding +from torch.distributed.tensor._ops.registration import register_prop_rule + + +aten = torch.ops.aten + + +@register_prop_rule(aten.convolution.default) +def convolution_rules(op_schema: OpSchema) -> OutputSharding: + ( + input_spec, + weight_spec, + bias_spec, + stride, + padding, + dilation, + _transposed, + _output_padding, + _groups, + ) = op_schema.args_schema + + assert isinstance(input_spec, DTensorSpec) + assert isinstance(weight_spec, DTensorSpec) + # bias_spec can be None (optional parameter in aten.convolution schema) + if bias_spec is not None: + assert isinstance(bias_spec, DTensorSpec) + assert input_spec.tensor_meta is not None + assert weight_spec.tensor_meta is not None + in_shape = input_spec.tensor_meta.shape + weight_shape = weight_spec.tensor_meta.shape + assert isinstance(stride, list), f"stride must be list, got {type(stride)}" + assert isinstance(padding, list), f"padding must be list, got {type(padding)}" + assert isinstance(dilation, list), f"dilation must be list, got {type(dilation)}" + # weight_shape might not be torch.Size in all cases (e.g., SymIntArrayRef during tracing) + # so we don't assert its type, just use it + out_conv_shape = [ + (d + 2 * padding[i] - dilation[i] * (weight_shape[i + 1] - 1) - 1) // stride[i] + + 1 + for (i, d) in enumerate(in_shape[2:]) + ] + output_shape = [in_shape[0], weight_shape[0]] + out_conv_shape + output_stride = [1] + for i in range(1, len(output_shape)): + output_stride.insert(0, output_stride[0] * output_shape[-i]) + output_dim_map = input_spec.dim_map + pending_sums = input_spec.sums + + tensor_meta = TensorMeta( + torch.Size(output_shape), + tuple(output_stride), + input_spec.tensor_meta.dtype, + ) + return OutputSharding( + DTensorSpec.from_dim_map( + input_spec.mesh, + output_dim_map, + pending_sums, + tensor_meta=tensor_meta, + ) + ) + + +@register_prop_rule(aten.convolution_backward.default) +def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: + input_spec = op_schema.args_schema[0] + ( + grad_output_spec, + input_spec, + weight_spec, + bias_shape_opt, + _stride, + _padding, + _dilation, + _transposed, + _output_padding, + _groups, + _output_mask, + ) = op_schema.args_schema + + assert isinstance(grad_output_spec, DTensorSpec) + assert isinstance(input_spec, DTensorSpec) + assert isinstance(weight_spec, DTensorSpec) + # bias_shape_opt can be None (optional parameter in aten.convolution_backward schema) + if bias_shape_opt is not None: + assert isinstance(bias_shape_opt, list) + assert input_spec.tensor_meta is not None + weight_tensor_meta = weight_spec.tensor_meta + + # Only create bias_tensor_meta if bias_shape_opt is not None + if bias_shape_opt is not None: + bias_tensor_meta = TensorMeta( + torch.Size(bias_shape_opt), + (1,), + input_spec.tensor_meta.dtype, + ) + else: + bias_tensor_meta = None + + grad_input_spec = input_spec + grad_weight_spec = DTensorSpec.from_dim_map( + input_spec.mesh, + [-1, -1, -1, -1], + [0], + tensor_meta=weight_tensor_meta, + ) + + # Only create grad_bias_spec if we have bias_tensor_meta + if bias_tensor_meta is not None: + grad_bias_spec = DTensorSpec.from_dim_map( + input_spec.mesh, + [-1], + [0], + tensor_meta=bias_tensor_meta, + ) + else: + grad_bias_spec = None + + # TODO: actually the output_mask is not respected here, we should + # set the corresponding spec to `None` if the output_mask is not `False` + # for a certain output Tensor. This also applies to the conv handler + # in torch/distributed/tensor/_tp_conv.py + return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9d46ede21f97bdf8539e73e14eab3a5697402d8e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -0,0 +1,186 @@ +import itertools +from dataclasses import dataclass + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import OpSpec, OpStrategy +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +@dataclass +class EinsumDims: + contracting_dims: list[str] + batch_dims: list[str] + lhs_out_only_dims: list[str] + rhs_out_only_dims: list[str] + + @classmethod + def parse_equation(cls, equation: str) -> tuple[list[str], str]: + # parse einop equation and extract arg specs + """ + Parse the einsum equation str to input dim chars and output dim char + """ + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + + # NOTE: only support at most two inputs, and single output + # extend to support more inputs if needed in future + assert len(input_dims) <= 2, "Only support at most two inputs" + assert len(output_dims) == 1, "Only support single output" + output_dim = output_dims[0] + return input_dims, output_dim + + @classmethod + def parse_dims(cls, input_dims: list[str], output_dim: str) -> "EinsumDims": + """ + Parse the dims and extract the contracting, batch, and free dimensions + for the left and right hand sides. + """ + dim_char_set: set[str] = set() + for input_dim in input_dims: + dim_char_set.update(input_dim) + + # get a deterministic order of all dim chars + all_dim_chars = sorted(dim_char_set) + + # parse input and output dimensions + lhs_out_only_dims, rhs_out_only_dims = [], [] + batch_dims, contracting_dims = [], [] + + for dim_char in all_dim_chars: + if dim_char not in output_dim: + contracting_dims.append(dim_char) + else: + is_batch_dim = True + for input_dim in input_dims: + is_batch_dim = is_batch_dim and dim_char in input_dim + + if is_batch_dim: + batch_dims.append(dim_char) + else: + assert len(input_dims) == 2, ( + "free dimension only supported for two inputs!" + ) + lhs, rhs = input_dims + if dim_char in lhs: + lhs_out_only_dims.append(dim_char) + elif dim_char in rhs: + rhs_out_only_dims.append(dim_char) + else: + raise RuntimeError("Invalid dimension character") + + return cls( + contracting_dims=contracting_dims, + batch_dims=batch_dims, + lhs_out_only_dims=lhs_out_only_dims, + rhs_out_only_dims=rhs_out_only_dims, + ) + + +def gen_einsum_strategies( + equation: str, + mesh: DeviceMesh, + *, + linearity: bool = False, +) -> OpStrategy: + """ + Generate a strategy list for the ops that follow einsum style notation. + + In principle, each mesh dim is independent of other device mesh dim when we + generate strategies. So we generate strategy over each device mesh dim and + do product combination on all mesh dims. We basically follow the below rule + for each device mesh dim: + + 1. Shard on contracting dim: When both inputs shard on contracting dim over + the same device dim. The result will be Partial over that device dim. + + 2. Shard on noncontracting dim: + 2.1: Shard on batch dim: output, both inputs all should shard on batch + dim. + 2.2: Shard on lhs only dim or rhs only dim: both output and lhs or rhs + input should shard on this free dim. + + 3. Linearity (Partial): If enabled, set Partial on output and inputs over + the same device mesh dim. + """ + # parse einop equation and extract dims + input_dims, output_dim = EinsumDims.parse_equation(equation) + edims = EinsumDims.parse_dims(input_dims, output_dim) + all_mesh_dim_strategies = [] + + # generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R] + strategies_over_one_mesh_dim = [] + + # placement list stores placements of [output, input1, input2, ...] + # first we always have replicate all for inputs and output + placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1) + strategies_over_one_mesh_dim.append(placement_list) + + # split batch dim + for batch_dim in edims.batch_dims: + output_batch_dim = output_dim.index(batch_dim) + placement_list = [Shard(output_batch_dim)] + for input_dim in input_dims: + input_batch_dim = input_dim.index(batch_dim) + placement_list.append(Shard(input_batch_dim)) + + strategies_over_one_mesh_dim.append(placement_list) + + # split contracting dim + for contracting_dim in edims.contracting_dims: + # Contracting dim can shard on same device axis for both inputs. This + # results in the output being Partial on that device axis. For example: + # bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x) + placement_list = [Partial()] + for input_dim in input_dims: + input_contracting_dim = input_dim.index(contracting_dim) + placement_list.append(Shard(input_contracting_dim)) + + strategies_over_one_mesh_dim.append(placement_list) + + # split lhs free dim + for lhs_dim in edims.lhs_out_only_dims: + lhs_free_dim_output = output_dim.index(lhs_dim) + lhs_free_dim_input = input_dims[0].index(lhs_dim) + # this means split the lhs input and output + # i.e. S(0), R -> S(0) + lhs_placement_list: list[Placement] = [ + Shard(lhs_free_dim_output), + Shard(lhs_free_dim_input), + Replicate(), + ] + strategies_over_one_mesh_dim.append(lhs_placement_list) + + # split rhs free dim + for rhs_dim in edims.rhs_out_only_dims: + rhs_free_dim_output = output_dim.index(rhs_dim) + rhs_free_dim_input = input_dims[1].index(rhs_dim) + rhs_placement_list: list[Placement] = [ + Shard(rhs_free_dim_output), + Replicate(), + Shard(rhs_free_dim_input), + ] + strategies_over_one_mesh_dim.append(rhs_placement_list) + + # linearity strategy + if linearity: + linearity_placement_list: list[Placement] = [Partial()] + for _ in input_dims: + linearity_placement_list.append(Partial()) + strategies_over_one_mesh_dim.append(linearity_placement_list) + + # generate strategies for entire mesh + all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim + strategy_combs = itertools.product(*all_mesh_dim_strategies) + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)] + strat = OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:]) + all_strategies.append(strat) + + return OpStrategy(all_strategies) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c4abf353be5430c3ff827077b0c6226840bb40 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor +from typing import cast + +import torch +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + StrategyType, +) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy +from torch.distributed.tensor.placement_types import ( + MaskPartial, + Partial, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +@register_op_strategy(aten.embedding.default) +def embedding_strategy(op_schema: OpSchema) -> StrategyType: + """ + This strategy handles embedding op. We have two possible embedding shardings: + rowwise and colwise + """ + weight_strategy = cast(OpStrategy, op_schema.args_schema[0]) + indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) + mesh = op_schema.get_mesh_from_args() + + weight_shape = weight_strategy.shape + indices_shape = indices_strategy.shape + output_emd_dim = len(indices_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate + colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial + embedding_partial_placement = MaskPartial(offset_shape=weight_shape, offset_dim=0) + + # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates + # from the input indices and use it for output reduction + rowwise_sharding: PlacementList = [ + embedding_partial_placement, + Shard(0), + embedding_partial_placement, + ] + single_mesh_dim_strategies.append(rowwise_sharding) + + # batch dim sharding, weight replicated, input can shard on any dim, output follows input + for input_dim in range(len(indices_shape)): + batch_sharding: PlacementList = [ + Shard(input_dim), + Replicate(), + Shard(input_dim), + ] + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) + + +@register_op_strategy(aten.embedding_dense_backward.default) +def embedding_dense_backward_strategy(op_schema: OpSchema) -> StrategyType: + """ + This strategy handles embedding op. We have two possible embedding shardings: + rowwise and colwise + """ + grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0]) + indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) + mesh = op_schema.get_mesh_from_args() + + grad_out_shape = grad_out_strategy.shape + indices_shape = indices_strategy.shape + grad_out_ndim = len(grad_out_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding backward, grad_out shard on last dim, input replicate, + # weight grad shard colwise + colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # batch dim sharding, weight replicated, grad_out/input have same sharding + # that can shard on any dim, weight grad partial + for input_dim in range(len(indices_shape)): + batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + # grad_out partial, input replicate, weight grad keep partial + partial_sharding: PlacementList = [Partial(), Partial(), Replicate()] + single_mesh_dim_strategies.append(partial_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..23fbd92bf99e4c4e1e9cb10ebde39a1918dfc12b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_math_ops.py @@ -0,0 +1,1406 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import math +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from typing import cast, Union + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, + TupleStrategy, +) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import ( + as_list, + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + is_tensor_evenly_shardable, + is_tensor_evenly_shardable_on_dim, + normalize_dim, + normalize_dims, +) +from torch.distributed.tensor._utils import normalize_to_torch_size +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + +@dataclass(frozen=True) +class NormReduction: + norm_type: int | float | str + + +ReductionOpType = Union[NormReduction, str] + + +@dataclass(frozen=True) +class _NormPartial(Partial): + """ + This placement is used for partial vector norm. + + For p-norms (where p not inf or -inf), the p-norm over n elements computes + (sum_i x_i^p)^(1/p) + where the sum is from i=1 to n. The reduction op is the p-norm itself. + For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm: + Rank 0: [t1, t2] | Rank 1: [t3, t4] + After computing 2-norm per gradient (partial placement): + Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)] + Converting from partial to replicate wants to ultimately get: + Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)] + This can be achieved by computing 2-norm on each rank's result. This holds + similarly for inf and -inf norm. For 0-norm, the reduction op is sum. + """ + + norm_type: int | float | str = 2 + + def __init__(self, norm_type: int | float | str = 2): + reduce_op = None + if norm_type in (float("inf"), "inf"): + reduce_op = "max" + elif norm_type in (float("-inf"), "-inf"): + reduce_op = "min" + elif isinstance(norm_type, (int, float)): + reduce_op = "sum" + else: + raise NotImplementedError(f"Unsupported norm type: {norm_type}") + + super().__init__(reduce_op) + object.__setattr__(self, "norm_type", norm_type) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + For example, consider 4 ranks, a (3,) replicated tensor, and 2-norm: + Ranks 0 and 1: sqrt(t1^2 + t2^2 + t3^3) + To convert from replicated to partial, we want f(x) such that + sqrt(t1^2 + t2^2 + t3^3) = sqrt(4f(t1)^2 + 4f(t2)^2 + 4f(t3)^2) + = sqrt(4) sqrt(f(t1)^2 + f(t2)^2 + f(t3)^2). + One such f(x) is f(x) = x / sqrt(4). This generalizes to d ranks and + p-norm as f(x) = x / d^(1/p). + """ + if self.reduce_op in ("max", "min"): + return tensor + elif self.reduce_op == "sum": + if self.norm_type == 0: + raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}") + elif self.norm_type == 1: + return tensor / mesh.size(mesh_dim) + if not isinstance(self.norm_type, (int, float)): + raise AssertionError( + f"Expected int or float, got {type(self.norm_type)}" + ) + return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type) + raise NotImplementedError(self.reduce_op) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + if not isinstance(shard_spec, Shard): + raise AssertionError(f"Expected Shard, got {type(shard_spec)}") + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) + return self._post_reduce_transform(reduced_tensor) + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim) + return self._post_reduce_transform(reduced_tensor) + + def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + if not isinstance(self.norm_type, (int, float)): + raise AssertionError( + f"Expected int or float, got {type(self.norm_type)}" + ) + if self.norm_type != 0 and self.norm_type != 1: + # pyrefly: ignore [unsupported-operation] + return tensor**self.norm_type + return tensor + + def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + if not isinstance(self.norm_type, (int, float)): + raise AssertionError( + f"Expected int or float, got {type(self.norm_type)}" + ) + if self.norm_type != 0 and self.norm_type != 1: + # pyrefly: ignore [unsupported-operation] + return tensor ** (1.0 / self.norm_type) + return tensor + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _NormPartial): + return False + return self.norm_type == other.norm_type + + def __hash__(self) -> int: + return 1 + hash(self.norm_type) + + def __repr__(self) -> str: + """ + machine readable representation of the _NormPartial placement + """ + return f"_NormPartial(reduce_op={self.reduce_op}, norm_type={self.norm_type})" + + def __str__(self) -> str: + """human readable representation of the _NormPartial placement""" + return f"_NormP({self.reduce_op}, {self.norm_type})" + + +def _infer_reduction_dims(dims_arg: object, ndim: int) -> list[int] | None: + if dims_arg is None: + return None + dims = cast(list[int], as_list(dims_arg)) + dims = cast(list[int], normalize_dims(dims, ndim)) + empty_dims = [[0], [-1], []] + if ndim == 0 and dims_arg in empty_dims: + return None + return dims + + +def _infer_reduce_dims_map( + reduction_dims: list[int], input_ndim: int, keep_dim=False +) -> list[int]: + reduction_dims_map = [] + new_dim_count = 0 + for input_dim in range(input_ndim): + if input_dim in reduction_dims and not keep_dim: + # if input dim in reduction dims, mark it as -1 + reduction_dims_map.append(-1) + else: + # otherwise mark it as the new dim + reduction_dims_map.append(new_dim_count) + new_dim_count += 1 + + return reduction_dims_map + + +def _replicate_dims_start_at( + placements: Sequence[Placement], start_dim: int = 0 +) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +# return new_placements which align with placements but skip the skipped_dim +def _skip_dim( + placements: tuple[Placement, ...], skipped_dim: int +) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if isinstance(p, Shard) and p.dim >= skipped_dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + return tuple(new_placements) + + +def replicate_reduction_dims( + placements: tuple[Placement, ...], reduction_dims: list[int] +) -> tuple[Placement, ...]: + # replicate the reduction dims if not reduction_linear + new_placements: list[Placement] = [] + + for p in placements: + if p.is_partial(): + new_placements.append(Replicate()) + elif isinstance(p, Shard) and p.dim in reduction_dims: + new_placements.append(Replicate()) + else: + new_placements.append(p) + + return tuple(new_placements) + + +def map_placements_after_reduction( + placements: tuple[Placement, ...], + reduction_dims: list[int], + reduction_dims_map: list[int], + reduction_op: ReductionOpType, +) -> tuple[Placement, ...]: + """ + Map each placement based on the output shape after reduction. + """ + new_placements: list[Placement] = [] + for placement in placements: + if isinstance(placement, (Replicate, Partial)): + new_placements.append(placement) + else: + if not isinstance(placement, Shard | _StridedShard): + raise AssertionError( + f"Expected Shard/_StridedShard, got {type(placement)}" + ) + shard_dim = placement.dim + new_shard_dim = reduction_dims_map[shard_dim] + if new_shard_dim == -1 or shard_dim in reduction_dims: + # if new_shard_dim collapsed or its in the reduction dims + # (i.e. for the case where keepdims=True), we generate partial + new_placements.append(get_placement_from_reduction_op(reduction_op)) + else: + if isinstance(placement, Shard): + new_placements.append(Shard(new_shard_dim)) + else: + new_placements.append( + _StridedShard( + new_shard_dim, split_factor=placement.split_factor + ) + ) + return tuple(new_placements) + + +def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement: + if isinstance(reduction_op, NormReduction): + return _NormPartial(norm_type=reduction_op.norm_type) + return Partial(reduction_op) + + +def common_reduction_strategy( + input_strategy: OpStrategy, + reduce_dims: list[int], + keep_dim: bool = False, + reduction_linear: bool = True, + reduction_op: ReductionOpType = "sum", +) -> OpStrategy: + """ + reduction_linear means that the reduction `f` follows this rule: + f([f(a), f(b)]) = f([a, b]) + + reduction linear should be super set of linearity. + """ + # by default follow reduction input strategy + reduction_strategy = OpStrategy([]) + + for op_spec in input_strategy.strategies: + if reduction_op == "avg": + output_spec = op_spec.output_spec + local_shape = list(output_spec.tensor_meta.shape) # type:ignore[union-attr] + for dim in reduce_dims: + if not is_tensor_evenly_shardable_on_dim(local_shape, output_spec, dim): + # reduce(avg) is not linear for unevenly sharded tensors + reduction_linear = False + break + + for p in op_spec.output_spec.placements: + # when the partial reduction op matches the global reduction op, + # we can delay redistribution (i.e max, max) + if isinstance(p, Partial) and p.reduce_op != reduction_op: + reduction_linear = False + break + + if not reduction_linear: + # input placements for this strategy should clear out pending sum and sharding + # on the reduction dimension + input_placements = replicate_reduction_dims( + op_spec.output_spec.placements, reduce_dims + ) + else: + input_placements = op_spec.output_spec.placements + + input_spec = DTensorSpec( + mesh=input_strategy.mesh, + placements=input_placements, + tensor_meta=op_spec.output_spec.tensor_meta, + ) + + reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim) + out_placements = map_placements_after_reduction( + input_spec.placements, reduce_dims, reduce_dims_map, reduction_op + ) + redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] + reduction_strategy.strategies.append( + OpSpec( + output_specs=DTensorSpec( + mesh=input_strategy.mesh, + placements=out_placements, + ), + input_specs=(input_spec,), + redistribute_cost=redistribute_cost, + ) + ) + + return reduction_strategy + + +LINEAR_REDUCTION_OP_MAP = { + aten.all.default: "product", + aten.all.dim: "product", + aten.sum.default: "sum", + aten.sum.dim_IntList: "sum", + aten.any.default: "sum", + aten.any.dim: "sum", + aten.any.out: "sum", + # These are only valid when there is no padding + aten.prod.default: "product", + aten.prod.dim_int: "product", + aten.prod.int_out: "product", + # avg is only linear when there is no padding + aten.mean.default: "avg", + aten.mean.dim: "avg", + aten.mean.out: "avg", + aten.max.default: "max", + aten.max.dim: "max", + aten.max.out: "max", + aten.min.default: "min", + aten.min.dim: "min", + aten.min.out: "min", + aten.amax.default: "max", + aten.amax.out: "max", + aten.amin.default: "min", + aten.amin.out: "min", + # argmax and argmin is using custom hanndler leveraging linear reduction of max and min + aten.argmax.default: "max", + aten.argmin.default: "min", +} + + +@register_op_strategy( + list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1) +) +def linear_reduction_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2]) + reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op] + return common_reduction_strategy( + input_strategy, + reduce_dims, + keep_dim=keep_dim, + reduction_linear=True, + reduction_op=reduction_op, + ) + + +@register_op_strategy(aten.cumsum.default, schema_info=RuntimeSchemaInfo(1)) +def cumsum_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + dim = args_schema[1] + if not isinstance(dim, int): + raise AssertionError(f"Expected int, got {type(dim)}") + + return common_reduction_strategy( + input_strategy, [dim], keep_dim=True, reduction_linear=False + ) + + +@register_op_strategy( + [ + aten.std.correction, + aten.std.correction_out, + aten.var.correction, + aten.var.correction_out, + ], + schema_info=RuntimeSchemaInfo(1, ["keepdim"]), +) +def std_var_reduction_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False)) + return common_reduction_strategy( + input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False + ) + + +@register_op_strategy( + [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1) +) +def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + if not isinstance(norm_type, (int, float, str)): + raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}") + dim = args_schema[2] if len(args_schema) > 2 else None + keepdim = args_schema[3] if len(args_schema) > 3 else False + dims = _infer_reduction_dims(dim, input_strategy.ndim) + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + return common_reduction_strategy( + input_strategy, + reduce_dims, + keep_dim=cast(bool, keepdim), + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + + +@register_op_strategy( + [aten._foreach_norm.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True) +) +def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + if not isinstance(input_tuple_strategy, TupleStrategy): + raise AssertionError( + f"Expected TupleStrategy, got {type(input_tuple_strategy)}" + ) + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + if not isinstance(norm_type, (int, float, str)): + raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}") + output_tuple_strategy_children: list[OpStrategy] = [] + for op_strategy in input_tuple_strategy.children: + if not isinstance(op_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(op_strategy)}") + reduce_dims = list(range(op_strategy.ndim)) + output_strategy = common_reduction_strategy( + op_strategy, + reduce_dims, + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + output_tuple_strategy_children.append(output_strategy) + return TupleStrategy(output_tuple_strategy_children) + + +@register_op_strategy( + [aten._foreach_max.default], schema_info=RuntimeSchemaInfo(1, needs_pytree=True) +) +def foreach_max_strategy(op_schema: OpSchema) -> TupleStrategy: + """ + Strategy for _foreach_max, which reduces each tensor in a list to its maximum value. + """ + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + if not isinstance(input_tuple_strategy, TupleStrategy): + raise AssertionError( + f"Expected TupleStrategy, got {type(input_tuple_strategy)}" + ) + output_tuple_strategy_children: list[OpStrategy] = [] + for op_strategy in input_tuple_strategy.children: + if not isinstance(op_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(op_strategy)}") + # Reduce all dimensions to get a scalar + reduce_dims = list(range(op_strategy.ndim)) + output_strategy = common_reduction_strategy( + op_strategy, + reduce_dims, + reduction_linear=True, + reduction_op="max", + ) + output_tuple_strategy_children.append(output_strategy) + return TupleStrategy(output_tuple_strategy_children) + + +@register_op_strategy( + [ + aten._linalg_svd.default, + aten.linalg_qr.default, + # TODO: The diagonal ops can have an improved sharding strategy for + # shard placements that does not require redistributing to replicate. + aten.diagonal_copy.default, + aten.diag_embed.default, + aten.diag.default, + aten.diagonal.default, + aten.tril.default, + aten.triu.default, + aten._linalg_eigh.default, + aten.upsample_bicubic2d.default, + aten.upsample_bilinear2d.default, + aten.upsample_linear1d.default, + aten.upsample_nearest2d.default, + aten.upsample_trilinear3d.default, + # TODO: support the full F.interpolate set of options. + ], + schema_info=RuntimeSchemaInfo(1), +) +def linalg_replicate_strategy(op_schema: OpSchema) -> OpStrategy: + """ + Since we do not have a simple way to compute some linear algebra operations + like SVD or QR decomposition, always fall back to replicate. + """ + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + mesh = input_strategy.mesh + + output_strategies: list[OpSpec] = [] + for placement_strategy in input_strategy.strategies: + replicate_placements = tuple(Replicate() for _ in range(mesh.ndim)) + replicate_spec = DTensorSpec( + mesh=mesh, + placements=replicate_placements, + tensor_meta=placement_strategy.output_spec.tensor_meta, + ) + redistribute_cost = [ + generate_redistribute_costs(input_strategy, replicate_spec) + ] + replicate_strategy = OpSpec( + output_specs=replicate_spec, + input_specs=(replicate_spec,), + redistribute_cost=redistribute_cost, + ) + output_strategies.append(replicate_strategy) + return OpStrategy(output_strategies) + + +@register_op_strategy( + [aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default], + schema_info=RuntimeSchemaInfo(1), +) +def softmax_strategy(op_schema: OpSchema) -> OpStrategy: + input_strategy, softmax_dim, *_ = op_schema.args_schema + input_strategy = cast(OpStrategy, input_strategy) + + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim) + + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # make sure input is replicated along the softmax dim + input_target_spec = DTensorSpec( + mesh=input_strategy.mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [softmax_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + output_target_spec = input_target_spec + output_strategy.strategies.append( + OpSpec( + output_specs=output_target_spec, + input_specs=[input_target_spec], + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [ + aten._log_softmax_backward_data.default, + aten._softmax_backward_data.default, + ], + schema_info=RuntimeSchemaInfo(2), +) +def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy: + grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + out_strategy = cast(OpStrategy, out_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, grad_out_strategy.ndim) + + grad_in_strategy = OpStrategy([]) + for grad_out_placement_strat, out_placement_strat in zip( + grad_out_strategy.strategies, out_strategy.strategies + ): + # follow the sharding of the grad_out or out depending on which has more shards + grad_out_src_spec = grad_out_placement_strat.output_spec + out_src_spec = out_placement_strat.output_spec + src_spec = ( + grad_out_src_spec + if grad_out_src_spec.num_shards >= out_src_spec.num_shards + else out_src_spec + ) + + # make sure inputs are replicated along the softmax dim + tgt_spec = DTensorSpec( + mesh=grad_out_strategy.mesh, + placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]), + ) + new_grad_out_spec = DTensorSpec( + mesh=tgt_spec.mesh, + placements=tgt_spec.placements, + tensor_meta=grad_out_src_spec.tensor_meta, + ) + new_out_spec = DTensorSpec( + mesh=tgt_spec.mesh, + placements=tgt_spec.placements, + tensor_meta=out_src_spec.tensor_meta, + ) + redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec) + redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec) + grad_in_strategy.strategies.append( + OpSpec( + output_specs=tgt_spec, + input_specs=(new_grad_out_spec, new_out_spec), + redistribute_cost=[redist_grad_out_cost, redist_out_cost], + ) + ) + + return grad_in_strategy + + +@register_op_strategy( + [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default], + schema_info=RuntimeSchemaInfo(3), +) +def nll_loss_forward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + if not len(op_schema.args_schema) == 5: + raise AssertionError(f"Expected 5 args, got {len(op_schema.args_schema)}") + + ( + input_strategy, + target_strategy, + weight_strategy, + reduction, + _, + ) = op_schema.args_schema + input_strategy = cast(OpStrategy, input_strategy) + target_strategy = cast(OpStrategy, target_strategy) + reduction = cast(int, reduction) + + input_shape = input_strategy.shape + channel_dim = 1 if len(input_shape) >= 2 else 0 + + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + + # make sure input is replicated along the channel dim + input_src_spec = input_placement_strategy.output_spec + input_expected_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [channel_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_expected_spec) + ) + + # target doesn't have channel dim, and it follows input on other dims + target_src_spec = target_strategy.strategies[idx].output_spec + target_expected_spec = DTensorSpec( + mesh=mesh, + placements=_skip_dim(input_expected_spec.placements, channel_dim), + tensor_meta=target_src_spec.tensor_meta, + ) + op_args_target_specs.append(target_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(target_strategy, target_expected_spec) + ) + + # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] + # make sure it is replicated + if weight_strategy is not None: + if not isinstance(weight_strategy, OpStrategy): + raise AssertionError( + f"Expected OpStrategy, got {type(weight_strategy)}" + ) + weight_src_spec = weight_strategy.strategies[idx].output_spec + weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_expected_spec) + ) + + if reduction == Reduction.NONE.value: + output_expected_spec = target_expected_spec + total_weight_expected_spec = DTensorSpec( + mesh=mesh, placements=tuple([Replicate()] * mesh.ndim) + ) + else: + if reduction == Reduction.MEAN.value: + reduction_op = "avg" + if not is_tensor_evenly_shardable( + target_expected_spec.shape, target_expected_spec + ): + raise ValueError( + "The intermediate results of nll_loss cannot be evenly sharded, \ + resulting in biased mean result." + ) + else: # reduction == Reduction.SUM.value: + reduction_op = "sum" + reduce_dims = list(range(target_expected_spec.ndim)) + reduce_dims_map = _infer_reduce_dims_map( + reduce_dims, target_expected_spec.ndim, keep_dim=False + ) + out_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + reduction_op, + ) + output_expected_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + ) + + # whether reduction is sum or mean, the total weight has to be summed up if not replicated + total_weight_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + "sum", + ) + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=total_weight_placements, + ) + + output_strategy.strategies.append( + OpSpec( + output_specs=(output_expected_spec, total_weight_expected_spec), + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default], + schema_info=RuntimeSchemaInfo(4), +) +def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + if not len(op_schema.args_schema) == 7: + raise AssertionError(f"Expected 7 args, got {len(op_schema.args_schema)}") + ( + grad_out_strategy, + input_strategy, + target_strategy, + weight_strategy, + reduction, + _, + total_weight_strategy, + ) = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + input_strategy = cast(OpStrategy, input_strategy) + target_strategy = cast(OpStrategy, target_strategy) + reduction = cast(int, reduction) + total_weight_strategy = cast(OpStrategy, total_weight_strategy) + + input_shape = input_strategy.shape + channel_dim = 1 if len(input_shape) >= 2 else 0 + + grad_in_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + + # make sure input is replicated along the channel dim + input_src_spec = input_placement_strategy.output_spec + input_expected_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [channel_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_expected_spec) + ) + + # target doesn't have channel dim, and it follows input on other dims + target_src_spec = target_strategy.strategies[idx].output_spec + target_expected_spec = DTensorSpec( + mesh=mesh, + placements=_skip_dim(input_expected_spec.placements, channel_dim), + tensor_meta=target_src_spec.tensor_meta, + ) + op_args_target_specs.append(target_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(target_strategy, target_expected_spec) + ) + + # grad_out follows target if there is no reduction; + # otherwise, it should be a replicated scalar. + grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec + if reduction == Reduction.NONE.value: + grad_out_expected_spec = target_expected_spec + else: + grad_out_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(grad_out_src_spec.placements), + tensor_meta=grad_out_src_spec.tensor_meta, + ) + op_args_target_specs.insert(0, grad_out_expected_spec) + redistribute_costs.insert( + 0, generate_redistribute_costs(grad_out_strategy, grad_out_expected_spec) + ) + + # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] + # make sure it is replicated + if weight_strategy is not None: + if not isinstance(weight_strategy, OpStrategy): + raise AssertionError( + f"Expected OpStrategy, got {type(weight_strategy)}" + ) + weight_src_spec = weight_strategy.strategies[idx].output_spec + weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_expected_spec) + ) + + # total_weight should always be replicated + total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(total_weight_src_spec.placements), + tensor_meta=total_weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(total_weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs( + total_weight_strategy, total_weight_expected_spec + ) + ) + + grad_in_expected_spec = input_expected_spec + grad_in_strategy.strategies.append( + OpSpec( + output_specs=grad_in_expected_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return grad_in_strategy + + +def _common_norm_forward_strategy( + op_schema: OpSchema, + rms_norm: bool = False, +) -> OpStrategy: + """Common forward strategy logic for layer_norm and rms_norm.""" + mesh = op_schema.get_mesh_from_args() + + if not rms_norm: + # layer_norm args: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). + if not len(op_schema.args_schema) == 5: + raise AssertionError(f"Expected 5 args, got {len(op_schema.args_schema)}") + ( + input_strategy, + normalized_shape, + weight_strategy, + bias_strategy, + _, + ) = op_schema.args_schema + else: + # rms_norm args: input, normalized_shape, weight, eps + if not len(op_schema.args_schema) == 4: + raise AssertionError(f"Expected 4 args, got {len(op_schema.args_schema)}") + ( + input_strategy, + normalized_shape, + weight_strategy, + _, + ) = op_schema.args_schema + bias_strategy = None + + # the current norm implementation requires that all + # input DTensor's sharding must be in form of OpStrategy + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + if not isinstance(normalized_shape, (int, Sequence, torch.Size)): + raise AssertionError( + f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}" + ) + normalized_size = normalize_to_torch_size(normalized_shape) + + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + + # we use OpStrategy because the output values (out, mean, rstd) + # should have the same placements + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # for the input tensor, we replicate it on the inner dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + if weight_strategy is not None: + if not isinstance(weight_strategy, OpStrategy): + raise AssertionError( + f"Expected OpStrategy, got {type(weight_strategy)}" + ) + weight_src_spec = weight_strategy.strategies[idx].output_spec + + # for the weight tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + weight_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_target_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_target_spec) + ) + + if bias_strategy is not None: + if not isinstance(bias_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(bias_strategy)}") + bias_src_spec = bias_strategy.strategies[idx].output_spec + + # for the bias tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + bias_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(bias_src_spec.placements), + tensor_meta=bias_src_spec.tensor_meta, + ) + op_args_target_specs.append(bias_target_spec) + redistribute_costs.append( + generate_redistribute_costs(bias_strategy, bias_target_spec) + ) + + # the output spec is the same as input spec + output_target_spec = input_target_spec + output_strategy.strategies.append( + OpSpec( + output_specs=output_target_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_forward_strategy(op_schema) + + +@register_op_strategy( + [aten._fused_rms_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_forward_strategy(op_schema, rms_norm=True) + + +def _common_norm_backward_strategy( + op_schema: OpSchema, + rms_norm: bool = False, +) -> OpStrategy: + """Common backward strategy logic for layer_norm and rms_norm.""" + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + if not rms_norm: + # layer_norm args: grad_out, input, normalized_shape, mean, rstd, + # weight, bias, output_mask. For None weight and bias, their + # corresponding objects will be None as well. + if not len(op_schema.args_schema) == 8: + raise AssertionError(f"Expected 8 args, got {len(op_schema.args_schema)}") + ( + grad_out_strategy, + input_strategy, + normalized_shape, + mean_strategy, + rstd_strategy, + weight_strategy, + bias_strategy, + output_mask, + ) = op_schema.args_schema + else: + # rms_norm args: grad_out, input, normalized_shape, rstd, + if not len(op_schema.args_schema) == 6: + raise AssertionError(f"Expected 6 args, got {len(op_schema.args_schema)}") + ( + grad_out_strategy, + input_strategy, + normalized_shape, + rstd_strategy, + weight_strategy, + output_mask, + ) = op_schema.args_schema + mean_strategy = None + bias_strategy = None + + if not isinstance(grad_out_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(grad_out_strategy)}") + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + if not isinstance(rstd_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(rstd_strategy)}") + if mean_strategy is not None: + if not isinstance(mean_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(mean_strategy)}") + + if not isinstance(normalized_shape, (int, Sequence, torch.Size)): + raise AssertionError( + f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}" + ) + normalized_size = normalize_to_torch_size(normalized_shape) + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + outer_dims = list(range(axis)) + + if not rms_norm: + if not (isinstance(output_mask, list) and len(output_mask) == 3): + raise AssertionError( + f"Expected output_mask to be list of length 3, got {type(output_mask)} " + f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}" + ) + else: + if not (isinstance(output_mask, list) and len(output_mask) == 2): + raise AssertionError( + f"Expected output_mask to be list of length 2, got {type(output_mask)} " + f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}" + ) + + # output tuple: (d_input, d_weight[, d_bias]) + out_tuple_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + # args for OpSpec + output_specs_list: list[DTensorSpec | None] = [] + input_specs_list: list[DTensorSpec] = [] + redistribute_costs = [] + + input_src_spec = input_placement_strategy.output_spec + # arg: grad_out + # TODO: change the strategy to the following rule. + # d_input is basically a product of element-wise mul of + # grad_out, rstd, and normalized input, among which rstd + # and normalized input (x_hat) should have the same sharding + # placements, and grad_out's sharding is determined by the + # pointwise result of x_hat and weight/bias. + # TODO: now grad_out spec follows input spec. we may need + # to change it to apply a pointwise rule over grad_out, + # input, and weight. + grad_out_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + input_specs_list.append(grad_out_target_spec) + redistribute_costs.append( + generate_redistribute_costs(grad_out_strategy, grad_out_target_spec) + ) + output_specs_list.append(grad_out_target_spec if output_mask[0] else None) + + # arg: input + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + input_specs_list.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + # arg: mean + if not rms_norm: + if mean_strategy is None: + raise AssertionError("Expected mean_strategy to not be None") + mean_src_spec = mean_strategy.strategies[idx].output_spec + input_specs_list.append(mean_src_spec) + redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + + # arg: rstd + rstd_src_spec = rstd_strategy.strategies[idx].output_spec + input_specs_list.append(rstd_src_spec) + redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) + + def _add_target_input_spec(strategy) -> DTensorSpec: + # shared logic for setting the weight and bias target input specs + if not isinstance(strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(strategy)}") + src_spec = strategy.strategies[idx].output_spec + # no need to redistribute since they should be replicated in forward pass + input_specs_list.append(src_spec) + redistribute_costs.append([0.0 for _ in strategy.strategies]) + return src_spec + + # arg: weight + # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) + # For RMS norm, mean is 0, so it's just: sum(grad_out * input / rstd, outer_dim, keepdim=False) + if weight_strategy is not None: + weight_src_spec = _add_target_input_spec(weight_strategy) + # TODO: now d_weight spec follows input spec w/ a reduction. + # we may need to change to a pointwise rule over grad_out and + # input, then apply a reduction. + inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, input_src_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + weight_out_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=weight_src_spec.tensor_meta, + ) + output_specs_list.append(weight_out_spec if output_mask[1] else None) + else: + if not rms_norm: + error_msg = "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." + else: + error_msg = "output_mask[1] should not be `True` while weight argument is `None` in _fused_rms_norm_backward." + if output_mask[1] is not False: + raise AssertionError(error_msg) + output_specs_list.append(None) + + # arg: bias + # d_bias = sum(grad_out, outer_dim, keepdim=False) + if not rms_norm: + if bias_strategy is not None: + bias_src_spec = _add_target_input_spec(bias_strategy) + # d_bias spec follows a reduction over grad_out + inp_placements = _replicate_dims_start_at( + grad_out_target_spec.placements, axis + ) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, grad_out_target_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + bias_out_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=bias_src_spec.tensor_meta, + ) + output_specs_list.append(bias_out_spec if output_mask[2] else None) + else: + if output_mask[2] is not False: + raise AssertionError( + "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." + ) + output_specs_list.append(None) + + out_tuple_strategy.strategies.append( + OpSpec( + output_specs=tuple(output_specs_list), + input_specs=input_specs_list, + redistribute_cost=redistribute_costs, + ) + ) + + return out_tuple_strategy + + +@register_op_strategy( + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_backward_strategy(op_schema) + + +@register_op_strategy( + [aten._fused_rms_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def fused_rms_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_backward_strategy(op_schema, rms_norm=True) + + +def sort_strategy(op_schema: OpSchema, sort_dim: int) -> OpStrategy: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + sort_dim = normalize_dim(sort_dim, input_strategy.ndim) + single_mesh_dim_strategies = [] + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + for dim in range(input_strategy.ndim): + if dim != sort_dim: + dim_shardings: PlacementList = [Shard(dim)] * 3 + single_mesh_dim_strategies.append(dim_shardings) + return expand_to_full_mesh_op_strategy( + input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 + ) + + +@register_op_strategy( + [aten.topk.default], + schema_info=RuntimeSchemaInfo(2), +) +def topk_strategy(op_schema: OpSchema) -> OpStrategy: + topk_dim = ( + cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 + ) + return sort_strategy(op_schema, topk_dim) + + +@register_op_strategy( + aten.sort.default, + schema_info=RuntimeSchemaInfo( + 1, + ), +) +def sort_default_strategy(op_schema: OpSchema) -> OpStrategy: + # mostly copy paste from topk_strategy + input_strategy = op_schema.args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + sort_dim = -1 + if len(op_schema.args_schema) > 1: + sort_dim = cast(int, op_schema.args_schema[1]) + return sort_strategy(op_schema, sort_dim) + + +@register_op_strategy( + aten.sort.stable, + schema_info=RuntimeSchemaInfo( + 1, + static_kwargkey=["dim", "descending", "stable"], + ), +) +def sort_stable_strategy(op_schema: OpSchema) -> OpStrategy: + # mostly copy paste from topk_strategy + input_strategy = op_schema.args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + sort_dim = -1 + if "dim" in op_schema.kwargs_schema: + sort_dim = cast(int, op_schema.kwargs_schema["dim"]) + return sort_strategy(op_schema, sort_dim) + + +@register_op_strategy( + [aten.histc.default], + # strategy choice depends on the value of 'min' and 'max' kwargs, which are position 2 and 3 + schema_info=RuntimeSchemaInfo(2), +) +def histc_strategy(op_schema: OpSchema) -> OpStrategy: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + single_mesh_dim_strategies: list[PlacementList] = [] + single_mesh_dim_strategies.append([Replicate(), Replicate()]) + + # histc can support sharded input and partial output on any input dim, provided the min and max + # values are user-specified. If not user-specified, the true min and max of the data in each local + # tensor will be used to compute bin boundaries, which will not be the same across ranks, leading to + # an incorrect final result + if len(op_schema.args_schema) == 4: + for dim in range(input_strategy.ndim): + dim_shardings: PlacementList = [Partial(), Shard(dim)] + single_mesh_dim_strategies.append(dim_shardings) + + return expand_to_full_mesh_op_strategy( + input_strategy.mesh, op_schema, single_mesh_dim_strategies + ) + + +@register_op_strategy( + [aten.logsumexp.default], + schema_info=RuntimeSchemaInfo( + # static_argnum is the position where non-Tensor args beings. + static_argnum=1, + # static_kwargkey is the name of kwargs to hash (which determines + # whether sharding prop can be cached). + static_kwargkey=["keepdim"], + ), +) +def logsumexp_strategy(op_schema: OpSchema) -> OpStrategy: + """Implements the sharding propagation strategy for logsumexp.""" + + # args_schema contains all but the DTensor args (e.g., dim, keepdim). + args_schema = op_schema.args_schema + if not len(args_schema) > 1: + raise AssertionError( + f"Expected more than 1 arg (input and dim are required), got {len(args_schema)}" + ) + + input_strategy = args_schema[0] + if not isinstance(input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") + + dims_arg = args_schema[1] + reduce_dims = _infer_reduction_dims(dims_arg, input_strategy.ndim) + if reduce_dims is None: + raise AssertionError("Expected reduce_dims to not be None") + + keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False)) + return common_reduction_strategy( + input_strategy, + reduce_dims, + keep_dim=keep_dim, + reduction_linear=False, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c00a44ef8f4f41730bdb4ca0550ffa1808a8fffe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py @@ -0,0 +1,1087 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor + + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, +) +from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import ( + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + infer_broadcast_dims_map, + is_tensor_shardable, + map_placements_after_broadcast, + prod, +) +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, + compute_local_stride, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +@register_op_strategy(aten.t.default) +def transpose_strategy(op_schema: OpSchema) -> OpStrategy: + self_strategy = op_schema.args_schema[0] + if not isinstance(self_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}") + + transpose_strategies = [] + for input_strategy in self_strategy.strategies: + input_spec = input_strategy.output_spec + # follow the input spec but transpose the Shard placements + output_placements = [ + Shard(1 - p.dim) if isinstance(p, Shard) else p + for p in input_spec.placements + ] + transpose_strategy = OpSpec( + output_specs=DTensorSpec( + mesh=input_strategy.mesh, + placements=tuple(output_placements), + ), + input_specs=(input_strategy.output_spec,), + ) + transpose_strategies.append(transpose_strategy) + + return OpStrategy(strategies=transpose_strategies) + + +def _mm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + self_strategy, mat2_strategy = op_schema.args_schema + if not isinstance(self_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}") + if not isinstance(mat2_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}") + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + if strtg.input_specs is None: + raise AssertionError( + f"Expected input_specs to be not None, got {strtg.input_specs}" + ) + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + if is_tensor_shardable( + self_strategy.shape, self_spec, allow_unbacked_sharding=True + ) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True + ): + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def _addmm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema + if not isinstance(self_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}") + if not isinstance(mat1_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(mat1_strategy)}") + if not isinstance(mat2_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}") + self_shape = self_strategy.shape + mm_out_shape = torch.Size( + [ + mat2_strategy.shape[-1] if i == len(mat1_strategy.shape) - 1 else dim_size + for i, dim_size in enumerate(mat1_strategy.shape) + ] + ) + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + # construct new strategy by consider the self arg + if strtg.input_specs is None: + raise AssertionError( + f"Expected input_specs to be not None, got {strtg.input_specs}" + ) + mat1_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + out_spec = strtg.output_spec + + # self arg's spec should follow the output of mm, but need + # to consider broadcast for the self arg + broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape) + self_placements = map_placements_after_broadcast( + out_spec.placements, mm_out_shape, broadcast_dims_map + ) + self_spec = DTensorSpec(mesh=mesh, placements=self_placements) + + if is_tensor_shardable( + mat1_strategy.shape, mat1_spec, allow_unbacked_sharding=True + ) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True + ): + # update input specs with new self spec + strtg.input_specs = (self_spec, mat1_spec, mat2_spec) + + # associate costs + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat1_strategy, mat1_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def _scaled_mm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + ( + self_strategy, + mat2_strategy, + scale_self_strategy, + scale_mat2_strategy, + bias_strategy, + scale_result_strategy, + *_, + ) = op_schema.args_schema + if not isinstance(self_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}") + if not isinstance(mat2_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}") + if not isinstance(scale_self_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(scale_self_strategy)}") + if not isinstance(scale_mat2_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(scale_mat2_strategy)}") + # TODO: add support for these later + if bias_strategy is not None: + raise AssertionError("_scaled_mm on DTensors doesn't support bias") + if scale_result_strategy is not None: + raise AssertionError("_scaled_mm on DTensors doesn't support scale_result") + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + if strtg.input_specs is None: + raise AssertionError( + f"Expected input_specs to be not None, got {strtg.input_specs}" + ) + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + # propagate the operands' specs to their scales, except for tensor-wise + # scaling which can have any numbers of dims (legacy...), hence sharding + # dims won't map. for tensor-wise, anyways, we can only do replication. + scale_self_spec = ( + DTensorSpec(self_spec.mesh, (Replicate(),)) + if prod(scale_self_strategy.shape) == 1 + else self_spec + ) + scale_mat2_spec = ( + DTensorSpec(mat2_spec.mesh, (Replicate(),)) + if prod(scale_mat2_strategy.shape) == 1 + else mat2_spec + ) + strtg.input_specs = list(strtg.input_specs) + [scale_self_spec, scale_mat2_spec] + if ( + is_tensor_shardable( + self_strategy.shape, self_spec, allow_unbacked_sharding=True + ) + and is_tensor_shardable( + mat2_strategy.shape, mat2_spec, allow_unbacked_sharding=True + ) + and is_tensor_shardable( + scale_self_strategy.shape, scale_self_spec, allow_unbacked_sharding=True + ) + and is_tensor_shardable( + scale_mat2_strategy.shape, scale_mat2_spec, allow_unbacked_sharding=True + ) + ): + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + generate_redistribute_costs(scale_self_strategy, scale_self_spec), + generate_redistribute_costs(scale_mat2_strategy, scale_mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +@register_op_strategy(aten.dot.default) +def dot_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _mm_like_strategy("i,i->", mesh, op_schema) + + +@register_op_strategy(aten.mm.default) +def mm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _mm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.addmm.default) +def addmm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _addmm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.bmm.default) +def bmm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy(aten.baddbmm.default) +def baddbmm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy(aten._scaled_mm.default) +def scaled_mm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _scaled_mm_like_strategy("mk,kn->mn", mesh, op_schema) + + +def _scaled_dot_product_flash_attention_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" + return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] + q_input_strategy = op_schema.args_schema[0] + if not isinstance(q_input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}") + # assuming q/k/v have the same shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 3 valid tensor outputs and 3 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), + Replicate(), + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + Replicate(), + Replicate(), + Replicate(), + Replicate(), + ] + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + if return_debug_mask: + debug_attn_mask_sharding: Placement = Shard(1) # num head dim + else: + # empty debug mask, replicated + debug_attn_mask_sharding = Replicate() + + num_heads_dim_sharding: PlacementList = [ + output_sharding, + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + debug_attn_mask_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Shard on the batch dimension + debug_attn_mask_sharding = Shard(0) if return_debug_mask else Replicate() + single_mesh_dim_strategies.append( + [ + Shard(0), # output + Shard(0), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + debug_attn_mask_sharding, # debugattn + Shard(0), # q + Shard(0), # k + Shard(0), # v + ] + ) + return single_mesh_dim_strategies + + +@register_op_strategy( + aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5) +) +def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation + # as it involves: matmul, pointwise, reduction ops together. + + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies( + op_schema + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +def _scaled_dot_product_flash_attention_backward_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" + q_input_strategy = op_schema.args_schema[1] + if not isinstance(q_input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}") + # assuming q/k/v have the same shape + + tensor_input_indices = [ + i + for i, arg_spec in enumerate(op_schema.args_schema) + if isinstance(arg_spec, OpStrategy) + ] + num_tensor_inputs = len(tensor_input_indices) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda backward case, we have 3 tensor outputs and 6 to 10 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [Replicate()] * (3 + num_tensor_inputs) + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + grad_output_sharding = Shard(1) # num head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + grad_qkv_sharding = Shard(1) # num head dim + + num_heads_dim_sharding: PlacementList = [ + grad_qkv_sharding, + grad_qkv_sharding, + grad_qkv_sharding, + grad_output_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + output_sharding, + logsumexp_sharding, + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + num_heads_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Batch sharding + batch_dim_sharding: PlacementList = [ + Shard(0), # grad_q + Shard(0), # grad_k + Shard(0), # grad_v + Shard(0), # grad_output + Shard(0), # q + Shard(0), # k + Shard(0), # v + Shard(0), # output + Shard(0), # logsumexp + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + batch_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(batch_dim_sharding) + + return single_mesh_dim_strategies + + +@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default) +def scaled_dot_product_flash_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_flash_attention_backward_base_strategies(op_schema) + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +@register_op_strategy(aten.constant_pad_nd.default) +def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args(validate=False) + + # TODO(d4l3k); implement a more correct strategy for constant_pad_nd + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec(mesh, (Replicate(),)), + input_specs=( + DTensorSpec(mesh, (Replicate(),)), + DTensorSpec(mesh, (Replicate(),)), + ), + redistribute_cost=[[1]], + ) + ] + ) + + +def _scaled_dot_product_efficient_attention_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" + q_input_strategy = op_schema.args_schema[0] + if not isinstance(q_input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}") + # assuming q/k/v have the same shape + + has_attn_bias = op_schema.args_schema[3] is not None + compute_log_sumexp = op_schema.args_schema[4] + + single_mesh_dim_strategies: list[PlacementList] = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 2 valid tensor outputs and 3 or 4 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), + Replicate(), + None, + None, + Replicate(), + Replicate(), + Replicate(), + ] + if has_attn_bias: + all_replicate.append(Replicate()) # attn bias + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the heads dimension + qkv_sharding = Shard(1) + output_sharding = Shard(1) + if compute_log_sumexp: + logsumexp_sharding: Placement = Shard(1) + else: + # empty logsumexp, replicated + logsumexp_sharding = Replicate() + + num_heads_dim_sharding = [ + output_sharding, + logsumexp_sharding, + None, + None, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + if has_attn_bias: + num_heads_dim_sharding.append(Shard(1)) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # batch sharding + if compute_log_sumexp: + logsumexp_sharding_dp: Placement = Shard(0) + else: + # empty logsumexp, replicated + logsumexp_sharding_dp = Replicate() + batch_sharding = [ + Shard(0), # output + logsumexp_sharding_dp, # logsumexp + None, # philox_seed + None, # philox_offset + Shard(0), # q + Shard(0), # k + Shard(0), # v + ] + if has_attn_bias: + batch_sharding.append(Shard(0)) + + single_mesh_dim_strategies.append(batch_sharding) + + return single_mesh_dim_strategies + + +@register_op_strategy( + aten._scaled_dot_product_efficient_attention.default, + schema_info=RuntimeSchemaInfo(4), +) +def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_base_strategies(op_schema) + ) + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, + ) + + +def _scaled_dot_product_efficient_attention_backward_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" + q_input_strategy = op_schema.args_schema[1] + if not isinstance(q_input_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}") + # assuming q/k/v have the same shape + has_attn_bias = op_schema.args_schema[4] is not None + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda backward case, we have 4 tensor outputs and 8 or 9 tensor inputs + # NOTE: Output sharding of grad_bias on heads dim if attn_bias is present; + # otherwise grad_bias will be empty and its DTensorSpec will be removed. + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias) + + if not has_attn_bias: + all_replicate[3] = None # grad bias is None if attn_bias is not present + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the heads dimension + grad_output_sharding = Shard(1) + qkv_sharding = Shard(1) + output_sharding = Shard(1) + logsumexp_sharding = Shard(1) + grad_qkv_sharding = Shard(1) + grad_bias_sharding = Shard(1) if has_attn_bias else None + + num_heads_dim_sharding: PlacementList = [ + grad_qkv_sharding, + grad_qkv_sharding, + grad_qkv_sharding, + grad_bias_sharding, + grad_output_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + # the place for optional input attn_bias, + output_sharding, + logsumexp_sharding, + ] + # input sharding of attn_bias on heads dim if present + if has_attn_bias: + num_heads_dim_sharding.insert(8, Shard(1)) + # accept replicate on the rest scalar tensor inputs + # namely philox_seed and philox_offset + num_heads_dim_sharding.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Shards on batch dim + batch_dim_sharding: PlacementList = [ + Shard(0), # grad_q + Shard(0), # grad_k + Shard(0), # grad_v + Shard(0) if has_attn_bias else None, # grad_bias + Shard(0), # grad_output + Shard(0), # q + Shard(0), # k + Shard(0), # v + Shard(0), # output + Shard(0), # logsumexp + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + if has_attn_bias: + batch_dim_sharding.insert(8, Shard(0)) + batch_dim_sharding.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(batch_dim_sharding) + + return single_mesh_dim_strategies + + +@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default) +def scaled_dot_product_efficient_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_backward_base_strategies(op_schema) + ) + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, + ) + + +def _scaled_dot_product_cudnn_attention_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" + ( + query_strategy, # query + _, # key + _, # value + attn_bias_strategy, + compute_log_sumexp, # compute_log_sumexp + *rest_args, # optional args: dropout_p, is_causal, return_debug_mask, scale + ) = op_schema.args_schema + return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2] + has_attn_bias = attn_bias_strategy is not None + debug_attn_mask_sharding: Placement | None = ( + Replicate() if return_debug_mask else None + ) + + if not isinstance(query_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(query_strategy)}") + # assuming q/k/v have the same shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 2 valid tensor outputs and 3 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), # output + Replicate(), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + # NOTE: debug_attn_mask is not supported by pytorch and is always an empty tensor + # https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840 + debug_attn_mask_sharding, # debug_attn_mask + Replicate(), # q + Replicate(), # k + Replicate(), # v + ] + if has_attn_bias: + all_replicate.append(Replicate()) # attn bias + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + tp_sharding = Shard(1) # num head dim + qkv_sharding = tp_sharding + output_sharding = tp_sharding + logsumexp_sharding = tp_sharding if compute_log_sumexp else Replicate() + debug_attn_mask_sharding = tp_sharding if return_debug_mask else None + + num_heads_dim_sharding: PlacementList = [ + output_sharding, + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # batch parallelism + logsumexp_sharding = Shard(0) if compute_log_sumexp else Replicate() + debug_attn_mask_sharding = Shard(0) if return_debug_mask else None + batch_dim_sharding: PlacementList = [ + Shard(0), # output + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, + Shard(0), # q + Shard(0), # k + Shard(0), # v + ] + single_mesh_dim_strategies.append(batch_dim_sharding) + + return single_mesh_dim_strategies + + +@register_op_strategy( + aten._scaled_dot_product_cudnn_attention.default, + schema_info=RuntimeSchemaInfo(4), +) +def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies( + op_schema + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +def _scaled_dot_product_cudnn_attention_backward_base_strategies( + op_schema: OpSchema, +) -> list[PlacementList]: + """Helper that returns list of base placement strategies (without CP).""" + if len(op_schema.args_schema) < 15: + raise AssertionError( + f"Expected at least 15 args_schema, got {len(op_schema.args_schema)}" + ) + has_attn_bias = op_schema.args_schema[8] is not None + has_scale = len(op_schema.args_schema) >= 16 and False + + query_strategy = op_schema.args_schema[1] + if not isinstance(query_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(query_strategy)}") + # assuming q/k/v have the same shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # cudnn outputs: (Tensor dq, Tensor dk, Tensor dv) + # cudnn inputs: ( + # Tensor grad_out, + # Tensor query, + # Tensor key, + # Tensor value, + # Tensor out, + # Tensor logsumexp, + # Tensor philox_seed, + # Tensor philox_offset, + # Tensor attn_bias, + # Tensor cum_seq_q, + # Tensor cum_seq_k, + # SymInt max_q, + # SymInt max_k, + # float dropout_p, + # bool is_causal, + # int? scale, + # ) + + # case 1: we can always accept full replication for both inputs and outputs + all_replicate_out: PlacementList = [ + Replicate(), # dq + Replicate(), # dk + Replicate(), # dv + ] + all_replicate_inp: PlacementList = [Replicate()] * 6 + all_replicate_inp += [ + Replicate() + ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor + all_replicate_inp += [Replicate() if has_attn_bias else None] + all_replicate_inp += [None] * 6 + if has_scale: + all_replicate_inp.append(None) + + all_replicate: PlacementList = all_replicate_out + all_replicate_inp + single_mesh_dim_strategies.append(all_replicate) + + # case 2: we can accept the sharding pattern of tensor parallelism, which + # shards on the num of head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + + num_heads_dim_sharding_out: PlacementList = [qkv_sharding] * 3 + num_heads_dim_sharding_inp: PlacementList = [qkv_sharding] * 4 + num_heads_dim_sharding_inp += [output_sharding] + num_heads_dim_sharding_inp += [logsumexp_sharding] + num_heads_dim_sharding_inp += [ + Replicate() + ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor + num_heads_dim_sharding_inp += [Shard(1) if has_attn_bias else None] + num_heads_dim_sharding_inp += [None] * 6 + if has_scale: + num_heads_dim_sharding_inp.append(None) + + num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # case 3: we can accept the sharding pattern of batch parallelism, which + # shards on the batch dimension + qkv_sharding = Shard(0) + output_sharding = Shard(0) + logsumexp_sharding = Shard(0) + + batch_dim_sharding_out: PlacementList = [qkv_sharding] * 3 + batch_dim_sharding_inp: PlacementList = [qkv_sharding] * 4 + batch_dim_sharding_inp += [output_sharding] + batch_dim_sharding_inp += [logsumexp_sharding] + batch_dim_sharding_inp += [ + Replicate() + ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor + batch_dim_sharding_inp += [Shard(0) if has_attn_bias else None] + batch_dim_sharding_inp += [None] * 6 + if has_scale: + batch_dim_sharding_inp.append(None) + + batch_dim_sharding = batch_dim_sharding_out + batch_dim_sharding_inp + single_mesh_dim_strategies.append(batch_dim_sharding) + + return single_mesh_dim_strategies + + +@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default) +def scaled_scaled_dot_product_cudnn_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema) + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +@register_op_strategy(aten._grouped_mm.default) +def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + mat1_strategy = op_schema.args_schema[0] + if not isinstance(mat1_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(mat1_strategy)}") + mat2_strategy = op_schema.args_schema[1] + if not isinstance(mat2_strategy, OpStrategy): + raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}") + if len(op_schema.args_schema) > 3: + bias_strategy = op_schema.args_schema[3] + if bias_strategy is not None: + raise AssertionError("grouped_mm doesn't support bias yet") + + single_mesh_dim_strategies = [] + + offs_placement = None + if len(op_schema.args_schema) > 2 and op_schema.args_schema[2] is not None: + offs_placement = Replicate() # offs should always be replicated + + all_replicate: PlacementList = [ + Replicate(), + Replicate(), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + partial_replicate: PlacementList = [ + Partial(), + Partial(), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + replicate_partial: PlacementList = [ + Partial(), + Replicate(), # mat1 + Partial(), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies = [all_replicate, partial_replicate, replicate_partial] + + if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 3: + # rowwise_replicate for 2dx3d not supported + replicate_colwise_2x3: PlacementList = [ + Shard(1), + Replicate(), # mat1 + Shard(2), # mat2 + offs_placement, # offs + None, # bias + ] + colwise_rowwise_2x3: PlacementList = [ + Partial(), + Shard(1), # mat1 + Shard(1), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend([replicate_colwise_2x3, colwise_rowwise_2x3]) + + if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 2: + # replicate_colwise for 3dx2d not supported + colwise_rowwise_3x2: PlacementList = [ + Partial(), + Shard(2), # mat1 + Shard(0), # mat2 + offs_placement, # offs + None, # bias + ] + rowwise_replicate_3x2: PlacementList = [ + Shard(0), + Shard(1), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend([colwise_rowwise_3x2, rowwise_replicate_3x2]) + + if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 2: + # colwise_rowwise for 2dx2d not supported + replicate_colwise_2x2: PlacementList = [ + Shard(2), + Replicate(), # mat1 + Shard(1), # mat2 + offs_placement, # offs + None, # bias + ] + rowwise_replicate_2x2: PlacementList = [ + Shard(1), + Shard(0), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend( + [replicate_colwise_2x2, rowwise_replicate_2x2] + ) + + if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 3: + replicate_colwise_3x3: PlacementList = [ + Shard(2), + Replicate(), # mat1 + Shard(2), # mat2 + offs_placement, # offs + None, # bias + ] + rowwise_replicate_3x3: PlacementList = [ + Shard(1), + Shard(1), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + colwise_rowwise_3x3: PlacementList = [ + Partial(), + Shard(2), # mat1 + Shard(1), # mat2 + offs_placement, # offs + None, # bias + ] + batch_dim_sharding: PlacementList = [ + Shard(0), + Shard(0), # mat1 + Shard(0), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend( + [ + replicate_colwise_3x3, + rowwise_replicate_3x3, + colwise_rowwise_3x3, + batch_dim_sharding, + ] + ) + + def valid_grouped_mm_strides( + input_specs: list[DTensorSpec], output_specs: tuple[DTensorSpec | None, ...] + ) -> bool: + # 1. compute the local-tensor shape/strides given this sharding proposal + # 2. apply the logic from the groped_mm meta function + # UGH the input DTensorSpecs are missing their tensormetas... so i can get them another way + def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta: + if not isinstance(spec.output_specs, DTensorSpec): + raise AssertionError( + f"Expected DTensorSpec, got {type(spec.output_specs)}" + ) + if not isinstance(spec.output_specs.tensor_meta, TensorMeta): + raise AssertionError( + f"Expected TensorMeta, got {type(spec.output_specs.tensor_meta)}" + ) + meta: TensorMeta = spec.output_specs.tensor_meta + local_stride = compute_local_stride(meta.stride, mesh, placements) + local_shape, _ = compute_local_shape_and_global_offset( + meta.shape, mesh, placements, skip_offset=True + ) + return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) + + # pyrefly: ignore [missing-attribute] + mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements) + # pyrefly: ignore [missing-attribute] + mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements) + + def check_valid_strides(meta: TensorMeta) -> bool: + # copied from `_meta_grouped_mm_common` in meta_registrations.py + end_dim = len(meta.shape) - 1 + alignment = 16 // meta.dtype.itemsize + if meta.stride[end_dim - 1] == 1 and meta.stride[end_dim] >= max( + 1, meta.shape[end_dim - 1] + ): + if meta.stride[end_dim] % alignment != 0: + return False + elif meta.stride[end_dim] == 1 and meta.stride[end_dim - 1] >= max( + 1, meta.shape[end_dim] + ): + if meta.stride[end_dim - 1] % alignment != 0: + return False + else: + return False + return True + + mat1_valid = check_valid_strides(mat1_meta) + mat2_valid = check_valid_strides(mat2_meta) + return mat1_valid and mat2_valid + + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=1, + is_valid_strategy_cb=valid_grouped_mm_strides, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_random_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dd4cf8fec226aa2538205c9a82f68ad05dbabb18 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_random_ops.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import torch +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + StrategyType, +) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import is_tensor_partial + + +aten = torch.ops.aten + + +@register_op_strategy( + [ + aten.normal_.default, + aten.uniform_.default, + aten.native_dropout.default, + aten.bernoulli_.float, + aten.bernoulli.default, + ] +) +def random_op_strategy(op_schema: OpSchema) -> StrategyType: + self_strategy = op_schema.args_schema[0] + assert isinstance(self_strategy, OpStrategy) + + random_strategy = OpStrategy([]) + for arg_strategy in self_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # TODO: figure out how inplace random op should behave when it's partial + raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") + random_strategy.strategies.append( + OpSpec( + output_specs=arg_spec, + input_specs=(arg_spec,), + redistribute_cost=[[0.0] * len(self_strategy.strategies)], + ) + ) + + return random_strategy diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/registration.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/registration.py new file mode 100644 index 0000000000000000000000000000000000000000..98ec79d101591864f34025c3249db8f060654154 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_ops/registration.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Callable +from typing import TypeAlias, TypeVar + +import torch +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor._op_schema import ( + OpSchema, + OutputSharding, + RuntimeSchemaInfo, + StrategyType, +) + + +# convenient wrapper to register sharding propagation rules +def register_prop_rule( + op: torch._ops.OpOverload | list[torch._ops.OpOverload], + schema_info: RuntimeSchemaInfo | None = None, +) -> Callable[ + [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding] +]: + def wrapper( + impl: Callable[[OpSchema], OutputSharding], + ) -> Callable[[OpSchema], OutputSharding]: + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( + overload, impl, schema_info + ) + return impl + + return wrapper + + +# Note: +# using TypeVar here allows the registration decorator to preserve the specific type info of the wrapped strategy, +# while hardcoding the typing on the wrapper (e.g. Callable[[OpSchema], StrategyType]) would mean mypy would treat +# the return value of the wrapped strategy as always being a `StrategyType` even if it were a derived class like +# MyStrategyType(StrategyType). +_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema) +_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType) +_ShardingStrategyFunc: TypeAlias = Callable[[_OpSchemaT], _StrategyTypeT] + + +def register_op_strategy( + op: torch._ops.OpOverload | list[torch._ops.OpOverload], + schema_info: RuntimeSchemaInfo | None = None, +) -> Callable[[_ShardingStrategyFunc], _ShardingStrategyFunc]: + # For every ATen op that accepts any args in this list, + # the arg itself can impact the strides (and potentially the sharding strategy) + # of the output tensor. + # thus, we will detect ATen schemas with any of these args and ensure + # that they get specialized here. + arg_names_that_require_specializing_cache_strategy = [ + "memory_format", + ] + + def wrapper(impl: _ShardingStrategyFunc) -> _ShardingStrategyFunc: + if isinstance(op, list): + overloads = op + else: + overloads = [op] + + for overload in overloads: + curr_schema_info = None + if schema_info is None: + specialized_args = [ + a.name + for a in overload._schema.arguments + if a.name in arg_names_that_require_specializing_cache_strategy + ] + if any(specialized_args): + curr_schema_info = RuntimeSchemaInfo( + static_kwargkey=specialized_args + ) + else: + curr_schema_info = schema_info + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, impl, curr_schema_info + ) + return impl + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_random.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_random.py new file mode 100644 index 0000000000000000000000000000000000000000..995a057b0c7faa085ae94d67e11d7603d057e05a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_random.py @@ -0,0 +1,478 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import warnings +from logging import getLogger +from typing import Optional + +import torch +from torch.distributed._local_tensor import maybe_run_for_local_tensor +from torch.distributed.device_mesh import _get_device_handle, DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import _StridedShard, Shard + + +logger = getLogger(__name__) + +__all__ = [ + "is_rng_supported_mesh", + "manual_seed", + "OffsetBasedRNGTracker", +] + +_rng_tracker: Optional["_RNGStateTracker"] = None + + +def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool: + """Checks if the current device of ``device_mesh`` supports DTensor's random APIs. + Currently DTensor Random APIs only supports cuda/cuda-like devices. We suggest + users call this API to test the availability before using our random APIs. + + Args: + device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the + random ops APIs are supported. + + Returns: + A bool value. True if ``device_mesh`` supports DTensor Random APIs; False otherwise. + + .. warning:: + Currently we only support correct RNG on cuda/cuda-like devices. + """ + device_handle = _get_device_handle(device_mesh.device_type) + if device_handle and hasattr(device_handle, "set_rng_state"): + return True + else: + # TODO: Logs way too much + warnings.warn( + f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh", + stacklevel=2, + ) + return False + + +def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: + """Sets the seed for generating random numbers for the calling rank. + + Args: + seed (int): The desired seed. + device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is + required that the ``device_mesh`` include the calling rank. This is + to ensure that the SPMD region maintains a synchronous RNG state, which + means no ranks should be initialized with values other than ``seed``. + + Returns: + None + + .. warning:: + :func:`manual_seed` does not check the ``seed`` value correctness. Users must + ensure on their own that the value passed in is the desired ``seed`` for ranks + within ``device_mesh``. + If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it, + ``manual_seed`` will throw an error. + Current implementation only supports a GPU device mesh. + """ + if not is_rng_supported_mesh(device_mesh): + warnings.warn( + "DTensor manual_seed() may not have complete support " + f"on {device_mesh.device_type} device mesh", + stacklevel=2, + ) + return + + # TODO: deprecate this API, but also need to ensure we disable broadcast for PP case, and that's currently + # bundled together with this API. See torchtitan/distributed/utils.py:set_determinism + # warnings.warn( + # "DTensor manual_seed() is deprecated, since DTensor no longer maintains a separate copy of generator state. " + # "Use `torch.manual_seed` instead" + # ) + # Note: we still need to ensure setting `run_state_sync=False` to support the pp case + + # instantiate a RNG tracker if haven't. By default DTensor uses an + # OffsetBasedRNGTracker to perform random operators. + global _rng_tracker + if not _rng_tracker: + _rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False) + + if device_mesh.get_coordinate() is None: + raise RuntimeError( + "manual_seed requires the current rank to be a part of the device mesh " + "otherwise DTensor RNG state on the rank will not be initialized and " + "the behavior of DTensor random ops is undefined." + ) + + # DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing + # as manual seed on torch. + # + # torch.manual_seed will handle LocalTensor mode correctly by + # iterating through all ranks if seed is a LocalIntNode. + torch.manual_seed(seed) + + +class _PhiloxState: + """ + Convenience accessor for interpreting the packed bits of (seed: uint64, offset: uint64) in the philox state, + which for some reason is actually exposed as a size-16 uint8 tensor. + + The state is always moved to .cpu since it is necessary for it to be on CPU before applying it back to a generator. + """ + + def __init__(self, state: torch.Tensor): + self._state = state.to("cpu") + + @property + def state(self): + return self._state + + @property + def offset(self) -> int: + return int(self._state[8:].view(dtype=torch.int64).item()) + + @offset.setter + def offset(self, offset: int) -> None: + offset_tensor = torch.tensor([offset], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self._state[8:] = offset_tensor + + @property + def seed(self) -> int: + return int(self._state[:8].view(dtype=torch.uint64).item()) + + @seed.setter + def seed(self, seed: int) -> None: + seed_tensor = torch.tensor([seed], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self._state[:8] = seed_tensor + + +class _RNGStateTracker: + """ + _RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object) + in a dict, mapping from a corresponding tag to each state tensor. It also provides + a set of convenient utility methods to help access/modify the state tensors. The most + important interface is _distribute_region which will be used when DTensor executes + a random op (an operator that calls RNG). + """ + + def __init__(self, device: torch.device): + # pyrefly: ignore [read-only] + self._device = device + self._device_handle = _get_device_handle(self._device.type) + if not (self._device_handle and self._device_handle.is_available()): + raise RuntimeError( + f"{self.__class__.__name__} instantiation requires the presence of " + f"{device.type} device but couldn't find." + ) + self._use_distribute_region = True + + @property + def distribute_region_enabled(self) -> bool: + return self._use_distribute_region + + @distribute_region_enabled.setter + def distribute_region_enabled(self, value) -> None: + self._use_distribute_region = value + + def _distribute_region( + self, spec: DTensorSpec, generator: torch.Generator | None = None + ): + pass + + def _manual_seed(self, parallel_seed: int) -> None: + pass + + +class OffsetBasedRNGTracker(_RNGStateTracker): + """ + This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states + should be shared and synchronized among all ranks to respect the semantics of DTensor + random operators. + + note: _RNGStateTracker only supports cuda/cuda-like device. + """ + + def __init__( + self, + device_mesh: DeviceMesh, + run_state_sync: bool = True, + ): + super().__init__(_resolve_device(device_mesh=device_mesh)) + assert self._device_handle is not None + # DTensor RNG tracker so far only supports CUDA/CUDA-like devices + if self._device.type == "cpu": + raise RuntimeError( + f"{self.__class__.__name__} instantiation requires the presence of " + f"CUDA/CUDA-like/XPU device. Got {self._device.type} instead." + ) + + rng_state = self._get_device_state() + if run_state_sync: + # synchronize RNG state using rank 0's current one + torch.distributed.broadcast(rng_state, 0) + my_rng_state = self._get_device_state() + if not all(my_rng_state == rng_state): + logger.warning( + "DTensor is synchronizing RNG states of every rank with the state from rank 0. " + "This behavior is deprecated. " + "Please call `torch.manual_seed()` on every rank that participates in SPMD DTensor Operations with " + "the same seed. If using Pipeline Parallelism, each pipeling state would use a different seed, " + "but all ranks belonging to one pipeline stage would use the same seed." + ) + self._set_device_state(rng_state) + + def _get_device_state(self) -> torch.Tensor: + if self._device.type == "hpu": + self._device_handle.set_rng_ctx("philox") + rng_state = self._device_handle.get_rng_state().to(self._device) + if self._device.type == "hpu": + self._device_handle.unset_rng_ctx("philox") + return rng_state + + def _set_device_state(self, state: torch.Tensor): + # It seems that the underlying generator wants a cpu tensor but the dtensor code expects `_get_device_state` + # to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug + # for now, we just convert back to cpu here to make sure it always works. + if self._device.type == "hpu": + self._device_handle.set_rng_ctx("philox") + self._device_handle.set_rng_state(state.to("cpu")) + if self._device.type == "hpu": + self._device_handle.unset_rng_ctx("philox") + + @contextlib.contextmanager + def _distribute_region( + self, spec: DTensorSpec, generator: torch.Generator | None = None + ): + from torch.distributed._local_tensor import maybe_enable_local_tracker + + if local_tracker_context := maybe_enable_local_tracker( + self._device.type, self.distribute_region_enabled, spec, generator + ): + with local_tracker_context: + yield + return + + # regular (non-LocalTensor) mode + if generator is not None: + # This is a little hacky, but for any user-passed generator, we store its state under a unique key, + # not because we need to keep a copy of it but because its the easiest way to make it work with the + # existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region. + state = _PhiloxState(generator.get_state()) + else: + state = _PhiloxState(self._get_device_state()) + + if self.distribute_region_enabled: + if self._device.type == "hpu": + self._device_handle.set_rng_ctx("philox") + old_offset = state.offset + self._set_pre_op_offset(state, spec) + with torch.random.fork_rng( + devices=[self._device], device_type=self._device.type + ): + assert self._device_handle is not None + self._device_handle.set_rng_state(state.state) + try: + yield # execute the region code + finally: + # update offset to synchronize among ranks + self._set_post_op_offset(state, spec, old_offset) + if self._device.type == "hpu": + self._device_handle.unset_rng_ctx("philox") + else: + yield + + if generator is not None: + # ensure we (a) propagate the state advancement back to the user's RNG so its visible and impacts any future + # usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates + # the seed value in their rng and uses it with DTensor again, we always use the latest value + generator.set_state(state.state) + else: + self._set_device_state(state.state) + + def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: + """Set the starting RNG offset for current device's local shard before actual + op execution. The pre_op_offset value should start from the current RNG offset + and increment by the size of local shard until it reaches the size of the whole + DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset + will be the same. + + Args: + state (:class:`Tensor`): The generator state to modify + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we prepare the offset for running random ops. + + Returns: + None + + .. warning:: + Note that, current implementation does not consider DTensor's continguity. + + Example: + take a DTensor of shape [8, 16] as an example. Assume that the DTensor + is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]), + and the mesh is: + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] + ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank + in the mesh. For example, the coordinate of rank 5 is (1, 0, 1). + + Another concept to introduce besides rank coordinate is shard coordinate. + Each rank holds a local shard of the DTensor. In the example, the DTensor + is partitioned into 4 [4, 8] shards. The first shard has 2 replicas and + rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) have 1 replica each. + That being said, the local shard on rank 0 and rank 2 correspond to the same + shard of the DTensor. To denote each DTensor shard, we use a shard coordinate + (in the example, it will be a tuple (i, j) where shard (i, j) has the slice + DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2). + + Once we have rank coordinate and shard coordinate, we can calculate on each rank + what shard of the DTensor the rank holds, with the help of dim_map. The dim_map + of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord + (x, y, z) is simply (z, x) by taking(rank_coord[dim_map[0]],rank_coord[dim_map[1]]). + Following this calculation, + rank 0 and rank 2 holds the shard of coord (0, 0); + rank 1 and rank 3 holds the shard of coord (0, 1); + rank 4 and rank 6 holds the shard of coord (1, 0); + rank 5 and rank 7 holds the shard of coord (1, 1); + + The last value to calculate before obtaining the starting offset is the shard linear index. + The starting offset for each rank will be its shard_linear_index * local_tensor_numel. + """ + mesh = spec.mesh + mesh_coordinate = mesh.get_coordinate() + assert mesh_coordinate is not None + + # Compute shard index and total number of shards on each tensor dim + shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info( + mesh_coordinate, spec + ) + + # compute shard linear index + shard_linear_idx = self._calc_shard_linear_idx( + shard_idx_by_dim, total_num_shards_by_dim + ) + + # compute starting offset using the first shard's size + local_size_on_rank_0 = _calc_first_shard_size(spec) + + from torch.distributed.tensor._ops.utils import prod + + local_size = prod(local_size_on_rank_0) + + # get current RNG offset + current_offset = state.offset + + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + state.offset = current_offset + offset_incr + + def _set_post_op_offset( + self, state: _PhiloxState, spec: DTensorSpec, old_offset: int + ) -> None: + """Sets the RNG to a synchronized state after running the local random op. Every + rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is + the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor + random ops. + + Args: + state (:class:`Tensor`): The generator state to modify. + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we post-process the offset for running random ops. + + Returns: + None + """ + dtensor_shape = spec.shape + + from torch.distributed.tensor._ops.utils import prod + + numel = prod(dtensor_shape) + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + numel = (numel + 3) // 4 * 4 + state.offset = old_offset + numel + + def _calc_shard_linear_idx( + self, shard_coord: list[int], shard_size: list[int] + ) -> int: + return _calc_shard_linear_idx(shard_coord, shard_size) + + +def _calc_first_shard_size(spec: DTensorSpec) -> list[int]: + local_size_on_rank_0 = list(spec.shape) + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard | _StridedShard): + mesh_dim_size = spec.mesh.size(idx) + shard_dim = placement.dim + local_size_on_rank_0[shard_dim], _ = placement._local_shard_size_and_offset( + spec.shape[shard_dim], + mesh_dim_size, + 0, + ) + return local_size_on_rank_0 + + +def _calc_shard_info( + mesh_coordinate: list[int], spec: DTensorSpec +) -> tuple[list[int], list[int]]: + mesh = spec.mesh + # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP + # case. Replace the custom logic with dim_map once we support it. + dim_map: list[int | list[int]] = [-1] * spec.ndim + for i, placement in enumerate(spec.placements): + if isinstance(placement, Shard | _StridedShard): + shard_dim = placement.dim + if dim_map[shard_dim] == -1: + dim_map[shard_dim] = [i] + else: + mesh_dim_list = dim_map[shard_dim] + assert isinstance(mesh_dim_list, list) + mesh_dim_list.append(i) + + # Compute shard coordinate: + # The coordinate on each tensor dim is a tuple (idx, range) + # If a DTensor is partitioned on its dim i into n shards, and the current rank + # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i + assert mesh_coordinate is not None + mesh_size = mesh.shape + shard_idx_by_dim = [] + total_num_shards_by_dim = [] # total number of shards on each tensor dim + for mesh_dim in dim_map: + shard_idx = 0 + total_num_shards = 1 + # the tensor dim is sharded on more than 1 mesh dim + if isinstance(mesh_dim, list): + rank_coord = [mesh_coordinate[d] for d in mesh_dim] + num_shards = [mesh_size[d] for d in mesh_dim] + # compute the shard idx and total number of shards + for idx, size in zip(rank_coord, num_shards): + shard_idx = shard_idx * size + idx + total_num_shards *= size + + shard_idx_by_dim.append(shard_idx) + total_num_shards_by_dim.append(total_num_shards) + return shard_idx_by_dim, total_num_shards_by_dim + + +def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: + # compute shard linear index + shard_linear_idx = 0 + shard_coord_stride = 1 + for idx, size in zip(reversed(shard_coord), reversed(shard_size)): + shard_linear_idx += idx * shard_coord_stride + shard_coord_stride *= size + + return shard_linear_idx + + +def _resolve_device(device_mesh: DeviceMesh) -> torch.device: + device_type = device_mesh.device_type + device_handle = _get_device_handle(device_type) + assert device_handle is not None + device_idx = device_mesh.get_rank() % device_handle.device_count() + + @maybe_run_for_local_tensor + def get_device(device_idx): + return torch.device(f"{device_type}:{device_idx:d}") + + return get_device(device_idx) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py new file mode 100644 index 0000000000000000000000000000000000000000..7119fd9ae6529c174f7a34f55145434a35070e2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py @@ -0,0 +1,1067 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import dataclasses +import itertools +import logging +import weakref +from collections import defaultdict +from collections.abc import Sequence +from functools import cache +from typing import cast, NamedTuple, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch.distributed._functional_collectives import _are_we_tracing +from torch.distributed.tensor._dtensor_spec import ( + DTensorSpec, + ShardOrder, + ShardOrderEntry, + TensorMeta, +) +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) +from torch.utils._debug_mode import get_active_debug_mode + + +logger = logging.getLogger(__name__) + +# Global configuration flag to control the redistribution planning strategy. +# When True, forces the graph-based algorithm using Dijkstra's shortest path. +# When False, prefers the greedy algorithm for faster planning. Uses the graph-based algorithm +# only when necessary to support strided-shard redistribution +_FORCE_MIN_COST_REDISTRIBUTION_PLAN: Optional[bool] = None + + +@contextlib.contextmanager +def use_min_cost_redistribution_plan(enabled: bool = True): + """ + Context manager to control the redistribution planning strategy for DTensor operations. + + This context manager allows you to choose between two algorithms for computing the + sequence of collective operations needed to redistribute a DTensor from one placement + to another: + + - **Graph-based**: Uses Dijkstra's algorithm to find the minimum-cost path + through all possible placement transformations. This approach considers the global + cost of all collective operations and finds the optimal sequence. Best for complex + redistribution patterns where reducing communication cost and memory overhead is critical. + + - **Greedy**: Uses a heuristic approach that makes locally optimal choices + at each step. This is faster to compute but may not produce the globally optimal + transformation sequence. Best for simple redistribution patterns or when planning + speed is more important than optimal communication. + + **Default Behavior (without this context manager):** + + When this context manager is NOT used, the algorithm selection follows this priority: + + 1. **Non-default shard orders** + → Always use graph-based algorithm (required for correctness) + + 2. **Explicit `use_graph_based_transform` parameter** to `_gen_transform_infos_non_cached` + → Use the specified algorithm (True = graph-based, False = greedy) + + 3. **No explicit parameter** (default case) + → Use greedy algorithm for faster planning + + **Behavior with this context manager:** + + This context manager overrides the default selection by setting the global flag + `_FORCE_MIN_COST_REDISTRIBUTION_PLAN`, which takes precedence over the explicit + `use_graph_based_transform` parameter (but not over non-default shard order requirements). + + **Cache Considerations:** + + The redistribution planner caches transform info for performance via the `@cache` + decorator on `_gen_transform_infos`. If you need to change the algorithm selection + for the same input specs, clear the cache using `_gen_transform_infos.cache_clear()` + to ensure the new setting takes effect and doesn't reuse cached results from a + previous run. + + Args: + enabled (bool): If True, forces the use of the graph-based algorithm. + If False, forces the use of the greedy algorithm. + Default: True + """ + global _FORCE_MIN_COST_REDISTRIBUTION_PLAN + old_value = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = enabled + try: + yield + finally: + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = old_value + + +class _TransformInfo(NamedTuple): + mesh_dim: int + src_dst_placements: tuple[Placement, Placement] + # logical_shape on this mesh dimension + logical_shape: list[int] + + +# Global cache for DTensorRedistributePlanner instances +_planner_cache: dict[ + tuple[weakref.ReferenceType, int], "DTensorRedistributePlanner" +] = {} + + +def get_redistribute_planner( + device_mesh: DeviceMesh, tensor_dimension: int +) -> "DTensorRedistributePlanner": + """ + Factory function to get or create a DTensorRedistributePlanner instance. + This function provides transparent caching of planner instances based on + device_mesh and tensor_dimension. Multiple calls with the same parameters + will return the same cached instance for better performance. + Args: + device_mesh: The device mesh for the planner + tensor_dimension: Number of tensor dimensions + Returns: + A DTensorRedistributePlanner instance (potentially cached) + """ + cache_key = (weakref.ref(device_mesh), tensor_dimension) + + if cache_key not in _planner_cache: + planner = DTensorRedistributePlanner(device_mesh, tensor_dimension) + _planner_cache[cache_key] = planner + + return _planner_cache[cache_key] + + +def clear_redistribute_planner_cache() -> None: + """Clear the cache of DTensorRedistributePlanner instances.""" + _planner_cache.clear() + + +class DTensorRedistributePlanner: + """ + This class is used to plan the collective calls to transform the local shard + of the DTensor from its current spec to the target spec. + Suppose there are N tensor dimensions and M mesh dimensions, the total + possible state size will be (N+2)*M*M!. + Note: Use get_redistribute_planner() factory function instead of direct + instantiation for automatic caching. + """ + + @dataclasses.dataclass(frozen=True, slots=True) + class DistState: + placements: tuple[Placement, ...] + tensor_dim_to_mesh_dim: ShardOrder + _hash: int | None = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) + + def __str__(self): + return DTensorSpec.format_shard_order_str( + self.placements, + self.tensor_dim_to_mesh_dim, + ) + + def __repr__(self): + return self.__str__() + + def __post_init__(self): + # precompute hash after all attributes are set + object.__setattr__( + self, + "_hash", + self._compute_hash(), + ) + + def __hash__(self) -> int: + return self._hash if self._hash is not None else self._compute_hash() + + def _compute_hash(self) -> int: + return hash( + ( + self.placements, + self.tensor_dim_to_mesh_dim, + ) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DTensorRedistributePlanner.DistState): + return False + if self._hash != other._hash: + return False + return ( + self.placements, + self.tensor_dim_to_mesh_dim, + ) == ( + other.placements, + other.tensor_dim_to_mesh_dim, + ) + + def _to_tuple(self, x): + """Convert a nested list structure to a nested tuple structure.""" + if isinstance(x, list | tuple): + return tuple(self._to_tuple(item) for item in x) + return x + + @staticmethod + def _dict_to_ShardOrder(x: dict[int, list[int]]) -> ShardOrder: + """Convert dict to ShardOrder""" + return tuple( + ShardOrderEntry(tensor_dim=key, mesh_dims=tuple(value)) + for key, value in sorted(x.items()) + if value + ) + + @staticmethod + def _ShardOrder_to_dict(x: ShardOrder) -> dict[int, list[int]]: + """Convert ShardOrder to dict with tensor dim as key""" + tensor_mesh_dim_dict = defaultdict(list) + for entry in x: + tensor_mesh_dim_dict[entry.tensor_dim] = list(entry.mesh_dims) + return tensor_mesh_dim_dict + + @staticmethod + def stringify_transform_infos( + mesh: DeviceMesh, + transform_infos: Sequence[_TransformInfo], + src_placement: tuple[Placement, ...], + src_shard_order: ShardOrder | None = None, + ) -> str: + """ + Generate a string representation of the sequence of state transitions + (placements and shard orders) as described by the given transform_info. + + Args: + mesh: The DeviceMesh used for the redistribution. + transform_infos: A sequence of _TransformInfo objects describing each + transformation step. + src_placement: The initial tuple of Placement objects. + src_shard_order: (Optional) The initial ShardOrder representing + the mapping of tensor dimensions to mesh dimensions. If None, + the default shard order is computed from src_placement and mesh. + + Returns: + A string showing the sequence of DistState transitions, separated by '->'. + """ + assert len(src_placement) == mesh.ndim + if src_shard_order is None: + src_shard_order = DTensorSpec.compute_default_shard_order(src_placement) + cur_placement = list(src_placement) + shard_order_dict = DTensorRedistributePlanner._ShardOrder_to_dict( + src_shard_order + ) + cur_state = DTensorRedistributePlanner.DistState( + tuple(cur_placement), src_shard_order + ) + state_list = [ + cur_state, + ] + for transform_info in transform_infos: + src_dim_placement, dst_dim_placement = transform_info.src_dst_placements + if src_dim_placement.is_shard(): + src_dim = src_dim_placement.dim # type: ignore[attr-defined] + assert ( + src_dim in shard_order_dict and len(shard_order_dict[src_dim]) > 0 + ) + shard_order_dict[src_dim].pop() + if dst_dim_placement.is_shard(): + dst_dim = dst_dim_placement.dim # type: ignore[attr-defined] + if dst_dim not in shard_order_dict: + shard_order_dict[dst_dim] = [] + shard_order_dict[dst_dim].append(transform_info.mesh_dim) + cur_placement[transform_info.mesh_dim] = dst_dim_placement + new_state = DTensorRedistributePlanner.DistState( + tuple(cur_placement), + DTensorRedistributePlanner._dict_to_ShardOrder(shard_order_dict), + ) + state_list.append(new_state) + return "->".join([str(s) for s in state_list]) + + def __init__( + self, + device_mesh: DeviceMesh, + tensor_dimension: int, + ) -> None: + """ + Initialize DTensorRedistributePlanner. + + Args: + device_mesh: The device mesh for this planner + tensor_dimension: Number of tensor dimensions + """ + self.device_mesh = device_mesh + self.coordinate = device_mesh.get_coordinate() + assert self.coordinate is not None + self.tensor_dimension = tensor_dimension + self.setup_collective_cost() + + def setup_collective_cost( + self, + all_reduce_cost: int = 4, + all_to_all_cost: int = 1, + all_gather_cost: int = 2, + reduce_scatter_cost: int = 2, + chunk_cost: int = 0, + ) -> None: + """ + Set up the cost weights for different collective operations. + """ + # those can be turned in a handler considering the tensor dim size + self.all_reduce_cost = all_reduce_cost + self.all_to_all_cost = all_to_all_cost + self.all_gather_cost = all_gather_cost + self.reduce_scatter = reduce_scatter_cost + self.chunk_cost = chunk_cost + + def get_next_state( + self, + placements: tuple[Placement, ...], + tensor_mesh_dim_tuple: ShardOrder, + ) -> dict["DTensorRedistributePlanner.DistState", int]: + # We map tensor dimensions to device mesh axes, similar to JAX-style + # sharding representation. Notation: + # S()[] means tensor dimension + # is sharded on the listed device mesh axes, where + # is sorted by device order. + # + # To generalize to arbitrary dimensionality, we use the following notation: + # S(a)[x, ...] : tensor dimension 'a' is sharded on device mesh axes x, ... (variadic, possibly empty) + # R[...] : replicated on the listed device mesh axes (possibly empty) + # P[...] : partial on the listed device mesh axes (possibly empty) + # The ellipsis '...' denotes a variadic wildcard, i.e., zero or more device mesh axes. + # + # Below are possible transitions from one sharding state to another. + # We use `S` for Shard, `R` for Replicate, and `P` for Partial. + # + # Case 1. Shard(a) -> Shard(b), use all-to-all (a2a), applies to: + # S(a)[..., x] -> S(b)[..., x] + # or + # S(a)[..., x, y]S(b)[..., z, k] -> S(a)[..., x]S(b)[..., z, k, y] + # where device order of 'y' > device order of 'z' and 'k' + # + # Case 2. Shard() -> Replicate(), use all-gather, applies to: + # S(a)[..., x, y, z] -> S(a)[..., x, y] + # + # Case 3. Partial() -> Replicate(), use all-reduce, applies to: + # P[..., x, y] -> P[..., y] or P[..., x] + # Note: this case can be disabled because all-reduce technically is not + # a primitive since it combines a reduce-scatter + all-gather. + # + # Case 4. Replicate() -> Shard(), use chunk, applies to: + # S(a)[..., z] -> S(a)[..., z, y] (`a` can be any tensor dim). Note that + # 'y' must be after 'z'. + # + # Case 5. Partial() -> Shard(), use reduce-scatter, applies to: + # P[..., x, y] -> P[..., x]S(a)[..., y] or P[..., x, y] -> P[..., y]S(a)[..., x] + # + # Case 6. Replicate() -> Partial(), local math op, applies to: + # R* -> P[..., x] + # + # NB: Device order in Partial placement doesn't take impact. We should be able + # to operate on any Partial mesh dim. + + # list of [DistState, cost] + all_next_state: dict[DTensorRedistributePlanner.DistState, int] = {} + + tensor_mesh_dim_dict = DTensorRedistributePlanner._ShardOrder_to_dict( + tensor_mesh_dim_tuple + ) + ###################################################################### + # handle case 1: Shard(a) -> Shard(b) + # For S(a), S(b), only the last device order of S(a) and S(b) can be a2a + # interchangeably. + + # convert sparse tuple + for entry in tensor_mesh_dim_tuple: + src_tensor_dim = entry.tensor_dim + for dst_tensor_dim in range(self.tensor_dimension): + if src_tensor_dim == dst_tensor_dim: + continue + # try move the last sharded device dim from + # Shard(src_tensor_dim) to Shard(dst_tensor_dim) + move_mesh_dim = tensor_mesh_dim_dict[src_tensor_dim].pop() + tensor_mesh_dim_dict[dst_tensor_dim].append(move_mesh_dim) + new_placements = list(placements) + new_placements[move_mesh_dim] = Shard(dst_tensor_dim) + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder( + tensor_mesh_dim_dict + ), + ) + all_next_state[dist_state] = self.all_to_all_cost + # reset content for next iteration + tensor_mesh_dim_dict[src_tensor_dim].append(move_mesh_dim) + tensor_mesh_dim_dict[dst_tensor_dim].pop() + # TODO(zpcore): support discovering submesh to prevent padding when + # tensor dim is not divisible by the mesh dim. + + ###################################################################### + # handle case 2: Shard() -> Replicate() + for entry in tensor_mesh_dim_tuple: + src_tensor_dim = entry.tensor_dim + move_mesh_dim = tensor_mesh_dim_dict[src_tensor_dim].pop() + new_placements = list(placements) + new_placements[move_mesh_dim] = Replicate() + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder(tensor_mesh_dim_dict), + ) + tensor_mesh_dim_dict[src_tensor_dim].append(move_mesh_dim) + all_next_state[dist_state] = self.all_gather_cost + + ###################################################################### + # handle case 3: Partial() -> Replicate() + for src_mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Partial): + continue + new_placements = list(placements) + new_placements[src_mesh_dim] = Replicate() + dist_state = self.DistState( + self._to_tuple(new_placements), tensor_mesh_dim_tuple + ) + all_next_state[dist_state] = self.all_reduce_cost + + ###################################################################### + # handle case 4: Replicate() -> Shard() + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Replicate): + continue + for dst_tensor_dim in range(self.tensor_dimension): + # try convert placement[mesh_dim] to Shard(dst_tensor_dim) + new_placements = list(placements) + new_placements[mesh_dim] = Shard(dst_tensor_dim) + tensor_mesh_dim_dict[dst_tensor_dim].append(mesh_dim) + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder( + tensor_mesh_dim_dict + ), + ) + all_next_state[dist_state] = self.chunk_cost + tensor_mesh_dim_dict[dst_tensor_dim].pop() + + ###################################################################### + # handle case 5: Partial() -> Shard() + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Partial): + continue + for dst_tensor_dim in range(self.tensor_dimension): + # try convert placement[mesh_dim] to Shard(dst_tensor_dim) + new_placements = list(placements) + new_placements[mesh_dim] = Shard(dst_tensor_dim) + tensor_mesh_dim_dict[dst_tensor_dim].append(mesh_dim) + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder( + tensor_mesh_dim_dict + ), + ) + all_next_state[dist_state] = self.reduce_scatter + tensor_mesh_dim_dict[dst_tensor_dim].pop() + + ###################################################################### + # handle case 6: Replicate() -> Partial(), default to partial(sum) + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Replicate): + continue + new_placements = list(placements) + new_placements[mesh_dim] = Partial() + dist_state = self.DistState( + self._to_tuple(new_placements), tensor_mesh_dim_tuple + ) + all_next_state[dist_state] = self.chunk_cost + + return all_next_state + + # TODO(zpcore): if the dst_state contains special placement like + # `_MaskPartial`, we will never reach that state. Need to support this case. + def find_min_cost_path( + self, src_state: DistState, dst_state: DistState + ) -> list["DTensorRedistributePlanner.DistState"]: + """ + Find the min cost path from src_state to dst_state using Dijkstra's + algorithm. + + Args: + src_state: The source state + dst_state: The destination state + + Returns: + A list of states representing the min cost path from src_state to + dst_state + """ + import heapq + + # priority queue (cost, counter, state, path) for Dijkstra's algorithm + # use counter to break ties and avoid comparing DistState objects + counter = 0 + pq: list[ + tuple[ + int, + int, + DTensorRedistributePlanner.DistState, + list[DTensorRedistributePlanner.DistState], + ] + ] = [(0, counter, src_state, [src_state])] + visited = set() + while pq: + cost, _, current_state, path = heapq.heappop(pq) + if current_state == dst_state: + return path + if current_state in visited: + continue + visited.add(current_state) + # get all possible next states and their costs + next_states = self.get_next_state( + current_state.placements, current_state.tensor_dim_to_mesh_dim + ) + for next_state, transition_cost in next_states.items(): + if next_state not in visited: + new_cost = cost + transition_cost + new_path = path + [next_state] + counter += 1 + heapq.heappush(pq, (new_cost, counter, next_state, new_path)) + raise AssertionError( + f"No path found from src_state {src_state} to dst_state {dst_state}" + ) + + def get_logical_shape( + self, + src_state: "DTensorRedistributePlanner.DistState", + mesh_dim: int, + full_tensor_shape: tuple[int, ...], + ) -> list[int]: + new_logical_shape = list(full_tensor_shape) + assert self.coordinate is not None + for entry in src_state.tensor_dim_to_mesh_dim: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + assert len(mesh_dims) > 0 + for mdim in mesh_dims: + if mdim == mesh_dim: + continue + new_size = Shard.local_shard_size_and_offset( + new_logical_shape[tensor_dim], + self.device_mesh.size(mesh_dim=mdim), + self.coordinate[mdim], + )[0] + new_logical_shape[tensor_dim] = new_size + return new_logical_shape + + def generate_graph_based_transform_infos( + self, + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + full_tensor_shape: tuple[int, ...], + ) -> list[_TransformInfo]: + # In case _StridedShard exists in placements, we let _StridedShard have + # higher priority to express shard_order. + if any( + isinstance(placement, _StridedShard) for placement in src_spec.placements + ): + src_placements, src_shard_order = ( + DTensorSpec._normalize_placements_into_shard_order( + src_spec.placements, src_spec.mesh + ) + ) + else: + src_placements = src_spec.placements + src_shard_order = src_spec.shard_order + if any( + isinstance(placement, _StridedShard) for placement in dst_spec.placements + ): + dst_placements, dst_shard_order = ( + DTensorSpec._normalize_placements_into_shard_order( + dst_spec.placements, dst_spec.mesh + ) + ) + else: + dst_placements = dst_spec.placements + dst_shard_order = dst_spec.shard_order + if src_shard_order is None or dst_shard_order is None: + raise NotImplementedError( + "Redistribution of _StridedShard placement is only supported for " + "_StridedShard that can be converted to ordered Shard placements. " + "Full _StridedShard redistribution support is not yet implemented." + ) + src_state = self.DistState(src_placements, src_shard_order) + dst_state = self.DistState(dst_placements, dst_shard_order) + transform_infos: list[_TransformInfo] = [] + state_path = self.find_min_cost_path(src_state, dst_state) + for cur_state, nxt_state in itertools.pairwise(state_path): + # find the mesh_dim that is different between cur_state and nxt_state + if cur_state.placements != nxt_state.placements: + update_mesh_dim = -1 + for mesh_dim, (cur_placement, nxt_placement) in enumerate( + zip(cur_state.placements, nxt_state.placements) + ): + if cur_placement != nxt_placement: + if update_mesh_dim != -1: + raise AssertionError( + "Multiple mesh_dims are different between cur_state and nxt_state" + ) + update_mesh_dim = mesh_dim + logical_shape = self.get_logical_shape( + cur_state, mesh_dim, full_tensor_shape + ) + transform_infos.append( + _TransformInfo( + mesh_dim=update_mesh_dim, + src_dst_placements=(cur_placement, nxt_placement), + logical_shape=logical_shape, + ) + ) + + return transform_infos + + def generate_greedy_transform_infos( + self, + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + ) -> list[_TransformInfo]: + """ + Generate the transform infos from the source placements to the target placements. + + To transform from source to target placement it might have multiple steps, i.e. it + might decompose Si -> Sj into Si -> R -> Sj. + This would detect if there're mis-aligned/nested shardings between src/dst placements. + E.g. Suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)), + in this case Shard(0) -> Shard(0) for mesh dimension 1 actually needs resharding, because in + the former is a nested-sharding of a tensor already already sharded dimension 0, whereas + the latter is the first sharding on tensor dimension 0. + """ + # logical shape records the logic tensor shape on the mesh dimension + # this is useful to ensure uneven sharding gets correct output shape + assert self.coordinate is not None + initial_logical_shape = list(src_spec.shape) + mesh_dims_to_logical_shape = [initial_logical_shape] + transform_infos: list[_TransformInfo] = [] + if self.device_mesh.ndim == 1: + # if device_mesh is 1D, redistribute is a simple direct + # transformation + transform_infos.append( + _TransformInfo( + mesh_dim=0, + src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]), + logical_shape=initial_logical_shape, + ) + ) + return transform_infos + + # Handle multi-dim device mesh placement redistribution First, we need + # to build the logical shape for each mesh dim for correct allgather + # uneven shards on each mesh dim (with dynamic padding) + for i, src in enumerate(src_spec.placements): + current_logical_shape = mesh_dims_to_logical_shape[i] + if isinstance(src, Shard): + if i < self.device_mesh.ndim - 1: + # calculate and save the logical shape for this sharding + mesh_dim_size = self.device_mesh.size(mesh_dim=i) + local_shard_size, _ = src._local_shard_size_and_offset( + current_logical_shape[src.dim], + mesh_dim_size, + self.coordinate[i], + ) + new_logical_shape = list(current_logical_shape) + new_logical_shape[src.dim] = local_shard_size + mesh_dims_to_logical_shape.append(new_logical_shape) + else: + mesh_dims_to_logical_shape.append(current_logical_shape) + + # Next, we need to derive the transform infos from src to dst + # placements, here we use a greedy search with step by step state + # transformations + current_placements = list(src_spec.placements) + target_placements = list(dst_spec.placements) + + if src_spec.num_shards > 1: + # If src_spec have sharding, it could potentially have sharding that + # is misaligned with dst_spec a common case of this is nested + # sharding (i.e. (S(0), S(0)) -> (R, S(0))). In those cases, we + # first traverse from inner placement to outer placement to detect + # misaligned shardings and properly replicate nested sharding first. + for mesh_dim in reversed(range(len(current_placements))): + current = current_placements[mesh_dim] + target = target_placements[mesh_dim] + # If target is not Shard, we can directly redistribute since we + # are traversing from inner to outer placements here + if isinstance(target, Shard): + # If target is Shard, check for nested sharding on the + # tensor dim BEFORE the current mesh_dim + shard_dim = target.dim + current_mesh_sharding, target_mesh_sharding = [], [] + for i, (s, p) in enumerate( + zip(current_placements, target_placements) + ): + if i >= mesh_dim: + break + if s.is_shard(shard_dim): + current_mesh_sharding.append(i) + if p.is_shard(shard_dim): + target_mesh_sharding.append(i) + + if current_mesh_sharding != target_mesh_sharding: + # if current/target_placements have misaligned sharding + # on the tensor dim BEFORE the current mesh_dim, we need + # to replicate the tensor on the mesh dim first to clear + # the nested sharding + target = Replicate() + + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + + # We always traverse from outer placement to inner placement to collect + # the remaining needed transform infos (i.e. the replication from nested + # sharding might need to further perform resharding to Shard again) + for mesh_dim, (current, target) in enumerate( + zip(current_placements, target_placements) + ): + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + return transform_infos + + +def _gen_transform_infos_non_cached( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + use_graph_based_transform: bool | None = None, +) -> list[_TransformInfo]: + device_mesh = src_spec.device_mesh + src_shard_order = src_spec.shard_order + dst_shard_order = dst_spec.shard_order + # DTensorSpec should automatically generate shard_order, and it can be () if + # no shard. + assert src_shard_order is not None and dst_shard_order is not None + # Determine which transform strategy to use: + # 1. Non-standard device order → always use graph-based + # 2. Global flag or explicit parameter True → use graph-based + # 3. Otherwise → use greedy + has_non_default_order = not all( + DTensorSpec.is_default_device_order(order) + for order in (src_shard_order, dst_shard_order) + ) + + if has_non_default_order is True: + use_graph_based_transform = True + elif _FORCE_MIN_COST_REDISTRIBUTION_PLAN is not None: + use_graph_based_transform = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + elif use_graph_based_transform is None: + use_graph_based_transform = False + drp = get_redistribute_planner(device_mesh, len(src_spec.shape)) + if use_graph_based_transform: + transform_infos = drp.generate_graph_based_transform_infos( + src_spec, dst_spec, src_spec.shape + ) + else: + transform_infos = drp.generate_greedy_transform_infos(src_spec, dst_spec) + return transform_infos + + +@cache +def _gen_transform_infos( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + use_graph_based_transform: bool | None = None, +) -> list[_TransformInfo]: + return _gen_transform_infos_non_cached( + src_spec, dst_spec, use_graph_based_transform + ) + + +def redistribute_local_tensor( + local_tensor: torch.Tensor, + current_spec: DTensorSpec, + target_spec: DTensorSpec, + *, + async_op: bool = False, + is_backward: bool = False, + use_graph_based_transform: bool | None = None, +) -> torch.Tensor: + """ + This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to + the target DTensorSpec, which involves the necessary collective calls to transform + the local shard of the DTensor from its current spec to the target spec. + """ + + if current_spec.mesh != target_spec.mesh: + # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same + raise NotImplementedError("Cross device mesh comm not supported yet!") + + new_local_tensor = local_tensor + device_mesh = current_spec.mesh + + my_coordinate = device_mesh.get_coordinate() + + if my_coordinate is None: + # if rank is not part of mesh, we skip redistribute and simply return local_tensor, + # which should be an empty tensor + return local_tensor + + if _are_we_tracing(): + transform_infos = _gen_transform_infos_non_cached( + current_spec, target_spec, use_graph_based_transform + ) + else: + transform_infos = _gen_transform_infos( + current_spec, target_spec, use_graph_based_transform + ) + + debug_mode = get_active_debug_mode() + redistribute_context = ( + debug_mode.record_redistribute_calls( # type: ignore[union-attr] + local_tensor, + current_spec.placements, + target_spec.placements, + DTensorRedistributePlanner.stringify_transform_infos( + device_mesh, + transform_infos, + current_spec.placements, + current_spec.shard_order, + ), + ) + if debug_mode is not None + else contextlib.nullcontext() + ) + + with redistribute_context: + for transform_info in transform_infos: + i = transform_info.mesh_dim + current, target = transform_info.src_dst_placements + num_chunks = device_mesh.size(mesh_dim=i) + + if current == target: + # short cut, just use the original local tensor + new_local_tensor = local_tensor + continue + + if num_chunks == 1: + # short cut, if there's only one shard, we don't need to do any collective + # comm, just use the original local tensor + new_local_tensor = local_tensor + continue + + if target.is_replicate(): + # Case 1: target is Replicate + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_value( + local_tensor, device_mesh, i + ) + elif current.is_shard(): + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + + elif target.is_shard(): + # Case 2: target is Shard + target_placement = cast(Shard, target) + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_shard_value( + local_tensor, device_mesh, i, target_placement + ) + elif current.is_replicate(): + # split the tensor and return the corresponding cloned local shard + new_local_tensor = target_placement._replicate_to_shard( + local_tensor, device_mesh, i, my_coordinate[i] + ) + else: + assert current.is_shard(), ( + f"Current placement should be shard but found {current}" + ) + shard_spec = cast(Shard, current) + if shard_spec.dim != target_placement.dim: + new_local_tensor = shard_spec._to_new_shard_dim( + local_tensor, + device_mesh, + i, + transform_info.logical_shape, + target_placement.dim, + ) + elif target.is_partial(): + if current.is_replicate(): + partial_spec = cast(Partial, target) + # skip the replicate to partial transformation when we are in backward pass + # In this case we keep the grad as replicate, this is because we don't + # want to convert the replicated gradients back to partial, although + # that's logically conform with the same layout, converting the gradients + # back to partial is actually useless as you would have to do reduce later + # which would be more expensive than keeping it replicate! For this reason, + # we keep the replicate grad here. + new_local_tensor = ( + partial_spec._partition_value(local_tensor, device_mesh, i) + if not is_backward + else local_tensor + ) + elif current.is_shard(): + if not is_backward: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + # for backward shard -> partial, we just need to convert the shard to replicate + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + # partial -> partial no op, should never hit + new_local_tensor = local_tensor + + if not async_op and isinstance( + new_local_tensor, funcol.AsyncCollectiveTensor + ): + new_local_tensor = new_local_tensor.wait() + local_tensor = new_local_tensor + return new_local_tensor + + +class Redistribute(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: "dtensor.DTensor", + device_mesh: DeviceMesh, + placements: tuple[Placement, ...], + async_op: bool = False, + forward_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, + ): + ctx.async_op = async_op + ctx.backward_dtype = backward_dtype + ctx.original_dtype = input._local_tensor.dtype + + if forward_dtype is not None and forward_dtype != input._local_tensor.dtype: + local_tensor = input._local_tensor.to(dtype=forward_dtype) + current_spec = DTensorSpec( + mesh=device_mesh, + placements=input._spec.placements, + tensor_meta=TensorMeta( + shape=input.shape, + stride=input.stride(), + dtype=forward_dtype, + ), + ) + else: + local_tensor = input._local_tensor + current_spec = input._spec + + ctx.current_spec = current_spec + + if current_spec.placements != placements: + target_spec = DTensorSpec( + device_mesh, placements, tensor_meta=current_spec.tensor_meta + ) + + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, async_op=async_op + ) + else: + # use the same local tensor if placements are the same. + output = local_tensor + target_spec = current_spec + + # pyrefly: ignore [bad-argument-type] + return dtensor.DTensor( + # pyrefly: ignore [bad-argument-count] + output, + target_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=input.requires_grad, + ) + + @staticmethod + def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] + previous_spec = ctx.current_spec + async_op = ctx.async_op + backward_dtype = ctx.backward_dtype or ctx.original_dtype + + if backward_dtype != grad_output._local_tensor.dtype: + local_tensor = grad_output._local_tensor.to(dtype=backward_dtype) + current_spec = DTensorSpec( + mesh=grad_output._spec.device_mesh, + placements=grad_output._spec.placements, + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=backward_dtype, + ), + ) + previous_spec = DTensorSpec( + mesh=previous_spec.device_mesh, + placements=previous_spec.placements, + tensor_meta=current_spec.tensor_meta, + ) + else: + local_tensor = grad_output._local_tensor + current_spec = grad_output._spec + + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + is_backward=True, + ) + + if output.dtype != ctx.original_dtype: + output = output.to(ctx.original_dtype) + + # normalize the target placement to replicate if it is partial + normalized_placements: list[Placement] = [] + for previous_placement in previous_spec.placements: + if previous_placement.is_partial(): + # keep target placement to replicate instead of partial in this case + normalized_placements.append(Replicate()) + else: + normalized_placements.append(previous_placement) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=output.dtype, + ), + ) + # pyrefly: ignore [bad-argument-type] + output_dtensor = dtensor.DTensor( + # pyrefly: ignore [bad-argument-count] + output, + spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=grad_output.requires_grad, + ) + + return ( + output_dtensor, + None, + None, + None, + None, + None, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..c1fddd05c9d6e7f38e637ea10a3bf2ffe0e16fe0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py @@ -0,0 +1,680 @@ +# mypy: allow-untyped-defs +import logging +import threading +from collections.abc import Callable, Sequence +from functools import lru_cache +from itertools import chain +from typing import cast + +import torch +from torch._guards import detect_fake_mode +from torch._ops import OpOverload +from torch._subclasses import FakeTensorMode +from torch.distributed._functional_collectives import _are_we_tracing +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpInfo, + OpSchema, + OpSpec, + OpStrategy, + OutputSharding, + OutputSpecType, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, + compute_local_stride, +) +from torch.distributed.tensor.placement_types import _StridedShard, Shard + + +aten = torch.ops.aten + +log = logging.getLogger(__name__) + + +def _length(obj) -> int: + if obj is None: + return 0 + if not isinstance(obj, Sequence): + return 1 + return len(obj) + + +class LocalLRUCache(threading.local): + def __init__(self, user_function: Callable) -> None: + self.cache = lru_cache(None)(user_function) + + def __call__(self, *args, **kwargs) -> object: + return self.cache(*args, **kwargs) + + def cache_info(self): + return self.cache.cache_info() + + def cache_clear(self): + return self.cache.cache_clear() + + +class ShardingPropagator: + def __init__(self) -> None: + self.op_to_rules: dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} + self.op_strategy_funcs: dict[ + OpOverload, + Callable[[OpSchema], StrategyType], + ] = {} + # op map to save static argnum to decide to reuse sharding prop cache or + # re-run sharding prop + self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {} + self.propagate_op_sharding = LocalLRUCache( + self.propagate_op_sharding_non_cached + ) + # op map to save indices of shape (and stride) args which may need to be + # modified in sharding prop + self.op_to_shape_and_stride_idx: dict[OpOverload, int | tuple[int, int]] = { + # new factory ops + aten.new_empty.default: 1, + aten.new_full.default: 1, + aten.new_ones.default: 1, + aten.new_zeros.default: 1, + aten.new_empty_strided.default: (1, 2), + # view ops + aten.expand.default: 1, + aten.reshape.default: 1, + aten.view.default: 1, + aten._unsafe_view.default: 1, + aten.select_backward.default: 1, + aten.slice_backward.default: 1, + } + + def register_sharding_prop_rule( + self, + op_overload: OpOverload, + rule_func: Callable[[OpSchema], OutputSharding], + schema_info: RuntimeSchemaInfo | None = None, + ): + """ + Register a sharding propagation rule for an operator. + """ + self.op_to_rules[op_overload] = rule_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def register_op_strategy( + self, + op_overload: OpOverload, + strategy_func: Callable[[OpSchema], StrategyType], + schema_info: RuntimeSchemaInfo | None = None, + ): + """ + Register a :class:`OpStrategy` generator for an operator. + + During the sharding propagation, DTensor wants to enumerate all + acceptable sharding specs (:class:`OpSpec`) for an operator, + and by "acceptable" we mean that the operator can be executed on + the ``_local_tensor`` of DTensor args/kwargs (with ``OpSpec.input_specs``) + and the output(s) constitute valid DTensor(s) (with ``OpSpec.output_specs``). + + ``strategy_func`` is the function that enumerates such acceptable specs + for the operator ``op_overload``. One general approach to write ``strategy_func`` + is, if the operator has simple arguments structure (e.g. mm, bmm), first enumerating + all sharding specs for the operands, and then filtering out the ones that + are not valid. For example, for ``mm``, the operands are two 2D tensors, and + if both ``input`` and ``mat2`` have sharding placements ``[Shard(0)]``, then this + is not an acceptable ``input_specs``. + + Once we have a way to enumerate all acceptable sharding specs, we can use each + of them to construct a :class:`OpSpec`. The ``OpSpec.input_specs`` directly comes + from the sharding spec, and the ``OpSpec.output_specs`` is therefore determined + (e.g. ``[Shard(1)]`` @ ``[Shard(0)]`` yields ``[Partial()]``). In addition, + :class:`OpSpec` also contains ``redistribute_cost`` which records the redistribution + cost from each :class:`OpSpec` in the source :class:`OpStrategy.strategies` to + the target sharding spec, for each operand. + + The ``strategy_func`` should return a :class:`OpStrategy` which contains a list of + all the :class:`OpSpec`s generated in the above. + + The optional ``schema_info`` tells which non-DTensor args/kwargs could affect the + cache and whether ``pytree`` is needed to flatten the nested args. ``static_argnum`` + marks the starting index of the non-DTensor args that should be hashed into the + sharding propagation hash key, and ``static_kwargkey`` marks the keys of the + non-DTensor kwargs that should be hashed. ``needs_pytree`` should be used when + the input arg has :class:`list` or :class:`dict` structure. + + For example, ``aten.cat.default`` op has a ``List[Tensor]`` argument ``tensors`` + and an ``int`` argument ``dim``. Because ``dim`` affects the sharding propagation + result, we want to pass ``RuntimeSchemaInfo(static_argnum=1)`` because the argument + index of ``dim`` is 1. Besides, we also want to set ``needs_pytree=True`` because + ``tensors`` needs be flattened in sharding propagation. Another example is + ``aten.histc.default``. ``histc`` has 4 arguments (self, bins, min, max) and the + last two would affect sharding propagation along with the :class:`DTensor` argument + ``self``. Since the argument index of ``min`` is 2, the `schema_info` should be + `RuntimeSchemaInfo(static_argnum=2)`. + """ + self.op_strategy_funcs[op_overload] = strategy_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def _propagate_tensor_meta_non_cached( + self, op_schema: OpSchema + ) -> None | TensorMeta | Sequence[TensorMeta | None]: + """ + Propagate the tensor metadata, it could either return a TensorMeta + or a list/tuple of TensorMetas + """ + if op_schema.op == aten.equal.default: + # data dependent ops can't be used for fake propagation + return None + + # NOTE: We must call the tracing in fake tensor mode so that it avoids + # materializing memory. + fake_mode = detect_fake_mode() or FakeTensorMode() + with fake_mode: + fake_args = op_schema.gen_fake_args() + fake_kwargs = op_schema.gen_fake_kwargs() + fake_out = op_schema.op(*fake_args, **fake_kwargs) + + if isinstance(fake_out, torch.Tensor): + return TensorMeta( + shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype + ) + + elif isinstance(fake_out, (tuple, list)): + tensor_meta_list: list[TensorMeta | None] = [] + for fake_out_item in fake_out: + if isinstance(fake_out_item, torch.Tensor): + tensor_meta_list.append( + TensorMeta( + shape=fake_out_item.shape, + stride=fake_out_item.stride(), + dtype=fake_out_item.dtype, + ) + ) + else: + tensor_meta_list.append(None) + return ( + tuple(tensor_meta_list) + if isinstance(fake_out, tuple) + else tensor_meta_list + ) + else: + # if fake is not a tensor or tuple of tensor, return as none + return None + + @lru_cache # noqa: B019 + def _propagate_tensor_meta( + self, op_schema: OpSchema + ) -> None | TensorMeta | Sequence[TensorMeta | None]: + """ + Cached version of _propagate_tensor_meta_non_cached + This is a private API. Use propagate_tensor_meta instead. + """ + return self._propagate_tensor_meta_non_cached(op_schema) + + def propagate_tensor_meta( + self, op_schema: OpSchema + ) -> None | TensorMeta | Sequence[TensorMeta | None]: + """ + Propagate the tensor metadata, it could either return a TensorMeta + or a list/tuple of TensorMetas. This is a public API that should be + used if cache should be used. + """ + if _are_we_tracing(): + return self._propagate_tensor_meta_non_cached(op_schema) + else: + return self._propagate_tensor_meta(op_schema) + + def _create_output_spec_with_new_tensor_meta( + self, + op: OpOverload, + output_specs: OutputSpecType, + output_tensor_meta: None | TensorMeta | Sequence[TensorMeta | None], + ) -> OutputSpecType: + """ + Wrap the output_specs with the tensor metadata from the output. + """ + + if isinstance(output_specs, DTensorSpec): + if not isinstance(output_tensor_meta, TensorMeta): + # Either error due to ShardingPropagator or due to incorrect OutputSpec + if not isinstance(output_tensor_meta, (tuple, list)): + raise ValueError( + "ShardingPropagator error: output does not have an associated " + "TensorMeta" + ) + raise ValueError( + f"For the op {op.name()}, `output_specs` has 1 output which does " + "not equal the " + f"number of op outputs: {len(output_tensor_meta)}." + ) + return output_specs.shallow_copy_with_tensor_meta(output_tensor_meta) + elif isinstance(output_specs, (tuple, list)): + new_specs: list[DTensorSpec | None] = [] + if not isinstance(output_tensor_meta, (tuple, list)) or len( + output_specs + ) != len(output_tensor_meta): + raise ValueError( + f"For the op {op.name()}, `output_specs` has {len(output_specs)} " + "outputs which does not equal the " + f"number of op outputs {_length(output_tensor_meta)}." + ) + + for i, spec in enumerate(output_specs): + if isinstance(spec, DTensorSpec): + output_tensor_meta_i = output_tensor_meta[i] + if not isinstance(output_tensor_meta_i, TensorMeta): + # NOTE: aten.convolution_backward.default is an exception and it + # needs extra handling because any Tensor in the output tuple + # can be `None` depending on the output_mask parameter. This can + # occur during double backpropagation or when certain gradients + # are not needed (e.g., grad_input when input has requires_grad=False, + # grad_weight/grad_bias when weight/bias have requires_grad=False, + # or grad_bias when bias is None). We explicitly allow the + # corresponding TensorMeta to be `None`. + if ( + op == aten.convolution_backward.default + and i in (0, 1, 2) + and output_tensor_meta_i is None + ): + assert isinstance(output_specs, list) + new_specs.append(None) + continue + else: + raise ValueError( + f"ShardingPropagator error: output {i} of {op.name()} " + "does not have an associated TensorMeta" + ) + + new_specs.append( + spec.shallow_copy_with_tensor_meta(output_tensor_meta_i) + ) + else: + new_specs.append(spec) + + return tuple(new_specs) + else: + assert output_specs is None + return output_specs + + def _wrap_with_op_strategy(self, op_schema: OpSchema) -> OpSchema: + """ + wrap a op_schema that contains DTensorSpec to another op_schema that contains + OpStrategy/TupleStrategy, the returned op_schema is then used for sharding + strategy propagation on pytorch operators. + """ + + def spec_to_strategy(spec: object) -> object: + if isinstance(spec, DTensorSpec): + return OpStrategy([OpSpec(spec)]) + elif ( + isinstance(spec, (list, tuple)) + and len(spec) > 0 + and isinstance(spec[0], DTensorSpec) + ): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy + ) + else: + return spec + + args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema] + + kwargs_op_strategy = { + k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items() + } + + return OpSchema( + op=op_schema.op, + args_schema=tuple(args_op_strategy), + kwargs_schema=kwargs_op_strategy, + schema_info=op_schema.schema_info, + ) + + def propagate(self, op_info: OpInfo) -> None: + # NB: The logic here is duplicated in _propagate_op_sharding_dispatch_slow_path. + # Ideally, this function would be deleted, but there are a handful of + # one off call sites here that aren't cleaned up. + + # We cannot use an lru cache if we know that inputs will have dynamic shapes, + # because SymInts are not hashable. + # This is generally ok because this only happens during tracing in torch.compile, + # and tracing does not need to be as fast as eagermode DTensor usages. + if _are_we_tracing(): + output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) + else: + output_sharding = cast( + OutputSharding, self.propagate_op_sharding(op_info.schema) + ) + op_info.output_sharding = output_sharding + + def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: + """ + Propagate the sharding for an operator given the op_schema. + """ + # no-op in OSS, logs API usage metrics in meta-internal runs + torch._C._log_api_usage_once( + "torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached" + ) + # special case op, we don't need to propagate for local + # scalar. TODO: figure out a better way to handle this + if op_schema.op is aten._local_scalar_dense.default: + return OutputSharding(None, op_schema) + + out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema) + if op_schema.op in self.op_strategy_funcs: + # wrap the op_schema with op strategy for sharding strategy propagation + strategy_schema = self._wrap_with_op_strategy(op_schema) + + # run sharding strategy propagation/generation + op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema) + + if isinstance(op_strategy, OpStrategy): + # single Op strategy + output_strategy = self._select_strategy(op_strategy, op_schema) + + # check if we need to redistribute the input + needs_redistribute = False + # check if we want to use args value from redistribute_schema + use_val_from_redistribute_schema = False + expected_input_specs: list[DTensorSpec] = [] + + # in case where the op does not specify input_specs and output_specs + # is a DTensorSpec, we use output_specs as the spec for each DTensor + # input arg. + if output_strategy.input_specs is None: + assert isinstance(output_strategy.output_specs, DTensorSpec) + + for idx, input_spec in enumerate(op_schema.args_spec): + desired_spec = ( + output_strategy.output_spec + if output_strategy.input_specs is None + else output_strategy.input_specs[idx] + ) + expected_input_specs.append( + desired_spec.shallow_copy_with_tensor_meta( + input_spec.tensor_meta + ) + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(expected_input_specs), {} + ) + suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) + + # shape and stride args need to be modified for + # view ops and new factory ops, potentially + if op_schema.op in self.op_to_shape_and_stride_idx: + assert isinstance(output_strategy.output_spec, DTensorSpec) + # It happens when the output has the same shape as the input + # and the input placements are not all Replicate(). + if any( + isinstance(p, Shard | _StridedShard) + for p in output_strategy.output_spec.placements + ): + schema = suggestion_schema or op_schema + assert isinstance(out_tensor_meta, TensorMeta) + suggestion_schema = self._adjust_shape_and_stride_args( + out_tensor_meta, schema, output_strategy.output_spec + ) + needs_redistribute = True + use_val_from_redistribute_schema = True + + # construct output spec for the op + if op_schema.return_type_tuple_tensor_like(): + # for ops that return multiple tensors and the output_specs is not + # a tuple, we use a tuple of that single output spec as the new + # output_specs + output_specs: OutputSpecType = output_strategy.output_specs + if isinstance(output_specs, DTensorSpec): + output_specs = tuple( + # create a new DTensorSpec with the same placement as the + # output_specs in output_strategy + DTensorSpec( + mesh=output_specs.mesh, + placements=output_specs.placements, + tensor_meta=output_specs.tensor_meta, + ) + for _ in range(len(op_schema.op._schema.returns)) + ) + elif ( + op_schema.return_type_tensor() + or op_schema.return_type_list_tensor_like() + ): + output_specs = output_strategy.output_specs + else: + output_specs = None + + output_sharding = OutputSharding( + output_specs, + suggestion_schema, + needs_redistribute=needs_redistribute, + use_val_from_redistribute_schema=use_val_from_redistribute_schema, + ) + elif isinstance(op_strategy, TupleStrategy): + # tuple strategy output sharding processing + # runtime select OpSpec for each TupleStrategy input arg + selected_strategies: list[OpSpec] = [] + out_spec_list: list[DTensorSpec] = [] + for strategy in op_strategy.children: + assert isinstance(strategy, OpStrategy) + selected_strategy = self._select_strategy(strategy) + selected_strategies.append(selected_strategy) + out_spec_list.append(selected_strategy.output_spec) + + needs_redistribute = False + suggestion_args: list[object] = [] + tensor_or_list_tensor_arg_idx = 0 + + for arg in op_schema.args_schema: + if ( + arg + and isinstance(arg, (list, tuple)) + and isinstance(arg[0], DTensorSpec) + ): + expected_input_spec_list: list[DTensorSpec] = [] + for idx, arg_spec in enumerate(arg): + expected_input_spec = selected_strategies[idx].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg_spec.tensor_meta + ) + ) + if arg_spec.placements != expected_input_spec.placements: + needs_redistribute = True + expected_input_spec_list.append(expected_input_spec) + suggestion_args.append( + tuple(expected_input_spec_list) + if isinstance(arg, tuple) + else expected_input_spec_list + ) + tensor_or_list_tensor_arg_idx += 1 + + elif isinstance(arg, DTensorSpec): + expected_input_spec = selected_strategies[0].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg.tensor_meta + ) + ) + if arg.placements != expected_input_spec.placements: + needs_redistribute = True + suggestion_args.append(expected_input_spec) + tensor_or_list_tensor_arg_idx += 1 + else: + suggestion_args.append(arg) + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema + ) + + output_sharding = OutputSharding( + tuple(out_spec_list) if out_tensor_meta is not None else None, + suggestion_schema, + needs_redistribute=needs_redistribute, + use_val_from_redistribute_schema=False, + ) + else: + raise ValueError("Unsupported op strategy type") + + # associate the output sharding with the output tensor metadata + new_output_spec = self._create_output_spec_with_new_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + output_sharding.output_spec = new_output_spec + return output_sharding + elif op_schema.op in self.op_to_rules: + # propagate the sharding with rule + sharding_prop_func = self.op_to_rules[op_schema.op] + + # step 1. there's sharding propagation rule, run + # sharding propagation to get the output sharding + try: + output_sharding = sharding_prop_func(op_schema) + except NotImplementedError as e: + raise e + except Exception as e: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}.\nError: {e}" + ) from e + + # step 2. if can't get output_spec from sharding + # propagation (i.e. no rules apply for input + # placements), we return the output sharding + # with schema suggestions, which can be used to + # decide how to do redistribute on inputs + if output_sharding.output_spec is None: + if output_sharding.redistribute_schema is None: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}!" + ) + else: + # we do auto redistribute on inputs if necessary + # run sharding propagation again with suggested schema + propagation_res = sharding_prop_func( + output_sharding.redistribute_schema + ) + # we set the output sharding with the new propagation result + # so that dispatching know both output_spec and redistribute_schema + # exist, which indicates a reshard is needed + output_sharding.output_spec = propagation_res.output_spec + output_sharding.needs_redistribute = True + + # associate the output sharding with the output tensor metadata + new_output_spec = self._create_output_spec_with_new_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + output_sharding.output_spec = new_output_spec + + return output_sharding + else: + raise NotImplementedError( + f"Operator {op_schema.op} does not have a sharding strategy registered." + ) + + def _select_strategy( + self, strategy: OpStrategy, op_schema: OpSchema | None = None + ) -> OpSpec: + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if len(strategy.strategies) == 1: + # short cut with only one possible OpSpec + return strategy.strategies[0] + + op_spec_costs: list[torch.types.FloatLikeType] = [] + no_redistribute_strategy_index: int = -1 + negative_cost_index: int = -1 + zero_cost_index: int = -1 + for strategy_idx, op_spec in enumerate(strategy.strategies): + assert op_spec.redistribute_cost is not None, ( + "must set redistribute cost each OpSpec!" + ) + redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost)) + op_spec_costs.append(redistribute_cost) + + # If there are strategies with negative/zero/no redistribute cost, + # we record those indices. + # TODO: Currently this only applies to OpStrategy selection. Requires extra + # logic to make it work for TupleStrategy, if needed. + if op_schema is not None: + if guard_or_false(redistribute_cost < 0): + if ( + negative_cost_index == -1 + or redistribute_cost < op_spec_costs[negative_cost_index] + ): + negative_cost_index = strategy_idx + elif guard_or_false(redistribute_cost == 0): + needs_redistribute = False + for spec_idx, input_spec in enumerate(op_schema.args_spec): + desired_spec = ( + op_spec.output_spec + if op_spec.input_specs is None + else op_spec.input_specs[spec_idx] + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + break + + if not needs_redistribute: + no_redistribute_strategy_index = strategy_idx + elif zero_cost_index == -1: + zero_cost_index = strategy_idx + + # prioritize negative/zero/no redistribute cost strategies + if negative_cost_index != -1: + # If there's negative cost, we select the one with the minimal cost, + # even if this means we need to redistribute, e.g. via local chunking. + # E.g. this can happen for ops in self.op_to_shape_and_stride_idx + # when the inputs / outputs are sharded. + selected_strategy_index = negative_cost_index + elif no_redistribute_strategy_index != -1: + selected_strategy_index = no_redistribute_strategy_index + elif zero_cost_index != -1: + selected_strategy_index = zero_cost_index + else: + # default to choosing minimal redistribute cost + min_cost = min(op_spec_costs) + selected_strategy_index = op_spec_costs.index(min_cost) + + return strategy.strategies[selected_strategy_index] + + def _adjust_shape_and_stride_args( + self, + out_tensor_meta: TensorMeta, + schema: OpSchema, + spec: DTensorSpec, + ) -> OpSchema: + shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] + if isinstance(shape_stride_idx, tuple): + shape_idx, stride_idx = shape_stride_idx + else: + shape_idx = shape_stride_idx + stride_idx = None + + expected_input_schema = list(schema.args_schema) + # adjust shape to be the same as that of the _local_tensor + # of the DTensor input arg at index 0, which is inferred + expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( + out_tensor_meta.shape, spec.mesh, spec.placements, skip_offset=True + ) + + # adjust the stride arg for aten.new_empty_strided.default + if stride_idx: + expected_input_schema[stride_idx] = compute_local_stride( + out_tensor_meta.stride, spec.mesh, spec.placements + ) + + return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_shards_wrapper.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_shards_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1673dd7e34b994470386e1fb1a5079c302302393 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_shards_wrapper.py @@ -0,0 +1,359 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + MetadataIndex, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + TensorWriteData, + WriteItem, + WriteItemType, +) + + +aten = torch.ops.aten + + +class LocalShardsWrapper(torch.Tensor): + """ + A wrapper class to hold local shards of a DTensor. + This class is used largely for checkpointing purposes and implicitly subtypes + the _Checkpointable protocol. + """ + + __slots__ = ["_local_shards", "_storage_meta"] + _local_shards: list[torch.Tensor] + _storage_meta: TensorStorageMetadata + + @staticmethod + def __new__( + cls, local_shards: list[torch.Tensor], local_offsets: list[tuple[int, ...]] + ) -> "LocalShardsWrapper": + assert all( + tensor.device == local_shards[0].device for tensor in local_shards[1:] + ) + + # if empty shard, we create a empty tensor + if len(local_shards) == 0: + r = torch.Tensor._make_wrapper_subclass( + cls, + torch.Size([0, 0]), + ) + r._local_shards = [] + r._storage_meta = TensorStorageMetadata( + properties=TensorProperties(), + size=torch.Size([0, 0]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0, 0]), sizes=torch.Size([0, 0]) + ) + ], + ) + return r + + # we calculate the total tensor size by "concat" on second tensor dimension + cat_tensor_shape = list(local_shards[0].size()) + if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[1] += shard.size()[1] + + # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension + if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[0] += shard.size()[0] + + wrapper_properties = TensorProperties.create_from_tensor(local_shards[0]) + wrapper_shape = torch.Size(cat_tensor_shape) + chunks_meta = [ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=shard.size(), + ) + for shard, offset in zip(local_shards, local_offsets) + ] + + r = torch.Tensor._make_wrapper_subclass( + cls, + torch.Size(cat_tensor_shape), + ) + r._local_shards = local_shards + r._storage_meta = TensorStorageMetadata( + properties=wrapper_properties, + size=wrapper_shape, + chunks=chunks_meta, + ) + + return r + + # necessary for ops dispatching from this subclass to its local shards + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + kwargs = kwargs or {} + + dispatcher = { + torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor, + torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor, + aten._to_copy.default: cls.handle_to_copy, + aten.view.default: cls.handle_view, + aten.equal.default: cls.handle_equal, + aten.detach.default: cls.handle_detach, + aten.clone.default: cls.handle_clone, + aten.new_empty.default: cls.handle_new_empty, + } + + if func in dispatcher: + return dispatcher[func](args, kwargs) + else: + raise NotImplementedError( + f"{func} is not supported for LocalShardsWrapper!" + ) + + @staticmethod + def handle_all_gather_into_tensor(args, kwargs) -> torch.Tensor: + dim = args[0].local_sizes()[0][1] + cat_tensor = torch.cat( + [t.view(-1) for t in args[0].local_shards()], dim=0 + ).view(-1, dim) + return torch.ops._c10d_functional.all_gather_into_tensor.default( + cat_tensor, *args[1:], **kwargs + ) + + @staticmethod + def handle_wait_tensor(args, kwargs) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor(args[0]) + + @staticmethod + def handle_to_copy(args, kwargs) -> torch.Tensor: + res_shards_list = [ + aten._to_copy.default(shard, *args[1:], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + def handle_view(args, kwargs) -> "LocalShardsWrapper": + view_shape = args[1] + res_shards_list = [] + if len(args[0].local_shards()) > 1: + if args[0].local_shards()[0].ndim == 2: + assert ( + args[0].storage_metadata().size[0] == view_shape[0] + and args[0].storage_metadata().size[1] == view_shape[1] + ) + # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on + # init calls view_as() on the global tensor shape + # will fail because the view shape is not applicable to individual shards. + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + elif args[0].local_shards()[0].ndim == 1: + assert args[0].storage_metadata().size[0] == view_shape[0] + # This case is for optimizer sharding as regardless of sharding type, optimizer state is row wise sharded + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + else: + raise NotImplementedError("No support for view on tensors ndim > 2") + else: + # view is called per shard + res_shards_list = [ + aten.view.default(shard, args[1], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + def handle_equal(args, kwargs) -> bool: + """ + LocalShardsWrapper equal impl also checks for equality of storage metadata + and the order of shards + """ + a, b = args[0], args[1] + if len(a.local_shards()) != len(b.local_shards()): + return False + if not all( + aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards()) + ): + return False + if a.storage_metadata() != b.storage_metadata(): + return False + return True + + @staticmethod + def handle_detach(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + deatched_local_shards = [ + aten.detach.default(shard) for shard in self_ls.local_shards() + ] + self_ls._local_shards = deatched_local_shards + self_ls._storage_meta.properties.requires_grad = False + return self_ls + + @staticmethod + def handle_clone(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + desired_memory_format = kwargs.get("memory_format", None) + if desired_memory_format and desired_memory_format != torch.preserve_format: + raise NotImplementedError( + f"{desired_memory_format} is not supported for LocalShardsWrapper!" + ) + cloned_local_shards = [ + shard.clone(memory_format=desired_memory_format) + for shard in self_ls._local_shards + ] + return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets()) + + @staticmethod + def handle_new_empty(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + return LocalShardsWrapper( + [torch.empty_like(shard) for shard in self_ls._local_shards], + self_ls.local_offsets(), + ) + + @property + def device(self) -> torch._C.device: # type: ignore[override] + return ( + self._local_shards[0].device if self._local_shards else torch.device("meta") + ) + + @property + def is_meta(self) -> bool: # type: ignore[override] + return self._local_shards[0].is_meta if self._local_shards else True + + def is_pinned(self) -> bool: # type: ignore[override] + return self._storage_meta.properties.pin_memory + + def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper": + self._storage_meta.properties.requires_grad = requires_grad + [shard.requires_grad_(requires_grad) for shard in self._local_shards] + return self + + def local_shards(self) -> list[torch.Tensor]: + """ + Returns a list of :class:`torch.Tensor' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + def local_sizes(self) -> list[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local sizes for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.sizes for chunk in self._storage_meta.chunks] + + def local_offsets(self) -> list[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local offsets for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.offsets for chunk in self._storage_meta.chunks] + + @property + def local_chunks(self) -> list[ChunkStorageMetadata]: + """ + Returns a :class:`list[ChunkStorageMetadata]` object corresponding to the + metadata for each tensor shard + """ + return self._storage_meta.chunks + + def storage_metadata(self) -> TensorStorageMetadata: + """ + Returns a :class:`TensorStorageMetadata` object corresponding to the + metadata for the local tensor on current rank + """ + return self._storage_meta + + def is_empty_shard(self) -> bool: + """ + Returns a :class:`bool` object indicating if the local tensor on current rank + is an empty tensor + """ + return self._storage_meta.size[0] == 0 and self._storage_meta.size[1] == 0 + + def __create_write_items__(self, fqn: str, object: Any) -> list[WriteItem]: + """ + For compatibility with DCP, we support creation of WriteItems + such that they can be saved properly. + """ + return [ + WriteItem( + index=MetadataIndex(fqn, chunks.offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=chunks.offsets, + sizes=chunks.sizes, + ), + properties=self._storage_meta.properties, + size=object.size(), + ), + ) + for tensor, chunks in zip(self.local_shards(), self.local_chunks) + ] + + def __create_chunk_list__(self) -> list[ChunkStorageMetadata]: + """ + For compatibility with DCP, we support creation of chunk lists + such that they can be saved properly. + """ + return self._storage_meta.chunks + + def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor: + """ + For compatibility with DCP, we support finding shard based on index + Return a 'torch.Tensor' shard based on 'MetadataIndex'. + """ + # Fast lookup path + if index.index is not None: + if ( + len(self._local_shards) > index.index + and self._storage_meta.chunks[index.index].offsets == index.offset + ): + return self._local_shards[index.index] + + if index.offset is not None: + for shard, chunk in zip(self._local_shards, self._storage_meta.chunks): + if chunk.offsets == index.offset: + return shard + + # Empty shard case + if len(self._local_shards) == 0 and self._storage_meta.chunks[ + 0 + ].sizes == torch.Size([0, 0]): + return torch.empty(0) + + raise ValueError( + f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" + ) + + def _get_tensor_size_bytes(self) -> int: + object_size = 0 + for shard in self.local_shards(): + object_size += shard.nelement() * shard.element_size() + return object_size + + def __hash__(self) -> int: + return id(self) + + def __repr__(self) -> str: # type: ignore[override] + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" + + def __str__(self) -> str: + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_tp_conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_tp_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..275cb07934b5030bc9cd5bc71dc66f82e98eb3b5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_tp_conv.py @@ -0,0 +1,293 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor +from typing import cast + +import torch +import torch.distributed as dist +import torch.distributed.tensor._api as dtensor + + +aten = torch.ops.aten + + +def _requires_data_exchange(padding, dim_map) -> bool: + # Data exchange is not need if only sharded across batch dim + if all(x == -1 for x in dim_map[1:]): + return False + # TODO: whether there requires data exchange is currently determined by padding + return padding[-1] != 0 + + +def _is_supported(input_size, kernel_size, stride, padding, dilation): + if dilation[-1] != 1: + raise RuntimeError("Dilation must be 1 for tensor parallel convolution.") + if padding[-1] != 0: + if stride[-1] != 1: + raise RuntimeError( + "Stride must be 1 when there is padding for tensor parallel convolution." + ) + if kernel_size[-1] // 2 > input_size[-1]: + raise RuntimeError( + "kernel_size[-1] // 2 should be less than or equal to input_size[-1] for tensor parallel convolution." + ) + else: + if not (input_size[-1] % stride[-1] == 0 and stride[-1] == kernel_size[-1]): + raise RuntimeError( + "It requires that input_size[-1] is divisible by stride[-1] and stride[-1] equals kernel_size[-1] " + "when there is padding for tensor parallel convolution." + ) + return True + + +def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size): + # dist comms and reconstruct local input tensor + send_to_right = in_tensor[..., -d1:].contiguous() + send_to_left = in_tensor[..., :d2].contiguous() + recv_from_right = torch.zeros_like(send_to_left) + recv_from_left = torch.zeros_like(send_to_right) + + send_op_right = dist.P2POp(dist.isend, send_to_right, right) + send_op_left = dist.P2POp(dist.isend, send_to_left, left) + recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right) + recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left) + + reqs = dist.batch_isend_irecv( + [send_op_right, send_op_left, recv_op_left, recv_op_right] + ) + for req in reqs: + req.wait() + + if rank == 0: + in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1) + elif rank == size - 1: + in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1) + else: + in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1) + + return in_tensor + + +def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size): + # dist comms and aggregate gradients for edge pixels + send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous() + send_to_left = grad_in_tensor[:, :, :, :d1].contiguous() + recv_from_right = torch.zeros_like(send_to_left) + recv_from_left = torch.zeros_like(send_to_right) + + send_op_right = dist.P2POp(dist.isend, send_to_right, right) + send_op_left = dist.P2POp(dist.isend, send_to_left, left) + recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right) + recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left) + + reqs = dist.batch_isend_irecv( + [send_op_right, send_op_left, recv_op_left, recv_op_right] + ) + for req in reqs: + req.wait() + + if rank == 0: + grad_in_tensor = grad_in_tensor[:, :, :, :-d2] + grad_in_tensor[:, :, :, -d1:] = torch.add( + grad_in_tensor[:, :, :, -d1:], recv_from_right + ) + elif rank == size - 1: + grad_in_tensor = grad_in_tensor[:, :, :, d1:] + grad_in_tensor[:, :, :, :d2] = torch.add( + grad_in_tensor[:, :, :, :d2], recv_from_left + ) + else: + grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2] + grad_in_tensor[:, :, :, -d1:] = torch.add( + grad_in_tensor[:, :, :, -d1:], recv_from_right + ) + grad_in_tensor[:, :, :, :d2] = torch.add( + grad_in_tensor[:, :, :, :d2], recv_from_left + ) + + +def tp_convolution( + op_call: torch._ops.OpOverload, + local_tensor_args: tuple[object, ...], + local_tensor_kwargs: dict[str, object], + dim_map: list[int], +) -> object: + assert op_call == aten.convolution.default + assert len(local_tensor_args) == 9 + + rank = dist.get_rank() + size = dist.get_world_size() + in_tensor = cast(torch.Tensor, local_tensor_args[0]) + weight = cast(torch.Tensor, local_tensor_args[1]) + stride, padding, dilation = local_tensor_args[3:6] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, list) + + if not _requires_data_exchange(padding, dim_map): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[-1] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = in_tensor + local_tensor_args = cast(tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # step3 remove extra outputs from the results + padding_w = padding[-1] + w = local_results.size(-1) + if rank == 0: + local_results = local_results[..., : w - padding_w] + elif rank == size - 1: + local_results = local_results[..., padding_w:] + else: + local_results = local_results[..., padding_w : w - padding_w] + + return local_results + + +def tp_convolution_backward( + op_call: torch._ops.OpOverload, + local_tensor_args: tuple[object, ...], + local_tensor_kwargs: dict[str, object], + dim_map: list[int], +) -> object: + assert op_call == aten.convolution_backward.default + assert len(local_tensor_args) == 11 + + rank = dist.get_rank() + size = dist.get_world_size() + grad_out_tensor = cast(torch.Tensor, local_tensor_args[0]) + in_tensor = cast(torch.Tensor, local_tensor_args[1]) + weight = cast(torch.Tensor, local_tensor_args[2]) + stride, padding, dilation = local_tensor_args[4:7] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, list) + + if not _requires_data_exchange(padding, dim_map): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[3] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 reconstruct local gradient output tensor + padding_w = padding[1] + if rank == 0: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (0, padding_w), "constant", 0 + ) + elif rank == size - 1: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, 0), "constant", 0 + ) + else: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, padding_w), "constant", 0 + ) + + # step3 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = grad_out_tensor + local_tensor_args_list[1] = in_tensor + local_tensor_args = cast(tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # step4 aggregate gradients for edge pixels + grad_in_tensor = local_results[0] + if grad_in_tensor is not None: + grad_in_tensor = _ring_send_recv_aggregate( + grad_in_tensor, d1, d2, left, right, rank, size + ) + local_results = list(local_results) + local_results[0] = grad_in_tensor + + local_results = cast(tuple[object, ...], local_results) + + return local_results + + +def convolution_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + output_spec = output_sharding.output_spec + assert isinstance(output_spec, dtensor.DTensorSpec) + + # local propagation + local_results = tp_convolution( + op_call, + tuple(op_info.local_args), + op_info.local_kwargs, + output_spec.dim_map, + ) + + return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec) + + +def convolution_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # Redistribute grad_output tensor to the same placement as input tensor + # pyrefly: ignore [bad-assignment] + args = list(args) + assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor) + # pyrefly: ignore [unsupported-operation] + args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements) + args = tuple(args) + + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec) + + # local propagation + local_results = tp_convolution_backward( + op_call, + tuple(op_info.local_args), + op_info.local_kwargs, + op_info.flat_args_schema[0].dim_map, + ) + + return dtensor.DTensor._op_dispatcher.wrap( + local_results, output_sharding.output_spec + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f085b681f94911521683c7d566dc60124e1c9047 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/_utils.py @@ -0,0 +1,461 @@ +import logging +import threading +from collections.abc import Sequence +from typing import Any, cast, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch._prims_common import ShapeType +from torch.distributed._local_tensor import maybe_run_for_local_tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) + + +logger = logging.getLogger(__name__) + + +class ExplicitRedistributionContext: + """ + Within this context manager, DTensor will refuse to perform implicit redistribution, + instead raising an error. Manual calls to ``redistribute()`` are required wherever a redistribution + must occur to avoid erroring. This can be used to ensure that the user is aware of all redistribution. + + Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backwards pass + may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual + calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-op + during forward and perform a manual redistribution during backwards. + + enable (bool) if False, disables the context manager. Can be used nested inside an enabled region. + + strict (bool) if True, triggers on any redistribution. If False, only triggers on redistributions that perform + communication. + + mode (str) Determines what happens when ExplicitRedistributionContext triggers: + "raise": raises an exceptoin, "warn" issues a warning + """ + + _local = threading.local() + + def __init__(self, enable: bool = True, strict: bool = False, mode="raise"): + self._enable = enable + self._strict = strict + if mode not in ("raise", "warn"): + raise RuntimeError(f"Invalid mode {mode}") + self._raise_on_redistribution = mode == "raise" + + @classmethod + def observe_redistribution( + cls, src_spec: DTensorSpec, dst_spec: DTensorSpec, message: str + ): + if instance := getattr(cls._local, "_active", None): + allowed = True + if instance._enable: + if instance._strict: + allowed = False + else: + allowed = redistribute_cost(src_spec, dst_spec) <= 0 + if not allowed: + if instance._raise_on_redistribution: + raise RuntimeError(message) + else: + logger.warning(message) + + def __enter__(self): + self._prev = getattr(ExplicitRedistributionContext._local, "_active", None) + ExplicitRedistributionContext._local._active = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + ExplicitRedistributionContext._local._active = self._prev + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + skip_offset: bool = False, +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. This is useful for checkpointing purpose. + + Example: + global_tensor = [[0, 1, 2, 3, 4], sharded on mesh (DP=2, TP=2) with (Shard(1), Shard(1)) + [10, 11, 12, 13, 14]] + + This table shows the return value of local_shape and global_offset for each rank. + (`local_tensor` is for illustration only). + + Note how the first coordinate of global_offset is always 0, corresponding to tensor dim 0 being replicated. + + Rank local_tensor local_shape global_offset + ------------------------------------------------------------- + 0 [[0, 1], (2, 2) (0, 0) + [10, 11]] + + 1 [[2], (2, 1) (0, 2) + [12]] + + 2 [[3], (2, 1) (0, 3) + [13]] + + 3 [[4], (2, 1) (0, 4) + [14]] + + Args: + global_shape (ShapeType): The global shape of the DTensor. + mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on. + placements (Sequence[:class:`Placement`]]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. + + Return: + local_shape: the shape of the DTensor's _local_tensor on the current rank. + global_offset: a tuple of offsets for each dimension of the global tensor shape, + identifying how this shard fits into the global tensor in each dimension. If + skip_offset is True, this will be an empty tuple. + + """ + return _compute_local_shape_and_global_offset( + global_shape, mesh.shape, mesh.get_coordinate(), placements, skip_offset + ) + + +@maybe_run_for_local_tensor +def _get_shard_size_and_offsets( + curr_local_size: int, + mesh_dim_size: int, + rank: int, + placement: Shard | _StridedShard, + previous_offsets, + zero_global_offset: int, + skip_offset: bool, +) -> tuple[int, Optional[torch.Tensor]]: + kwargs: dict[str, Any] = { + "curr_local_size": curr_local_size, + "num_chunks": mesh_dim_size, + "rank": rank, + } + if isinstance(placement, _StridedShard): + kwargs["return_first_offset"] = False + shard_size, shard_offsets = placement._local_shard_size_and_offset(**kwargs) + if skip_offset: + return shard_size, None + if shard_size == 0: + return shard_size, torch.arange(zero_global_offset, zero_global_offset + 1) + if isinstance(placement, Shard) and not isinstance(placement, _StridedShard): + assert isinstance(shard_offsets, int) + index = torch.arange(shard_offsets, shard_offsets + shard_size) + else: + assert isinstance(shard_offsets, list) + index = torch.tensor(shard_offsets) + if previous_offsets is None: + return shard_size, index + else: + return shard_size, previous_offsets[index] + + +@maybe_run_for_local_tensor +def _get_first_offset(offsets: torch.Tensor) -> int: + return int(offsets[0]) + + +# accept 'plain data types' to enable simpler unit testing without creating device mesh +def _compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh_shape: ShapeType, + my_coordinate: list[int] | None, + placements: Sequence[Placement], + skip_offset: bool = False, +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Suppose you have a full tensor with size global_shape, and you have sharded + it according to placements for mesh_shape. This function returns, for a + specific coordinate my_coordinate in the device mesh: + + - The size of your local shard WITHOUT padding (i.e., if you have + an uneven split, your size might be smaller than the other entries + in your dim), and + + - Where the data for your shard begins, in the full tensor. + + This function is fairly simple if your tensor is evenly sharded; the complication + is around uneven splits. There is also some complication for handling StridedShard, + which changes the order you should apply sharding. + + Args: + global_shape (ShapeType): The global shape of the tensor. + mesh_shape (ShapeType): The shape of the device mesh. + my_coordinate (Optional[list[int]]): The coordinate of the current rank in the device mesh. + placements (Sequence[Placement]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. + + Returns: + tuple: A tuple containing: + - local_shape (tuple[int, ...]): The shape of the local shard on the current rank. + - global_offset (tuple[int, ...]): The offsets for each dimension identifying where + this shard begins in the global tensor. If skip_offset is True, this will be an + empty tuple. + """ + + empty_offset = () + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((0,), empty_offset) + + local_shape = list(global_shape) + # Perform shard from left to right. For example, + # global tensor: [0, 1, 2, 3, 4, 5, 6, 7] + # placements: S(0), SS(0, split_factor=2) + # mesh_shape: (2, 2) + # After S(0), shard_dim_to_global_offsets are + # {0: [0, 1, 2, 3]} on my_coordinate [0, 0] [0, 1] + # {0: [4, 5, 6, 7]} on my_coordinate [1, 0] [1, 1] + # After SS(0, split_factor=2), shard_dim_to_global_offsets are + # {0: [0, 2]} on my_coordinate [0, 0] + # {0: [1, 3]} on my_coordinate [0, 1] + # {0: [4, 6]} on my_coordinate [1, 0] + # {0: [5, 7]} on my_coordinate [1, 1] + shard_dim_to_global_offsets = {} + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, (Shard, _StridedShard)): + continue + shard_dim = placement.dim + zero_global_offset = global_shape[shard_dim] + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) + previous_offsets = shard_dim_to_global_offsets.get(shard_dim) + shard_size, shard_offsets = _get_shard_size_and_offsets( + local_shape[shard_dim], + mesh_shape[mesh_dim], + my_coordinate[mesh_dim], + placement, + previous_offsets, + zero_global_offset, + skip_offset, + ) + local_shape[shard_dim] = shard_size + shard_dim_to_global_offsets[shard_dim] = shard_offsets + if skip_offset: + return tuple(local_shape), empty_offset + global_offset = [0] * len(global_shape) + for shard_dim, global_offsets in shard_dim_to_global_offsets.items(): + global_offset[shard_dim] = _get_first_offset(global_offsets) + return tuple(local_shape), tuple(global_offset) + + +compute_global_tensor_info = torch._C._DTensor_compute_global_tensor_info + + +def compute_local_tensor_info( + global_tensor: torch.Tensor, + mesh: DeviceMesh, + placements: Sequence[Placement], +) -> tuple[list[int], list[int]]: + """ + Compute the local size and stride of a DTensor from the given global tensor info. + + For example, if we have a global tensor with size (4, 8, 4) and stride (32, 1, 8). + If the DTensor placements are [Shard(2)] and world_size is 2; + then the local size is (4, 8, 2) and stride is (16, 1, 8). + + Args: + tensor (:class:`torch.Tensor`): + Global tensor which DTensor will distribute + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + placements (Sequence[:class:`Placement`]): + The attribute of the DTensor that describes its layout + on the mesh topology. + + Returns: + local_shape: A List of int which specifies the size of the local tensor. + local_stride: A List of int which specifies the stride of the local tensor. + """ + local_shape = list(global_tensor.size()) + local_stride = list(global_tensor.stride()) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if placement.is_shard(): + shard_placement = cast(Shard, placement) + if shard_placement.dim < 0: + raise AssertionError( + "Shard placements should have negative dims normalized in " + f"the user-facing APIs: {shard_placement}" + ) + shard_dim = shard_placement.dim + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)} " + f"for placement number {idx}." + ) + + global_dim_size = local_shape[shard_dim] + assert global_dim_size % mesh_dim_size == 0, ( + f"Global dim {global_dim_size} not divisible by mesh size {mesh_dim_size}" + ) + local_shape[shard_dim] = global_dim_size // mesh_dim_size + + # shrink strides that were scaled up globally + for i in range(len(local_stride)): + if ( + i != shard_dim + and local_stride[i] >= local_stride[shard_dim] * mesh_dim_size + ): + local_stride[i] = local_stride[i] // mesh_dim_size + + elif not isinstance(placement, (Replicate, Partial)): + raise RuntimeError(f"placement type {type(placement)} not supported!") + + return local_shape, local_stride + + +def compute_global_tensor_shape( + shape: torch.Size, mesh: DeviceMesh, placements: Sequence[Placement] +) -> torch.Size: + """ + Compute the global size of a DTensor from the given local tensor shape, + the mesh and placements. Different from `compute_global_tensor_info`, + which assumes sharding is even, this util allgathers local shards' shapes + from all ranks and thus can support uneven sharding. + NOTE: Currently this function only supports 1D mesh. + + Args: + shape (:class:`torch.Size`): + Shape of the local tensor + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + placements (Sequence[:class:`Placement`]]): + The attribute of the DTensor that describes its layout + on the mesh topology. + + Return: + tensor_shape: Shape of the global DTensor. + """ + if len(placements) != 1: + raise NotImplementedError( + "compute_global_tensor_shape only supports 1 placement for now." + ) + + if len(placements) != mesh.ndim: + raise RuntimeError( + "Expected one placement per mesh dim, " + f"but found {len(placements)} placements and {mesh.ndim} mesh dims." + ) + + if isinstance(placements[0], Replicate): + return shape + elif isinstance(placements[0], Shard): + + @maybe_run_for_local_tensor + def _create_local_shape_tensor(shape): + return torch.tensor(list(shape), device=mesh.device_type) + + local_shape = _create_local_shape_tensor(shape) + gathered_shaped_tensors = [ + torch.empty_like(local_shape, device=local_shape.device) + for _ in range(mesh.size()) + ] + funcol.all_gather_inplace(gathered_shaped_tensors, local_shape, mesh) + + @maybe_run_for_local_tensor + def _validate_and_compute_global_shape(local_shape, gathered_shaped_tensors): + sharded_dim_sum = 0 + shard_dim = placements[0].dim # type: ignore[union-attr] + other_dims = [d for d in range(len(shape)) if d != shard_dim] + for shape_tensor in gathered_shaped_tensors: + if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): + raise RuntimeError( + "Non-sharded dimensions should have identical size across ranks." + ) + shape_tensor_list = shape_tensor.tolist() + sharded_dim_sum += shape_tensor_list[shard_dim] + return sharded_dim_sum + + sharded_dim_sum = _validate_and_compute_global_shape( + local_shape, gathered_shaped_tensors + ) + global_shape = list(shape) + global_shape[placements[0].dim] = sharded_dim_sum + return torch.Size(global_shape) + else: + raise NotImplementedError( + f"Placement type {type(placements[0])} not supported." + ) + + +def try_find_mesh_from_args( + op_call: torch._ops.OpOverload, args: Sequence[object] +) -> DeviceMesh: + """ + Find the device mesh object from args. + It returns None if no mesh is found. + NOTE: we can optimize this search if needed + """ + for arg in args: + if isinstance(arg, (dtensor.DTensor, DTensorSpec)): + return arg.device_mesh + elif ( + isinstance(arg, (list, tuple)) + and len(arg) > 0 + and isinstance(arg[0], (dtensor.DTensor, DTensorSpec)) + ): + return arg[0].device_mesh + + raise ValueError(f"Cannot find device mesh from args for op : {op_call}.") + + +def compute_local_stride( + global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> tuple[int, ...]: + """ + Compute the stride of a local tensor shard, given the global stride of the DTensor. + NOTE: Currently this function is assuming the DTensor is evenly shardable. + """ + stride_divisors = [1] * len(global_stride) + for mesh_idx, p in enumerate(placements): + if p.is_shard(): + i = cast(Shard, p).dim + # tensor dimension i is sharded on mesh dimension mesh_idx, + # so we need to divide all the strides larger than stride[i] + # (by the submesh size) + for j in range(len(global_stride)): + if global_stride[j] > global_stride[i]: + stride_divisors[j] *= mesh.size(mesh_idx) + return tuple( + global_stride[i] // stride_divisors[i] for i in range(len(global_stride)) + ) + + +def normalize_to_torch_size(size) -> torch.Size: # type: ignore[no-untyped-def] + """ + Unify variable types of size argument to torch.Size + Acceptable types include: + int, Sequence[int], Tuple[int], Tuple[Sequence[int]], + or torch.Size + """ + if isinstance(size, torch.Size): + return size + + if isinstance(size, int): + torch_size = [size] + elif len(size) == 1 and isinstance(size[0], Sequence): + torch_size = list(size[0]) + else: + torch_size = list(size) + return torch.Size(torch_size) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/device_mesh.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..ca59ded5eb52bc0a3878e76077ad2879df4bf499 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/device_mesh.py @@ -0,0 +1,9 @@ +from torch.distributed.device_mesh import ( # noqa: F401 + _get_device_handle, + _mesh_resources, + DeviceMesh, + init_device_mesh, +) + + +__all__ = ["init_device_mesh", "DeviceMesh"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0012040d74a3e0caaf23a71c138681b9c372e591 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Iterator +from contextlib import contextmanager + +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor.experimental._attention import context_parallel +from torch.distributed.tensor.experimental._func_map import local_map +from torch.distributed.tensor.experimental._register_sharding import register_sharding + + +__all__ = ["context_parallel", "implicit_replication", "local_map", "register_sharding"] + + +@contextmanager +def implicit_replication() -> Iterator[None]: + """ + This context manager allows :class:`DTensor` to implicitly treat all non-DTensors (``torch.Tensor``) + in the program be replicate :class:`DTensor` s during the operator computation. + + .. warning:: This might possible lead to incorrect results if ``torch.Tensor`` s are not replicated + in practice, please use it at your discretion. + """ + try: + DTensor._op_dispatcher._allow_implicit_replication = True + yield + finally: + DTensor._op_dispatcher._allow_implicit_replication = False + + +# Set namespace for exposed private names +context_parallel.__module__ = "torch.distributed.tensor.experimental" +implicit_replication.__module__ = "torch.distributed.tensor.experimental" +local_map.__module__ = "torch.distributed.tensor.experimental" +register_sharding.__module__ = "torch.distributed.tensor.experimental" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_attention.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f238739ddd5cf4f8e120f1e6a0337f0cfc8cc58d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_attention.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Backward compatibility stub - this module has been moved to _context_parallel/_attention.py + +from ._context_parallel._attention import ( + _CausalBehavior, + _context_parallel_shard, + _ContextParallel, + _cp_options, + _disable_context_parallel_dispatcher, + _enable_context_parallel_dispatcher, + _is_causal_behavior, + _RotateMethod, + _templated_ring_attention, + context_parallel, + context_parallel_unshard, + set_rotate_method, +) +from ._context_parallel._load_balancer import ( + _HeadTailLoadBalancer, + _LoadBalancer, + _PerDocumentHeadTailLoadBalancer, + _PTRRLoadBalancer, +) + + +# TODO(fegin): add deprecation message once the final interfaces are concluded. +__all__ = [ + "_CausalBehavior", + "_context_parallel_shard", + "_ContextParallel", + "_cp_options", + "_disable_context_parallel_dispatcher", + "_enable_context_parallel_dispatcher", + "_is_causal_behavior", + "_RotateMethod", + "_templated_ring_attention", + "context_parallel", + "context_parallel_unshard", + "set_rotate_method", + "_HeadTailLoadBalancer", + "_LoadBalancer", + "_PerDocumentHeadTailLoadBalancer", + "_PTRRLoadBalancer", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_func_map.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_func_map.py new file mode 100644 index 0000000000000000000000000000000000000000..759841a40aaa14b3f985dc7bce730198617ada5b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_func_map.py @@ -0,0 +1,278 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import functools +from collections.abc import Callable, Sequence +from typing import Optional, Union + +import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor.placement_types import Placement + + +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + + +__all__ = ["local_map"] + +PlacementType = Optional[Sequence[Placement]] +InputPlacements = Optional[tuple[PlacementType, ...]] +OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]] + + +def local_map( + func: Callable | None = None, + out_placements: OutputPlacements = None, + in_placements: InputPlacements = None, + in_grad_placements: InputPlacements = None, + device_mesh: DeviceMesh | None = None, + *, + redistribute_inputs: bool = False, +): + """ + :meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s + to a function that is written to be applied on ``torch.Tensor`` s. It is done by extracting + the local components of :class:`DTensor`, call the function, and wrap the outputs to + :class:`DTensor` according to the ``out_placements``. + + Args: + func (Callable): the function to be applied on each local shard of + :class:`DTensor` s. + out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]): + the desired placements of the :class:`DTensor` s in ``func``'s flattened output. + If the flattened ``output`` is a single value, the ``out_placements`` should be + of type `PlacementType`. Otherwise if the flattened ``output`` has multiple + values, the ``out_placements`` should be a tuple of `PlacementType` values 1:1 + mapping to the flattened ``output``. + Besides, for :class:`Tensor` output, we use `PlacementType` as its + placements (a `Tuple[Placement]` value). For non-Tensor output, the `PlacementType` + should be `None`. + Note that the only exception is when no :class:`DTensor` argument is passed + in. In this case, even if `out_placements` is not `None`, the result function + should ignore the desired placements because the function is not running with + :class:`DTensor` s. + in_placements (Tuple[`PlacementType`, ...], optional): + the required placements of the :class:`DTensor` s in the flattened inputs of ``func``. + If ``in_placements`` is specified, :meth:`local_map` would examine whether the + placements of each :class:`DTensor` argument is the same as the required + placements or not. If the placements are not the same and + ``redistribute_inputs`` is ``False``, an exception will be raised. Otherwise if + ``redistribute_inputs`` is ``True``, the argument will be first redistributed to + the required sharding placements before passing its local tensor to ``func``. + The only exception is when required placements are not ``None`` and the + argument is a :class:`torch.Tensor`. In this case, the placements examination + will be skipped and the argument will be directly passed to ``func``. + If ``in_placements`` is ``None``, no placements examination will be performed. + Default: None + in_grad_placements (Tuple[`PlacementType`, ...], optional): + the placements hint of the :class:`DTensor` s gradient corresponds + to the flattened input DTensor. This argument is the hint that user + can give to :meth:`to_local` in case the gradient layout of the + local tensor input does not match its :class:`DTensor` input layout. + If not specified, we will assume the gradient layout of the local + tensor input remains the same as the original :class:`DTensor` input + and use that for gradient computation. Default: None. + device_mesh (:class:`DeviceMesh`, optional): + the device mesh that the output :class:`DTensor` s are placed on. If not + specified, this will be inferred from the first input :class:`DTensor`'s device + mesh. Default: None. + + Keyword Args: + redistribute_inputs (bool, optional): + the bool value indicating whether to reshard the input :class:`DTensor` s when + their placements are different from the required input placements. If this + value is ``False`` and some :class:`DTensor` input has a different placement, + an exception will be raised. Default: False. + + Returns: + A ``Callable`` that applies ``func`` to each local shard of the input :class:`DTensor` + and returns a :class:`DTensor` constructed from the return value of ``func``. + + Raises: + AssertionError: For any non-DTensor output, we require its corresponding + output placement in ``out_placements`` be None. An AssertionError will be raised + if this is not the case. + + ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs + a redistribution according to ``in_placements``. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> def mm_allreduce_forward(device_mesh, W, X): + >>> partial_sum_tensor = torch.mm(W, X) + >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) + >>> return reduced_tensor + >>> + >>> W = torch.randn(12, 8, requires_grad=False) + >>> X = torch.randn(8, 16, requires_grad=False) + >>> Y = torch.mm(W, X) + >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh + >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + >>> + >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion + >>> local_mm_allreduce_forward = local_map( + >>> mm_allreduce_forward, + >>> out_placements=[Replicate()], + >>> in_placements=[col_wise, row_wise], + >>> device_mesh=device_mesh, + >>> ) + >>> + >>> W_dt = distribute_tensor( + ... W, device_mesh, (col_wise) + ... ) # col-wisely sharded W tensor + >>> X_dt = distribute_tensor( + ... X, device_mesh, (row_wise) + ... ) # row-wisely sharded X tensor + >>> Y_dt = local_mm_allreduce_forward( + ... device_mesh, W_dt, X_dt + ... ) # apply local_mm_allreduce_forward to DTensors + + .. note:: This API is currently experimental and subject to change + """ + + if func is None: + # decorator mode + def decorated(func): + return local_map( + func=func, + out_placements=out_placements, + in_placements=in_placements, + in_grad_placements=in_grad_placements, + device_mesh=device_mesh, + redistribute_inputs=redistribute_inputs, + ) + + return decorated + + return functools.partial( + _local_map_wrapped, + func, + out_placements, + in_placements, + in_grad_placements, + device_mesh, + redistribute_inputs, + ) + + +def _local_map_wrapped( + func: Callable, + out_placements: OutputPlacements, + in_placements: InputPlacements, + in_grad_placements: InputPlacements, + device_mesh: DeviceMesh | None, + redistribute_inputs: bool, + *args, + **kwargs, +): + # process input args + flat_args, args_spec = pytree.tree_flatten(args) + if in_placements is not None: + assert len(in_placements) == len(flat_args), ( + f"in_placements length {len(in_placements)} does not match the number " + f"of input args {len(flat_args)}!" + ) + + # we assume every DTensor object is placed on the same device mesh + flat_local_args = [] + seen_dtensor_arg = False + for idx, arg in enumerate(flat_args): + if isinstance(arg, DTensor): + # TODO: the current code doesn't consider the uneven sharding case + # Need to think about what the consequence is when the input DTensor + # is uneven sharded. + if device_mesh is None: # infer device mesh from the DTensor arg + device_mesh = arg.device_mesh + + # this function is applied to at least one DTensor argument + seen_dtensor_arg = True + + if in_placements is not None: + spec = in_placements[idx] + assert spec is not None, ( + f"DTensor input {arg} expects placements but received {spec}!" + ) + + if not isinstance(spec, tuple): + spec = tuple(spec) + + if arg.placements != spec: + if redistribute_inputs: + # redistribute to input placements + arg = arg.redistribute(placements=spec) + else: + raise ValueError( + f"arg {arg} in local_map has a mismatched placements: " + f"arg placements is {arg.placements} but the input " + f"placements is {spec}! " + "If redistribute_inputs is wanted, set " + "redistribute_inputs=True to local_map." + ) + + if in_grad_placements is not None: + spec = in_grad_placements[idx] + assert spec is not None, ( + f"DTensor input {arg} expects in grad placements but received {spec}!" + ) + if not isinstance(spec, tuple): + spec = tuple(spec) + local_arg = arg.to_local(grad_placements=spec) + else: + local_arg = arg.to_local() + + if isinstance(local_arg, AsyncCollectiveTensor): + local_arg = local_arg.wait() + + flat_local_args.append(local_arg) + else: + # Non-Tensor input must have None in `in_placements` + if in_placements is not None and not isinstance(arg, torch.Tensor): + spec = in_placements[idx] + assert spec is None, ( + f"Non-Tensor input {arg} expects None placements " + f"but received {spec}!" + ) + + flat_local_args.append(arg) + + # pyrefly: ignore [bad-argument-type] + local_args = pytree.tree_unflatten(flat_local_args, args_spec) + + out = func(*local_args, **kwargs) + + if seen_dtensor_arg: + # process output to be DTensor if we've seen DTensor inputs + flat_out, out_spec = pytree.tree_flatten(out) + + flat_dist_out = [] + out_placements_tuple = ( + out_placements if isinstance(out_placements, tuple) else (out_placements,) + ) + assert len(flat_out) == len(out_placements_tuple), ( + "local_map requires one PlacementType be provided for each output value," + f" received {len(out_placements_tuple)} out_placements but" + f" {len(flat_out)} is expected!" + ) + for out, spec in zip(flat_out, out_placements_tuple): + if isinstance(out, torch.Tensor): + assert not isinstance(out, DTensor), ( + f"torch.Tensor output expected but received {type(out)}: {out}" + ) + + flat_dist_out.append( + DTensor.from_local(out, device_mesh, spec, run_check=False) + ) + else: + assert spec is None, ( + f"Non-tensor output {out} expects None placements but received {spec}!" + ) + + flat_dist_out.append(out) + + # pyrefly: ignore [bad-argument-type] + return pytree.tree_unflatten(flat_dist_out, out_spec) + else: + return out diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_register_sharding.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_register_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..7b365dcf286d03be9628c5f909682bcd0a818f7e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_register_sharding.py @@ -0,0 +1,136 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Callable, Sequence +from functools import partial + +import torch +from torch._ops import OpOverload +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy + + +__all__ = ["register_sharding"] + + +def register_sharding(op: OpOverload | list[OpOverload]): + """ + :meth:`register_sharding` is an experimental API that allows users to register sharding + strategies for an operator when the tensor inputs and outputs are DTensor. + It can be useful when: (1) there doesn't exist a default sharding strategy for ``op``, + e.g. when ``op`` is a custom operator that is not supported by :class:`DTensor`; (2) + when users would like to overwrite default sharding strategies of existing operators. + + Args: + op (Union[OpOverload, List[OpOverload]]): + An op or a list of ops to register the customized sharding function. + + Returns: + A function decorator which can be used to wrap a function that defines the sharding + strategy for the operator specified in ``op``. The defined sharding strategy will be + registered to DTensor and will override the default sharding strategy if DTensor has + already implemented the operator. The customized sharding function takes the same inputs + as the original op (except that if an arg is a :class:`torch.Tensor`, it will be + replaced by a tensor-like object that DTensor uses internally). The function should + return a sequence of 2-tuples, each specifying acceptable output placements and its + corresponding input placements. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> @register_sharding(aten._softmax.default) + >>> def custom_softmax_sharding(x, dim, half_to_float): + >>> softmax_dim = dim if dim >= 0 else dim + x.ndim + >>> acceptable_shardings = [] + >>> + >>> all_replicate = ([Replicate()], [Replicate(), None, None]) + >>> acceptable_shardings.append(all_replicate) + >>> + >>> for sharding_dim in range(x.ndim): + >>> if sharding_dim != softmax_dim: + >>> all_sharded = ( + >>> [Shard(sharding_dim)], + >>> [Shard(sharding_dim), None, None], + >>> ) + >>> acceptable_shardings.append(all_sharded) + >>> + >>> return acceptable_shardings + + .. note:: This API is currently experimental and subject to change + """ + + def custom_strategy( + custom_sharding_fn: Callable[ + ..., Sequence[tuple[PlacementList, PlacementList]] + ], + op_schema: OpSchema, + ) -> StrategyType: + def strategy_to_spec(strategy: object) -> object: + if isinstance(strategy, OpStrategy): + # take the output spec from the first strategy + return strategy.strategies[0].output_spec + elif isinstance(strategy, TupleStrategy): + return tuple(strategy_to_spec(s) for s in strategy.children) + else: + return strategy + + mesh = op_schema.get_mesh_from_args() + + args_schema = tuple(strategy_to_spec(i) for i in op_schema.args_schema) + kwargs_schema = { + k: strategy_to_spec(v) for k, v in op_schema.kwargs_schema.items() + } + + acceptable_shardings = custom_sharding_fn(*args_schema, **kwargs_schema) + + single_mesh_dim_strategies: list[PlacementList] = [] + for output_specs, input_specs in acceptable_shardings: + single_mesh_dim_strategies.append(output_specs + input_specs) + + # TODO: handle out variant ops + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=len(op_schema.op._schema.returns), + inplace_op=op_schema.is_inplace_op(), + ) + + def wrapper(custom_sharding_fn): + def derive_schema_info(op): + # NOTE: without user directly providing RuntimeSchemaInfo, for now + # we create it in a conservative fashion as follows: + # 1. let static_argnum be the first int argument + # 2. let static_kwargkey include all the int type kwargs + # 3. always set needs_pytree=True + static_argnum = 100 + static_kwargkey: list[str] = [] + for i, arg in enumerate(op._schema.arguments): + if isinstance(arg.type, torch.IntType) or ( + isinstance(arg.type, torch.OptionalType) + and isinstance(arg.type.getElementType(), torch.IntType) + ): + static_argnum = min(i, static_argnum) + if arg.kwarg_only: + static_kwargkey.append(arg.name) + return RuntimeSchemaInfo( + static_argnum, static_kwargkey or None, needs_pytree=True + ) + + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, + partial(custom_strategy, custom_sharding_fn), + derive_schema_info(overload), + ) + + return custom_sharding_fn + + return wrapper diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_tp_transform.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_tp_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..1075df79f33956d710348330b38f56228ebc871b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_tp_transform.py @@ -0,0 +1,557 @@ +# mypy: allow-untyped-defs +import copy +import operator +from collections.abc import Sequence +from typing import Any, cast + +import torch +from torch._subclasses.fake_tensor import FakeTensor +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OutputSharding, + OutputSpecType, +) +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +from torch.export import ExportedProgram +from torch.export.exported_program import ExportGraphSignature +from torch.fx import GraphModule +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.node import Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils import _pytree as pytree + + +__all__ = ["tensor_parallel_transformation"] + +aten = torch.ops.aten + + +def tensor_parallel_transformation( + exported_program: ExportedProgram, + rank: int, + world_size: int, + device_type: str, + parallel_strategies: dict[str, ParallelStyle], +) -> ExportedProgram: + """ + The entry point function to perform graph transformations on an exported program + to transform a single-device graph into a tensor parallel graph. + + .. warning:: + This API is experimental and subject to change. + """ + + gm = exported_program.graph_module + sig = copy.deepcopy(exported_program.graph_signature) + state_dict = copy.copy(exported_program.state_dict) + + with gm._set_replace_hook(sig.get_replace_hook()): + res = _TensorParallelTransformPass( + rank, + world_size, + device_type, + state_dict, + exported_program.graph_signature, + parallel_strategies, + )(gm) + assert res is not None + gm = res.graph_module + + return exported_program._update(gm, sig, state_dict=state_dict) + + +class _TensorParallelTransformPass(PassBase): + """ + This pass is responsible for transforming a single-device graph into a tensor parallel + graph. It will mark the OpSpec of each node in the graph, partition the graph into + distributed graph, then shard the parameters/buffers accordingly. + """ + + def __init__( + self, + rank: int, + world_size: int, + device_type: str, + state_dict: dict[str, torch.Tensor], + graph_signature: ExportGraphSignature, + parallel_strategies: dict[str, ParallelStyle], + ) -> None: + super().__init__() + self.rank = rank + self.mesh = DeviceMesh(device_type, torch.arange(world_size)) + self.state_dict: dict[str, torch.Tensor] = state_dict + self.graph_signature = graph_signature + self.parallel_strategies = parallel_strategies + + def call(self, graph_module) -> PassResult: + gm = copy.deepcopy(graph_module) + + parameter_placements = _generate_parameter_and_buffer_placements( + list(self.state_dict.keys()), self.parallel_strategies + ) + placement_strategies = _mark_sharding( + gm, self.graph_signature, self.mesh, parameter_placements + ) + _partitioner(gm) + _shard_state_dict( + self.state_dict, placement_strategies, self.graph_signature, self.mesh + ) + return PassResult(gm, True) + + +def _generate_parameter_and_buffer_placements( + params_and_buffers: list[str], + parallel_strategies: dict[str, ParallelStyle], +) -> dict[str, Placement]: + """ + Build parameter placements based on the give parallel style of linear layers. + """ + parameter_placements: dict[str, Placement] = {} + for linear_fqn, parallel_style in parallel_strategies.items(): + weight_fqn = f"{linear_fqn}.weight" + bias_fqn = f"{linear_fqn}.bias" + assert weight_fqn in params_and_buffers + parameter_placements[weight_fqn] = ( + Shard(0) if parallel_style == ColwiseParallel else Shard(1) + ) + if bias_fqn in params_and_buffers: + parameter_placements[bias_fqn] = ( + Shard(0) if parallel_style == ColwiseParallel else Replicate() + ) + return parameter_placements + + +def _mark_tensor_parallel_shardings( + gm: GraphModule, + graph_signature: ExportGraphSignature, + mesh: DeviceMesh, + parameter_placements: dict[str, Placement], +) -> dict[Node, OpSpec]: + """ + Mark the placement strategies of the parameter and buffer placeholder nodes. + """ + placement_strategies: dict[Node, OpSpec] = {} + num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len( + graph_signature.inputs_to_buffers + ) + placeholder_idx: int = 0 + for node in gm.graph.nodes: + if node.op == "placeholder": + if placeholder_idx < num_params_and_buffers: + fqn: str = _get_input_node_fqn(node.name, graph_signature) + placement: Placement = ( + parameter_placements[fqn] + if fqn in parameter_placements + else Replicate() + ) + placement_strategies[node] = _create_placement_strategy( + node, + mesh, + placements=(placement,), + ) + placeholder_idx += 1 + else: + placement_strategies[node] = _create_placement_strategy( + node, + mesh, + placements=(Replicate(),), + ) + return placement_strategies + + +def _get_input_node_fqn(input_name: str, graph_signature: ExportGraphSignature) -> str: + """ + Return the FQN of an input node. + """ + if input_name in graph_signature.inputs_to_parameters: + return graph_signature.inputs_to_parameters[input_name] + elif input_name in graph_signature.inputs_to_buffers: + return graph_signature.inputs_to_buffers[input_name] + else: + raise ValueError( + f"{input_name} not found in inputs_to_parameters or inputs_to_buffers" + ) + + +def _mark_sharding( + gm: GraphModule, + graph_signature: ExportGraphSignature, + mesh: DeviceMesh, + parameter_placements: dict[str, Placement], +) -> dict[Node, OpSpec]: + """ + Mark the sharding strategy for each node in the graph module. + """ + placement_strategies: dict[Node, OpSpec] = _mark_tensor_parallel_shardings( + gm, + graph_signature, + mesh, + parameter_placements, + ) + + for node in gm.graph.nodes: + if node.op == "placeholder": + if node not in placement_strategies: + placement_strategies[node] = _create_placement_strategy( + node, mesh, placements=(Replicate(),) + ) + node.meta["sharding"] = placement_strategies[node] + elif node.op == "call_function": + if node.target is operator.getitem: + input_nodes = node.all_input_nodes + assert len(input_nodes) == 1, ( + f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}" + ) + arg_strategy = placement_strategies[input_nodes[0]] + placement_strategies[node] = _create_placement_strategy( + node, + mesh, + placements=arg_strategy.output_spec.placements, + input_specs=_get_input_node_specs(node, placement_strategies), + ) + node.meta["sharding"] = placement_strategies[node] + else: + op_schema = _get_op_schema(node, placement_strategies) + + # get DTensor specs for inputs and outputs + if ( + op_schema.op + not in DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs + and op_schema.op + not in DTensor._op_dispatcher.sharding_propagator.op_to_rules + ): + # Mark all as replicated + output_sharding = _generate_default_output_sharding( + node, + mesh, + op_schema, + ) + else: + output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding( # type: ignore[assignment] + op_schema, + ) + placement_strategies[node] = OpSpec( + # pyrefly: ignore [bad-argument-type] + output_specs=_get_output_spec_from_output_sharding(output_sharding), + # pyrefly: ignore [missing-attribute] + input_specs=output_sharding.redistribute_schema.args_spec + # pyrefly: ignore [missing-attribute] + if output_sharding.redistribute_schema is not None + else _get_input_node_specs(node, placement_strategies), + ) + node.meta["sharding"] = placement_strategies[node] + elif node.op == "output": + node.meta["sharding"] = None + else: + raise RuntimeError(f"op code {node.op} not supported") + return placement_strategies + + +def _get_output_spec_from_output_sharding( + output_sharding: OutputSharding, +) -> DTensorSpec: + """ + Util function to extract output spec from output sharding. + """ + if isinstance(output_sharding.output_spec, DTensorSpec): + return output_sharding.output_spec + else: + # For ops that return multiple outputs, the outputs should have the same output spec + assert isinstance(output_sharding.output_spec, Sequence) + assert output_sharding.output_spec[0] is not None + output_sharding.output_spec[0].tensor_meta = None + return output_sharding.output_spec[0] + + +def _create_placement_strategy( + node: Node, + mesh: DeviceMesh, + placements: tuple[Placement, ...], + input_specs: Sequence[DTensorSpec] | None = None, +) -> OpSpec: + """ + Util function to construct an OpSpec for a given node. + """ + placement = OpSpec( + input_specs=input_specs, + output_specs=DTensorSpec( + mesh=mesh, + placements=placements, + ), + ) + _populate_tensor_meta(node, placement.output_specs) + return placement + + +def _populate_tensor_meta(node: Node, output_spec: OutputSpecType) -> None: + """ + Util function to populate tensor meta of output_spec based on node metadata. + """ + if isinstance(node.meta["val"], Sequence): + assert isinstance(output_spec, Sequence) + for spec, fake_tensor in zip(output_spec, node.meta["val"]): + assert spec is not None + spec.tensor_meta = TensorMeta( + shape=fake_tensor.shape, + stride=fake_tensor.stride(), + dtype=fake_tensor.dtype, + ) + else: + assert isinstance(output_spec, DTensorSpec) + output_spec.tensor_meta = TensorMeta( + shape=node.meta["val"].shape, + stride=node.meta["val"].stride(), + dtype=node.meta["val"].dtype, + ) + + +def _generate_default_output_sharding( + node: Node, + mesh: DeviceMesh, + op_schema: OpSchema, +) -> OutputSharding: + """ + Util function to create a default output sharding that suggests Replicate placement for both args and outputs. + """ + + def update_arg_spec(arg_spec: DTensorSpec) -> DTensorSpec: + return DTensorSpec( + mesh=arg_spec.mesh, + placements=(Replicate(),), + tensor_meta=arg_spec.tensor_meta, + ) + + new_op_schema = OpSchema( + op=op_schema.op, + args_schema=pytree.tree_map_only( + DTensorSpec, update_arg_spec, op_schema.args_schema + ), + kwargs_schema=op_schema.kwargs_schema, + ) + + def create_output_spec(tensor: FakeTensor) -> DTensorSpec: + return DTensorSpec( + mesh=mesh, + placements=(Replicate(),), + tensor_meta=TensorMeta( + shape=tensor.shape, + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) + + return OutputSharding( + output_spec=pytree.tree_map_only( + FakeTensor, create_output_spec, node.meta["val"] + ), + redistribute_schema=new_op_schema, + needs_redistribute=True, + ) + + +def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Graph partitioner that partitions the single device graph + to distributed graph + """ + for node in gm.graph.nodes: + node_sharding = node.meta["sharding"] + if node.op == "placeholder": + out_spec = node_sharding.output_spec + local_val = _partition_val(node.meta["val"], out_spec) + # update node value + node.meta["val"] = local_val + elif node.op == "call_function": + out_spec = node_sharding.output_spec + # check if there's misaligned sharding, insert reshard if there is + expected_input_specs = node_sharding.input_specs + for idx, input_arg in enumerate(node.all_input_nodes): + input_arg_sharding = input_arg.meta["sharding"] + input_arg_spec = input_arg_sharding.output_spec + desired_spec = ( + out_spec + if expected_input_specs is None + else expected_input_specs[idx] + ) + if input_arg_spec != desired_spec: + _insert_reshard_gm( + gm, node, input_arg, input_arg_spec, desired_spec + ) + # convert output val to its local component + output_val = node.meta["val"] + node.meta["val"] = _partition_val(output_val, out_spec) + elif node.op == "output": + for input_arg in node.all_input_nodes: + # input args of output should be Replicate, otherwise redistribution is needed. + input_args_to_check: Sequence[Node] = ( + input_arg if isinstance(input_arg, Sequence) else [input_arg] + ) + for arg in input_args_to_check: + arg_sharding = arg.meta["sharding"] + arg_spec = arg_sharding.output_spec + desired_spec = copy.copy(arg_spec) + desired_spec.placements = (Replicate(),) + if arg_spec != desired_spec: + _insert_reshard_gm(gm, node, arg, arg_spec, desired_spec) + else: + raise RuntimeError(f"op code {node} not supported") + + _clean_up_graph_metadata(gm) + gm.graph.lint() + gm.recompile() + return gm + + +def _partition_val(val: Any, spec: DTensorSpec) -> Any: + """ + util function to convert a full tensor val to its local component + """ + if isinstance(val, torch.Tensor): + local_shard = val + if val.ndim == 0: + # If it's already a scalar tensor, it is already local, we don't + # need to do anything + return local_shard + + for idx, placement in enumerate(spec.placements): + if placement.is_shard(): + placement = cast(Shard, placement) + num_chunks = spec.mesh.size(mesh_dim=idx) + my_coord = spec.mesh.get_coordinate() + assert my_coord is not None, "current rank not in mesh!" + my_coord_on_mesh_dim = my_coord[idx] + local_shard = placement._split_tensor( + local_shard, num_chunks, with_padding=False, contiguous=True + )[0][my_coord_on_mesh_dim] + return local_shard + elif isinstance(val, (list, tuple)): + return val.__class__(_partition_val(v, spec) for v in val) + else: + raise RuntimeError(f"val type {type(val)} not supported") + + +def _insert_reshard_gm( + gm: torch.fx.GraphModule, + node: Node, + input_arg: Node, + input_arg_spec: DTensorSpec, + desired_spec: DTensorSpec, +) -> None: + """ + Transform the graph for tensor redistribution. + """ + input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"] + desired_spec.tensor_meta = input_arg.meta["tensor_meta"] + input_arg_tensor = input_arg.meta["val"] + + # insert reshard operation + def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor: + return redistribute_local_tensor( + local_tensor, + input_arg_spec, + desired_spec, + ) + + reshard_gm = make_fx(reshard_fn)(input_arg_tensor) + reshard_gm_nodes = list(reshard_gm.graph.nodes) + input_node = reshard_gm_nodes[0] + with gm.graph.inserting_before(node): + # copy nn_module_stack metadata for output, all-reduce nodes + for reshard_node in reshard_gm.graph.nodes: + if reshard_node.op not in ["placeholder", "output"]: + reshard_node.meta["nn_module_stack"] = ( + copy.copy(input_arg.meta["nn_module_stack"]) + if input_arg.op != "placeholder" + else copy.copy(node.meta["nn_module_stack"]) + ) + output_node = gm.graph.graph_copy( + reshard_gm.graph, + val_map={ + input_node: input_arg, + }, + ) + node.replace_input_with(input_arg, output_node) # type: ignore[arg-type] + + +def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None: + """ + Clean up the graph by removing sharding and partitioning related metadata + """ + for node in gm.graph.nodes: + if "sharding" in node.meta: + del node.meta["sharding"] + if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): + local_tensor_meta = _extract_tensor_metadata(node.meta["val"]) + node.meta["tensor_meta"] = local_tensor_meta + + +def _get_input_node_specs( + node: Node, placement_strategies: dict[Node, OpSpec] +) -> tuple[DTensorSpec, ...]: + """ + Get the input specs of a node. + """ + input_specs_list: list[DTensorSpec] = [] + for input_arg in node.all_input_nodes: + if input_arg in placement_strategies: + output_spec = placement_strategies[input_arg].output_specs + assert isinstance(output_spec, DTensorSpec) + input_specs_list.append(output_spec) + else: + raise ValueError(f"{input_arg} does not have output_spec populated.") + return tuple(input_specs_list) + + +def _get_op_schema(node: Node, placement_strategies: dict[Node, OpSpec]) -> OpSchema: + """ + Util function to construct the operator schema of a node. + """ + args_schema_list = pytree.tree_map_only( + Node, lambda arg: placement_strategies[arg].output_specs, node.args + ) + op_schema = OpSchema( + op=cast(torch._ops.OpOverload, node.target), + args_schema=tuple(args_schema_list), + kwargs_schema=cast(dict[str, object], node.kwargs), + ) + return op_schema + + +def _shard_state_dict( + state_dict: dict[str, torch.Tensor], + placement_strategies: dict[Node, OpSpec], + graph_signature: ExportGraphSignature, + mesh: DeviceMesh, +) -> None: + """ + Inplace partition the weights based on the OpSpec + """ + for node, op_spec in placement_strategies.items(): + if node.op != "placeholder": + continue + if node.name in graph_signature.inputs_to_parameters: + fqn = graph_signature.inputs_to_parameters[node.name] + elif node.name in graph_signature.inputs_to_buffers: + fqn = graph_signature.inputs_to_buffers[node.name] + else: + continue + assert fqn in state_dict, f"{fqn} not found in state dict: {state_dict.keys()}" + + original_param = state_dict[fqn] + dtensor_param = distribute_tensor( + original_param, + mesh, + op_spec.output_spec.placements, + ) + local_param = dtensor_param.to_local() + state_dict[fqn] = ( + torch.nn.Parameter(local_param) + if isinstance(original_param, torch.nn.Parameter) + else local_param + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeaf359bc2f9e2d273ae40a5a122eea376e07c9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py @@ -0,0 +1,1114 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +from dataclasses import dataclass, field +from typing import cast, Optional + +import torch +import torch._C +import torch.distributed._functional_collectives as funcol +from torch._C._distributed import Placement +from torch.distributed._local_tensor import maybe_run_for_local_tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._collective_utils import ( + fill_empty_tensor_to_shards, + mesh_broadcast, + mesh_scatter, + pad_tensor, + shard_dim_alltoall, + unpad_tensor, +) +from torch.distributed.tensor._ops._mask_buffer import MaskBuffer + + +__all__ = ["Placement", "Shard", "Replicate", "Partial", "MaskPartial"] + + +# Appease TestPublicBindings.test_correct_module_names +Placement.__module__ = "torch.distributed.tensor.placement_types" + + +class Shard(torch._C._distributed.Shard): + """ + The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension + ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension only holds a shard/piece of the global Tensor. The + ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the + last few shards on the DeviceMesh dimension might be empty when the tensor dimension + is not evenly divisible on the DeviceMesh dimension. The ``Shard`` placement can be + used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) + + Args: + dim (int): The tensor dimension that describes the DTensor is sharded over its + corresponding DeviceMesh dimension. + + .. warning:: sharding on a tensor dimension where the tensor dimension size is not + evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. + """ + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> tuple[list[torch.Tensor], list[int]]: + """ + This function uses torch.chunk to split a tensor into num_chunks shards along + the Shard placement dimension, and return a list of shards with their pad sizes. + + Keyword args: + with_padding (bool, optional): when True, we pad the tensor on the last + few ranks before calling the collectives (i.e. scatter/all_gather, etc.). + This is because collectives usually require equal size tensor inputs + """ + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) + + # chunk tensor over dimension `dim` into n slices + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) + tensor_list = fill_empty_tensor_to_shards( + tensor_list, self.dim, num_chunks - len(tensor_list) + ) + + # compute the chunk size inline with ``torch.chunk`` to calculate padding + full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks + + shard_list: list[torch.Tensor] = [] + pad_sizes: list[int] = [] + for shard in tensor_list: + if with_padding: + pad_size = Shard._get_shard_pad_size(full_chunk_size, shard, self.dim) + shard = pad_tensor(shard, self.dim, pad_size) + pad_sizes.append(pad_size) + if contiguous: + shard = shard.contiguous() + shard_list.append(shard) + return shard_list, pad_sizes + + @staticmethod + @maybe_run_for_local_tensor + def local_shard_size_and_offset( + curr_local_size: int, + num_chunks: int, + rank: int, + ) -> tuple[int, int]: + """ + Given the size of the current local tensor (which may already be sharded on some dimensions), + computes the new local shard size and offset given the desired number of chunks + (num_chunks is generally equal to the size of the current sharding dim). + + Note: new local shard offset is relative to the current sharded tensor, not the global tensor. + See `_utils.compute_local_shape_and_global_offset` for computing global offset. + + Returns (new local shard size, offset) + + """ + # Compute the chunk size inline with ``torch.chunk`` + if curr_local_size % num_chunks == 0: + full_chunk_size = curr_local_size // num_chunks + return full_chunk_size, full_chunk_size * rank + + # uneven sharding case + full_chunk_size = (curr_local_size + num_chunks - 1) // num_chunks + shard_starting_idx = full_chunk_size * rank + + if curr_local_size < shard_starting_idx: + return 0, curr_local_size + else: + local_shard_size = ( + min(curr_local_size, shard_starting_idx + full_chunk_size) + - shard_starting_idx + ) + return local_shard_size, shard_starting_idx + + def _local_shard_size_and_offset( + self, + curr_local_size: int, + num_chunks: int, + rank: int, + ) -> tuple[int, int | None]: + return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank) + + @staticmethod + @maybe_run_for_local_tensor + def _maybe_unpad_tensor_with_sizes( + dim, local_tensor, pad_sizes, mesh_dim_local_rank, make_contiguous + ) -> torch.Tensor: + # Only unpad if the local_tensor was padded on the dimension. + if pad_sizes[mesh_dim_local_rank] > 0: + local_tensor = unpad_tensor( + local_tensor, dim, pad_sizes[mesh_dim_local_rank] + ) + if make_contiguous: + local_tensor = local_tensor.contiguous() + return local_tensor + + def _shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + """ + Shard and scatter a tensor on a mesh dimension (use coordinate 0 on the + mesh dimension as source of truth). + + Create the local tensor for this rank following the given Shard + placement. If src_data_rank is None, perform only local splitting. + Otherwise, additionally scatter data from src_data_rank. Unlike + ``_split_tensor``, which supports uneven sharding via padding, this + method requires the tensor dimension to be evenly divisible by the + number of chunks (mesh dimension size). + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + + if src_data_rank is None: + # src_data_rank specified as None explicitly means to skip the + # communications, simply split + scatter_list, _ = self._split_tensor( + tensor, num_chunks, with_padding=False, contiguous=True + ) + + return self._select_shard(scatter_list, mesh_dim_local_rank) + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + + it = iter(scatter_list) + first = next(it) + # Tensors in the scatter list are expected to have the same shape because + # split is requested with padding. + assert all(first.shape == v.shape for v in it) + + output = torch.empty_like(first) + + # perform scatter from the src_data_rank as data source when it is not None + mesh_scatter( + output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank + ) + + return Shard._maybe_unpad_tensor_with_sizes( + self.dim, output, pad_sizes, mesh_dim_local_rank, True + ) + + @classmethod + def _make_shard_tensor( + cls, + dim: int, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + shard_placement = cls(dim) + return shard_placement._shard_tensor(tensor, mesh, mesh_dim, src_data_rank) + + def _reduce_shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + reduce_op: str, + mesh_dim: int, + ) -> torch.Tensor: + """ + reduce and scatter a tensor on a mesh dimension + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return tensor + + is_padded = tensor.size(self.dim) % num_chunks != 0 + pad_sizes = None + if is_padded: + scattered_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + tensor = torch.cat(scattered_list, dim=self.dim) + elif not tensor.is_contiguous(): + tensor = tensor.contiguous() + + output = funcol.reduce_scatter_tensor( + tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) + ) + + if is_padded: + assert pad_sizes is not None + output = Shard._maybe_unpad_tensor_with_sizes( + self.dim, output, pad_sizes, my_coordinate[mesh_dim], False + ) + return output + + @maybe_run_for_local_tensor + def _maybe_pad_tensor( + self, + local_tensor: torch.Tensor, + logical_dim_size: int, + num_chunks: int, + ) -> torch.Tensor: + is_padded = logical_dim_size % num_chunks != 0 + + if is_padded: + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + pad_size = full_chunk_size - local_tensor.size(self.dim) + local_tensor = pad_tensor(local_tensor, self.dim, pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + return local_tensor + + @maybe_run_for_local_tensor + def _maybe_unpad_tensor( + self, + local_tensor: torch.Tensor, + logical_dim_size: int, + num_chunks: int, + ) -> torch.Tensor: + is_padded = logical_dim_size % num_chunks != 0 + + if is_padded: + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] + local_tensor = unpad_tensor(local_tensor, self.dim, unpad_size) + + return local_tensor + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + ) -> torch.Tensor: + """ + This function all_gather all shards and return a tensor that + is replicated on the previously sharded mesh dimension + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + logical_dim_size = current_logical_shape[self.dim] + + local_tensor = self._maybe_pad_tensor( + local_tensor, logical_dim_size, num_chunks + ) + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + + result = self._maybe_unpad_tensor(result, logical_dim_size, num_chunks) + + return result + + @staticmethod + @maybe_run_for_local_tensor + def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor: + return shards[shard_index].clone() + + def _replicate_to_shard( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_index: int, + ) -> torch.Tensor: + """ + transform from replicated tensor to a sharded tensor on + the current rank, which would perform a local chunk + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + shards, _ = self._split_tensor( + local_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + + return Shard._select_shard(shards, shard_index) + + @staticmethod + @maybe_run_for_local_tensor + def _get_shard_pad_size( + full_size: int, local_tensor: torch.Tensor, dim: int + ) -> int: + """ + Get the padding size of the local tensor on the shard dimension. + """ + return full_size - local_tensor.size(dim) + + @staticmethod + def _compute_padding_info( + current_logical_shape: list[int], + num_chunks: int, + old_shard_dim: int, + new_shard_dim: int, + ) -> tuple[bool, int, int, bool, int, int]: + results = [] + for shard_dim in [old_shard_dim, new_shard_dim]: + dim_logical_size = current_logical_shape[shard_dim] + dim_padding = dim_logical_size % num_chunks != 0 + dim_full_chunk_size = (dim_logical_size + num_chunks - 1) // num_chunks + results.append((dim_padding, dim_logical_size, dim_full_chunk_size)) + + return results[0] + results[1] + + @staticmethod + @maybe_run_for_local_tensor + def _pad_for_new_shard_dim( + current_logical_shape: list[int], + local_tensor: torch.Tensor, + num_chunks: int, + old_shard_dim: int, + new_shard_dim: int, + ) -> torch.Tensor: + ( + old_dim_padding, + _, + old_dim_full_chunk_size, + new_dim_padding, + _, + new_dim_full_chunk_size, + ) = Shard._compute_padding_info( + current_logical_shape, num_chunks, old_shard_dim, new_shard_dim + ) + + if old_dim_padding: + old_dim_pad_size = Shard._get_shard_pad_size( + old_dim_full_chunk_size, local_tensor, old_shard_dim + ) + local_tensor = pad_tensor(local_tensor, old_shard_dim, old_dim_pad_size) + if new_dim_padding: + new_dim_pad_size = Shard._get_shard_pad_size( + new_dim_full_chunk_size * num_chunks, local_tensor, new_shard_dim + ) + local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + return local_tensor + + @staticmethod + @maybe_run_for_local_tensor + def _unpad_for_new_shard_dim( + current_logical_shape: list[int], + local_tensor: torch.Tensor, + num_chunks: int, + old_shard_dim: int, + new_shard_dim: int, + local_rank: int, + ) -> torch.Tensor: + ( + old_dim_padding, + _, + old_dim_full_chunk_size, + new_dim_padding, + new_dim_logical_size, + new_dim_full_chunk_size, + ) = Shard._compute_padding_info( + current_logical_shape, num_chunks, old_shard_dim, new_shard_dim + ) + + if old_dim_padding: + old_dim_unpad_size = ( + old_dim_full_chunk_size * num_chunks + - current_logical_shape[old_shard_dim] # type: ignore[possibly-undefined] + ) + local_tensor = unpad_tensor(local_tensor, old_shard_dim, old_dim_unpad_size) # type: ignore[possibly-undefined] + + if new_dim_padding: + local_shard_size_on_new_dim = Shard.local_shard_size_and_offset( + new_dim_logical_size, num_chunks, local_rank + )[0] + new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] + local_tensor = unpad_tensor(local_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] + + return local_tensor + + def _to_new_shard_dim( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + new_shard_dim: int, + ) -> torch.Tensor: + """ + transform from existing sharded tensor to a new sharded tensor on + that shard on a new dimension, which performs an alltoall + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return local_tensor + + num_chunks = mesh.size(mesh_dim=mesh_dim) + + local_tensor = Shard._pad_for_new_shard_dim( + current_logical_shape, local_tensor, num_chunks, self.dim, new_shard_dim + ) + + new_tensor = shard_dim_alltoall( + local_tensor, self.dim, new_shard_dim, mesh, mesh_dim + ) + + new_tensor = Shard._unpad_for_new_shard_dim( + current_logical_shape, + new_tensor, + num_chunks, + self.dim, + new_shard_dim, + my_coordinate[mesh_dim], + ) + + return new_tensor + + def __hash__(self) -> int: + return hash(self.dim) + + def __repr__(self) -> str: + """ + machine readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + def __str__(self) -> str: + """human readable representation of the Shard placement""" + return f"S({self.dim})" + + +class _StridedShard(torch._C._distributed.StridedShard): + """ + _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor + is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. + We call this right-to-left sharding which is the opposite of the default + left-to-right sharding. See the example below: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [Shard(0), Shard(0)] + + The default sharding behavior shards the tensor on "dp" mesh dimension first then + "tp" dimension. The sharding result will be: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 1 (row 2-3) + 2 | (1, 0) | 2 (row 4-5) + 3 | (1, 1) | 3 (row 6-7) + + While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on + "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the + result: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The consequence is, any attempt to redistribute this DTensor to a full replica will + produce a wrong result because the shard-to-replicate redistribution always happens + right-to-left, regardless it's left-to-right sharding or right-to-left. To address + this, we use _StridedShard placement to make this right-to-left sharding compatible + with our left-to-right convention on both tensor distribution and redistribution. + + Now with _StridedShard, the right-to-left sharding above can be represented as: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [_StridedShard(0, split_factor=2), Shard(0)] + + And a left-to-right processing of `placements` will produce the same result, which is + different from using the `Shard` placement: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The argument `split_factor` is the number of existing shards over the tensor sharding + dimension before processing the _StridedShard placement, as if the sharding happened + right-to-left. In the example above, the tensor should first be sharded on the "tp" + dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the + `split_factor` of the _StridedShard placement on "dp" dim is 2. + + TODO: we should remove _StridedShard placement once we can unify it with Shard + """ + + def __hash__(self) -> int: + return hash((self.dim, self.split_factor)) + + def __repr__(self) -> str: + """ + machine readable representation of the _StridedShard placement + """ + return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" + + def __str__(self) -> str: + """human readable representation of the _StridedShard placement""" + return f"_S({self.dim}, {self.split_factor})" + + @staticmethod + @maybe_run_for_local_tensor + def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor: + return shards[shard_index].clone() + + @classmethod + def _make_shard_tensor( + cls, + dim: int, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + split_factor: int = 1, + ) -> torch.Tensor: + strided_shard_placement = cls(dim=dim, split_factor=split_factor) + return strided_shard_placement._shard_tensor( + tensor, mesh, mesh_dim, src_data_rank + ) + + def _shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: Optional[int] = 0, + ) -> torch.Tensor: + """ + Shard and scatter a tensor on a mesh dimension (use coordinate 0 on the + mesh dimension as source of truth). + + Create the local tensor for this rank following the given StridedShard + placement. If src_data_rank is None, perform only local splitting. + Otherwise, additionally scatter data from src_data_rank. Unlike + ``_split_tensor``, which supports uneven sharding via padding, this + method requires the tensor dimension to be evenly divisible by the + number of chunks (mesh dimension size). + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + + if src_data_rank is None: + # src_data_rank specified as None explicitly means to skip the + # communications, simply split + scatter_list, _ = self._split_tensor( + tensor, num_chunks, with_padding=False, contiguous=True + ) + + return self._select_shard(scatter_list, mesh_dim_local_rank) + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + + it = iter(scatter_list) + first = next(it) + # Tensors in the scatter list are expected to have the same shape because + # split is requested with padding. + assert all(first.shape == v.shape for v in it) + + output = torch.empty_like(first) + + # perform scatter from the src_data_rank as data source when it is not None + mesh_scatter( + output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank + ) + + return Shard._maybe_unpad_tensor_with_sizes( + self.dim, output, pad_sizes, mesh_dim_local_rank, True + ) + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> tuple[list[torch.Tensor], list[int]]: + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) + + # Essentially _StridedShard express the right-to-left sharding in the + # reversed order. Here we perform first_split as the virtual "right" sharding, + # and then second_split as the virtual "left" sharding, and finally assemble + # results in the transposed left-first order. + + # First split: chunk into split_factor pieces + first_split = list(torch.chunk(tensor, self.split_factor, dim=self.dim)) + first_split = fill_empty_tensor_to_shards( + first_split, self.dim, self.split_factor - len(first_split) + ) + + # Second split: chunk each piece into num_chunks pieces + second_split = [] + for s in first_split: + chunks = list(torch.chunk(s, num_chunks, dim=self.dim)) + chunks = fill_empty_tensor_to_shards( + chunks, self.dim, num_chunks - len(chunks) + ) + second_split.append(chunks) + + shard_list: list[torch.Tensor] = [] + for i in range(num_chunks): + shard = torch.cat( + [second_split[j][i] for j in range(self.split_factor)], + dim=self.dim, + ) + if contiguous: + shard = shard.contiguous() + shard_list.append(shard) + + # The amount of padding is determined by the local chunk with the largest size. + pad_sizes: list[int] = [] + max_chunk_size = max([shard.size(self.dim) for shard in shard_list]) + if with_padding: + pad_sizes = [max_chunk_size - shard.size(self.dim) for shard in shard_list] + + return shard_list, pad_sizes + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + ) -> torch.Tensor: + """ + replay the replicate-to-shard process to understand how to stitch shards back + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + logical_dim_size = current_logical_shape[self.dim] + + # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed + # so that we can reuse self._split_tensor which splits on self.dim + shape = [1] * self.dim + [logical_dim_size] + indices_tensor = torch.arange( + logical_dim_size, device=local_tensor.device + ).view(shape) + + sharded_indices, _ = self._split_tensor( + indices_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + # squeeze back to 1D indices tensor + sharded_indices = [shard.view(-1) for shard in sharded_indices] + + max_chunk_size = max([len(shard) for shard in sharded_indices]) + local_pad_size = max_chunk_size - local_tensor.size(self.dim) + local_tensor_padded = pad_tensor(local_tensor, self.dim, local_pad_size) + + if not local_tensor_padded.is_contiguous(): + local_tensor_padded = local_tensor_padded.contiguous() + + replicate_tensor_permuted_padded = funcol.all_gather_tensor( + local_tensor_padded, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if isinstance(replicate_tensor_permuted_padded, funcol.AsyncCollectiveTensor): + replicate_tensor_permuted_padded = replicate_tensor_permuted_padded.wait() + + if replicate_tensor_permuted_padded.shape[self.dim] > logical_dim_size: + replicate_tensor_permuted = unpad_tensor( + replicate_tensor_permuted_padded, + self.dim, + replicate_tensor_permuted_padded.shape[self.dim] - logical_dim_size, + ) + else: + replicate_tensor_permuted = replicate_tensor_permuted_padded + + permutation = torch.cat(sharded_indices) + inv_permutation = torch.argsort(permutation) + replicate_tensor = torch.index_select( + replicate_tensor_permuted, self.dim, inv_permutation + ) + + return replicate_tensor.contiguous() + + @staticmethod + @maybe_run_for_local_tensor + def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int: + return len(sharded_indices[rank]) + + # delete pyre-ignore once separating _StridedShard from Shard + def _local_shard_size_and_offset( # pyre-ignore[bad-override] + self, + curr_local_size: int, + num_chunks: int, + rank: int, + return_first_offset: bool = True, + ) -> tuple[int, int | list[int]]: + return _StridedShard.local_shard_size_and_offset( + self, curr_local_size, num_chunks, rank, return_first_offset + ) + + @staticmethod + @maybe_run_for_local_tensor + def local_shard_size_and_offset( # pyre-ignore[bad-override] + self, + curr_local_size: int, + num_chunks: int, + rank: int, + return_first_offset: bool = True, + ) -> tuple[int, list[int] | int]: + """ + Compute the local shard size and offset(s) for a _StridedShard placement. + + Unlike the regular Shard placement which produces contiguous offsets, _StridedShard + produces non-contiguous (strided) offsets due to the right-to-left sharding semantics. + This method computes the actual indices that belong to the local shard. + + Args: + self (_StridedShard): The _StridedShard placement instance. + curr_local_size (int): The current size of the tensor dimension to be sharded. + num_chunks (int): Number of chunks to split the dimension into (typically the mesh dimension size). + rank (int): The rank index to compute the shard for. + return_first_offset (bool): If True, return only the first offset as an int. If False, + return all offsets as a list. Defaults to True. + + Returns: + tuple: A tuple containing: + - local_shard_size (int): The number of elements in the local shard for this rank. + - offset (int | list[int]): If return_first_offset is True, returns the first offset + as an int. If False or if the shard size is 0, returns a list of all offsets + (which may be empty for empty shards). + """ + # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed + # so that we can reuse self._split_tensor which splits on self.dim + shape = [1] * self.dim + [curr_local_size] + indices_tensor = torch.arange( + curr_local_size, + ).view(shape) + + sharded_indices, _ = self._split_tensor( + indices_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + # squeeze back to 1D indices tensor + sharded_indices = [shard.view(-1) for shard in sharded_indices] + + local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank) + if local_shard_size > 0: + offsets = sharded_indices[rank].tolist() + else: + offsets = [] + + if return_first_offset: + # Always return an int for consistency across ranks. + # For empty shards, return -1 as an invalid offset indicator. + offsets = offsets[0] if len(offsets) > 0 else -1 + + return local_shard_size, offsets + + +class Replicate(torch._C._distributed.Replicate): + """ + The ``Replicate()`` placement describes the DTensor replicating on a corresponding + ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a + replica of the global Tensor. The ``Replicate`` placement can be used by all + DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) + """ + + def __hash__(self) -> int: + # every replicate placement is the same + return -1 + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + def __str__(self) -> str: + """ + human readable representation of the Replicate placement + """ + return "R" + + @classmethod + def _make_replicate_tensor( + cls, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + """ + Replicate (broadcast) a torch.Tensor on a mesh dimension (use + the first coordinate on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + tensor = tensor.contiguous() + + if src_data_rank is not None: + # perform broadcast from the src_data_rank as data source when it is not None + mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim, group_src=src_data_rank) + return tensor + + def _replicate_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + return Replicate._make_replicate_tensor(tensor, mesh, mesh_dim, src_data_rank) + + +class Partial(torch._C._distributed.Partial): + """ + The ``Partial(reduce_op)`` placement describes the DTensor that is pending + reduction on a specified ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension holds the partial value of the global Tensor. User can + redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` + placement on the specified ``DeviceMesh`` dimension using ``redistribute``, + which would trigger necessary communication operations under the hood (i.e. + ``allreduce``, ``reduce_scatter``). + + Args: + reduce_op (str, optional): The reduction op to be used for the partial DTensor + to produce Replicated/Sharded DTensor. Only element-wise reduction operations + are supported, including: "sum", "avg", "product", "max", "min", default: "sum". + + .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, + and can only be used by the ``DTensor.from_local`` API. + """ + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #1: + # _reduce_value: reduce the value of the tensor on the mesh dimension + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # Partial placement contract #2: + # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #3: + # _partition_value: partition the value of a replicated tensor on the mesh dimension + + # _partition_value is the conjugate operation of _reduce_value, e.g. + # - _partition_value on a sum reduce op is just a division operation + # - _reduce_value on a sum reduce op would just be a sum(allreduce) operation + num_chunks = mesh.size(mesh_dim=mesh_dim) + if self.reduce_op == "sum": + return tensor / num_chunks + elif self.reduce_op in ("avg", "min", "max"): + return tensor + else: + raise ValueError( + f"Replicate to Partial({self.reduce_op}) conversion is not supported." + ) + + def __hash__(self) -> int: + return 1 + hash(self.reduce_op) + + def __repr__(self) -> str: + """ + machine readable representation of the Partial placement + """ + return f"Partial({self.reduce_op})" + + def __str__(self) -> str: + """ + human readable representation of the Partial placement + """ + return f"P({self.reduce_op})" + + +# We keep the old _Partial name for a while for BC reason +_Partial = Partial + + +@dataclass(frozen=True) +class MaskPartial(Partial): + """ + A partial mask placement devised for rowwise sharded embedding op, where we need + to mask and adjust the indices to the local embedding shard, embedding masking + is a special type of the Partial placement + + NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor + lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. + """ + + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + + # required fields for computing the local offset and deriving the mask + offset_shape: torch.Size | None = None + offset_dim: int = 0 + + def __init__( + self, + reduce_op=None, + mask_buffer=None, + offset_shape=None, + offset_dim=0, + *args, + **kwargs, + ): + super().__init__(reduce_op) + if mask_buffer is None: + mask_buffer = MaskBuffer() + object.__setattr__(self, "mask_buffer", mask_buffer) + object.__setattr__(self, "offset_shape", offset_shape) + object.__setattr__(self, "offset_dim", offset_dim) + + @staticmethod + @maybe_run_for_local_tensor + def _mask_tensor( + tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int + ) -> tuple[torch.Tensor, torch.Tensor]: + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | ( + tensor >= local_offset_on_dim + local_shard_size + ) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + return mask, masked_tensor + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + my_coordinate = mesh.get_coordinate() + assert my_coordinate is not None, "my_coordinate should not be None" + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + assert self.offset_shape is not None, ( + "offset_shape needs to be set for MaskPartial" + ) + local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset( + self.offset_shape[self.offset_dim], + num_chunks, + my_coordinate[mesh_dim], + ) + mask, masked_tensor = MaskPartial._mask_tensor( + tensor, local_offset_on_dim, local_shard_size + ) + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # by the time we need reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # perform sum reduction + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we need reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return ( + self.reduce_op == other.reduce_op + and self.offset_shape == other.offset_shape + and self.offset_dim == other.offset_dim + ) + + def __hash__(self) -> int: + return 1 + hash( + ( + self.reduce_op, + self.offset_shape, + self.offset_dim, + ) + ) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"MaskPartial(reduce_op={self.reduce_op}, offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return f"MaskP({self.reduce_op}, {self.offset_shape}, {self.offset_dim})" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74d1072ed9b53e5fae9289a1a9e6f2326eb1149b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_draft_export.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_draft_export.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6ab0f741645dbe1a6b107382bad4263aa4ac01 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_draft_export.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_leakage_detection_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_leakage_detection_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..316263f16ec9b6b7344dc9d92b62e293f0e09c46 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_leakage_detection_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..534b5468d4799a90ff291d7e8ae48eda38f0123c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..405cff31ddae53f3f439515098e1867db7d7c3d1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_safeguard.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_safeguard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..463fb1489b511a83f0b7fe70f63742f316f82e89 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_safeguard.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_swap.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_swap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dc087d44a6f8d199f2656259c37105da2b7c6ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_swap.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0752bdb0cad63ab4420532fa9a650c5eb92c757e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_trace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d3549f9e392611394cfae13f78e7f8f423b2326 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_unlift.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_unlift.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f322d95144b97b7270be53b83ee7068dcafae9f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_unlift.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..645aa338b8033ceb15c4a8c2399f9641c50da890 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/custom_obj.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/custom_obj.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..720b457bf5576a672e4a1b78d6bfb6046fca4557 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/custom_obj.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06ade5eb19bece7b9420daf1cac24c0db5999a9d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44a0f5fbe9b9d2846145d0b2893e7fb092005eb1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f18610614f35eb95c1fd3599d0cced44f5fd2ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9542dcc5b1c82dcf4d46a41108004bedecc731e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d11f93d0089ee27dfcfd61884bb3eaa7dc6ca6b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02f06dbb8944c09965f1d3f0d7a1a8be3104e117 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14399a7bfdadd7d7a35781892dd60e8809a6d5b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__init__.py @@ -0,0 +1,430 @@ +import copy +import dataclasses +import functools +import os +import types +import typing +import typing_extensions +import zipfile +from pathlib import Path + +import torch +from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file +from torch.export.exported_program import _decompose_exported_program + + +_InputT = typing_extensions.ParamSpec("_InputT") +_RetT = typing.TypeVar("_RetT") + + +__all__ = [] # type: ignore[var-annotated] + + +def _copy_graph_module_and_signature( + ep: torch.export.ExportedProgram, +) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: + # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), + # and this can break placeholder names in some particular cases. + # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. + # So we manually overwrite placeholder names by reading the old graph. + gm = copy.deepcopy(ep.graph_module) + new_graph_signature = copy.deepcopy(ep.graph_signature) + + # iterate over old/new graph modules + for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr] + old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] + new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] + # iterate over placeholders + assert len(old_phs) == len(new_phs) + for old_node, new_node in zip(old_phs, new_phs): + new_node.name = old_node.name + + return gm, new_graph_signature + + +def _remove_detach_pass( + gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature +) -> None: + with gm._set_replace_hook(sig.get_replace_hook()): + for node in list(reversed(gm.graph.nodes)): + if node.op != "call_function": + continue + if ( + node.target is torch.ops.aten.detach.default + and len(node.users) == 1 + and next(iter(node.users)).target is torch.ops.aten.detach.default + ): + next(iter(node.users)).replace_all_uses_with(node) + + gm.graph.eliminate_dead_code() + gm.recompile() + + +def _export_forward_backward( + ep: torch.export.ExportedProgram, joint_loss_index: int = 0 +) -> torch.export.ExportedProgram: + """ + WARNING: This API is highly unstable and will be subject to change in the future. + """ + from torch._decomp import core_aten_decompositions + + ep = _decompose_exported_program( + ep, + cia_to_decomp={}, + python_decomp_table=core_aten_decompositions(), + joint_loss_index=joint_loss_index, + # For serialization purpose, we don't want to decompose custom triton ops. + # If users would like to decompose custom triton ops, they could do it + # with run_decompositions() API. + decompose_custom_triton_ops=False, + ) + gm, new_graph_signature = _copy_graph_module_and_signature(ep) + _remove_detach_pass(gm, new_graph_signature) + + return ep._update(gm, new_graph_signature) + + +def _sticky_export( + forward_func: typing.Callable[_InputT, _RetT], + dynamic_shapes_callback: typing.Callable[ + _InputT, list[typing.Any] | dict[str, typing.Any] | tuple[typing.Any, ...] + ] + | None = None, +) -> typing.Callable[_InputT, _RetT]: + """ + Lazily export the model on first forward call. + Usage: + model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback) + """ + model = forward_func.__self__ # type: ignore[attr-defined] + original_forward = forward_func.__func__ # type: ignore[attr-defined] + + @functools.wraps(forward_func) + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: + # Unpatch forward to avoid recursion during export + model.forward = types.MethodType(original_forward, model) + + dynamic_shapes_spec = None + if dynamic_shapes_callback: + dynamic_shapes_spec = dynamic_shapes_callback(*args, **kwargs) + + try: + exported = torch.export.export( + model, + args, + kwargs, + dynamic_shapes=dynamic_shapes_spec, + ).module() + wrapper._exported_artifact = exported # type: ignore[attr-defined] + finally: + # Restore the wrapper after export + model.forward = wrapper + + return exported(*args, **kwargs) + + return wrapper + + +@dataclasses.dataclass +class _ExportMethod: + overloads: dict[str, torch.export.ExportedProgram] + fallbacks: list[torch.export.ExportedProgram] + + +class _ExportPackage: + """ + An export package is a collection of torch.export()-ed PyTorch models consisting of + a list of exported methods and their corresponding overloads. ExportPackage is introduced + on top of torch.export() to support the following use cases: + - Exporting a model with multiple methods if a model has multiple independent parts. + - Exporting a function with multiple overloads based on tensor shapes or other metadata. + + ExportPackage is designed to contain multiple methods (associated with method names) and for + each method, it can have multiple overloads (associated with overload names). + + Here is an example of the data structure for an ExportPackage: + ``` + ExportPackage( + methods={ + "decoder": ExportMethod( + overloads={ + "prefill": ExportedProgram(...), + "decode": ExportedProgram(...), + }, + fallbacks=[], + ), + "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]), + }, + ) + ``` + + To export a model into an ExportPackage, users can use the exporter API provided by ExportPackage. + Exporter is a decorator that takes a callable and returns a wrapper. The wrapper will export the + function into an ExportPackage, when it's invoked with some sample inputs (similar to how + torch.compile() works). For more details, please refer to the document on .exporter() method. + + This design allows users to decouple the exported callables from the actual sample inputs which can + be helpful for use cases where the exported callable is hidden behind helper functions or when sample + inpusts are hard to get. + + NOTE: This is an experimental API and anything can be changed in the future. + + Example usage: + ``` + def fn(x): + return x + 1 + + def main(f, x): + x += 1 + ret = f(x) + return ret + 1 + + package = ExportPackage() + main(package.exporter(fn), torch.randn(3, 2)) + ``` + + """ + + def __init__(self) -> None: + self.methods: dict[str, _ExportMethod] = {} + + def _exporter( + self, + method: str, + fn: typing.Callable[_InputT, _RetT], + *, + fallback: str = "once", + ) -> typing.Callable[_InputT, _RetT]: + """ + A function/module decorator that sets up a callable to be exported later invoked. + By default the exporter will only trigger torch.export for once and error on + later invocations. To customize this behavior, users have the following two options: + 1. Call .define_overload() method on the returned wrapper to define an overload. + 2. Adjust the fallback policy using `fallback` argument. + + An "overload" is a named branch for an ExportMethod with a user defined precondition, + typically based on input tensor shapes. It's up to a downstream backend implementation + of ExportMethod to respect the precondition later in inference. + + define_overload() takes arguments like the following: + - A name, for indexing purposes in a backend. + - A callable (spec) that: + - Has the same model input signature as the original model code. + - Returns an optional dynamic shape spec. + + Exporter will only export an overload when the spec callable successfully returns + a result without raising AssertionError. + + For example: + ``` + package = ExportPackage() + + + def prefill(x, xa, kv_cache): + assert x.shape[1] == 3 + assert kv_cache == {} + + + def decode(x, xa, kv_cache): + assert x.shape[1] > 1 + assert len(kv_cache) > 0 + return {...} # dynamic shape specs here + + + exporter = ( + package.exporter(decoder) + .define_overload("prefill", prefill) + .define_overload("decode", decode) + ) + ``` + + A "fallback" is exported when no overload precondition matches a given set of sample + inputs. Overloads should + Fallbacks don't have names and are ordered in a list. It's up to a backend to decide + which fallback is used amony multiple ones. + + A reference backend implementation of ExportMethod may look like the following: + ``` + def execute(method: ExportMethod, *args, **kwargs): + for overload in method.overloads: + if match_precondition(overload, *args, **kwargs): + return execute_overload(overload, *args, **kwargs) + for fallback in method.fallbacks: + if match_precondition(fallback, *args, **kwargs): + return execute_fallback(fallback, *args, **kwargs) + ``` + + Args: + method(str): The method name for an exported part of PyTorch model. This + will be saved together with the exported/compiled artifacts + in any serialization format and can be used as the key to + index ExportPackage methods later. + fn(callable): A PyTorch function/module to be exported. + fallback(str): The fallback policy to decide when to call torch.export + - "once" is the default policy. Under this policy a PyTorch program is assumed + to be only called once later and an error will be raised for subsequent + runs. + - "error" means the ExportMethod will never have any fallbacks, meaning + users should define all the possible overloads ahead of time. + + """ + + fallbacks: list[torch.export.ExportedProgram] = [] + specs: dict[str, typing.Callable[_InputT, typing.Any]] = {} + overloads: dict[str, torch.export.ExportedProgram] = {} + self.methods[method] = _ExportMethod(fallbacks=fallbacks, overloads=overloads) + + @functools.wraps(fn) + def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] + import torch.export._wrapper_utils + + model: torch.nn.Module + if not isinstance(fn, torch.nn.Module): + model = torch.export._wrapper_utils._WrapperModule(fn) + else: + model = fn + + for k, v in specs.items(): + try: + if isinstance(fn, torch.nn.Module): + dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type] + else: + # pyrefly: ignore [invalid-param-spec] + dynamic_shapes = v(*args, **kwargs) + except AssertionError: + continue + if k not in overloads: + ep = torch.export.export( + model, args, kwargs, dynamic_shapes=dynamic_shapes + ) + overloads[k] = ep + ep = overloads[k] + return ep.module()(*args, **kwargs) + + if fallback == "error": + raise RuntimeError( + f"Exporter: Cannot export fallback {fn} when fallback policy is set to 'error'," + + "please specify an overload or adjust the fallback policy." + ) + elif fallback == "once": + if len(fallbacks) > 0: + raise RuntimeError( + f"Exporter: Cannot export {fn} more than once, " + + "please specify an overload or adjust the fallback policy." + ) + else: + raise RuntimeError(f"Unknown fallback policy: {fallback}") + ep = torch.export.export(model, args, kwargs) + + fallbacks.append(ep) + return ep.module()(*args, **kwargs) + + if isinstance(fn, torch.nn.Module): + _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 + fn, + lambda _: _exporter_context, # type: ignore[arg-type] + ) + + def _define_overload( + overload: str, spec: typing.Callable[_InputT, typing.Any] + ) -> typing.Any: + assert overload not in specs + assert callable(spec) + assert overload.isidentifier() + specs[overload] = spec + return _exporter_context + + assert not hasattr(fn, "_define_overload") + _exporter_context._define_overload = _define_overload # type: ignore[attr-defined] + + # pyrefly: ignore [bad-return] + return _exporter_context + + @property + def _method_overloads( + self, + ) -> typing.Iterator[tuple[str, torch.export.ExportedProgram]]: + for method, method_data in self.methods.items(): + for overload, ep in method_data.overloads.items(): + yield f"{method}:{overload}", ep + + def _compiled_and_package( + self, + f: torch.types.FileLike, + standalone: bool = False, + package_example_inputs: bool = False, + ) -> None: + options: dict[str, typing.Any] = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": True, + "always_keep_tensor_constants": True, + # we'll change this back to False once we enable weight deduping for standalone mode + "aot_inductor.package_constants_in_so": standalone, + "aot_inductor_mode.compile_standalone": standalone, + } + aoti_files_map = {} + model_names = [] + for name, ep in self._method_overloads: + name = name.replace(":", "__") + model_names.append(name) + options["aot_inductor.model_name_for_generated_files"] = name + aoti_files = torch._inductor.aot_compile( + ep.module(), # type: ignore[arg-type] + ep.example_inputs[0], + kwargs=ep.example_inputs[1], + options=options, + ) + # pyrefly: ignore [unsupported-operation] + aoti_files_map[name] = aoti_files + + from torch._inductor.package import package + + pt2_path = package.package_aoti( + f, + aoti_files_map, # type: ignore[arg-type] + ) + + if not standalone: + return + + assert isinstance(pt2_path, str) + base_directory = os.path.dirname(pt2_path) + package_name = os.path.basename(pt2_path)[:-4] + with ( + zipfile.ZipFile(pt2_path, "r") as zip_ref, + ): + zip_ref.extractall(base_directory) + + example_inputs_map: dict[str, int] | None = ( + {} if package_example_inputs else None + ) + use_cuda = False + for name, ep in self._method_overloads: + name = name.replace(":", "__") + # TODO: also dump kwargs + # TODO: currently only support list of Tensors and they need to be on the same device + if not ep.example_inputs: + continue + for inp in ep.example_inputs[0]: + if isinstance(inp, torch.Tensor) and inp.device.type == "cuda": + # TODO: more carefully determine the device type + use_cuda = True + if package_example_inputs: + assert example_inputs_map is not None + example_inputs_map[name] = len(ep.example_inputs[0]) + for i, t in enumerate(ep.example_inputs[0]): + path = Path(base_directory) / f"{name}_input_{i}.pt" + torch.save(t, path) + + # Detect if ROCm is being used + is_hip = torch.version.hip is not None + cmake_file_str = _get_make_file(package_name, model_names, use_cuda, is_hip) + + with open(Path(base_directory) / "CMakeLists.txt", "w") as file: + file.write(cmake_file_str) + + main_file_str = _get_main_cpp_file( + package_name, model_names, use_cuda, example_inputs_map, is_hip + ) + with open(Path(base_directory) / "main.cpp", "w") as file: + file.write(main_file_str) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39f325bdf35d8753d72fc9d1bd534b6b8171161a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5aaad6e3bba4900a03fe64992783e6415b774ee Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1005effe2f299a2bd33ac0517e24b46d840bf675 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/experimental/_utils.py @@ -0,0 +1,238 @@ +import logging + +from torch._inductor.utils import IndentedBuffer + + +__all__ = [] # type: ignore[var-annotated] +logger = logging.getLogger(__name__) + + +def _get_main_cpp_file( + package_name: str, + model_names: list[str], + cuda: bool, + example_inputs_map: dict[str, int] | None, + is_hip: bool, +) -> str: + """ + Generates a main.cpp file for AOTInductor standalone models in the specified package. + + Args: + package_name (str): Name of the package containing the models. + model_names (List[str]): List of model names to include in the generated main.cpp. + cuda (bool): Whether to generate code with CUDA support. + example_inputs_map (Optional[Dict[str, List[Tensor]]]): A mapping from model name to + its list of example input tensors. If provided, the generated main.cpp will + load and run these inputs. + + Returns: + str: The contents of the generated main.cpp file as a string. + """ + + ib = IndentedBuffer() + + ib.writelines( + [ + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + ] + ) + if cuda: + if is_hip: + ib.writelines( + [ + "#include ", + ] + ) + + else: + ib.writelines( + [ + "#include ", + "#include ", + ] + ) + + for model_name in model_names: + ib.writeline( + f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"' + ) + + ib.newline() + for model_name in model_names: + ib.writeline(f"using torch::aot_inductor::AOTInductorModel{model_name};") + + ib.writelines( + [ + "using torch::aot_inductor::ConstantHandle;", + "using torch::aot_inductor::ConstantMap;", + "", + "int main(int argc, char* argv[]) {", + ] + ) + + with ib.indent(): + ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";') + ib.writeline("try {") + + with ib.indent(): + ib.writeline("c10::Device device(device_str);") + + if example_inputs_map is not None: + # TODO: add device + for i, model_name in enumerate(model_names): + num_inputs = example_inputs_map[model_name] + + ib.writeline(f"// Load input tensors for model {model_name}") + ib.writeline(f"std::vector input_tensors{i + 1};") + ib.writeline(f"for (int j = 0; j < {num_inputs}; ++j) {{") + with ib.indent(): + ib.writeline( + f'std::string filename = "{model_name}_input_" + std::to_string(j) + ".pt";' + ) + ib.writeline("std::ifstream in(filename, std::ios::binary);") + ib.writeline("if (!in.is_open()) {") + with ib.indent(): + ib.writeline( + 'std::cerr << "Failed to open file: " << filename << std::endl;' + ) + ib.writeline("return 1;") + ib.writeline("}") + ib.writeline( + "std::vector buffer((std::istreambuf_iterator(in)), std::istreambuf_iterator());" + ) + ib.writeline( + "torch::IValue ivalue = torch::pickle_load(buffer);" + ) + ib.writeline( + f"input_tensors{i + 1}.push_back(ivalue.toTensor().to(device));" + ) + ib.writeline("}") + ib.newline() + + ib.newline() + ib.writeline("\n// Create array of input handles") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto input_handles{i + 1} =", + f" torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors{i + 1});", + ] + ) + + ib.writeline("\n// Create array for output handles") + for i in range(len(model_names)): + ib.writeline(f"AtenTensorHandle output_handle{i + 1};") + + ib.writeline("\n// Create and load models") + for i, model_name in enumerate(model_names): + ib.writelines( + [ + f"auto constants_map{i + 1} = std::make_shared();", + f"auto constants_array{i + 1} = std::make_shared>();", + f"auto model{i + 1} = std::make_unique(", + f" std::move(constants_map{i + 1}),", + f" std::move(constants_array{i + 1}),", + " device_str,", + f' "{package_name}/data/aotinductor/{model_name}/");', + f"model{i + 1}->load_constants();", + ] + ) + + if example_inputs_map is not None: + ib.writeline("\n// Run the models") + for i in range(len(model_names)): + ib.writeline( + f"torch::aot_inductor::DeviceStreamType stream{i + 1} = nullptr;" + ) + ib.writeline( + f"model{i + 1}->run(&input_handles{i + 1}[0], &output_handle{i + 1}, stream{i + 1}, nullptr);" + ) + + ib.writeline("\n// Convert output handles to tensors") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto output_tensor{i + 1} =", + f" torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle{i + 1}, 1);", + ] + ) + + ib.writeline("\n// Validate outputs") + for i in range(len(model_names)): + ib.writeline( + f"""std::cout << "output_tensor{i + 1}\\n" << output_tensor{i + 1} << std::endl;""" + ) + ib.writeline( + f"""torch::save(output_tensor{i + 1}, "output_tensor{i + 1}.pt");""" + ) + + ib.writeline("return 0;") + + ib.writelines( + [ + "} catch (const std::exception &e) {", + ] + ) + with ib.indent(): + ib.writeline('std::cerr << "Error: " << e.what() << std::endl;') + ib.writeline("return 1;") + + ib.writeline("}") + ib.writeline("}") + + return ib.getvalue() + + +def _get_make_file( + package_name: str, model_names: list[str], cuda: bool, is_hip: bool +) -> str: + ib = IndentedBuffer() + + ib.writelines( + [ + "cmake_minimum_required(VERSION 3.10)", + "project(TestProject)", + "", + "set(CMAKE_CXX_STANDARD 17)", + "", + ] + ) + + from torch._inductor.config import test_configs + + if test_configs.use_libtorch: + ib.writeline("find_package(Torch REQUIRED)") + + if cuda: + if is_hip: + ib.writeline("find_package(hip REQUIRED)") + else: + ib.writeline("find_package(CUDA REQUIRED)") + + ib.newline() + for model_name in model_names: + ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)") + + ib.writeline("\nadd_executable(main main.cpp)") + if cuda: + if is_hip: + ib.writeline("target_compile_definitions(main PRIVATE USE_HIP)") + else: + ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)") + + model_libs = " ".join(model_names) + ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})") + + if cuda: + if is_hip: + ib.writeline("target_link_libraries(main PRIVATE hip::host)") + else: + ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") + + return ib.getvalue() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/passes/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5634bd4eadb7a80ddb7521ec0dae26fb2cfec5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/passes/__init__.py @@ -0,0 +1,97 @@ +from typing import Union + +import torch +import torch.utils._pytree as pytree +from torch.export.exported_program import ExportedProgram + + +__all__ = ["move_to_device_pass"] + + +def move_to_device_pass( + ep: ExportedProgram, location: torch.device | str | dict[str, str] +) -> ExportedProgram: + """ + Move the exported program to the given device. + + Args: + ep (ExportedProgram): The exported program to move. + location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to. + If a string, it is interpreted as a device name. + If a dict, it is interpreted as a mapping from + the existing device to the intended one + + Returns: + ExportedProgram: The moved exported program. + """ + + def _get_new_device( + curr_device: torch.device, + location: torch.device | str | dict[str, str], + ) -> str: + if isinstance(location, dict): + if str(curr_device) in location: + return location[str(curr_device)] + else: + return str(curr_device) + else: + return str(location) + + # move all the state_dict + for k, v in ep.state_dict.items(): + if isinstance(v, torch.nn.Parameter): + ep._state_dict[k] = torch.nn.Parameter( + v.to(_get_new_device(v.device, location)), + v.requires_grad, + ) + else: + ep._state_dict[k] = v.to(_get_new_device(v.device, location)) + + # move all the constants + for k, v in ep.constants.items(): + if isinstance(v, torch.Tensor): + ep._constants[k] = v.to(_get_new_device(v.device, location)) + + # move example_inputs if they exist + if ep.example_inputs is not None: + args, kwargs = ep.example_inputs + moved_args = pytree.tree_map_only( + torch.Tensor, + lambda tensor: tensor.to(_get_new_device(tensor.device, location)), + args, + ) + moved_kwargs = pytree.tree_map_only( + torch.Tensor, + lambda tensor: tensor.to(_get_new_device(tensor.device, location)), + kwargs, + ) + ep._example_inputs = (moved_args, moved_kwargs) + + for m in ep.graph_module.modules(): + if isinstance(m, torch.fx.GraphModule): + for node in m.graph.nodes: + # move all the nodes kwargs with burnt-in device + if "device" in node.kwargs: + kwargs = node.kwargs.copy() + kwargs["device"] = _get_new_device(kwargs["device"], location) + node.kwargs = kwargs + + if ( + node.op == "call_function" + and node.target is torch.ops.aten.to.device + ): + args = list(node.args) + # pyrefly: ignore [unsupported-operation] + args[1] = _get_new_device(args[1], location) + node.args = tuple(args) + + # move all the tensor metadata + node.meta["val"] = pytree.tree_map( + lambda v: v.to(_get_new_device(v.device, location)) + if isinstance(v, torch.Tensor) + else v, + node.meta.get("val"), + ) + + ep.validate() + return ep diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/passes/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f97532144bbfcb42259cf2ceeebd2b398f893ae6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2bf26a275d9eef91f4b6807ac472b2cd0c30b0f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py @@ -0,0 +1,4 @@ +from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter + + +__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d600448fcf5a89c7478fc7c33ebfc2c5e5156114 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3058f446c894ef3edabb865b93f75a6e8cd6d5d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d51e4937f092999e1372992f13ad158678664b84 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61035f516b6a81baee4162b625a7d73f8db4a59b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py new file mode 100644 index 0000000000000000000000000000000000000000..1b46db0958d28b37a602686b34e400c17cecacb3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py @@ -0,0 +1,1204 @@ +import glob +import io +import json +import logging +import os +import tempfile +import zipfile +from dataclasses import dataclass +from typing import Any, IO, TYPE_CHECKING, TypeAlias + +import torch +import torch.utils._pytree as pytree +from torch._export.serde import schema +from torch._export.serde.serialize import ( + _dataclass_to_dict, + _dict_to_dataclass, + deserialize_device, + deserialize_scalar_type, + deserialize_size, + deserialize_storage_offset, + deserialize_stride, + ExportedProgramDeserializer, + serialize, + serialize_tensor_meta, + SerializedArtifact, +) +from torch._inductor.cpp_builder import normalize_path_separator +from torch._subclasses.fake_tensor import FakeTensor +from torch.export import ExportedProgram +from torch.export._tree_utils import reorder_kwargs +from torch.export.pt2_archive._package_weights import ( + get_complete, + group_weights, + TensorProperties, + Weights, +) +from torch.export.pt2_archive.constants import ( + AOTINDUCTOR_DIR, + ARCHIVE_FORMAT_PATH, + ARCHIVE_FORMAT_VALUE, + ARCHIVE_VERSION_PATH, + ARCHIVE_VERSION_VALUE, + CONSTANTS_CONFIG_FILENAME_FORMAT, + CONSTANTS_DIR, + CUSTOM_OBJ_FILENAME_PREFIX, + EXECUTORCH_DIR, + EXTRA_DIR, + MODELS_DIR, + MODELS_FILENAME_FORMAT, + SAMPLE_INPUTS_FILENAME_FORMAT, + TENSOR_CONSTANT_FILENAME_PREFIX, + WEIGHT_FILENAME_PREFIX, + WEIGHTS_CONFIG_FILENAME_FORMAT, + WEIGHTS_DIR, +) +from torch.types import FileLike + + +if TYPE_CHECKING: + from torch.utils._ordered_set import OrderedSet + + +DEFAULT_PICKLE_PROTOCOL = 2 +AOTI_FILES: TypeAlias = list[str | Weights] | dict[str, list[str | Weights]] + + +logger: logging.Logger = logging.getLogger(__name__) + + +def is_pt2_package(serialized_model: bytes | str) -> bool: + """ + Check if the serialized model is a PT2 Archive package. + """ + try: + with zipfile.ZipFile( + io.BytesIO(serialized_model) + if isinstance(serialized_model, bytes) + else serialized_model + ) as zip_reader: + root_folder = zip_reader.namelist()[0].split(os.path.sep)[0] + archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}" + if archive_format_path in zip_reader.namelist(): + return zip_reader.read(archive_format_path) == b"pt2" + except Exception: + logger.info("Model is not a PT2 package") + return False + + +class PT2ArchiveWriter: + """ + Context manager for writing a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + if isinstance(archive_path_or_buffer, str): + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) + self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] + # NOTICE: version here is different from the archive_version + # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version + # archive_version is the version of the PT2 archive spec, which write to /archive_version + self.archive_file.set_min_version(6) + + def __enter__(self) -> "PT2ArchiveWriter": + return self + + def __exit__(self, *args: Any) -> None: + if not self.has_record(ARCHIVE_FORMAT_PATH): + self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE) + + if not self.has_record(ARCHIVE_VERSION_PATH): + self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE) + + self.close() + + def has_record(self, name: str) -> bool: + """ + Check if a record exists in the archive. + """ + return name in self.archive_file.get_all_written_records() + + def count_prefix(self, prefix: str) -> int: + """ + Count the number of records that start with a given prefix. + """ + return sum( + 1 + for record in self.archive_file.get_all_written_records() + if record.startswith(prefix) + ) + + def write_bytes(self, name: str, data: bytes) -> None: + """ + Write a bytes object to the archive. + name: The destination file inside the archive. + data: The bytes object to write. + """ + assert isinstance(data, bytes), f"Expected bytes but got {type(data)}" + self.archive_file.write_record(name, data, len(data)) + + def write_string(self, name: str, data: str) -> None: + """ + Write a string object to the archive. + name: The destination file inside the archive. + data: The string object to write. + """ + assert isinstance(data, str), f"Expected string but got {type(data)}" + data_bytes = data.encode() + self.write_bytes(name, data_bytes) + + def write_file(self, name: str, file_path: str) -> None: + """ + Copy a file into the archive. + name: The destination file inside the archive. + file_path: The source file on disk. + """ + assert os.path.isfile(file_path), f"{file_path} is not a valid file path" + + with open(file_path, "rb") as f: + file_bytes = f.read() + self.write_bytes(name, file_bytes) + + def write_folder(self, archive_dir: str, folder_dir: str) -> None: + """ + Copy a folder into the archive. + archive_dir: The destination folder inside the archive. + folder_dir: The source folder on disk. + """ + assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path" + + file_paths = filter( + os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True) + ) + for file_path in file_paths: + # pyrefly: ignore [no-matching-overload] + filename = os.path.relpath(file_path, folder_dir) + archive_path = os.path.join(archive_dir, filename) + # pyrefly: ignore [bad-argument-type] + self.write_file(archive_path, file_path) + + def close(self) -> None: + """ + Close the archive. + """ + self.archive_file.write_end_of_file() + + +class PT2ArchiveReader: + """ + Context manager for reading a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + if isinstance(archive_path_or_buffer, str): + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) + self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] + assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( + "Invalid archive format" + ) + + def __enter__(self) -> "PT2ArchiveReader": + return self + + def __exit__(self, *args: Any) -> None: + # torch._C.PyTorchFileReader doesn't have a close method + pass + + def read_bytes(self, name: str) -> bytes: + """ + Read a bytes object from the archive. + name: The source file inside the archive. + """ + return self.archive_file.get_record(name) + + def read_string(self, name: str) -> str: + """ + Read a string object from the archive. + name: The source file inside the archive. + """ + data = self.read_bytes(name) + return data.decode() + + def archive_version(self) -> int: + """ + Get the archive version. + """ + try: + archive_version = self.read_string(ARCHIVE_VERSION_PATH) + except Exception: + # if archive_version is not found, it means the archive is older than version 0. + # In this case, we assume the archive is version 0. + archive_version = "0" + + return int(archive_version) + + def get_file_names(self) -> list[str]: + """ + Get the file names in the archive. + """ + return self.archive_file.get_all_records() + + +is_pt2_package.__module__ = "torch.export.pt2_archive" +PT2ArchiveWriter.__module__ = "torch.export.pt2_archive" +PT2ArchiveReader.__module__ = "torch.export.pt2_archive" + + +def _package_aoti_files( + archive_writer: PT2ArchiveWriter, + aoti_files: AOTI_FILES | None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if aoti_files is None: + return + + if isinstance(aoti_files, list): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + + all_weights: dict[str, Weights] = {} # model_name -> weight + weights_configs: dict[ + str, dict[str, Any] + ] = {} # model_name -> (weight_name -> (filename, shape, stride, offset)) + + for model_name, files in aoti_files.items(): + num_so_files = 0 + weights_configs[model_name] = {} + + for file in files: + if file == "": + continue + + if isinstance(file, Weights): + all_weights[model_name] = file + continue + + if file.endswith(".so"): + num_so_files += 1 + if num_so_files > 1: + raise RuntimeError( + f"Multiple .so files found in {files}. " + "You might need to clear your cache " + "directory before calling aoti_compile again." + ) + + filename = os.path.basename(file) + if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX): + new_filepath = os.path.join(CONSTANTS_DIR, filename) + else: + new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename) + logger.debug( + "Saving AOTI generated file %s to archive in %s", file, new_filepath + ) + archive_writer.write_file( + str(new_filepath), + file, + ) + + if len(all_weights) > 0: + # Dedup weights + grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights) + for idx, group in enumerate(grouped_tensors): + filename = f"{WEIGHT_FILENAME_PREFIX}{idx}" + model_name, weight_name = get_complete(group, all_weights) + complete_tensor, _ = all_weights[model_name].get_weight(weight_name) + buffer = io.BytesIO() + torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol) + archive_writer.write_bytes( + os.path.join(WEIGHTS_DIR, filename), buffer.getvalue() + ) + for model_name, weight_name in group: + _, w_property = all_weights[model_name].get_weight(weight_name) + weights_configs[model_name][weight_name] = ( + filename, + w_property.shape, + w_property.stride, + w_property.offset, + ) + + for model_name, weights_config in weights_configs.items(): + archive_writer.write_string( + os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"), + json.dumps(weights_config), + ) + logger.debug("packaging weights_config for model %s", model_name) + logger.debug(weights_config) + + +def _is_fake_tensor(t: torch.Tensor) -> bool: + return isinstance(t, FakeTensor) + + +def _is_tensor_subclass(t: torch.Tensor) -> bool: + return isinstance(t, torch.Tensor) and type(t.data) is not torch.Tensor + + +def _get_raw_tensor_bytes(value: torch.Tensor) -> bytes: + """ + Get the raw bytes of a tensor. This is used to save the tensor in pt2 archive. + """ + # NOTE: don't chain .cpu() with .data_ptr(). If an HtoD copy needs to be + # performed, the CPU copy needs to be kept alive when its underlying + # memory is accessed. + import ctypes + + if _is_fake_tensor(value): + value_bytes = b"" + elif value.data_ptr(): + cpu_tensor = value.cpu() + value_untyped_storage = cpu_tensor.untyped_storage() + # we store the raw bytes the untyped storage. Tensor metadata is stored separately + value_bytes = bytes( + ctypes.cast( + value_untyped_storage.data_ptr(), + ctypes.POINTER(ctypes.c_ubyte * value_untyped_storage.size()), + ).contents + ) + else: + # for empty tensor + value_bytes = b"" + return value_bytes + + +def _should_use_pickle(t: torch.Tensor) -> bool: + return _is_tensor_subclass(t) and not _is_fake_tensor(t) + + +def _save_pickled_tensors( + pickled_items: list[tuple[str, torch.Tensor]], + archive_writer: PT2ArchiveWriter, + config: dict[str, schema.PayloadMeta], + directory: str, + filename_prefix: str, + idx: int, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> int: + """Save pickled tensors and update config. Returns updated index.""" + for item_fqn, tensor in pickled_items: + path_name = f"{filename_prefix}{idx}" + archive_path = os.path.join(directory, path_name) + buffer = io.BytesIO() + torch.save(tensor, buffer, pickle_protocol=pickle_protocol) + archive_writer.write_bytes(archive_path, buffer.getvalue()) + + config[item_fqn] = schema.PayloadMeta( + path_name=path_name, + is_param=isinstance(tensor, torch.nn.Parameter), + use_pickle=True, + tensor_meta=serialize_tensor_meta(tensor), + ) + idx += 1 + return idx + + +def _save_raw_tensors( + raw_items: dict[str, tuple[torch.Tensor, TensorProperties]], + model_name: str, + archive_writer: PT2ArchiveWriter, + config: dict[str, schema.PayloadMeta], + directory: str, + filename_prefix: str, + idx: int, +) -> int: + """Save deduplicated raw tensor bytes and update config. Returns updated index.""" + if not raw_items: + return idx + + weights_dict = {model_name: Weights(raw_items)} + storage_groups = group_weights(weights_dict) + + for group in storage_groups: + # Find the complete tensor that covers all others in this storage group + model_name, complete_item_name = get_complete(group, weights_dict) + complete_tensor, _ = weights_dict[model_name].get_weight(complete_item_name) + + path_name = f"{filename_prefix}{idx}" + archive_path = os.path.join(directory, path_name) + tensor_bytes = _get_raw_tensor_bytes(complete_tensor) + archive_writer.write_bytes(archive_path, tensor_bytes) + idx += 1 + + for _, item_fqn in group: + tensor, _ = weights_dict[model_name].get_weight(item_fqn) + config[item_fqn] = schema.PayloadMeta( + path_name=path_name, + is_param=isinstance(tensor, torch.nn.Parameter), + use_pickle=False, + tensor_meta=serialize_tensor_meta(tensor), + ) + + return idx + + +def _package_state_dict( + model_name: str, + exported_program: ExportedProgram, + archive_writer: PT2ArchiveWriter, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> schema.PayloadConfig: + weights_config: dict[str, schema.PayloadMeta] = {} + + pickled_weights: list[tuple[str, torch.Tensor]] = [] + raw_weights: dict[str, tuple[torch.Tensor, TensorProperties]] = {} + + # Categorize weights + for weight_fqn, weight_tensor in exported_program.state_dict.items(): + assert isinstance(weight_tensor, torch.Tensor), ( + "only torch.Tensor is allowed in state_dict" + ) + if _should_use_pickle(weight_tensor): + pickled_weights.append((weight_fqn, weight_tensor)) + else: + raw_weights[weight_fqn] = (weight_tensor, TensorProperties(weight_tensor)) + + idx = archive_writer.count_prefix(os.path.join(WEIGHTS_DIR, WEIGHT_FILENAME_PREFIX)) + + # Save weights in pickle format + idx = _save_pickled_tensors( + pickled_weights, + archive_writer, + weights_config, + WEIGHTS_DIR, + WEIGHT_FILENAME_PREFIX, + idx, + pickle_protocol, + ) + + # Save weights in raw bytes format + _save_raw_tensors( + raw_weights, + model_name, + archive_writer, + weights_config, + WEIGHTS_DIR, + WEIGHT_FILENAME_PREFIX, + idx, + ) + + return schema.PayloadConfig(config=weights_config) + + +def _package_constants( + model_name: str, + exported_program: ExportedProgram, + archive_writer: PT2ArchiveWriter, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> schema.PayloadConfig: + constants_config: dict[str, schema.PayloadMeta] = {} + + pickled_constants: list[tuple[str, torch.Tensor]] = [] + raw_constants: dict[str, tuple[torch.Tensor, TensorProperties]] = {} + custom_objects: list[tuple[str, torch._C.ScriptObject]] = [] + + # Categorize constants + for constant_fqn, constant in exported_program.constants.items(): + if isinstance(constant, torch.Tensor): + if _should_use_pickle(constant): + pickled_constants.append((constant_fqn, constant)) + else: + raw_constants[constant_fqn] = (constant, TensorProperties(constant)) + + elif isinstance(constant, torch._C.ScriptObject): + custom_objects.append((constant_fqn, constant)) + + else: + raise RuntimeError(f"Unsupported constant type: {type(constant)}") + + tensor_idx = archive_writer.count_prefix( + os.path.join(CONSTANTS_DIR, TENSOR_CONSTANT_FILENAME_PREFIX) + ) + custom_obj_idx = archive_writer.count_prefix( + os.path.join(CONSTANTS_DIR, CUSTOM_OBJ_FILENAME_PREFIX) + ) + + # Save constants in pickle format + tensor_idx = _save_pickled_tensors( + pickled_constants, + archive_writer, + constants_config, + CONSTANTS_DIR, + TENSOR_CONSTANT_FILENAME_PREFIX, + tensor_idx, + pickle_protocol, + ) + + # Save constants in raw bytes format + _save_raw_tensors( + raw_constants, + model_name, + archive_writer, + constants_config, + CONSTANTS_DIR, + TENSOR_CONSTANT_FILENAME_PREFIX, + tensor_idx, + ) + + # Handle custom objects + for constant_fqn, constant in custom_objects: + path_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}" + archive_path = os.path.join(CONSTANTS_DIR, path_name) + custom_obj_bytes = torch._C._pickle_save(constant) + archive_writer.write_bytes(archive_path, custom_obj_bytes) + + constants_config[constant_fqn] = schema.PayloadMeta( + path_name=path_name, + is_param=False, + use_pickle=True, + tensor_meta=None, + ) + custom_obj_idx += 1 + + return schema.PayloadConfig(config=constants_config) + + +def _package_payload_config( + archive_writer: PT2ArchiveWriter, + payload_config: schema.PayloadConfig, + config_file: str, +) -> None: + """ + Save the payload config as json file in the archive. + """ + archive_writer.write_string( + config_file, json.dumps(_dataclass_to_dict(payload_config)) + ) + + +def _package_exported_programs( + archive_writer: PT2ArchiveWriter, + exported_programs: ExportedProgram | dict[str, ExportedProgram] | None, + opset_version: dict[str, int] | None = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if exported_programs is None: + return + + if isinstance(exported_programs, ExportedProgram): + exported_programs = {"model": exported_programs} + + assert isinstance(exported_programs, dict) + + for model_name, ep in exported_programs.items(): + weights_config = _package_state_dict( + model_name, ep, archive_writer, pickle_protocol + ) + weights_config_file = WEIGHTS_CONFIG_FILENAME_FORMAT.format(model_name) + _package_payload_config(archive_writer, weights_config, weights_config_file) + + constants_config = _package_constants( + model_name, ep, archive_writer, pickle_protocol + ) + constants_config_file = CONSTANTS_CONFIG_FILENAME_FORMAT.format(model_name) + _package_payload_config(archive_writer, constants_config, constants_config_file) + + artifact: SerializedArtifact = serialize( + ep, + opset_version, + pickle_protocol, + ) + + archive_writer.write_bytes( + MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program + ) + archive_writer.write_bytes( + SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name), + artifact.example_inputs, + ) + + +def _package_extra_files( + archive_writer: PT2ArchiveWriter, extra_files: dict[str, Any] | None +) -> None: + if extra_files is None: + return + + for extra_file_name, content in extra_files.items(): + archive_writer.write_string(f"{EXTRA_DIR}{extra_file_name}", content) + + +def _package_executorch_files( + archive_writer: PT2ArchiveWriter, executorch_files: dict[str, bytes] | None +) -> None: + if executorch_files is None: + return + + for file_name, content in executorch_files.items(): + archive_writer.write_bytes(f"{EXECUTORCH_DIR}{file_name}", content) + + +def package_pt2( + f: FileLike, + *, + exported_programs: ExportedProgram | dict[str, ExportedProgram] | None = None, + aoti_files: AOTI_FILES | None = None, + extra_files: dict[str, Any] | None = None, + opset_version: dict[str, int] | None = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, + executorch_files: dict[str, bytes] | None = None, +) -> FileLike: + r""" + Saves the artifacts to a PT2Archive format. The artifact can then be loaded + using ``load_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): + The exported program to save, or a dictionary mapping model name to an + exported program to save. The exported program will be saved under + models/\*.json. If only one ExportedProgram is specified, this will + automatically be named "model". + + aoti_files (Union[list[str], dict[str, list[str]]]): A list of files + generated by AOTInductor via + ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, + or a dictionary mapping model name to its AOTInductor generated files. + If only one set of files is specified, this will automatically be named + "model". + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of the pt2. + + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + pickle_protocol: can be specified to override the default protocol + + executorch_files (Optional[dict[str, bytes]]): Optional executorch + artifacts to save. + + """ + assert not ( + exported_programs is None and aoti_files is None and extra_files is None + ), ( + "No value passed in for `exported_programs`, `aoti_files`, and " + "`extra_files`, implying that you do not plan on saving anything." + ) + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2")) + ): + # TODO: turn this into an error + logger.warning( + "Expect archive file to be a file ending in .pt2, or is a buffer. " + "Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + # pyrefly: ignore [bad-argument-type] + with PT2ArchiveWriter(f) as archive_writer: + _package_exported_programs( + archive_writer, exported_programs, pickle_protocol=pickle_protocol + ) + _package_aoti_files( + archive_writer, + aoti_files, + pickle_protocol=pickle_protocol, + ) + _package_extra_files(archive_writer, extra_files) + _package_executorch_files(archive_writer, executorch_files) + + if isinstance(f, (io.IOBase, IO)): + f.seek(0) + # pyrefly: ignore [bad-return] + return f + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 + """ + + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader + + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = self.loader.boxed_run(flat_inputs) + return pytree.tree_unflatten(flat_outputs, out_spec) + + def get_metadata(self) -> dict[str, str]: + return self.loader.get_metadata() + + def load_constants( + self, + constants_map: dict[str, torch.Tensor], + *, + check_full_update: bool, + user_managed: bool = False, + ) -> None: + """ + Given a mapping of constant fqns to tensors, load the constants into the model. + You can use ``get_constant_fqns`` to get the list of constant fqns that + are needed in the compiled model. + + Args: + constants_map: A mapping of constant fqns to tensors. + check_full_update: Whether to add check to see if all the constants + are updated and have values. + """ + self.loader.load_constants( + constants_map, False, check_full_update, user_managed + ) + + def get_constant_fqns(self) -> list[str]: + return self.loader.get_constant_fqns() + + def __deepcopy__(self, memo: dict[Any, Any] | None) -> "AOTICompiledModel": + logger.warning( + "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied." + ) + return AOTICompiledModel(self.loader) + + +@dataclass +class PT2ArchiveContents: + exported_programs: dict[str, ExportedProgram] + aoti_runners: dict[str, AOTICompiledModel] + extra_files: dict[str, Any] + + +def _create_flat_tensor_from_bytes( + tensor_bytes: bytes, + tensor_meta: schema.TensorMeta, +) -> torch.Tensor: + """ + Create a flat tensor from raw bytes with dtype, device and requires_grad. + It will be re-strided based on size, stride, and storage_offset later. + """ + dtype = deserialize_scalar_type(tensor_meta.dtype) + size = deserialize_size(tensor_meta.sizes) + device = deserialize_device(tensor_meta.device) + + if len(tensor_bytes) != 0: + tensor = torch.frombuffer( + tensor_bytes, dtype=dtype, requires_grad=tensor_meta.requires_grad + ).to(device) + else: + # cannot call torch.frombuffer() on empty bytes + logger.warning( + "Cannot call torch.frombuffer() on empty bytes. " + "Creating a tensor with zeros as workaround." + ) + tensor = torch.zeros(size, dtype=dtype, device=device) + + return tensor + + +def _build_file_map( + archive_reader: PT2ArchiveReader, + config: schema.PayloadConfig, + base_dir: str, +) -> dict[str, torch.Tensor]: + """ + Build a map from file path to the payload in flat tensor format. + """ + file_map: dict[str, torch.Tensor] = {} + for payload_meta in config.config.values(): + # skip pickled objects + if payload_meta.use_pickle: + continue + # skip files that already exist in the map + if payload_meta.path_name in file_map: + continue + + tensor_bytes = archive_reader.read_bytes( + os.path.join(base_dir, payload_meta.path_name) + ) + assert payload_meta.tensor_meta is not None + tensor = _create_flat_tensor_from_bytes(tensor_bytes, payload_meta.tensor_meta) + file_map[payload_meta.path_name] = tensor + + return file_map + + +def _load_payload_config( + archive_reader: PT2ArchiveReader, + config_file: str, +) -> schema.PayloadConfig: + """ + Load and parse a payload config from the archive. + """ + return _dict_to_dataclass( + schema.PayloadConfig, + json.loads(archive_reader.read_string(config_file)), + ) + + +def _load_state_dict( + archive_reader: PT2ArchiveReader, + model_name: str, +) -> dict[str, torch.Tensor] | bytes: + # Make it BC compatible with legacy weight files + legacy_weights_file = f"{WEIGHTS_DIR}{model_name}.pt" + if legacy_weights_file in archive_reader.get_file_names(): + logger.warning( + "You are loading weight from the legacy format. " + "Please generate a new pt2 file using torch.export.save()." + ) + return archive_reader.read_bytes(legacy_weights_file) + else: + weights_config_file = WEIGHTS_CONFIG_FILENAME_FORMAT.format(model_name) + assert weights_config_file in archive_reader.get_file_names(), ( + f"{weights_config_file} not found in PT2 archive" + ) + weights_config = _load_payload_config(archive_reader, weights_config_file) + # construct the mapping from file name (e.g. weight_0) to flat weight payload + state_dict_file_map = _build_file_map( + archive_reader, weights_config, WEIGHTS_DIR + ) + # chain the mapping weight FQN -> weight file name -> strided weight payload + # so that the aliasing of weights is preserved + state_dict: dict[str, torch.Tensor] = {} + for weight_fqn, payload_meta in weights_config.config.items(): + if payload_meta.use_pickle: + weight_bytes = archive_reader.read_bytes( + os.path.join(WEIGHTS_DIR, payload_meta.path_name) + ) + state_dict[weight_fqn] = torch.load( + io.BytesIO(weight_bytes), weights_only=False + ) + else: + tensor_meta = payload_meta.tensor_meta + assert tensor_meta is not None + weight_tensor = torch.as_strided( + input=state_dict_file_map[payload_meta.path_name], + size=deserialize_size(tensor_meta.sizes), + stride=deserialize_stride(tensor_meta.strides), + storage_offset=deserialize_storage_offset( + tensor_meta.storage_offset + ), + ) + if payload_meta.is_param: + state_dict[weight_fqn] = torch.nn.Parameter( + weight_tensor, requires_grad=tensor_meta.requires_grad + ) + else: + state_dict[weight_fqn] = weight_tensor + + return state_dict + + +def _load_constants( + archive_reader: PT2ArchiveReader, + model_name: str, +) -> dict[str, torch.Tensor] | bytes: + # Make it BC compatible with legacy constant files + legacy_constants_file = f"{CONSTANTS_DIR}{model_name}.pt" + if legacy_constants_file in archive_reader.get_file_names(): + logger.warning( + "You are loading constant from the legacy format. " + "Please generate a new pt2 file using torch.export.save()." + ) + return archive_reader.read_bytes(legacy_constants_file) + else: + constants_config_file = CONSTANTS_CONFIG_FILENAME_FORMAT.format(model_name) + assert constants_config_file in archive_reader.get_file_names(), ( + f"{constants_config_file} not found in PT2 archive" + ) + constants_config = _load_payload_config(archive_reader, constants_config_file) + # construct the mapping from file name (e.g. constant_0) to constant payload + constant_file_map = _build_file_map( + archive_reader, constants_config, CONSTANTS_DIR + ) + # chain the mapping constant FQN -> constant file name -> strided constant payload + # so that the aliasing of constants is preserved + constants: dict[str, torch.Tensor] = {} + for constant_fqn, payload_meta in constants_config.config.items(): + path_name = payload_meta.path_name + if path_name.startswith(TENSOR_CONSTANT_FILENAME_PREFIX): + if payload_meta.use_pickle: + constant_bytes = archive_reader.read_bytes( + os.path.join(CONSTANTS_DIR, path_name) + ) + constants[constant_fqn] = torch.load( + io.BytesIO(constant_bytes), weights_only=False + ) + else: + tensor_meta = payload_meta.tensor_meta + assert tensor_meta is not None + constant_tensor = torch.as_strided( + input=constant_file_map[path_name], + size=deserialize_size(tensor_meta.sizes), + stride=deserialize_stride(tensor_meta.strides), + storage_offset=deserialize_storage_offset( + tensor_meta.storage_offset + ), + ) + constants[constant_fqn] = constant_tensor + + elif path_name.startswith(CUSTOM_OBJ_FILENAME_PREFIX): + constant_bytes = archive_reader.read_bytes( + os.path.join(CONSTANTS_DIR, path_name) + ) + constants[constant_fqn] = torch._C._pickle_load_obj(constant_bytes) + + else: + raise RuntimeError(f"Unsupported constant type: {path_name}") + + return constants + + +def _load_exported_programs( + archive_reader: PT2ArchiveReader, + file_names: list[str], + expected_opset_version: dict[str, int] | None, +) -> dict[str, ExportedProgram]: + exported_program_files = [ + file for file in file_names if file.startswith(MODELS_DIR) + ] + exported_programs = {} + for file in exported_program_files: + prefix, suffix = MODELS_FILENAME_FORMAT.split( + "{}" + ) # split "models/{}.json" into "models/" and "json" + model_name = file[ + len(prefix) : -len(suffix) + ] # given "models/foo.json" we can now get "foo" + + sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) + serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) + + from torch._export.serde.serialize import _bytes_to_dataclass + + exported_program_bytes = archive_reader.read_bytes(file) + serialized_exported_program = _bytes_to_dataclass( + schema.ExportedProgram, exported_program_bytes + ) + state_dict = _load_state_dict(archive_reader, model_name) + constants = _load_constants(archive_reader, model_name) + + ep = ExportedProgramDeserializer(expected_opset_version).deserialize( + serialized_exported_program, + state_dict, + constants, + serialized_sample_inputs, + ) + + exported_programs[model_name] = ep + + return exported_programs + + +def _load_extra_files( + archive_reader: PT2ArchiveReader, file_names: list[str] +) -> dict[str, Any]: + extra_files = [file for file in file_names if file.startswith(EXTRA_DIR)] + + extra_file_contents: dict[str, Any] = {} + for file in extra_files: + contents = archive_reader.read_string(file) + extra_file_contents[file[len(EXTRA_DIR) :]] = contents + + return extra_file_contents + + +def _load_aoti( + file: str, + model_name: str, + run_single_threaded: bool, + num_runners: int, + device_idx: int, +) -> AOTICompiledModel: + loaded_metadata = torch._C._aoti.AOTIModelPackageLoader.load_metadata_from_package( # type: ignore[attr-defined] + file, model_name + ) + + device = loaded_metadata["AOTI_DEVICE_KEY"] + current_device_info = torch._inductor.codecache.get_device_information(device) + + for k, v in current_device_info.items(): + if k in loaded_metadata: + if v != loaded_metadata[k]: + logger.warning( + "Device information mismatch for %s: %s vs %s. " + "This could cause some issues when loading the AOTInductor compiled artifacts.", + k, + v, + loaded_metadata[k], + ) + + aoti_compiled_model = AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + file, + model_name, + run_single_threaded, + num_runners, + device_idx, + ) + ) + + return aoti_compiled_model + + +def load_pt2( + f: FileLike, + *, + expected_opset_version: dict[str, int] | None = None, + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, + load_weights_from_disk: bool = False, +) -> PT2ArchiveContents: # type: ignore[type-arg] + """ + Loads all the artifacts previously saved with ``package_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + num_runners (int): Number of runners to load AOTInductor artifacts + + run_single_threaded (bool): Whether the model should be run without + thread synchronization logic. This is useful to avoid conflicts with + CUDAGraphs. + + device_index (int): The index of the device to which the PT2 package is + to be loaded. By default, `device_index=-1` is used, which corresponds + to the device `cuda` when using CUDA. Passing `device_index=1` would + load the package to `cuda:1`, for example. + + Returns: + A ``PT2ArchiveContents`` object which contains all the objects in the PT2. + """ + + from torch._inductor.cpp_builder import normalize_path_separator + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error in 2.9 + logger.warning( + "Unable to load package. f must be a buffer or a file ending in " + ".pt2. Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + weights = {} + weight_maps = {} + # pyrefly: ignore [bad-argument-type] + with PT2ArchiveReader(f) as archive_reader: + version = archive_reader.read_string(ARCHIVE_VERSION_PATH) + if version != ARCHIVE_VERSION_VALUE: + raise ValueError( + f"Saved archive version {version} does not match our current " + f"archive version {ARCHIVE_VERSION_VALUE}." + ) + + file_names = archive_reader.get_file_names() + + exported_programs = _load_exported_programs( + archive_reader, file_names, expected_opset_version + ) + extra_files = _load_extra_files(archive_reader, file_names) + + # Get a list of AOTI model names + aoti_model_names: set[str] = set() + for file in file_names: + if file.startswith(AOTINDUCTOR_DIR): + file_end = file[ + len(AOTINDUCTOR_DIR) : + ] # remove data/aotinductor/ prefix + file_end = normalize_path_separator( + file_end + ) # Win32 need normalize path before split. + model_name = file_end.split("/")[ + 0 + ] # split "model_name/...cpp" into "model_name" + aoti_model_names.add(model_name) + if load_weights_from_disk and file.endswith("weights_config.json"): + weight_map = json.loads(archive_reader.read_string(file)) + weight_maps[model_name] = weight_map + elif load_weights_from_disk and file.startswith(WEIGHTS_DIR): + weight_file_name = file[ + len(WEIGHTS_DIR) : + ] # remove data/weights/ prefix + weight_bytes = archive_reader.read_bytes(file) + loaded_weight = torch.load(io.BytesIO(weight_bytes)) + weights[weight_file_name] = loaded_weight + + if isinstance(f, (io.IOBase, IO)): + if len(aoti_model_names) > 0: + # Workaround for AOTIModelPackageLoader not reading buffers + with tempfile.NamedTemporaryFile(suffix=".pt2") as tf: + f.seek(0) + tf.write(f.read()) + f.seek(0) + logger.debug("Writing buffer to tmp file located at %s.", tf.name) + + aoti_runners = { + model_name: _load_aoti( + tf.name, + model_name, + run_single_threaded, + num_runners, + device_index, + ) + for model_name in aoti_model_names + } + else: + aoti_runners = {} + else: + aoti_runners = { + model_name: _load_aoti( + f, + model_name, + run_single_threaded, + num_runners, + device_index, + ) + for model_name in aoti_model_names + } + + if weight_maps: + for model_name in aoti_model_names: + model_weights = {} + for weight_name, (file, shape, stride, storage_offset) in weight_maps[ + model_name + ].items(): + weight = weights[file] + model_weights[weight_name] = weight.as_strided( + shape, stride, storage_offset + ) + + # user_managed=True ensures the weights updates are shared by all runners. + aoti_runners[model_name].load_constants( + model_weights, check_full_update=True, user_managed=True + ) + + return PT2ArchiveContents(exported_programs, aoti_runners, extra_files) + + +def load_weights_to_pt2_contents( + pt2_contents: PT2ArchiveContents, weights_map: dict[str, Any] +) -> None: + """ + Load weights into the models in PT2 archive contents + + Args: + pt2_contents (PT2ArchiveContents): The contents of the PT2 archive. + """ + for model_name, weights in weights_map.items(): + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in PT2 archive contents.") + pt2_contents.aoti_runners[model_name].load_constants( + weights, check_full_update=True, user_managed=True + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..5acd86feebf0a691d7e527e4ea382e7b4aaabf9c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py @@ -0,0 +1,135 @@ +import collections +import warnings + +import torch +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._ordered_set import OrderedSet + + +def _end_ptr(tensor: torch.Tensor) -> int: + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +class TensorProperties: + def __init__(self, tensor: torch.Tensor): + self.is_fake = isinstance(tensor, FakeTensor) + self.is_contiguous = tensor.is_contiguous() + self.storage_ptr = None + self.storage_size = None + self.start = None + self.end = None + + if not self.is_fake: + # only get the storage pointer for real tensors + # pyrefly: ignore [bad-assignment] + self.storage_ptr = tensor.untyped_storage().data_ptr() + if self.is_contiguous: + # only get storage size and start/end pointers for contiguous tensors + # pyrefly: ignore [bad-assignment] + self.storage_size = tensor.untyped_storage().nbytes() + # pyrefly: ignore [bad-assignment] + self.start = tensor.data_ptr() + # pyrefly: ignore [bad-assignment] + self.end = _end_ptr(tensor) + + # info to recover tensor + self.shape = tensor.shape + self.stride = tensor.stride() + self.offset = tensor.storage_offset() + + def is_complete(self) -> bool: + """ + Whether the tensor completely overlaps with its underlying storage + """ + if self.is_fake: + # Theoretically, fake tensors should not appear in weights + # But we handle this corner case to make it always complete + return True + if not self.is_contiguous: + return False + + assert self.storage_ptr is not None + assert self.storage_size is not None + assert self.start is not None + assert self.end is not None + return ( + self.start == self.storage_ptr + and self.end == self.storage_ptr + self.storage_size + ) + + +class Weights(dict): + """ + A dictionary mapping from weight name to a tuple of (tensor, TensorProperties). + tensor represents the actual initial value of the weight. + TensorProperties represents the properties of the weight that are needed to recover the weight. + + We use two separate entries because `tensor` could be a clone of the original weight tensor, + so it doesn't have the same property as the original weight (such as underlying storage pointer). + """ + + def __init__(self, weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]]): + super().__init__(weight_dict) + + def get_weight(self, name: str) -> tuple[torch.Tensor, TensorProperties]: + return self[name] + + def get_weight_properties(self, name: str) -> TensorProperties: + return self[name][1] + + +def get_complete( + group: OrderedSet[tuple[str, str]], models_weights: dict[str, Weights] +) -> tuple[str, str]: + """ + `group` is a (model_name, weight_name) tuple. + `model_weights` is a dictionary mapping from model name to its Weights. + + One of the tensor in `group` must be complete and they must share the + same underlying storage. + + Returns the name of the complete tensor in the `group`. If multiple + tensors are complete, returns an arbitrary one. + """ + + def get_tensor_properties(name_tuple: tuple[str, str]) -> TensorProperties: + # returns the tensor properties + (model_name, weight_name) = name_tuple + return models_weights[model_name].get_weight_properties(weight_name) + + for name_tuple in group: + tensor_property = get_tensor_properties(name_tuple) + if tensor_property.is_complete(): + return name_tuple + + warnings.warn( + "No complete tensor found in the group! Returning the first one. " + "This may cause issues when your weights are not on CPU.", + stacklevel=2, + ) + assert len(group) > 0 + return next(iter(group)) + + +def group_weights(all_weights: dict[str, Weights]) -> list[OrderedSet[tuple[str, str]]]: + """ + Group weights that share the same underlying storage. + + Returns a list of sets, each set contains a tuple of (model_name, weight_name). + """ + + weights_dict: dict[tuple[int, torch.dtype], OrderedSet[tuple[str, str]]] = ( + collections.defaultdict(OrderedSet) + ) # (storage_key, dtype) -> set(weight) + + for model_name, weights in all_weights.items(): + for weight_name, (tensor, properties) in weights.items(): + weights_dict[(properties.storage_ptr, tensor.dtype)].add( + (model_name, weight_name) + ) + + return list(weights_dict.values()) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..4b05e257b8f3dfc387b553f0aeecc7a0e1653528 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py @@ -0,0 +1,35 @@ +# Defined in torch/csrc/export/pt2_archive_constants.h +from torch._C._export import pt2_archive_constants + + +AOTINDUCTOR_DIR: str = pt2_archive_constants.AOTINDUCTOR_DIR +ARCHIVE_FORMAT_PATH: str = pt2_archive_constants.ARCHIVE_FORMAT_PATH +ARCHIVE_FORMAT_VALUE: str = pt2_archive_constants.ARCHIVE_FORMAT_VALUE +ARCHIVE_ROOT_NAME: str = pt2_archive_constants.ARCHIVE_ROOT_NAME +ARCHIVE_VERSION_PATH: str = pt2_archive_constants.ARCHIVE_VERSION_PATH +ARCHIVE_VERSION_VALUE: str = pt2_archive_constants.ARCHIVE_VERSION_VALUE +CONSTANTS_DIR: str = pt2_archive_constants.CONSTANTS_DIR +CONSTANTS_CONFIG_FILENAME_FORMAT: str = ( + pt2_archive_constants.CONSTANTS_CONFIG_FILENAME_FORMAT +) +CUSTOM_OBJ_FILENAME_PREFIX: str = pt2_archive_constants.CUSTOM_OBJ_FILENAME_PREFIX +EXECUTORCH_DIR: str = pt2_archive_constants.EXECUTORCH_DIR +EXTRA_DIR: str = pt2_archive_constants.EXTRA_DIR +MODELS_DIR: str = pt2_archive_constants.MODELS_DIR +MODELS_FILENAME_FORMAT: str = pt2_archive_constants.MODELS_FILENAME_FORMAT +MODULE_INFO_PATH: str = pt2_archive_constants.MODULE_INFO_PATH +MTIA_DIR: str = pt2_archive_constants.MTIA_DIR +SAMPLE_INPUTS_DIR: str = pt2_archive_constants.SAMPLE_INPUTS_DIR +SAMPLE_INPUTS_FILENAME_FORMAT: str = pt2_archive_constants.SAMPLE_INPUTS_FILENAME_FORMAT +TENSOR_CONSTANT_FILENAME_PREFIX: str = ( + pt2_archive_constants.TENSOR_CONSTANT_FILENAME_PREFIX +) +WEIGHTS_CONFIG_FILENAME_FORMAT: str = ( + pt2_archive_constants.WEIGHTS_CONFIG_FILENAME_FORMAT +) +WEIGHT_FILENAME_PREFIX: str = pt2_archive_constants.WEIGHT_FILENAME_PREFIX +WEIGHTS_DIR: str = pt2_archive_constants.WEIGHTS_DIR +XL_MODEL_WEIGHTS_DIR: str = pt2_archive_constants.XL_MODEL_WEIGHTS_DIR +XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ( + pt2_archive_constants.XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a568c9cfc478b0f3d6f1133514255b400c596bb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2599a2c48dea85fd07dc7427a42f356c793ef1b4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65d350fb2a2a12805f81dfe08288516605ea8084 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..217aedb0d10785e29d7cd35c56862c2b2ae39cec Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..805ced6a369e599c7f0dccfa3b3c0a3416fbf388 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e54ce77747e1f430ab79c464f4dfa9bd184e47e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc differ