Spaces:
Paused
Paused
Commit ·
54f9225
1
Parent(s): c493a61
Add device and dtype logging in app.py for better debugging
Browse files
app.py
CHANGED
|
@@ -188,7 +188,7 @@ def load_infinity(
|
|
| 188 |
model_path='',
|
| 189 |
scale_schedule=None,
|
| 190 |
vae=None,
|
| 191 |
-
device='cuda',
|
| 192 |
model_kwargs=None,
|
| 193 |
text_channels=2048,
|
| 194 |
apply_spatial_patchify=0,
|
|
@@ -295,6 +295,7 @@ def load_visual_tokenizer(args):
|
|
| 295 |
|
| 296 |
def load_transformer(vae, args):
|
| 297 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
| 298 |
model_path = args.model_path
|
| 299 |
if args.checkpoint_type == 'torch':
|
| 300 |
# copy large model to local; save slim to local; and copy slim to nas; load local slim model
|
|
@@ -420,8 +421,8 @@ weights_path.mkdir(exist_ok=True)
|
|
| 420 |
download_infinity_weights(weights_path)
|
| 421 |
|
| 422 |
# Device setup
|
| 423 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 424 |
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
|
|
|
|
| 425 |
|
| 426 |
# Define args
|
| 427 |
args = argparse.Namespace(
|
|
|
|
| 188 |
model_path='',
|
| 189 |
scale_schedule=None,
|
| 190 |
vae=None,
|
| 191 |
+
device='cuda',
|
| 192 |
model_kwargs=None,
|
| 193 |
text_channels=2048,
|
| 194 |
apply_spatial_patchify=0,
|
|
|
|
| 295 |
|
| 296 |
def load_transformer(vae, args):
|
| 297 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 298 |
+
print(f"Device: {device}")
|
| 299 |
model_path = args.model_path
|
| 300 |
if args.checkpoint_type == 'torch':
|
| 301 |
# copy large model to local; save slim to local; and copy slim to nas; load local slim model
|
|
|
|
| 421 |
download_infinity_weights(weights_path)
|
| 422 |
|
| 423 |
# Device setup
|
|
|
|
| 424 |
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
|
| 425 |
+
print(f"Using dtype: {dtype}")
|
| 426 |
|
| 427 |
# Define args
|
| 428 |
args = argparse.Namespace(
|