Spaces:
Running
on
Zero
Running
on
Zero
Initialize single-process distributed group for single GPU inference compatibility
Browse files
sample.py
CHANGED
|
@@ -54,7 +54,29 @@ DEFAULT_CAPTIONS = {
|
|
| 54 |
|
| 55 |
def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Module, Optional[torch.nn.Module], tuple]:
|
| 56 |
"""Initialize and load the model, VAE, and text encoder."""
|
|
|
|
|
|
|
|
|
|
| 57 |
dist = utils.Distributed()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 59 |
|
| 60 |
# Set random seed
|
|
|
|
| 54 |
|
| 55 |
def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Module, Optional[torch.nn.Module], tuple]:
|
| 56 |
"""Initialize and load the model, VAE, and text encoder."""
|
| 57 |
+
# Initialize distributed training context
|
| 58 |
+
# For single GPU inference, we still need to initialize process group
|
| 59 |
+
# because the model code uses torch.distributed.all_reduce
|
| 60 |
dist = utils.Distributed()
|
| 61 |
+
|
| 62 |
+
# If not running with torchrun, initialize single-process group
|
| 63 |
+
if not dist.distributed and torch.cuda.is_available():
|
| 64 |
+
import os
|
| 65 |
+
# Initialize single-process process group for model compatibility
|
| 66 |
+
if not torch.distributed.is_initialized():
|
| 67 |
+
os.environ['MASTER_ADDR'] = 'localhost'
|
| 68 |
+
os.environ['MASTER_PORT'] = '12355'
|
| 69 |
+
os.environ['RANK'] = '0'
|
| 70 |
+
os.environ['LOCAL_RANK'] = '0'
|
| 71 |
+
os.environ['WORLD_SIZE'] = '1'
|
| 72 |
+
torch.distributed.init_process_group(
|
| 73 |
+
backend='nccl' if torch.cuda.is_available() else 'gloo',
|
| 74 |
+
init_method='env://',
|
| 75 |
+
world_size=1,
|
| 76 |
+
rank=0,
|
| 77 |
+
)
|
| 78 |
+
print("✅ Initialized single-process distributed group for model compatibility")
|
| 79 |
+
|
| 80 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
|
| 82 |
# Set random seed
|