Spaces:
Runtime error
Runtime error
Commit
·
13d3a21
1
Parent(s):
30b3bf7
update
Browse files- tokenflow_pnp.py +2 -1
tokenflow_pnp.py
CHANGED
|
@@ -28,6 +28,7 @@ elif torch.backends.mps.is_available():
|
|
| 28 |
device = "mps"
|
| 29 |
else:
|
| 30 |
device = "cpu"
|
|
|
|
| 31 |
|
| 32 |
class TokenFlow(nn.Module):
|
| 33 |
def __init__(self, config,
|
|
@@ -117,7 +118,7 @@ class TokenFlow(nn.Module):
|
|
| 117 |
depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
|
| 118 |
depth_maps.append(depth_map)
|
| 119 |
|
| 120 |
-
return torch.cat(depth_maps).to(
|
| 121 |
|
| 122 |
def get_pnp_inversion_prompt(self):
|
| 123 |
inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
|
|
|
|
| 28 |
device = "mps"
|
| 29 |
else:
|
| 30 |
device = "cpu"
|
| 31 |
+
to = torch.float16 if self.device == 'cuda' else torch.float32
|
| 32 |
|
| 33 |
class TokenFlow(nn.Module):
|
| 34 |
def __init__(self, config,
|
|
|
|
| 118 |
depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
|
| 119 |
depth_maps.append(depth_map)
|
| 120 |
|
| 121 |
+
return torch.cat(depth_maps).to(to).to(self.device)
|
| 122 |
|
| 123 |
def get_pnp_inversion_prompt(self):
|
| 124 |
inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
|