Plat committed on
Commit
cdb5002
·
1 Parent(s): f56068c

chore: use bf16 with cuda

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -28,6 +28,7 @@ DEVICE = (
28
  if torch.backends.mps.is_available()
29
  else torch.device("cpu")
30
  )
 
31
  MAX_TOKEN_LENGTH = 32
32
 
33
  model_map: dict[str, JiTModel] = {} # {model_path: model}
@@ -64,6 +65,7 @@ def load_model(
64
  label2id_path: str,
65
  config_path: str,
66
  device: torch.device,
 
67
  ) -> tuple[JiTModel, dict]:
68
  """モデルを読み込む"""
69
 
@@ -84,7 +86,7 @@ def load_model(
84
  )
85
  model.eval()
86
  model.requires_grad_(False)
87
- model.to(device=device)
88
  model_map[model_path] = model # cache
89
 
90
  label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
@@ -113,9 +115,10 @@ def generate_images(
113
  label2id_path=label2id_path,
114
  config_path=config_path,
115
  device=DEVICE,
 
116
  )
117
 
118
- with torch.inference_mode():
119
  images = model.generate(
120
  prompt=[prompt] * batch_size,
121
  negative_prompt=negative_prompt,
@@ -127,7 +130,7 @@ def generate_images(
127
  cfg_time_range=[0.1, 1.0],
128
  seed=seed if seed >= 0 else None,
129
  device=DEVICE,
130
- execution_dtype=model.config.torch_dtype,
131
  )
132
 
133
  return images
@@ -271,6 +274,7 @@ if __name__ == "__main__":
271
  label2id_path=LABEL2ID_PATH,
272
  config_path=CONFIG_PATH,
273
  device=DEVICE,
 
274
  )
275
 
276
  demo().launch()
 
28
  if torch.backends.mps.is_available()
29
  else torch.device("cpu")
30
  )
31
+ DTYPE = torch.bfloat16 if DEVICE.type in ["cuda"] else torch.float16
32
  MAX_TOKEN_LENGTH = 32
33
 
34
  model_map: dict[str, JiTModel] = {} # {model_path: model}
 
65
  label2id_path: str,
66
  config_path: str,
67
  device: torch.device,
68
+ dtype: torch.dtype = DTYPE,
69
  ) -> tuple[JiTModel, dict]:
70
  """モデルを読み込む"""
71
 
 
86
  )
87
  model.eval()
88
  model.requires_grad_(False)
89
+ model.to(device=device, dtype=dtype)
90
  model_map[model_path] = model # cache
91
 
92
  label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
 
115
  label2id_path=label2id_path,
116
  config_path=config_path,
117
  device=DEVICE,
118
+ dtype=DTYPE,
119
  )
120
 
121
+ with torch.inference_mode(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
122
  images = model.generate(
123
  prompt=[prompt] * batch_size,
124
  negative_prompt=negative_prompt,
 
130
  cfg_time_range=[0.1, 1.0],
131
  seed=seed if seed >= 0 else None,
132
  device=DEVICE,
133
+ execution_dtype=DTYPE,
134
  )
135
 
136
  return images
 
274
  label2id_path=LABEL2ID_PATH,
275
  config_path=CONFIG_PATH,
276
  device=DEVICE,
277
+ dtype=DTYPE,
278
  )
279
 
280
  demo().launch()