import os, re, sys, platform, shutil, subprocess, importlib, json from functools import cached_property from typing import Union from glob import glob from importlib.metadata import version, PackageNotFoundError from lib.conf import * class DeviceInstaller(): def __init__(self): self.system = sys.platform self.arch = self.check_arch self.python_version = sys.version_info[:2] self.python_version_tuple = sys.version_info @cached_property def check_platform(self)->str: return self.detect_platform_tag() @cached_property def check_arch(self)->str: return self.detect_arch_tag() @cached_property def check_hardware(self)->tuple: return self.detect_device() @cached_property def cpu_baseline(self)->bool: machine = platform.machine().lower() if machine not in (archs['X86_64'], archs['AMD64']): return True cpuinfo_version = self.get_package_version('py-cpuinfo') if not cpuinfo_version: subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', '--upgrade-strategy', 'only-if-needed', '--no-cache-dir', 'py-cpuinfo']) from cpuinfo import get_cpu_info flags = set(get_cpu_info().get('flags', [])) return {'sse4_2', 'popcnt', 'ssse3'}.issubset(flags) def check_device_info(self, mode:str)->str: if mode == NATIVE: name, tag, msg = self.check_hardware pyvenv = [3, 10] if tag in ['jetson51', 'jetson60', 'jetson61'] else list(max_python_version) arch = archs['AARCH64'] if name in [devices['JETSON']['proc']] else self.arch os_env = 'linux' if name == devices['JETSON']['proc'] else self.check_platform if all([name, tag, os_env, arch, pyvenv]): device_info = {"name": name, "os": os_env, "arch": arch, "pyvenv": pyvenv, "tag": tag, "note": msg} try: with open(device_info_json, 'w', encoding='utf-8') as f: json.dump(device_info, f) except OSError as e: error = f'warning: could not write .device_info.json: {e}' print(error, file=sys.stderr) return json.dumps(device_info) elif mode == FULL_DOCKER: device_info = None if os.path.isfile(device_info_json): try: with open(device_info_json, 'r', encoding='utf-8') as f: device_info = json.load(f) except (OSError, json.JSONDecodeError): pass if device_info is None: env_str = os.environ.get('DOCKER_DEVICE_STR', '') if env_str: try: device_info = json.loads(env_str) except json.JSONDecodeError: pass if device_info is not None: devices[device_info['name'].upper()]['found'] = True return json.dumps(device_info) elif mode == BUILD_DOCKER: name, tag, msg = self.check_hardware os_env = 'manylinux_2_28' pyvenv = [3, 10] if tag in ['jetson51', 'jetson60', 'jetson61'] else list(max_python_version) arch = archs['AARCH64'] if name in [devices['JETSON']['proc']] else self.arch if name in [devices['JETSON']['proc'], devices['MPS']['proc']]: name = tag = devices['CPU']['proc'] device_info = {"name": name, "os": os_env, "arch": arch, "pyvenv": pyvenv, "tag": tag, "note": msg.replace('!', '')} try: with open(device_info_json, 'w', encoding='utf-8') as f: json.dump(device_info, f) except OSError as e: error = f'warning: could not write .device_info.json: {e}' print(error, file=sys.stderr) return json.dumps(device_info) return '' def get_package_version(self, pkg:str)->Union[str, bool]: try: return version(pkg) except PackageNotFoundError: return False def detect_platform_tag(self)->str: if self.system == systems['WINDOWS']: return 'win' if self.system == systems['MACOS']: return 'macosx_11_0' if self.system == systems['LINUX']: return 'manylinux_2_28' return 'unknown' def detect_arch_tag(self)->str: m = platform.machine().upper() return archs.get(m, 'unknown') def detect_device(self)->str: def has_cmd(cmd:str)->bool: return shutil.which(cmd) is not None def try_cmd(cmd:str)->str: try: out = subprocess.check_output( cmd, shell = True, stderr = subprocess.DEVNULL ) return out.decode().lower() except Exception: return '' def lib_version_parse(text:str)->Union[str, None]: if not text: return None text = text.strip() if text.startswith('{'): try: obj = json.loads(text) if isinstance(obj, dict): if devices['CUDA']['proc'] in obj and isinstance(obj[devices['CUDA']['proc']], dict): v = obj[devices['CUDA']['proc']].get('version') if v: return str(v) v = obj.get('version') if v: return str(v) except Exception: pass m = re.search(r'cuda\s*version\s*([0-9]+(?:\.[0-9]+){1,2})', text, re.IGNORECASE) if m: return m.group(1) m = re.search(r'cuda\s*([0-9]+(?:\.[0-9]+)?)', text, re.IGNORECASE) if m: return m.group(1) m = re.search(r'rocm\s*version\s*([0-9]+(?:\.[0-9]+){0,2})', text, re.IGNORECASE) if m: return m.group(1) # CHANGED: keep full version, don't truncate to major.minor m = re.search(r'hip\s*version\s*([0-9]+(?:\.[0-9]+){0,2})', text, re.IGNORECASE) if m: return m.group(1) # CHANGED: keep full version m = re.search(r'(oneapi|xpu)\s*(toolkit\s*)?version\s*([0-9]+(?:\.[0-9]+)?)', text, re.IGNORECASE) if m: return m.group(3) return None def version_classify(version_str:Union[str, None], version_range:dict)->tuple: # Returns (cmp, current_tuple, min_tuple, max_tuple) # cmp: -1 = below min, 0 = in range, 1 = above max, None = parse fail / unranged # current_tuple is (major, minor, patch) — patch defaults to 0 if version_str is None: return (None, None, None, None) min_raw = tuple(version_range.get('min', (0, 0))) max_raw = tuple(version_range.get('max', (0, 0))) # Pad min/max to 3-tuples for consistent comparison min_tuple = min_raw + (0,) * (3 - len(min_raw)) if len(min_raw) < 3 else min_raw[:3] max_tuple = max_raw + (0,) * (3 - len(max_raw)) if len(max_raw) < 3 else max_raw[:3] try: parts = version_str.split('.') major = int(parts[0]) minor = int(parts[1]) if len(parts) > 1 else 0 patch = int(parts[2]) if len(parts) > 2 else 0 except (ValueError, IndexError): return (None, None, min_tuple, max_tuple) current = (major, minor, patch) if min_tuple == (0, 0, 0) and max_tuple == (0, 0, 0): return (0, current, min_tuple, max_tuple) # Compare on (major, minor) only for the range check — patch doesn't gate current_mm = (major, minor) min_mm = min_tuple[:2] max_mm = max_tuple[:2] if min_mm != (0, 0) and current_mm < min_mm: return (-1, current, min_tuple, max_tuple) if max_mm != (0, 0) and current_mm > max_mm: return (1, current, min_tuple, max_tuple) return (0, current, min_tuple, max_tuple) def tegra_version()->str: if os.path.exists('/etc/nv_tegra_release'): return try_cmd('cat /etc/nv_tegra_release') return '' def jetpack_version(text:str)->tuple: m1 = re.search(r'r(\d+)', text) m2 = re.search(r'revision:\s*([\d\.]+)', text) msg = '' if not m1 or not m2: msg = 'Unrecognized JetPack version. Falling back to CPU.' return ('unknown', msg) l4t_major = int(m1.group(1)) rev = m2.group(1) parts = rev.split('.') rev_major = int(parts[0]) if l4t_major < 35: msg = f'JetPack too old (L4T {l4t_major}). Please upgrade to JetPack 6+. Falling back to CPU.' return ('unsupported', msg) if l4t_major == 35: return ('51', msg) if rev_major == 2: return ('60', msg) return ('61', msg) def has_amd_gpu_pci(): if self.system == systems['MACOS']: return False if os.name == 'posix': sysfs = '/sys/bus/pci/devices' if os.path.isdir(sysfs): for d in os.listdir(sysfs): dev = os.path.join(sysfs, d) try: with open(f'{dev}/vendor') as f: if f.read().strip() not in ('0x1002', '0x1022'): continue with open(f'{dev}/class') as f: cls = f.read().strip() if cls.startswith('0x0300') or cls.startswith('0x0302'): return True except Exception: pass if has_cmd('lspci'): out = try_cmd('lspci -nn').lower() return ( ('1002:' in out or '1022:' in out) and (' vga ' in out or ' 3d ' in out) ) return False if os.name == 'nt': if has_cmd('wmic'): out = try_cmd('wmic path win32_VideoController get Name,PNPDeviceID').lower() return 'ven_1002' in out if has_cmd('powershell'): out = try_cmd('powershell -Command "Get-PnpDevice -Class Display | Select-Object -ExpandProperty InstanceId"').lower() return 'ven_1002' in out return False return False def has_rocm(): if self.system == systems['LINUX']: rocm_paths = ['/opt/rocm', '/opt/rocm/bin/rocminfo'] if any(os.path.exists(p) for p in rocm_paths): return True return has_cmd('rocminfo') elif self.system == systems['WINDOWS']: hip_path = os.environ.get('HIP_PATH') if hip_path and os.path.isdir(hip_path): return True program_files = os.environ.get('ProgramFiles', '') if program_files and glob(os.path.join(program_files, 'AMD', 'ROCm', '*')): return True return has_cmd('rocminfo') return False def has_nvidia_gpu_pci(): if self.system == systems['MACOS']: return False if os.name == 'posix': sysfs = '/sys/bus/pci/devices' if os.path.isdir(sysfs): for d in os.listdir(sysfs): dev = os.path.join(sysfs, d) try: with open(f'{dev}/vendor') as f: if f.read().strip() != '0x10de': continue with open(f'{dev}/class') as f: cls = f.read().strip() if cls.startswith('0x0300') or cls.startswith('0x0302'): return True except Exception: pass if has_cmd('lspci'): out = try_cmd('lspci -nn').lower() return '10de:' in out and (' vga ' in out or ' 3d ' in out) return False if os.name == 'nt': if has_cmd('nvidia-smi'): return True if has_cmd('wmic'): out = try_cmd('wmic path win32_VideoController get Name,PNPDeviceID').lower() return 'ven_10de' in out if has_cmd('powershell'): out = try_cmd( 'powershell -Command "Get-PnpDevice -Class Display | ' 'Select-Object -ExpandProperty InstanceId"' ).lower() return 'ven_10de' in out return False return False def is_wsl2(): if os.name != 'posix': return False try: with open('/proc/version', 'r', encoding='utf-8', errors='ignore') as f: return 'microsoft' in f.read().lower() except Exception: return False def has_cuda(): if self.system == systems['MACOS']: return False if not has_cmd('nvidia-smi'): return False out = try_cmd('nvidia-smi -L').lower() if not out: return False if 'failed' in out or 'error' in out or 'no devices were found' in out: return False return 'gpu' in out def has_intel_gpu_pci(): if self.system == systems['MACOS']: return False if os.name == 'posix': sysfs = '/sys/bus/pci/devices' if os.path.isdir(sysfs): for d in os.listdir(sysfs): dev = os.path.join(sysfs, d) try: with open(f'{dev}/vendor') as f: if f.read().strip() != '0x8086': continue with open(f'{dev}/class') as f: cls = f.read().strip() if cls.startswith('0x0300') or cls.startswith('0x0302'): return True except Exception: pass if has_cmd('lspci'): out = try_cmd('lspci -nn').lower() return '8086:' in out and (' vga ' in out or ' 3d ' in out) return False if os.name == 'nt': if has_cmd('wmic'): out = try_cmd('wmic path win32_VideoController get Name,PNPDeviceID').lower() return 'ven_8086' in out if has_cmd('powershell'): out = try_cmd( 'powershell -Command "Get-PnpDevice -Class Display | ' 'Select-Object -ExpandProperty InstanceId"' ).lower() return 'ven_8086' in out return False return False def has_xpu(): if self.system == systems['MACOS']: return False if os.name == 'posix': if not os.path.exists('/dev/dri/renderD128'): return False if has_cmd('sycl-ls'): out = try_cmd('sycl-ls').lower() if 'level-zero' in out and 'gpu' in out: return True if has_cmd('clinfo'): out = try_cmd('clinfo').lower() if 'intel' in out and 'gpu' in out: return True return False if os.name == 'nt': if has_cmd('sycl-ls'): out = try_cmd('sycl-ls').lower() return 'gpu' in out return False return False name = None tag = None msg = '' arch = platform.machine().lower() forced_tag = os.environ.get('DEVICE_TAG') if forced_tag: tag_letters = re.match(r'[a-zA-Z]+', forced_tag) if tag_letters: tag_letters = tag_letters.group(0).lower() name = devices['CUDA']['proc'] if tag_letters == 'cu' else devices['ROCM']['proc'] if tag_letters == devices['ROCM']['proc'] else devices['JETSON']['proc'] if tag_letters == devices['JETSON']['proc'] else devices['XPU']['proc'] if tag_letters == devices['XPU']['proc'] else devices['MPS']['proc'] if tag_letters == devices['MPS']['proc'] else devices['CPU']['proc'] devices[name.upper()]['found'] = True tag = forced_tag msg = f'Hardware forced from DEVICE_TAG={tag}' else: msg = f'DEVICE_TAG not valid' else: # ============================================================ # JETSON # ============================================================ if arch in (archs['AARCH64'],archs['ARM64']) and (os.path.exists('/etc/nv_tegra_release') or 'tegra' in try_cmd('cat /proc/device-tree/compatible')): raw = tegra_version() jp_code, msg = jetpack_version(raw) if jp_code not in ('unsupported', 'unknown'): if os.path.exists('/etc/nv_tegra_release'): devices['JETSON']['found'] = True name = devices['JETSON']['proc'] tag = f'jetson{jp_code}' elif os.path.exists('/proc/device-tree/compatible'): out = try_cmd('cat /proc/device-tree/compatible') if 'tegra' in out: devices['JETSON']['found'] = True name = devices['JETSON']['proc'] tag = f'jetson{jp_code}' else: out = try_cmd('uname -a') if 'tegra' in out: msg = 'Jetson GPU detected but not(?) compatible' if devices['JETSON']['found']: os.environ['CUDA_MODULE_LOADING'] = 'LAZY' os.environ['TORCH_CUDA_ENABLE_CUDA_GRAPH'] = '0' os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,garbage_collection_threshold:0.6,expandable_segments:False' # ============================================================ # ROCm # ============================================================ elif has_rocm() and has_amd_gpu_pci(): def _normalize_version(v:str)->tuple: '''Parse version string into (major, minor, patch). Patch defaults to 0.''' m = re.search(r'(\d+)\.(\d+)(?:\.(\d+))?', v or '') if not m: return () major = int(m.group(1)) minor = int(m.group(2)) patch = int(m.group(3)) if m.group(3) else 0 return (major, minor, patch) os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:False' os.environ['PYTORCH_HIP_ALLOC_CONF'] = 'expandable_segments:False' version = () msg = '' hip_device_count = 0 # 1) HIP runtime detection via ctypes (primary) try: import ctypes libhip = None if os.name == 'nt': hip_path = os.environ.get('HIP_PATH', '') candidates = ['amdhip64.dll'] if hip_path: candidates.insert(0, os.path.join(hip_path, 'bin', 'amdhip64.dll')) for lib_name in candidates: try: libhip = ctypes.CDLL(lib_name) break except OSError: continue else: candidates = ['libamdhip64.so'] min_major, _ = rocm_version_range['min'] max_major, _ = rocm_version_range['max'] for major in range(max_major + 2, min_major - 1, -1): candidates.append(f'libamdhip64.so.{major}') hip_lib_dirs = [ '/opt/rocm/lib', '/opt/rocm/lib64', '/usr/lib/x86_64-linux-gnu', '/usr/lib64', ] for p in sorted(glob('/opt/rocm-*/lib'), reverse=True): hip_lib_dirs.append(p) for d in hip_lib_dirs: if os.path.isdir(d): try: for f in sorted(os.listdir(d), reverse=True): if f.startswith('libamdhip64.so.'): candidates.append(os.path.join(d, f)) except OSError: pass for lib_name in candidates: try: libhip = ctypes.CDLL(lib_name) break except OSError: continue if libhip: device_count = ctypes.c_int() if libhip.hipGetDeviceCount(ctypes.byref(device_count)) == 0: hip_device_count = device_count.value v_int = ctypes.c_int() if libhip.hipRuntimeGetVersion(ctypes.byref(v_int)) == 0: v = v_int.value if v >= 10000000: major = v // 10000000 minor = (v % 10000000) // 100000 patch = v % 100000 elif v >= 100: major = v // 100 minor = (v % 100) // 10 patch = 0 else: major, minor, patch = v, 0, 0 if hip_device_count > 0: version = (major, minor, patch) else: ver_disp = f'{major}.{minor}.{patch}' if patch else f'{major}.{minor}' msg = f'HIP runtime present ({ver_disp}) but no devices.' except (OSError, AttributeError): pass # 2) hipcc fallback if not version: if os.name == 'posix' and has_cmd('hipcc'): out = try_cmd('hipcc --version') if out: m = re.search(r'HIP version:\s*([\d.]+)', out, re.IGNORECASE) if m: version = _normalize_version(m.group(1)) elif os.name == 'nt': hip_path = os.environ.get('HIP_PATH', '') hipcc = os.path.join(hip_path, 'bin', 'hipcc') if hip_path else '' if hipcc and os.path.isfile(hipcc): out = try_cmd(f'"{hipcc}" --version') if out: m = re.search(r'HIP version:\s*([\d.]+)', out, re.IGNORECASE) if m: version = _normalize_version(m.group(1)) if not version and has_cmd('hipcc'): out = try_cmd('hipcc --version') if out: m = re.search(r'HIP version:\s*([\d.]+)', out, re.IGNORECASE) if m: version = _normalize_version(m.group(1)) # 3) torch.version.hip fallback if not version: try: import torch if getattr(torch.version, 'hip', None): version = _normalize_version(torch.version.hip) except Exception: pass # 4) ROCm install dir fallback if not version: if os.name == 'posix': for p in sorted(glob('/opt/rocm-*'), reverse=True): base = os.path.basename(p).replace('rocm-', '') v = _normalize_version(base) if v: version = v break if not version: for p in ('/opt/rocm/.info/version', '/opt/rocm/version'): if os.path.exists(p): try: with open(p, 'r', encoding='utf-8', errors='ignore') as f: version = _normalize_version(lib_version_parse(f.read())) break except Exception: pass elif os.name == 'nt': program_files = os.environ.get('ProgramFiles', '') if program_files: for p in sorted(glob(os.path.join(program_files, 'AMD', 'ROCm', '*')), reverse=True): v = _normalize_version(os.path.basename(p)) if v: version = v break if not version: for env in ('ROCM_PATH', 'HIP_PATH'): base = os.environ.get(env) if base: for p in (os.path.join(base, 'version'), os.path.join(base, '.info', 'version')): if os.path.exists(p): try: with open(p, 'r', encoding='utf-8', errors='ignore') as f: version = _normalize_version(lib_version_parse(f.read())) break except Exception: pass if version: break if version: version_str = '.'.join(str(p) for p in version) cmp, current, min_tuple, max_tuple = version_classify(version_str, rocm_version_range) # min_ver / max_ver: strip trailing .0 for display (range tuples are major.minor) min_ver = f'{min_tuple[0]}.{min_tuple[1]}' max_ver = f'{max_tuple[0]}.{max_tuple[1]}' if self.system == systems['WINDOWS'] and version < max_tuple: msg = f'ROCm {version_str} on Windows; needs to be upgraded to {max_ver}.x.' elif cmp == -1: msg = f'ROCm {version_str} < min {min_ver}. Please upgrade.' elif cmp is None: msg = 'ROCm GPU detected but version unparseable.' else: devices['ROCM']['found'] = True name = devices['ROCM']['proc'] compat_versions = [] for t, entry in torch_matrix.items(): if self.system not in entry['os'] or not t.startswith('rocm'): continue ver_str = t[len('rocm-rel-'):] if t.startswith('rocm-rel-') else t[len('rocm'):] tag_ver = _normalize_version(ver_str) if not tag_ver: continue compat_versions.append(tag_ver) tag = None if compat_versions: le_versions = [v for v in compat_versions if v <= version] if le_versions: matched = max(le_versions) if self.system == systems['WINDOWS']: tag = f'rocm-rel-{matched[0]}.{matched[1]}.{matched[2]}' if matched[2] else f'rocm-rel-{matched[0]}.{matched[1]}' else: tag = f'rocm{matched[0]}.{matched[1]}.{matched[2]}' if matched[2] else f'rocm{matched[0]}.{matched[1]}' if cmp == 1: msg = f'ROCm {version_str} > tested max {max_ver}; using {tag} torch build.' if tag else f'ROCm {version_str} detected but no compatible torch build for this OS.' elif not tag: msg = f'ROCm {version_str} detected but no compatible torch build for this OS.' else: msg = 'ROCm hardware detected but AMD ROCm base runtime not installed.' # 5) Last-resort torch fallback if not devices['ROCM']['found']: try: import torch if torch.cuda.is_available() and hasattr(torch.version, 'hip') and torch.version.hip: devices['ROCM']['found'] = True version = _normalize_version(torch.version.hip) if version: if self.system == systems['WINDOWS'] and version < tuple(rocm_version_range['max']): devices['ROCM']['found'] = False max_ver = f"{rocm_version_range['max'][0]}.{rocm_version_range['max'][1]}" msg = f'ROCm {".".join(str(p) for p in version)} on Windows; needs to be upgraded to {max_ver}.x.' else: compat_versions = [] for t, entry in torch_matrix.items(): if self.system not in entry['os'] or not t.startswith('rocm'): continue ver_str = t[len('rocm-rel-'):] if t.startswith('rocm-rel-') else t[len('rocm'):] tag_ver = _normalize_version(ver_str) if not tag_ver: continue compat_versions.append(tag_ver) tag = None if compat_versions: le_versions = [v for v in compat_versions if v <= version] if le_versions: matched = max(le_versions) if self.system == systems['WINDOWS']: tag = f'rocm-rel-{matched[0]}.{matched[1]}.{matched[2]}' if matched[2] else f'rocm-rel-{matched[0]}.{matched[1]}' else: tag = f'rocm{matched[0]}.{matched[1]}.{matched[2]}' if matched[2] else f'rocm{matched[0]}.{matched[1]}' msg = '' except Exception: pass # ============================================================ # CUDA # ============================================================ elif has_cuda() and (has_nvidia_gpu_pci() or is_wsl2()): version = '' msg = '' # 1) CUDA runtime detection via ctypes (primary) try: import ctypes libcudart = None if os.name == 'nt': # CUDA 12+ filename dropped the minor: 'cudart64_12.dll' # CUDA 11.x still has minor suffix: 'cudart64_11{minor}.dll' candidates = [] # Forward-compat for CUDA 13/14/15 (newest first) for major in range(15, 11, -1): candidates.append(f'cudart64_{major}.dll') # CUDA 11.x minors (newest first) for minor in range(9, -1, -1): candidates.append(f'cudart64_11{minor}.dll') for dll in candidates: try: libcudart = ctypes.CDLL(dll) break except OSError: continue else: # Linux / WSL2 — SONAME is major-only for CUDA 11+ candidates = ['libcudart.so'] min_major, _ = cuda_version_range['min'] max_major, _ = cuda_version_range['max'] # Extend upward past max for tolerance for major in range(max_major + 3, min_major - 1, -1): candidates.append(f'libcudart.so.{major}') cuda_lib_dirs = [ '/usr/local/cuda/lib64', '/usr/lib/x86_64-linux-gnu', '/usr/lib64', ] for d in cuda_lib_dirs: if os.path.isdir(d): try: for f in sorted(os.listdir(d), reverse=True): if f.startswith('libcudart.so.'): candidates.append(os.path.join(d, f)) except OSError: pass for lib_name in candidates: try: libcudart = ctypes.CDLL(lib_name) break except OSError: continue if libcudart: v_int = ctypes.c_int() if libcudart.cudaRuntimeGetVersion(ctypes.byref(v_int)) == 0: device_count = ctypes.c_int() if libcudart.cudaGetDeviceCount(ctypes.byref(device_count)) == 0: v = v_int.value major = v // 1000 minor = (v % 1000) // 10 if device_count.value > 0: version = f'{major}.{minor}' else: msg = f'CUDA runtime present ({major}.{minor}) but no devices.' else: v = v_int.value major = v // 1000 minor = (v % 1000) // 10 msg = f'CUDA runtime present ({major}.{minor}) but cudaGetDeviceCount failed.' except (OSError, AttributeError): pass # 2) CUDA toolkit version file (fallback) if not version: if os.name == 'posix': for p in ('/usr/local/cuda/version.json', '/usr/local/cuda/version.txt'): if os.path.exists(p): with open(p, 'r', encoding='utf-8', errors='ignore') as f: version = lib_version_parse(f.read()) or '' break elif os.name == 'nt': cuda_path = os.environ.get('CUDA_PATH') if cuda_path: for p in ( os.path.join(cuda_path, 'version.json'), os.path.join(cuda_path, 'version.txt'), ): if os.path.exists(p): with open(p, 'r', encoding='utf-8', errors='ignore') as f: version = lib_version_parse(f.read()) or '' break # 3) Version comparison + tag assignment # Tolerant: CUDA > max is accepted (driver is backward-compatible), # but torch build tag clamps at max (cu128) so we install a real wheel. if version: cmp, current, min_tuple, max_tuple = version_classify(version, cuda_version_range) min_ver = f'{min_tuple[0]}.{min_tuple[1]}' max_ver = f'{max_tuple[0]}.{max_tuple[1]}' if cmp == -1: msg = f'CUDA {version} < min {min_ver}. Please upgrade.' elif cmp is None: msg = f'CUDA version {version} unparseable.' else: devices['CUDA']['found'] = True name = devices['CUDA']['proc'] if cmp == 1: tag = f'cu{max_tuple[0]}{max_tuple[1]}' msg = f'CUDA {version} > tested max {max_ver}; using cu{max_tuple[0]}{max_tuple[1]} torch build.' else: tag = f'cu{current[0]}{current[1]}' # still index 0/1, ignore patch else: msg = 'CUDA Toolkit or Runtime not installed or hardware not detected.' # 4) PyTorch fallback (only helps if a CUDA-enabled torch is already installed) if not devices['CUDA']['found']: try: import torch if torch.cuda.is_available(): devices['CUDA']['found'] = True torch_cuda_ver = torch.version.cuda if torch_cuda_ver: cmp, current, min_tuple, max_tuple = version_classify(torch_cuda_ver, cuda_version_range) if cmp == 1: tag = f'cu{max_tuple[0]}{max_tuple[1]}' elif cmp == 0 and current is not None: tag = f'cu{current[0]}{current[1]}' else: tag = f'cu{max_tuple[0]}{max_tuple[1]}' name = devices['CUDA']['proc'] msg = '' except Exception: pass # 5) nvidia-smi header parsing — last-resort rescue # Works driver-only; useful on fresh installs with no toolkit # and CPU-only torch (where step 4 can't help). if not devices['CUDA']['found'] and has_cmd('nvidia-smi'): out = try_cmd('nvidia-smi') # Header line: '| NVIDIA-SMI ... Driver Version: ... CUDA Version: 12.4 |' m = re.search(r'cuda\s*version\s*:?\s*([0-9]+(?:\.[0-9]+)?)', out, re.IGNORECASE) if m: smi_version = m.group(1) cmp, current, min_tuple, max_tuple = version_classify(smi_version, cuda_version_range) max_ver = '.'.join(str(p) for p in max_tuple) if cmp == -1: msg = f'CUDA {smi_version} (from nvidia-smi) < min. Please upgrade.' elif cmp is not None: devices['CUDA']['found'] = True name = devices['CUDA']['proc'] if cmp == 1: tag = f'cu{max_tuple[0]}{max_tuple[1]}' msg = f'CUDA {smi_version} (from nvidia-smi) > tested max {max_ver}; using cu{max_tuple[0]}{max_tuple[1]} torch build.' else: tag = f'cu{current[0]}{current[1]}' msg = f'CUDA {smi_version} detected via nvidia-smi (driver-only).' # ============================================================ # INTEL XPU # ============================================================ elif has_xpu() and has_intel_gpu_pci(): version = '' msg = '' xpu_device_count = 0 # 1) Level Zero / SYCL runtime detection via ctypes (primary) try: import ctypes libze = None if os.name == 'nt': candidates = ['ze_loader.dll'] oneapi_root = os.environ.get('ONEAPI_ROOT', '') if oneapi_root: candidates.insert(0, os.path.join(oneapi_root, 'bin', 'ze_loader.dll')) for lib_name in candidates: try: libze = ctypes.CDLL(lib_name) break except OSError: continue else: candidates = ['libze_loader.so', 'libze_loader.so.1'] ze_lib_dirs = [ '/usr/lib/x86_64-linux-gnu', '/usr/lib64', '/opt/intel/oneapi/lib', ] for d in ze_lib_dirs: if os.path.isdir(d): try: for f in sorted(os.listdir(d), reverse=True): if f.startswith('libze_loader.so.'): candidates.append(os.path.join(d, f)) except OSError: pass for lib_name in candidates: try: libze = ctypes.CDLL(lib_name) break except OSError: continue if libze: if libze.zeInit(ctypes.c_uint(0)) == 0: driver_count = ctypes.c_uint(0) if libze.zeDriverGet(ctypes.byref(driver_count), None) == 0 and driver_count.value > 0: xpu_device_count = driver_count.value except (OSError, AttributeError): pass # 2) sycl-ls detection if not version: if os.name == 'posix' and has_cmd('sycl-ls'): out = try_cmd('sycl-ls') if out: gpu_lines = [l for l in out.splitlines() if 'gpu' in l.lower()] if gpu_lines and xpu_device_count == 0: xpu_device_count = len(gpu_lines) elif os.name == 'nt': oneapi_root = os.environ.get('ONEAPI_ROOT', '') sycl_ls = os.path.join(oneapi_root, 'bin', 'sycl-ls') if oneapi_root else '' if sycl_ls and os.path.isfile(sycl_ls): out = try_cmd(f'"{sycl_ls}"') if out: gpu_lines = [l for l in out.splitlines() if 'gpu' in l.lower()] if gpu_lines and xpu_device_count == 0: xpu_device_count = len(gpu_lines) if xpu_device_count == 0 and has_cmd('sycl-ls'): out = try_cmd('sycl-ls') if out: gpu_lines = [l for l in out.splitlines() if 'gpu' in l.lower()] if gpu_lines: xpu_device_count = len(gpu_lines) # 3) oneAPI version file if not version: if os.name == 'posix': for p in ( '/opt/intel/oneapi/version.txt', '/opt/intel/oneapi/compiler/latest/version.txt', '/opt/intel/oneapi/runtime/latest/version.txt', ): if os.path.exists(p): with open(p, 'r', encoding='utf-8', errors='ignore') as f: version = lib_version_parse(f.read()) or '' break elif os.name == 'nt': oneapi_root = os.environ.get('ONEAPI_ROOT') if oneapi_root: for p in ( os.path.join(oneapi_root, 'version.txt'), os.path.join(oneapi_root, 'compiler', 'latest', 'version.txt'), os.path.join(oneapi_root, 'runtime', 'latest', 'version.txt'), ): if os.path.exists(p): with open(p, 'r', encoding='utf-8', errors='ignore') as f: version = lib_version_parse(f.read()) or '' break # Version comparison + tag assignment (unranged by default: accepts anything) if version: cmp, current, min_tuple, max_tuple = version_classify(version, xpu_version_range) min_ver = '.'.join(str(p) for p in min_tuple) max_ver = '.'.join(str(p) for p in max_tuple) if cmp == -1: msg = f'XPU oneAPI {version} < min {min_ver}. Please upgrade.' elif cmp is None: msg = 'Intel GPU detected but oneAPI version unparseable.' else: devices['XPU']['found'] = True name = devices['XPU']['proc'] tag = devices['XPU']['proc'] if cmp == 1: msg = f'XPU oneAPI {version} > tested max {max_ver}; using default xpu torch build.' elif xpu_device_count > 0: msg = 'Intel GPU detected but oneAPI toolkit version file not found.' else: msg = 'Intel GPU detected but oneAPI Base Toolkit not installed.' # 4) PyTorch last-resort fallback if not devices['XPU']['found']: try: import torch if hasattr(torch, 'xpu') and torch.xpu.is_available(): devices['XPU']['found'] = True xpu_device_count = torch.xpu.device_count() name = devices['XPU']['proc'] tag = devices['XPU']['proc'] msg = 'XPU detected via PyTorch fallback.' except Exception: pass # ============================================================ # APPLE MPS # ============================================================ elif self.system == systems['MACOS'] and arch in (archs['ARM64'], archs['AARCH64']): devices['MPS']['found'] = True name = devices['MPS']['proc'] tag = devices['MPS']['proc'] # ============================================================ # CPU # ============================================================ if tag is None: name = devices['CPU']['proc'] tag = devices['CPU']['proc'] name, tag, msg = (v.strip() if isinstance(v, str) else v for v in (name, tag, msg)) return (name, tag, msg) def version_pkg(self, pkg_name:str, local_path:str|None=None)->str|None: if pkg_name: try: return version(pkg_name) except PackageNotFoundError: pass if not local_path or not os.path.isdir(local_path): return None version_file = os.path.join(local_path, 'version.txt') if os.path.exists(version_file): try: with open(version_file, 'r', encoding = 'utf-8') as f: return f.read().strip() except Exception: pass return None def version_tuple(self, v:str, max_parts:int=3)->tuple: m = re.search(r'\d+(?:\.\d+)*', v) if not m: return (0,) * max_parts nums = [int(n) for n in m.group(0).split('.')[:max_parts]] return tuple(nums + [0] * (max_parts - len(nums))) def eval_marker(self, marker_part:str)->tuple|bool: env = { 'python_version': '.'.join(map(str, sys.version_info[:2])), 'sys_platform': sys.platform, 'platform_system': platform.system(), 'platform_machine': platform.machine() } m = re.match(r'(\w+)\s*(==|!=|>=|<=|>|<)\s*["\']([^"\']+)["\']', marker_part) if not m: raise ValueError(f'Unsupported marker: {marker_part}') key, op, value = m.groups() if key not in env: raise ValueError(f'Unknown marker variable: {key}') def vt(v): return tuple(map(int, v.split('.'))) if v[0].isdigit() else v left = vt(env[key]) right = vt(value) if op == '==': return left == right if op == '!=': return left != right if op == '>=': return left >= right if op == '<=': return left <= right if op == '>': return left > right if op == '<': return left < right return False def install_python_packages(self)->int: if not os.path.exists(requirements_file): error = f'Warning: File {requirements_file} not found. Skipping package check.' print(error) return 1 overrides = {} if self.system == systems['MACOS']: overrides['onnxruntime-gpu'] = None try: with open(requirements_file, 'r') as f: contents = f.read().replace('\r', '\n') packages = [] for line in contents.splitlines(): pkg = line.strip() if not pkg or not re.search(r'[a-zA-Z0-9]', pkg): continue if '#' in pkg: pkg = pkg.split('#', 1)[0].strip() if not pkg: continue head = re.split(r'[<>=!\[;]', pkg, 1)[0].strip().lower() if head in {'torch', 'torchaudio'}: continue if head == 'onnxruntime-gpu' and self.system == systems['MACOS']: continue if head in overrides: pkg = overrides[head] packages.append(pkg) missing_packages = [] for package in packages: raw_pkg = package.strip() if ';' in raw_pkg: pkg_part, marker_part = raw_pkg.split(';', 1) marker_part = marker_part.strip() try: if not self.eval_marker(marker_part): continue except Exception as e: error = f'Warning: Could not evaluate marker {marker_part} for {pkg_part}: {e}' print(error) raw_pkg = pkg_part.strip() clean_pkg = re.sub(r'\[.*?\]', '', raw_pkg) local_path = None pkg_name = None if os.path.isdir(clean_pkg): local_path = os.path.abspath(clean_pkg) else: vcs_match = re.search(r'([\w\-]+)\s*@?\s*git\+', clean_pkg) if vcs_match: pkg_name = vcs_match.group(1) else: pkg_base = re.split(r'[<>=!]', clean_pkg, maxsplit=1)[0].strip() pkg_name = pkg_base if 'git+' in raw_pkg or '://' in raw_pkg: spec = importlib.util.find_spec(pkg_name) if spec is None: msg = f'{pkg_name} (git package) is missing.' print(msg) missing_packages.append(raw_pkg) continue if local_path: pkg_name = os.path.basename(local_path) vendor_version = self.version_pkg(None, local_path) if not vendor_version: msg = f'{local_path} has no detectable version.' print(msg) missing_packages.append(raw_pkg) continue try: installed_version = version(pkg_name) except PackageNotFoundError: error = f'{pkg_name} is not installed.' print(error) missing_packages.append(raw_pkg) continue if installed_version != vendor_version: msg = f'{pkg_name} version mismatch: installed {installed_version} != vendor {vendor_version}.' print(msg) missing_packages.append(raw_pkg) continue installed_version = self.version_pkg(pkg_name, None) if not installed_version: msg = f'{pkg_name} is not installed.' print(msg) missing_packages.append(raw_pkg) continue if '+' in installed_version: installed_version = installed_version.split('+', 1)[0] pkg_spec_part = re.split(r'[<>=!]', clean_pkg, maxsplit=1) spec_str = clean_pkg[len(pkg_spec_part[0]):].strip() if spec_str: req_match = re.search(r'(==|!=|>=|<=|>|<)\s*(\d+\.\d+(?:\.\d+)?)', spec_str) if req_match: op, req_ver = req_match.groups() req_v = self.version_tuple(req_ver, 3) norm_match = re.match(r'^(\d+\.\d+(?:\.\d+)?)', installed_version) short_version = norm_match.group(1) if norm_match else installed_version installed_v = self.version_tuple(short_version, 3) if op == '==' and installed_v != req_v: msg = f'{pkg_name} (installed {installed_version}) != required {req_ver}.' print(msg) missing_packages.append(raw_pkg) elif op == '>=' and installed_v < req_v: msg = f'{pkg_name} (installed {installed_version}) < required {req_ver}.' print(msg) missing_packages.append(raw_pkg) elif op == '<=' and installed_v > req_v: msg = f'{pkg_name} (installed {installed_version}) > allowed {req_ver}.' print(msg) missing_packages.append(raw_pkg) elif op == '>' and installed_v <= req_v: msg = f'{pkg_name} (installed {installed_version}) <= required {req_ver}.' print(msg) missing_packages.append(raw_pkg) elif op == '<' and installed_v >= req_v: msg = f'{pkg_name} (installed {installed_version}) >= restricted {req_ver}.' print(msg) missing_packages.append(raw_pkg) elif op == '!=' and installed_v == req_v: msg = f'{pkg_name} (installed {installed_version}) == excluded {req_ver}.' print(msg) missing_packages.append(raw_pkg) if missing_packages: msg = '\nInstalling missing or upgrade packages…\n' print(msg) subprocess.call([sys.executable, '-m', 'pip', 'cache', 'purge']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', 'pip']) for raw_pkg in missing_packages: try: cmd = [sys.executable, '-m', 'pip', 'install', '--upgrade', '--upgrade-strategy', 'only-if-needed', '--no-cache-dir'] cmd.append(raw_pkg) subprocess.check_call(cmd) except subprocess.CalledProcessError as e: msg = f'Failed to install {raw_pkg}: {e}' print(msg) return 1 msg = '\nAll required packages are installed.' print(msg) return self.check_dictionary() except Exception as e: error = f'install_python_packages() error: {e}' print(error) return 1 def check_numpy(self)->bool: try: numpy_version = self.get_package_version('numpy') or None torch_version = self.get_package_version('torch') or None min_cpu_baseline = self.cpu_baseline numpy_pkg = None if torch_version is None: return False torch_version_base = self.version_tuple(torch_version) if numpy_version is None: if torch_version_base <= self.version_tuple('2.2.2'): numpy_pkg = 'numpy<2' elif not min_cpu_baseline: numpy_pkg = 'numpy<2.4.0' else: numpy_pkg = 'numpy' else: numpy_version_base = self.version_tuple(numpy_version) if torch_version_base <= self.version_tuple('2.2.2') and numpy_version_base >= self.version_tuple('2.0.0'): numpy_pkg = 'numpy<2' elif not min_cpu_baseline and numpy_version_base >= self.version_tuple('2.4.0'): numpy_pkg = 'numpy<2.4.0' if numpy_pkg is not None: subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', '--upgrade-strategy', 'only-if-needed', '--no-cache-dir', '--force-reinstall', numpy_pkg]) return True except subprocess.CalledProcessError as e: error = f'Failed to install numpy package: {e}' print(error) return False except Exception as e: error = f'Error while installing numpy package: {e}' print(error) return False def check_dictionary(self)->bool: import unidic unidic_path = unidic.DICDIR dicrc = os.path.join(unidic_path, 'dicrc') if not os.path.exists(dicrc) or os.path.getsize(dicrc) == 0: try: error = 'UniDic dictionary not found or incomplete. Downloading now…' print(error) subprocess.run(['python', '-m', 'unidic', 'download'], check=True) except (subprocess.CalledProcessError, ConnectionError, OSError) as e: error = f'Failed to download UniDic dictionary. Error: {e}. Unable to continue without UniDic. Exiting…' raise SystemExit(error) return 1 return 0 def install_device_packages(self, device_info_str:str)->int: def _tag_ok(installed_tag): # CPU index: '/whl/cpu' -> bare on macOS, '+cpu' on linux/windows; both are fine if tag == devices['CPU']['proc']: return installed_tag is None or installed_tag == devices['CPU']['proc'] # MPS: installed from '/whl/cpu' on macOS arm64 -> bare wheels if device_info['name'] == devices['MPS']['proc']: return installed_tag is None # ROCm Windows (TheRock): matrix key is 'rocm-rel-X.Y.Z' (kept distinct from # the linux 'rocmX.Y' keys) but the wheel's local version drops '-rel-', # e.g. tag='rocm-rel-7.2.1' -> '+rocm7.2.1' (optionally with a build suffix). if device_info['name'] == devices['ROCM']['proc'] and self.system == systems['WINDOWS']: wheel_tag = tag.replace('-rel-', '') return installed_tag == wheel_tag or (installed_tag is not None and installed_tag.startswith(f'{wheel_tag}-')) # CUDA, XPU, ROCm Linux, Jetson: must be exactly '+' # (a pure hex local version means a custom/dev build -> reinstall) return installed_tag == tag def _needs_reinstall(): # torch: base version + local tag must match what we'd install for this device if not torch_version_current_full: return True if torch_version_current_base != torch_version_matrix: return True if not _tag_ok(current_tag): return True # torchaudio: base version + local tag must match what we'd install for this device torchaudio_full = self.get_package_version('torchaudio') if not torchaudio_full: return True torchaudio_base = torchaudio_full.split('+', 1)[0] if torchaudio_base != torch_version_matrix: return True m_ta = re.search(r'\+(.+)$', torchaudio_full) torchaudio_tag = m_ta.group(1) if m_ta else None if not _tag_ok(torchaudio_tag): return True # torchcodec: presence only (when torch >= 2.9 needs it) if self.version_tuple(torch_version_matrix, 2) >= (2, 9) and not self.get_package_version('torchcodec'): return True return False def _probe_gpus()->dict: script = os.path.abspath('./tools/detect_gpus.py') try: proc = subprocess.run( [sys.executable, script], capture_output=True, text=True, timeout=30, ) if proc.returncode != 0: return {'count': 0, 'backend': None, 'error': proc.stderr.strip() or 'non-zero exit'} return json.loads(proc.stdout.strip() or '{}') except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError) as e: return {'count': 0, 'backend': None, 'error': str(e)} try: if device_info_str: device_info = json.loads(device_info_str) if device_info: msg = f'---> Hardware detected: {device_info}' print(msg) tag = device_info.get('tag') if tag in ['unknown','unsupported']: return 0 key = 'last' if self.python_version >= (3, 12) else 'base' torch_version_matrix = torch_matrix[tag].get(key) or torch_matrix[tag]['base'] torchcodec_version_matrix = torch_matrix[tag]['codec'] # macOS Intel was dropped from torch wheels after 2.2.2 — pin it before # any version comparison happens, otherwise _needs_reinstall() compares # against the matrix's 'last' and triggers an unnecessary reinstall. if device_info['os'] == 'macosx_11_0' and device_info['arch'] == archs['X86_64']: torch_version_matrix = '2.2.2' torchcodec_version_matrix = '' # 2.2.2 < 2.9, no torchcodec torch_version_current_full = self.get_package_version('torch') torch_version_current_base = None current_tag = None non_standard_tag = None if torch_version_current_full: m = re.search(r'\+(.+)$', torch_version_current_full) current_tag = m.group(1) if m else None non_standard_match = re.fullmatch(r'[0-9a-f]{7,40}', current_tag) if current_tag is not None else None non_standard_tag = non_standard_match.group(0) if non_standard_match else None torch_version_current_base = torch_version_current_full.split('+',1)[0] if _needs_reinstall(): try: msg = f"Installing the right library packages for {device_info['name']}…" print(msg) os_env = device_info['os'] arch = device_info['arch'] toolkit_version = ''.join(c for c in tag if c.isdigit()) tag_dir = tag py_major, py_minor = device_info['pyvenv'] tag_py = f'cp{py_major}{py_minor}' subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-cache-dir', 'filelock', 'jinja2', 'fsspec', 'networkx', 'sympy']) if device_info['name'] == devices['JETSON']['proc']: url = default_jetson_url torch_pkg = f"{url}/torch-v{toolkit_version}/torch-{torch_version_matrix}%2B{tag}-{tag_py}-{tag_py}-{os_env}_{arch}.whl" torchaudio_pkg = f"{url}/torchaudio-v{toolkit_version}/torchaudio-{torch_version_matrix}%2B{tag}-{tag_py}-{tag_py}-{os_env}_{arch}.whl" subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', '--upgrade-strategy', 'only-if-needed', '--no-cache-dir', torch_pkg]) subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', '--upgrade-strategy', 'only-if-needed', '--no-cache-dir', torchaudio_pkg]) subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-cache-dir', 'scikit-learn']) subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-cache-dir', 'scipy']) elif device_info['name'] == devices['ROCM']['proc'] and self.system == systems['WINDOWS']: url = default_pytorch_amd_url norm_tag = tag.replace('-rel-', '') # rocm_sdk is required by torch ROCm wheels on Windows; install it first if missing import importlib.util if importlib.util.find_spec('rocm_sdk') is None: rocm_ver = tag[len('rocm-rel-'):] if tag.startswith('rocm-rel-') else tag sdk_pkgs = [ f'{url}/{tag}/rocm_sdk_core-{rocm_ver}-py3-none-{os_env}_{arch}.whl', f'{url}/{tag}/rocm_sdk_devel-{rocm_ver}-py3-none-{os_env}_{arch}.whl', f'{url}/{tag}/rocm_sdk_libraries_custom-{rocm_ver}-py3-none-{os_env}_{arch}.whl', f'{url}/{tag}/rocm-{rocm_ver}.tar.gz', ] msg = f'Installing ROCm SDK {rocm_ver}…' print(msg) subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', *sdk_pkgs]) torch_pkg = f'{url}/{tag}/torch-{torch_version_matrix}%2B{norm_tag}-{tag_py}-{tag_py}-{os_env}_{arch}.whl' torchaudio_pkg = f'{url}/{tag}/torchaudio-{torch_version_matrix}%2B{norm_tag}-{tag_py}-{tag_py}-{os_env}_{arch}.whl' subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-cache-dir', '--no-deps', torch_pkg, torchaudio_pkg]) else: url = default_pytorch_url tag_dir = 'cpu' if device_info['name'] == devices['MPS']['proc'] else tag subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade', '--upgrade-strategy', 'only-if-needed', '--no-cache-dir', f'torch=={torch_version_matrix}', f'torchaudio=={torch_version_matrix}', '--force-reinstall', '--index-url', f'{url}/{tag_dir}']) if self.version_tuple(torch_version_matrix, 2) >= (2, 9): is_cpu_aarch64_linux = ( tag == devices['CPU']['proc'] and device_info['os'] in ('manylinux_2_28', 'linux') and device_info['arch'] == archs['AARCH64'] ) has_native_codec = ( device_info['name'] == devices['CUDA']['proc'] and self.system != systems['WINDOWS'] ) or tag == devices['CPU']['proc'] if is_cpu_aarch64_linux: torchcodec_index_url = f"{default_torchcodec_arm_url}/torchcodec-{arch}-{tag_py}/torchcodec-{torchcodec_version_matrix}%2B{tag}-{tag_py}-{tag_py}-{os_env}_{arch}.whl" subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-cache-dir', '--no-deps', torchcodec_index_url]) else: if has_native_codec: torchcodec_index_url = f'{default_pytorch_url}/{tag_dir}' else: torchcodec_index_url = f'{default_pytorch_url}/cpu' subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-cache-dir', '--no-deps', f'torchcodec=={torchcodec_version_matrix}', '--index-url', torchcodec_index_url]) except subprocess.CalledProcessError as e: error = f'Failed to install torch package: {e}' print(error) return 1 except Exception as e: error = f'Error while installing torch package: {e}' print(error) return 1 if device_info['os'] == 'linux' and ('jetpack' in device_info.get('note', '').lower() or device_info['name'] == devices['JETSON']['proc']): libgomp_src = '/usr/lib/aarch64-linux-gnu/libgomp.so' if os.path.exists(libgomp_src): libs_list = ['ctranslate2.libs', 'scikit_learn.libs'] libs_dir = os.path.join('python_env', 'lib', f'python{sys.version_info.major}.{sys.version_info.minor}', 'site-packages') for lib in libs_list: lib_path = os.path.join(libs_dir, lib) if os.path.isdir(lib_path): for libgomp_dst in glob(os.path.join(lib_path, 'libgomp*')): if os.path.islink(libgomp_dst): if os.path.realpath(libgomp_dst) == os.path.realpath(libgomp_src): continue os.unlink(libgomp_dst) else: os.unlink(libgomp_dst) msg = 'Create symlink to use OS libgomp.' print(msg) os.symlink(libgomp_src, libgomp_dst) if not self.check_numpy(): return 1 gpu_info = _probe_gpus() device_info_dict['gpu_count'] = gpu_info['count'] device_info_dict['gpu_backend'] = gpu_info['backend'] if gpu_info.get('error'): error = f'GPU detection warning: {gpu_info["error"]}' print(error) if gpu_info['count'] > 0: idx = ','.join(str(i) for i in range(gpu_info['count'])) if gpu_info['backend'] == 'cuda': os.environ['CUDA_VISIBLE_DEVICES'] = idx elif gpu_info['backend'] == 'rocm': os.environ['HIP_VISIBLE_DEVICES'] = idx elif gpu_info['backend'] == 'xpu': os.environ['ONEAPI_DEVICE_SELECTOR'] = f'level_zero:{idx}' os.environ['ZE_AFFINITY_MASK'] = idx return 0 else: error = 'install_device_packages() error: device_info_str is empty' print(error) else: error = f'install_device_packages() error: json.loads() could not decode device_info_str={device_info_str}' print(error) return 1 except Exception as e: error = f'install_device_packages() error: {e}' print(error) return 1