Commit eecb045
Parent(s): 366fd1c

Refactor load_tokenizer function to include error handling and device optimizations; streamline model loading process and improve memory management
app.py CHANGED
@@ -57,7 +57,7 @@ def encode_prompt(text_tokenizer, text_encoder, prompt):
     text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
     lens: List[int] = mask.sum(dim=-1).tolist()
     cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
-    Ltext = max(lens)
+    Ltext = max(lens)
     kv_compact = []
     for len_i, feat_i in zip(lens, text_features.unbind(0)):
         kv_compact.append(feat_i[:len_i])
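For reference, the cu_seqlens_k line in the context above builds cumulative sequence boundaries for the packed (variable-length) text features. A small illustration with a made-up attention mask:

import torch
import torch.nn.functional as F

# Two prompts with 3 and 2 real tokens, padded to length 4 (hypothetical mask).
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
lens = mask.sum(dim=-1).tolist()  # [3, 2]
cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
print(cu_seqlens_k)  # tensor([0, 3, 5], dtype=torch.int32): offsets of each sequence in the packed buffer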
@@ -77,15 +77,40 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
     print('[Save slim model] done')
     return save_file
 
-def load_tokenizer(t5_path=''):
-    ...  # original body not recoverable in this view
+def load_tokenizer(t5_path='google/flan-t5-xl'):
+    """
+    Load and configure the T5 tokenizer and encoder with optimizations.
+    """
+    try:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        bf16_supported = device.type == 'cuda' and torch.cuda.is_bf16_supported()
+        dtype = torch.bfloat16 if bf16_supported else torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            t5_path,
+            legacy=True,
+            model_max_length=512,
+            use_fast=True,
+        )
+
+        if device.type == 'cuda':
+            torch.cuda.empty_cache()
+
+        encoder = T5EncoderModel.from_pretrained(
+            t5_path,
+            torch_dtype=dtype,
+        )
+
+        encoder.eval().requires_grad_(False).to(device)
+
+        if device.type == 'cuda' and not bf16_supported:
+            encoder.half()
+
+        return tokenizer, encoder
+
+    except Exception as e:
+        print(f"Error loading tokenizer/encoder: {str(e)}")
+        raise RuntimeError("Failed to initialize text models") from e
 
 def load_infinity(
     rope2d_each_sa_layer,
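A minimal usage sketch of the refactored loader, assuming the hunk above is applied; the prompt string is illustrative, and the encoder call mirrors the first line of encode_prompt from the first hunk:

import torch

tokenizer, encoder = load_tokenizer(t5_path='google/flan-t5-xl')
tokens = tokenizer(['a photo of a red fox'], return_tensors='pt', padding=True)
input_ids = tokens.input_ids.to(encoder.device)
mask = tokens.attention_mask.to(encoder.device)
with torch.no_grad():
    # Same call pattern as encode_prompt's first context line.
    feats = encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
print(feats.shape)  # (batch, seq_len, 2048) for flan-t5-xl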
@@ -154,8 +179,8 @@ def load_infinity(
     state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))
 
-    # Initialize random number generator on the correct device
-    infinity_test.rng = torch.Generator(device=device)
+    # # Initialize random number generator on the correct device
+    # infinity_test.rng = torch.Generator(device=device)
 
     return infinity_test
 
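Commenting out the eager torch.Generator(device=device) presumably avoids creating CUDA state during model loading; the @spaces.GPU decorator in the last hunk suggests a ZeroGPU Space, where CUDA is only usable inside decorated functions. If a seeded device generator is needed later, it can be created lazily at generation time. A sketch, not part of this commit:

def get_generator(device, seed=None):
    # Build the generator only once a GPU is actually attached.
    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)
    return g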
@@ -315,29 +340,7 @@ def load_transformer(vae, args):
     model_path = args.model_path
 
     if args.checkpoint_type == 'torch':
-        if osp.exists(args.cache_dir):
-            local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
-        else:
-            local_model_path = model_path
-
-        if args.enable_model_cache:
-            slim_model_path = model_path.replace('ar-', 'slim-')
-            local_slim_model_path = local_model_path.replace('ar-', 'slim-')
-            os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
-            if not osp.exists(local_slim_model_path):
-                if osp.exists(slim_model_path):
-                    shutil.copyfile(slim_model_path, local_slim_model_path)
-                else:
-                    if not osp.exists(local_model_path):
-                        shutil.copyfile(model_path, local_model_path)
-                    save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
-                    if not osp.exists(slim_model_path):
-                        shutil.copyfile(local_slim_model_path, slim_model_path)
-                    os.remove(local_model_path)
-                    os.remove(model_path)
-            slim_model_path = local_slim_model_path
-        else:
-            slim_model_path = model_path
+        slim_model_path = model_path
         print(f'Loading checkpoint from {slim_model_path}')
     else:
         raise ValueError(f"Unsupported checkpoint_type: {args.checkpoint_type}")
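The cache-and-slim-conversion branch is dropped entirely: with checkpoint_type == 'torch', the checkpoint at args.model_path is now loaded as-is. If a slim checkpoint is still wanted, save_slim_model (unchanged earlier in app.py) can be called manually; a sketch with hypothetical paths:

# Hypothetical paths for illustration only.
slim_path = save_slim_model('weights/infinity_2b_reg.pth',
                            save_file='weights/infinity_2b_reg_slim.pth',
                            device='cpu')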
@@ -465,10 +468,13 @@ args = argparse.Namespace(
 )
 
 # Load models
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 text_tokenizer, text_encoder = load_tokenizer(t5_path="google/flan-t5-xl")
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 vae = load_visual_tokenizer(args)
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 infinity = load_transformer(vae, args)
-
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 
 # Define the image generation function
 @spaces.GPU
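All four probes print the identical "VRAM before forward pass" label even though each actually runs after a different loading stage, so the log cannot distinguish them; torch.cuda.memory_allocated() also counts only tensor memory owned by this process. A labeled helper would disambiguate (a sketch; log_vram is not part of the commit):

import torch

def log_vram(stage: str) -> None:
    # memory_allocated() reports tensors held by this process, not total GPU use.
    mb = torch.cuda.memory_allocated() / 1024 ** 2
    print(f"VRAM after {stage}: {mb:.2f} MB")

log_vram('load_tokenizer')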