Spaces:
Paused
Paused
Commit ·
66cfd91
1
Parent(s): 9d8246c
Make the generator CPU
Browse files
app.py
CHANGED
|
@@ -160,7 +160,7 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
|
|
| 160 |
return save_file
|
| 161 |
|
| 162 |
def load_tokenizer(t5_path =''):
|
| 163 |
-
print(
|
| 164 |
text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
|
| 165 |
text_tokenizer.model_max_length = 512
|
| 166 |
text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
|
|
@@ -186,7 +186,7 @@ def load_infinity(
|
|
| 186 |
use_flex_attn=False,
|
| 187 |
bf16=False,
|
| 188 |
):
|
| 189 |
-
print(
|
| 190 |
|
| 191 |
# Set device if not provided
|
| 192 |
if device is None:
|
|
@@ -232,13 +232,13 @@ def load_infinity(
|
|
| 232 |
infinity_test.eval()
|
| 233 |
infinity_test.requires_grad_(False)
|
| 234 |
|
| 235 |
-
print(
|
| 236 |
state_dict = torch.load(model_path, map_location=device)
|
| 237 |
print(infinity_test.load_state_dict(state_dict))
|
| 238 |
|
| 239 |
# Initialize random number generator on the correct device
|
| 240 |
-
infinity_test.rng = torch.Generator(device=
|
| 241 |
-
|
| 242 |
return infinity_test
|
| 243 |
|
| 244 |
def transform(pil_img, tgt_h, tgt_w):
|
|
|
|
| 160 |
return save_file
|
| 161 |
|
| 162 |
def load_tokenizer(t5_path =''):
|
| 163 |
+
print('[Loading tokenizer and text encoder]')
|
| 164 |
text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
|
| 165 |
text_tokenizer.model_max_length = 512
|
| 166 |
text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
|
|
|
|
| 186 |
use_flex_attn=False,
|
| 187 |
bf16=False,
|
| 188 |
):
|
| 189 |
+
print('[Loading Infinity]')
|
| 190 |
|
| 191 |
# Set device if not provided
|
| 192 |
if device is None:
|
|
|
|
| 232 |
infinity_test.eval()
|
| 233 |
infinity_test.requires_grad_(False)
|
| 234 |
|
| 235 |
+
print('[Load Infinity weights]')
|
| 236 |
state_dict = torch.load(model_path, map_location=device)
|
| 237 |
print(infinity_test.load_state_dict(state_dict))
|
| 238 |
|
| 239 |
# Initialize random number generator on the correct device
|
| 240 |
+
infinity_test.rng = torch.Generator(device="cpu")
|
| 241 |
+
|
| 242 |
return infinity_test
|
| 243 |
|
| 244 |
def transform(pil_img, tgt_h, tgt_w):
|