Spaces:
Paused
Paused
Commit
·
9914b63
1
Parent(s):
54f9225
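Optimize Infinity model loading: clear the CUDA cache before building the model, construct it inside a bf16 autocast / no_grad context, adjust the device assignment (`.to(device)`), and remove the redundant post-construction `infinity_test.cuda()` and `torch.cuda.empty_cache()` calls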
Browse files
app.py
CHANGED
|
@@ -197,7 +197,8 @@ def load_infinity(
|
|
| 197 |
):
|
| 198 |
print(f'[Loading Infinity]')
|
| 199 |
text_maxlen = 512
|
| 200 |
-
|
|
|
|
| 201 |
infinity_test: Infinity = Infinity(
|
| 202 |
vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
|
| 203 |
shared_aln=True, raw_scale_schedule=scale_schedule,
|
|
@@ -215,7 +216,7 @@ def load_infinity(
|
|
| 215 |
inference_mode=True,
|
| 216 |
train_h_div_w_list=[1.0],
|
| 217 |
**model_kwargs,
|
| 218 |
-
).to(device
|
| 219 |
print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
|
| 220 |
|
| 221 |
if bf16:
|
|
@@ -225,9 +226,6 @@ def load_infinity(
|
|
| 225 |
infinity_test.eval()
|
| 226 |
infinity_test.requires_grad_(False)
|
| 227 |
|
| 228 |
-
infinity_test.cuda()
|
| 229 |
-
torch.cuda.empty_cache()
|
| 230 |
-
|
| 231 |
print(f'[Load Infinity weights]')
|
| 232 |
state_dict = torch.load(model_path, map_location=device)
|
| 233 |
print(infinity_test.load_state_dict(state_dict))
|
|
@@ -529,7 +527,6 @@ with gr.Blocks() as demo:
|
|
| 529 |
# Output Section
|
| 530 |
gr.Markdown("### Generated Image")
|
| 531 |
output_image = gr.Image(label="Generated Image", type="pil")
|
| 532 |
-
gr.Markdown("**Tip:** Right-click the image to save it.")
|
| 533 |
|
| 534 |
# Error Handling
|
| 535 |
error_message = gr.Textbox(label="Error Message", visible=False)
|
|
|
|
| 197 |
):
|
| 198 |
print(f'[Loading Infinity]')
|
| 199 |
text_maxlen = 512
|
| 200 |
+
torch.cuda.empty_cache()
|
| 201 |
+
with torch.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
|
| 202 |
infinity_test: Infinity = Infinity(
|
| 203 |
vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
|
| 204 |
shared_aln=True, raw_scale_schedule=scale_schedule,
|
|
|
|
| 216 |
inference_mode=True,
|
| 217 |
train_h_div_w_list=[1.0],
|
| 218 |
**model_kwargs,
|
| 219 |
+
).to(device)
|
| 220 |
print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
|
| 221 |
|
| 222 |
if bf16:
|
|
|
|
| 226 |
infinity_test.eval()
|
| 227 |
infinity_test.requires_grad_(False)
|
| 228 |
|
|
|
|
|
|
|
|
|
|
| 229 |
print(f'[Load Infinity weights]')
|
| 230 |
state_dict = torch.load(model_path, map_location=device)
|
| 231 |
print(infinity_test.load_state_dict(state_dict))
|
|
|
|
| 527 |
# Output Section
|
| 528 |
gr.Markdown("### Generated Image")
|
| 529 |
output_image = gr.Image(label="Generated Image", type="pil")
|
|
|
|
| 530 |
|
| 531 |
# Error Handling
|
| 532 |
error_message = gr.Textbox(label="Error Message", visible=False)
|