Spaces: Running on Zero

Plat committed · cdb5002
Parent(s): f56068c

chore: use bf16 with cuda

app.py CHANGED
```diff
@@ -28,6 +28,7 @@ DEVICE = (
     if torch.backends.mps.is_available()
     else torch.device("cpu")
 )
+DTYPE = torch.bfloat16 if DEVICE.type in ["cuda"] else torch.float16
 MAX_TOKEN_LENGTH = 32
 
 model_map: dict[str, JiTModel] = {}  # {model_path: model}
```
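The added `DTYPE` constant picks bfloat16 on CUDA and float16 everywhere else. A minimal standalone sketch of the same selection pattern; the CUDA branch of `DEVICE` is not visible in this hunk, so its exact form below is an assumption:

```python
import torch

# Standalone sketch of the device/dtype selection above. bfloat16 keeps
# float32's exponent range, so it is the safer half-precision choice on
# CUDA; MPS and CPU fall back to float16 here, matching the diff.
DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()  # assumed CUDA branch, not shown in the hunk
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
DTYPE = torch.bfloat16 if DEVICE.type in ["cuda"] else torch.float16
print(DEVICE, DTYPE)
```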
```diff
@@ -64,6 +65,7 @@ def load_model(
     label2id_path: str,
     config_path: str,
     device: torch.device,
+    dtype: torch.dtype = DTYPE,
 ) -> tuple[JiTModel, dict]:
     """モデルを読み込む"""
 
```
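`load_model` (its docstring reads "load the model") now takes a `dtype` keyword defaulting to the module-level `DTYPE`. A hedged sketch of the load-once-and-cast shape visible across these hunks; `build_model` is a hypothetical stand-in for the Space's actual loader:

```python
import torch
import torch.nn as nn

# Hedged sketch of the caching pattern in the diff: models are cached by
# path, and the new dtype argument casts the weights once at load time.
model_map: dict[str, nn.Module] = {}  # {model_path: model}

def build_model(model_path: str) -> nn.Module:
    return nn.Linear(8, 8)  # placeholder for the real model construction

def load_model(model_path: str, device: torch.device,
               dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    if model_path in model_map:        # cache hit: skip loading and casting
        return model_map[model_path]
    model = build_model(model_path)
    model.eval()
    model.requires_grad_(False)        # inference only: freeze everything
    model.to(device=device, dtype=dtype)
    model_map[model_path] = model      # cache
    return model

m = load_model("demo", torch.device("cpu"), torch.float32)
# The cache is keyed only by path, so a later call with a different dtype
# returns the already-cast model unchanged.
assert load_model("demo", torch.device("cpu")) is m
```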
```diff
@@ -84,7 +86,7 @@ def load_model(
     )
     model.eval()
     model.requires_grad_(False)
-    model.to(device=device)
+    model.to(device=device, dtype=dtype)
     model_map[model_path] = model  # cache
 
     label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
```
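This one-line change does the actual memory saving: `Module.to(device=..., dtype=...)` moves and casts all floating-point parameters and buffers in place, so the weights are stored directly in half precision. A toy check with plain `nn` modules, not the Space's `JiTModel`:

```python
import torch
import torch.nn as nn

# Toy check: Module.to(dtype=...) casts floating-point parameters and
# buffers in place; integer buffers (e.g. BatchNorm's num_batches_tracked)
# are intentionally left untouched by PyTorch.
net = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)).eval()
net.requires_grad_(False)
net.to(device=torch.device("cpu"), dtype=torch.bfloat16)

assert all(p.dtype == torch.bfloat16 for p in net.parameters())
assert all(b.dtype == torch.bfloat16
           for b in net.buffers() if b.is_floating_point())
print("weights now stored in", net[0].weight.dtype)  # torch.bfloat16
```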
```diff
@@ -113,9 +115,10 @@ def generate_images(
         label2id_path=label2id_path,
         config_path=config_path,
         device=DEVICE,
+        dtype=DTYPE,
     )
 
-    with torch.inference_mode():
+    with torch.inference_mode(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
         images = model.generate(
             prompt=[prompt] * batch_size,
             negative_prompt=negative_prompt,
```
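Stacking the two context managers combines two orthogonal savings: `inference_mode()` disables autograd bookkeeping entirely (stricter than `no_grad()`), while `autocast` dispatches eligible ops such as matmuls to the requested low-precision dtype. A minimal runnable illustration:

```python
import torch

# Minimal illustration of the combined context managers in the hunk above.
# inference_mode() removes autograd overhead; autocast runs matmuls in the
# requested low-precision dtype. bfloat16 autocast works on CPU and CUDA.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
x = torch.randn(2, 8, device=device)
w = torch.randn(8, 8, device=device)

with torch.inference_mode(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
    y = x @ w

print(y.dtype)          # torch.bfloat16: autocast chose the low-precision kernel
print(y.requires_grad)  # False: no autograd state was recorded
```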
```diff
@@ -127,7 +130,7 @@ def generate_images(
             cfg_time_range=[0.1, 1.0],
             seed=seed if seed >= 0 else None,
             device=DEVICE,
-            execution_dtype=
+            execution_dtype=DTYPE,
         )
 
     return images
```
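`execution_dtype` is a parameter of `JiTModel.generate`, whose internals this diff does not show. As a purely hypothetical illustration of what such a knob typically controls in a sampler, the tensors created during generation (for example the initial noise) are allocated in the requested dtype instead of float32:

```python
import torch

# Hypothetical illustration only: JiTModel.generate's internals are not in
# this diff. A sampler's execution_dtype typically sets the dtype of the
# intermediate tensors it allocates, such as the initial noise.
def make_initial_noise(batch: int, channels: int, size: int,
                       device: torch.device, execution_dtype: torch.dtype,
                       seed: int | None = None) -> torch.Tensor:
    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)  # mirrors the seed plumbing in the hunk above
    return torch.randn(batch, channels, size, size,
                       generator=gen, device=device, dtype=execution_dtype)

noise = make_initial_noise(1, 4, 32, torch.device("cpu"),
                           execution_dtype=torch.bfloat16, seed=42)
print(noise.shape, noise.dtype)  # torch.Size([1, 4, 32, 32]) torch.bfloat16
```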
```diff
@@ -271,6 +274,7 @@ if __name__ == "__main__":
         label2id_path=LABEL2ID_PATH,
         config_path=CONFIG_PATH,
         device=DEVICE,
+        dtype=DTYPE,
     )
 
     demo().launch()
```