Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| check_dependencies.py | |
| Script to check and install required dependencies for MorphGuard training and inference. | |
| """ | |
| import subprocess | |
| import sys | |
| import os | |
| import importlib | |
| import argparse | |
# Base requirements, checked on every run regardless of --model-types.
# Each key is the package name that gets checked for importability; each value
# is the pip requirement specifier (name plus minimum version) that is queued
# for `pip install` when that check fails.
# NOTE(review): a few keys ("pillow", "scikit-learn", "opencv-python") are pip
# distribution names that differ from their import names (PIL, sklearn, cv2).
REQUIRED_PACKAGES = {
    "torch": "torch>=1.12.0",
    "torchvision": "torchvision>=0.13.0",
    "pytorch_lightning": "pytorch_lightning>=1.8.0",
    "timm": "timm>=0.6.11",
    "diffusers": "diffusers>=0.11.1",
    "transformers": "transformers>=4.21.0",
    "pillow": "pillow>=9.2.0",
    "numpy": "numpy>=1.22.0",
    "scipy": "scipy>=1.8.0",
    "tqdm": "tqdm>=4.64.0",
    "scikit-learn": "scikit-learn>=1.1.2",
    "torchmetrics": "torchmetrics>=0.9.3",
    "matplotlib": "matplotlib>=3.5.3",
    "opencv-python": "opencv-python>=4.6.0.66",
    "omegaconf": "omegaconf>=2.2.3",
    "gradio": "gradio>=3.0.0",
    "requests": "requests>=2.28.1",
}
# Optional requirements, grouped by model type. The group keys are the same
# values accepted by the --model-types CLI option; each group maps a package
# name (checked for importability) to its pip requirement specifier.
# NOTE(review): "pywavelets" is imported as `pywt`, not under its pip name.
OPTIONAL_PACKAGES = {
    "freq": {"pywavelets": "pywavelets>=1.3.0"},
    "gan": {"ninja": "ninja>=1.10.2"},
    "diffusion": {
        "ftfy": "ftfy>=6.1.1",
        "accelerate": "accelerate>=0.12.0",
    },
}
# pip distribution names whose importable module name differs from the
# distribution name. Without this mapping, importing e.g. "scikit-learn"
# always fails, so the package is reported missing (and reinstalled) on
# every run even when it is installed.
IMPORT_NAME_OVERRIDES = {
    "pillow": "PIL",
    "scikit-learn": "sklearn",
    "opencv-python": "cv2",
    "pywavelets": "pywt",
}


def resolve_import_name(package_name):
    """Return the module name to import in order to detect ``package_name``.

    Known mismatches are looked up in ``IMPORT_NAME_OVERRIDES``
    case-insensitively, since pip treats distribution names case-insensitively.
    Otherwise hyphens are replaced with underscores, matching the usual
    convention for the module a hyphenated distribution installs
    (e.g. "pytorch-lightning" -> "pytorch_lightning"). Names without hyphens
    are returned unchanged.
    """
    override = IMPORT_NAME_OVERRIDES.get(package_name.lower())
    if override is not None:
        return override
    return package_name.replace("-", "_")


def check_package(package_name):
    """Check if a package is installed.

    Args:
        package_name: pip distribution name (e.g. "scikit-learn") or a plain
            module name (e.g. "torch").

    Returns:
        True if the corresponding module imports successfully, else False.
        The module is actually imported, so a package that is installed but
        raises ImportError while loading also counts as missing.
    """
    try:
        importlib.import_module(resolve_import_name(package_name))
    except ImportError:
        return False
    return True
def install_package(package_spec):
    """Install one or more packages using pip.

    Args:
        package_spec: a single pip requirement specifier (e.g. "numpy>=1.22.0")
            or an iterable of specifiers. Each specifier is passed to pip as a
            separate argument. A single string is never split, so to install
            several packages pass a list, not a space-joined string.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    specs = [package_spec] if isinstance(package_spec, str) else list(package_spec)
    if not specs:
        # `pip install` with no requirements is an error; there is nothing to do.
        return
    # Run pip via the current interpreter so packages are installed into the
    # same environment this script is running in.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *specs])
def check_and_install(packages_dict, model_types=None):
    """Check for missing packages and install them with pip.

    Args:
        packages_dict: mapping of package name -> pip requirement specifier.
            Each name is checked with ``check_package``.
        model_types: optional iterable of keys into ``OPTIONAL_PACKAGES``
            whose extra packages are checked as well. Unknown keys are ignored.

    Returns:
        True if nothing was missing or every missing package installed
        successfully; False if at least one installation failed.
    """
    # Merge base and optional packages. setdefault keeps the first spec seen
    # for a name, so a model type listed twice does not queue duplicate installs.
    to_check = dict(packages_dict)
    for model_type in model_types or ():
        for package_name, package_spec in OPTIONAL_PACKAGES.get(model_type, {}).items():
            to_check.setdefault(package_name, package_spec)

    missing_packages = [
        package_spec
        for package_name, package_spec in to_check.items()
        if not check_package(package_name)
    ]

    if not missing_packages:
        print("All required packages are already installed.")
        return True

    print(f"Installing missing packages: {', '.join(missing_packages)}")
    failed = []
    for package_spec in missing_packages:
        # Each specifier must reach pip as its own argument. Space-joining them
        # into one string makes pip reject the whole thing as one invalid
        # requirement. Installing one at a time also pinpoints what failed.
        try:
            install_package(package_spec)
        except subprocess.CalledProcessError:
            failed.append(package_spec)

    if failed:
        print("Failed to install some packages. Try installing them manually:")
        for package_spec in failed:
            # Quote the spec: an unquoted ">=" is treated by the shell as
            # output redirection.
            print(f'  pip install "{package_spec}"')
        return False
    return True
def check_cuda():
    """Report whether PyTorch can see a CUDA device.

    When CUDA is usable, prints the availability flag, CUDA version, device
    count and the name of device 0. Otherwise prints a slow-on-CPU warning.
    Any ``Exception`` raised along the way is printed instead of propagated;
    this includes torch failing to import.

    Returns:
        True if CUDA is available, False if it is not or an error occurred.
    """
    try:
        import torch

        available = torch.cuda.is_available()
        if not available:
            print("CUDA is not available. Training will be slow on CPU.")
            return False

        version = torch.version.cuda
        n_devices = torch.cuda.device_count()
        first_device = torch.cuda.get_device_name(0) if n_devices > 0 else "N/A"
        report = (
            f"CUDA is available: {available}",
            f"CUDA version: {version}",
            f"CUDA device count: {n_devices}",
            f"CUDA device name: {first_device}",
        )
        for line in report:
            print(line)
        return True
    except Exception as err:
        print(f"Error checking CUDA: {err}")
        return False
def main(argv=None):
    """Parse CLI arguments, check (and optionally install) packages, then check CUDA.

    Args:
        argv: argument list to parse. When None, argparse reads
            ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.

    Returns:
        True when all packages are present (or were installed successfully);
        False when packages are missing in --check-only mode or an installation
        failed. The CUDA check is informational and never changes the result.
    """
    parser = argparse.ArgumentParser(description="Check and install MorphGuard dependencies")
    parser.add_argument("--model-types", type=str, nargs="+",
                        choices=["freq", "gan", "diffusion"],
                        help="Optional model types to install dependencies for")
    parser.add_argument("--check-only", action="store_true",
                        help="Only check for missing packages without installing")
    args = parser.parse_args(argv)

    # Drop repeated model types (e.g. "--model-types freq freq") while keeping
    # their order, so each optional package is reported/installed only once.
    model_types = list(dict.fromkeys(args.model_types or []))

    print("Checking Python packages...")
    if args.check_only:
        # Only check, don't install.
        missing = [name for name in REQUIRED_PACKAGES if not check_package(name)]
        for model_type in model_types:
            for package_name in OPTIONAL_PACKAGES.get(model_type, {}):
                if not check_package(package_name):
                    missing.append(f"{package_name} (required for {model_type})")
        if missing:
            print(f"Missing packages: {', '.join(missing)}")
            return False
        print("All required packages are installed.")
    elif not check_and_install(REQUIRED_PACKAGES, model_types):
        # Check and install; stop early if any installation failed.
        return False

    # Check CUDA availability (reported only; does not affect the return value).
    print("\nChecking CUDA availability...")
    check_cuda()
    print("\nAll checks completed.")
    return True
if __name__ == "__main__":
    # Map main()'s boolean outcome to a process exit status: 0 on success, 1 otherwise.
    sys.exit(0 if main() else 1)