Update 8bitapp.py
Browse files- 8bitapp.py +15 -8
8bitapp.py
CHANGED
|
@@ -13,7 +13,6 @@ from torch.cuda.amp import autocast
|
|
| 13 |
import warnings
|
| 14 |
import random
|
| 15 |
from bitsandbytes.nn import Linear8bitLt
|
| 16 |
-
from transformers import AutoModel
|
| 17 |
|
| 18 |
# Suppress warnings for cleaner output
|
| 19 |
warnings.filterwarnings("ignore")
|
|
@@ -52,24 +51,32 @@ try:
|
|
| 52 |
# Load MusicGen model in FP16
|
| 53 |
musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
|
| 54 |
|
| 55 |
-
# Apply 8-bit quantization to
|
| 56 |
def quantize_to_8bit(model):
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
if isinstance(module, torch.nn.Linear):
|
| 59 |
# Replace with 8-bit linear layer
|
| 60 |
-
parent =
|
| 61 |
-
|
|
|
|
| 62 |
parent = getattr(parent, part)
|
| 63 |
-
setattr(parent,
|
| 64 |
module.in_features,
|
| 65 |
module.out_features,
|
| 66 |
bias=module.bias is not None,
|
| 67 |
has_fp16_weights=False,
|
| 68 |
threshold=6.0
|
| 69 |
))
|
|
|
|
|
|
|
| 70 |
return model
|
| 71 |
|
| 72 |
-
# Quantize the model
|
| 73 |
musicgen_model = quantize_to_8bit(musicgen_model)
|
| 74 |
musicgen_model.to(device)
|
| 75 |
|
|
@@ -94,7 +101,7 @@ def print_resource_usage(stage: str):
|
|
| 94 |
print("---------------")
|
| 95 |
|
| 96 |
# Check available GPU memory
|
| 97 |
-
def check_vram_availability(required_gb=
|
| 98 |
"""Check if sufficient VRAM is available for audio generation."""
|
| 99 |
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
| 100 |
allocated_vram = torch.cuda.memory_allocated() / (1024**3)
|
|
|
|
| 13 |
import warnings
|
| 14 |
import random
|
| 15 |
from bitsandbytes.nn import Linear8bitLt
|
|
|
|
| 16 |
|
| 17 |
# Suppress warnings for cleaner output
|
| 18 |
warnings.filterwarnings("ignore")
|
|
|
|
| 51 |
# Load MusicGen model in FP16
|
| 52 |
musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
|
| 53 |
|
| 54 |
# Apply 8-bit quantization to the language model (lm) component.
def quantize_to_8bit(model):
    """Swap every ``torch.nn.Linear`` inside ``model.lm`` for a bitsandbytes
    ``Linear8bitLt`` layer, preserving the trained weights.

    The original implementation installed freshly-constructed 8-bit layers
    without copying parameters, which silently discards the pretrained
    weights. Here each replacement layer is loaded from the source layer's
    state dict; ``Int8Params`` performs the actual int8 quantization when
    the model is later moved to CUDA.

    Args:
        model: MusicGen model exposing its transformer under the ``lm``
            attribute.

    Returns:
        The same ``model`` object with its linear layers replaced in place.

    Raises:
        AttributeError: If ``model`` has no ``lm`` attribute.
    """
    if not hasattr(model, 'lm'):
        raise AttributeError("MusicGen model does not have 'lm' attribute for quantization.")
    lm = model.lm

    # Collect replacement targets first: mutating the module tree while
    # iterating named_modules() can skip or revisit entries. An empty name
    # would mean lm itself is a Linear; skip it — there is no parent to
    # setattr on in that case.
    targets = [
        (name, module)
        for name, module in lm.named_modules()
        if isinstance(module, torch.nn.Linear) and name
    ]

    for name, module in targets:
        # Walk down to the immediate parent of the layer being replaced.
        parent = lm
        name_parts = name.split('.')
        for part in name_parts[:-1]:
            parent = getattr(parent, part)

        int8_layer = Linear8bitLt(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            has_fp16_weights=False,
            threshold=6.0,
        )
        # Copy the trained parameters; without this the new layer holds
        # randomly-initialized weights and the model output is garbage.
        int8_layer.load_state_dict(module.state_dict())
        setattr(parent, name_parts[-1], int8_layer)

    print(f"Quantized {len(targets)} linear layers to 8-bit.")
    return model
|
| 78 |
|
# Quantize the model in place, then move the result onto the target device.
musicgen_model = quantize_to_8bit(musicgen_model)
musicgen_model.to(device)
|
| 82 |
|
|
|
|
| 101 |
print("---------------")
|
| 102 |
|
| 103 |
# Check available GPU memory
|
| 104 |
+
def check_vram_availability(required_gb=4.0):
|
| 105 |
"""Check if sufficient VRAM is available for audio generation."""
|
| 106 |
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
| 107 |
allocated_vram = torch.cuda.memory_allocated() / (1024**3)
|