xushengyuan commited on
Commit
70c780d
·
1 Parent(s): 1daf6b4

refactor quantization method selection

Browse files
Files changed (2) hide show
  1. acestep/handler.py +16 -9
  2. test.py +7 -2
acestep/handler.py CHANGED
@@ -247,15 +247,22 @@ class AceStepHandler:
247
  if compile_model:
248
  self.model = torch.compile(self.model)
249
 
250
- if self.quantization == "int8_weight_only":
251
- from torchao.quantization import quantize_, Int8WeightOnlyConfig
252
- quantize_(self.model, Int8WeightOnlyConfig())
253
- logger.info("DiT quantized with Int8WeightOnlyConfig")
254
- elif self.quantization == "fp8_weight_only":
255
- from torchao.quantization import quantize_, Float8WeightOnlyConfig
256
- quantize_(self.model, Float8WeightOnlyConfig())
257
- elif self.quantization is not None:
258
- raise ValueError(f"Unsupported quantization type: {self.quantization}")
 
 
 
 
 
 
 
259
 
260
 
261
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
 
247
  if compile_model:
248
  self.model = torch.compile(self.model)
249
 
250
+ if self.quantization is not None:
251
+ from torchao.quantization import quantize_
252
+ if self.quantization == "int8_weight_only":
253
+ from torchao.quantization import Int8WeightOnlyConfig
254
+ quant_config = Int8WeightOnlyConfig()
255
+ elif self.quantization == "fp8_weight_only":
256
+ from torchao.quantization import Float8WeightOnlyConfig
257
+ quant_config = Float8WeightOnlyConfig()
258
+ elif self.quantization == "w8a8_dynamic":
259
+ from torchao.quantization import Int8DynamicActivationInt8WeightConfig, MappingType
260
+ quant_config = Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC)
261
+ else:
262
+ raise ValueError(f"Unsupported quantization type: {self.quantization}")
263
+
264
+ quantize_(self.model, quant_config)
265
+ logger.info("DiT quantized with: %s", self.quantization)
266
 
267
 
268
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
test.py CHANGED
@@ -46,7 +46,7 @@ def main():
46
  compile_model=True,
47
  offload_to_cpu=True,
48
  offload_dit_to_cpu=False, # Keep DiT on GPU
49
- quantization="fp8_weight_only", # Enable FP8 weight-only quantization
50
  )
51
 
52
  if not enabled:
@@ -108,7 +108,12 @@ def main():
108
  print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
109
  else:
110
  print("Skipping 5Hz LLM generation...")
111
- metadata = {}
 
 
 
 
 
112
  audio_codes = None
113
  lm_status = "Skipped"
114
 
 
46
  compile_model=True,
47
  offload_to_cpu=True,
48
  offload_dit_to_cpu=False, # Keep DiT on GPU
49
+ quantization="int8_weight_only", # Enable INT8 weight-only quantization
50
  )
51
 
52
  if not enabled:
 
108
  print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
109
  else:
110
  print("Skipping 5Hz LLM generation...")
111
+ metadata = {
112
+ 'bpm': 90,
113
+ 'keyscale': 'A major',
114
+ 'timesignature': '4',
115
+ 'duration': 240,
116
+ }
117
  audio_codes = None
118
  lm_status = "Skipped"
119