Spaces:
Running on Zero
Running on Zero
Commit ·
70c780d
1
Parent(s): 1daf6b4
refactor quantization method selection
Browse files
- acestep/handler.py +16 -9
- test.py +7 -2
acestep/handler.py
CHANGED
|
@@ -247,15 +247,22 @@ class AceStepHandler:
|
|
| 247 |
if compile_model:
|
| 248 |
self.model = torch.compile(self.model)
|
| 249 |
|
| 250 |
-
if self.quantization
|
| 251 |
-
from torchao.quantization import quantize_
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
|
| 261 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
|
|
|
| 247 |
if compile_model:
|
| 248 |
self.model = torch.compile(self.model)
|
| 249 |
|
| 250 |
+
if self.quantization is not None:
|
| 251 |
+
from torchao.quantization import quantize_
|
| 252 |
+
if self.quantization == "int8_weight_only":
|
| 253 |
+
from torchao.quantization import Int8WeightOnlyConfig
|
| 254 |
+
quant_config = Int8WeightOnlyConfig()
|
| 255 |
+
elif self.quantization == "fp8_weight_only":
|
| 256 |
+
from torchao.quantization import Float8WeightOnlyConfig
|
| 257 |
+
quant_config = Float8WeightOnlyConfig()
|
| 258 |
+
elif self.quantization == "w8a8_dynamic":
|
| 259 |
+
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, MappingType
|
| 260 |
+
quant_config = Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC)
|
| 261 |
+
else:
|
| 262 |
+
raise ValueError(f"Unsupported quantization type: {self.quantization}")
|
| 263 |
+
|
| 264 |
+
quantize_(self.model, quant_config)
|
| 265 |
+
logger.info("DiT quantized with:",self.quantization)
|
| 266 |
|
| 267 |
|
| 268 |
silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
|
test.py
CHANGED
|
@@ -46,7 +46,7 @@ def main():
|
|
| 46 |
compile_model=True,
|
| 47 |
offload_to_cpu=True,
|
| 48 |
offload_dit_to_cpu=False, # Keep DiT on GPU
|
| 49 |
-
quantization="
|
| 50 |
)
|
| 51 |
|
| 52 |
if not enabled:
|
|
@@ -108,7 +108,12 @@ def main():
|
|
| 108 |
print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
|
| 109 |
else:
|
| 110 |
print("Skipping 5Hz LLM generation...")
|
| 111 |
-
metadata = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
audio_codes = None
|
| 113 |
lm_status = "Skipped"
|
| 114 |
|
|
|
|
| 46 |
compile_model=True,
|
| 47 |
offload_to_cpu=True,
|
| 48 |
offload_dit_to_cpu=False, # Keep DiT on GPU
|
| 49 |
+
quantization="int8_weight_only", # Enable INT8 weight-only quantization
|
| 50 |
)
|
| 51 |
|
| 52 |
if not enabled:
|
|
|
|
| 108 |
print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
|
| 109 |
else:
|
| 110 |
print("Skipping 5Hz LLM generation...")
|
| 111 |
+
metadata = {
|
| 112 |
+
'bpm': 90,
|
| 113 |
+
'keyscale': 'A major',
|
| 114 |
+
'timesignature': '4',
|
| 115 |
+
'duration': 240,
|
| 116 |
+
}
|
| 117 |
audio_codes = None
|
| 118 |
lm_status = "Skipped"
|
| 119 |
|