| |
| """ |
| Fish Speech S2 Pro Quantization Toolkit |
| ======================================== |
| Quantizes the S2 Pro model at multiple precision levels and generates |
| voice-cloned TTS samples for quality comparison. |
| |
| Usage: |
| python quantize.py --phase all # Run all phases |
| python quantize.py --phase 1a # FP8 only |
| python quantize.py --phase 1b # INT4 only |
| python quantize.py --phase 2a # INT4 on all layers (Slow + Fast AR) |
| python quantize.py --phase 2b # INT8 |
| python quantize.py --phase 2c # INT3 |
| python quantize.py --phase 3a # INT2 |
| python quantize.py --phase 3b # INT2 all layers |
| |
| Requirements: |
| - CUDA GPU with >= 24GB VRAM (A100 40/80GB recommended) |
| - pip install torch einops loguru ormsgpack hydra-core omegaconf safetensors torchaudio soundfile |
| |
| Author: Fish Speech Quantization Experiment |
| """ |
| import os, sys, json, time, gc, traceback, argparse |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import soundfile as sf |
| from pathlib import Path |
| from collections import OrderedDict |
| from safetensors.torch import save_file |
|
|
# Presumably keeps HuggingFace tokenizers from using/warning about
# parallelism; set before any fish-speech import happens.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Dtype used for model loading, KV caches and autocast throughout the script.
DTYPE = torch.bfloat16
# Default checkpoint location handed to init_model / load_codec_model.
BASE_MODEL = "fishaudio/s2-pro"
|
|
| |
| |
| |
|
|
| class FP8Linear(nn.Module): |
| """Per-row symmetric FP8 (float8_e4m3fn) weight-only quantization. |
| Proven zero-quality-loss approach from drbaph/s2-pro-fp8.""" |
| def __init__(self, in_f, out_f, bias=True): |
| super().__init__() |
| self.in_features = in_f |
| self.out_features = out_f |
| self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn)) |
| self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) |
| self.has_bias = bias |
| if bias: |
| self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) |
| else: |
| self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear): |
| fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None) |
| FP8_MAX = 448.0 |
| w = linear.weight.data.detach().to(torch.bfloat16) |
| scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX |
| scale = scale.clamp(min=1e-12) |
| w_q = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) |
| fp8.weight.data.copy_(w_q) |
| fp8.weight_scale.data.copy_(scale) |
| if linear.bias is not None: |
| fp8.bias.data.copy_(linear.bias.data.detach().to(torch.bfloat16)) |
| return fp8 |
|
|
| def forward(self, x): |
| w = self.weight.to(torch.bfloat16) * self.weight_scale |
| return nn.functional.linear(x, w, self.bias) |
|
|
|
|
class INT8Linear(nn.Module):
    """Per-row symmetric INT8 weight-only quantization.

    Each output row is scaled by ``abs-max(row) / 127`` and rounded into the
    symmetric code range [-127, 127]. Weights are dequantized on every
    forward call; the output follows the input's dtype.

    Buffers:
        weight:       (out_features, in_features) int8 codes
        weight_scale: (out_features, 1) float32
        bias:         (out_features,) bfloat16, or ``None`` when bias=False
    """

    def __init__(self, in_f, out_f, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear):
        """Quantize a float ``nn.Linear`` into an INT8Linear (bias copied as bf16)."""
        q = INT8Linear(linear.in_features, linear.out_features, linear.bias is not None)
        # float32, not bf16: bf16 keeps only 8 significant bits, so w / scale
        # would already be rounded to 0.5 steps near +/-100 before .round(),
        # shifting some codes by one.
        w = linear.weight.detach().to(torch.float32)
        scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-12)
        # Symmetric clamp: -128 has no positive counterpart under this scale.
        w_q = (w / scale).round().clamp(-127, 127).to(torch.int8)
        q.weight.data.copy_(w_q)
        q.weight_scale.data.copy_(scale)
        if linear.bias is not None:
            q.bias.data.copy_(linear.bias.detach().to(torch.bfloat16))
        return q

    def forward(self, x):
        # Dequantize in float32, then match the input dtype so the layer also
        # works outside torch.autocast.
        w = (self.weight.to(torch.float32) * self.weight_scale).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
|
|
class INT4Linear(nn.Module):
    """Group-wise symmetric INT4 weight-only quantization (default group_size=128).

    The input dimension is split into groups of ``group_size`` columns; each
    (row, group) gets scale ``abs-max / 7`` and codes in [-7, 7]. This
    approximates GPTQ-style grouping but uses no calibration data.

    NOTE: codes are stored one per int8 element (not bit-packed), so the
    memory/disk footprint is that of INT8 plus the per-group scales.

    Buffers:
        weight_q:     (out_features, in_features) int8 codes
        weight_scale: (out_features, ceil(in_features / group_size)) float32
        bias:         (out_features,) bfloat16, or ``None`` when bias=False
    """

    def __init__(self, in_f, out_f, group_size=128, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(
            out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear, group_size=128):
        """Quantize a float ``nn.Linear`` into an INT4Linear.

        The input dimension is zero-padded up to a multiple of *group_size*
        for the group statistics; padding columns are dropped before storing.
        """
        in_f = linear.in_features
        out_f = linear.out_features
        q = INT4Linear(in_f, out_f, group_size, linear.bias is not None)
        # float32 math: bf16 (8 significant bits) would round w / scale, and
        # the scale itself, before the final .round().
        w = linear.weight.detach().to(torch.float32)
        n_groups = (in_f + group_size - 1) // group_size
        pad = n_groups * group_size - in_f
        if pad > 0:
            w = nn.functional.pad(w, (0, pad))
        w_g = w.reshape(out_f, n_groups, group_size)
        scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 7.0
        w_q = (w_g / scale).round().clamp(-7, 7).to(torch.int8)
        q.weight_q.data.copy_(w_q.reshape(out_f, -1)[:, :in_f])
        q.weight_scale.data.copy_(scale.squeeze(-1))
        if linear.bias is not None:
            q.bias.data.copy_(linear.bias.detach().to(torch.bfloat16))
        return q

    def forward(self, x):
        # Expand per-group scales to per-column, dequantize in float32, then
        # match the input dtype so the layer also works outside autocast.
        s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features]
        w = (self.weight_q.to(torch.float32) * s).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
|
|
class INT3Linear(nn.Module):
    """Group-wise symmetric INT3 weight-only quantization (default group_size=128).

    Each (row, group of ``group_size`` input columns) gets scale
    ``abs-max / 3`` and codes in [-3, 3] (7 levels).

    NOTE: codes are stored one per int8 element (not bit-packed), so the
    memory/disk footprint is that of INT8 plus the per-group scales.

    Buffers:
        weight_q:     (out_features, in_features) int8 codes
        weight_scale: (out_features, ceil(in_features / group_size)) float32
        bias:         (out_features,) bfloat16, or ``None`` when bias=False
    """

    def __init__(self, in_f, out_f, group_size=128, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(
            out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear, group_size=128):
        """Quantize a float ``nn.Linear`` into an INT3Linear.

        The input dimension is zero-padded up to a multiple of *group_size*
        for the group statistics; padding columns are dropped before storing.
        """
        in_f = linear.in_features
        out_f = linear.out_features
        q = INT3Linear(in_f, out_f, group_size, linear.bias is not None)
        # float32 math avoids bf16 pre-rounding of the scale and of w / scale.
        w = linear.weight.detach().to(torch.float32)
        n_groups = (in_f + group_size - 1) // group_size
        pad = n_groups * group_size - in_f
        if pad > 0:
            w = nn.functional.pad(w, (0, pad))
        w_g = w.reshape(out_f, n_groups, group_size)
        scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 3.0
        w_q = (w_g / scale).round().clamp(-3, 3).to(torch.int8)
        q.weight_q.data.copy_(w_q.reshape(out_f, -1)[:, :in_f])
        q.weight_scale.data.copy_(scale.squeeze(-1))
        if linear.bias is not None:
            q.bias.data.copy_(linear.bias.detach().to(torch.bfloat16))
        return q

    def forward(self, x):
        # Expand per-group scales to per-column, dequantize in float32, then
        # match the input dtype so the layer also works outside autocast.
        s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features]
        w = (self.weight_q.to(torch.float32) * s).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
|
|
class INT2Linear(nn.Module):
    """Group-wise symmetric "INT2" weight-only quantization (default group_size=64).

    Codes are restricted to {-1, 0, 1} (ternary: 3 of the 4 levels a 2-bit
    code offers) with per-(row, group) scale = group abs-max, so any weight
    smaller than half its group's abs-max rounds to 0.

    NOTE: codes are stored one per int8 element (not bit-packed), so the
    memory/disk footprint is that of INT8 plus the per-group scales.

    Buffers:
        weight_q:     (out_features, in_features) int8 codes
        weight_scale: (out_features, ceil(in_features / group_size)) float32
        bias:         (out_features,) bfloat16, or ``None`` when bias=False
    """

    def __init__(self, in_f, out_f, group_size=64, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(
            out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear, group_size=64):
        """Quantize a float ``nn.Linear`` into an INT2Linear.

        The input dimension is zero-padded up to a multiple of *group_size*
        for the group statistics; padding columns are dropped before storing.
        """
        in_f = linear.in_features
        out_f = linear.out_features
        q = INT2Linear(in_f, out_f, group_size, linear.bias is not None)
        # float32 math avoids bf16 pre-rounding of the scale and of w / scale.
        w = linear.weight.detach().to(torch.float32)
        n_groups = (in_f + group_size - 1) // group_size
        pad = n_groups * group_size - in_f
        if pad > 0:
            w = nn.functional.pad(w, (0, pad))
        w_g = w.reshape(out_f, n_groups, group_size)
        scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
        w_q = (w_g / scale).round().clamp(-1, 1).to(torch.int8)
        q.weight_q.data.copy_(w_q.reshape(out_f, -1)[:, :in_f])
        q.weight_scale.data.copy_(scale.squeeze(-1))
        if linear.bias is not None:
            q.bias.data.copy_(linear.bias.detach().to(torch.bfloat16))
        return q

    def forward(self, x):
        # Expand per-group scales to per-column, dequantize in float32, then
        # match the input dtype so the layer also works outside autocast.
        s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features]
        w = (self.weight_q.to(torch.float32) * s).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
|
|
|
|
| |
| |
| |
|
|
def apply_quantization(model, quant_class, target="slow_ar", skip_names=None, **kwargs):
    """Replace ``nn.Linear`` submodules of *model* in place with quantized versions.

    Args:
        model: module whose Linear layers are replaced (mutated in place).
        quant_class: class exposing ``from_linear(linear, **kwargs)``.
        target: 'slow_ar' skips every layer whose qualified name contains
            "fast_" (intended: only the Slow AR stack); 'all' quantizes both
            Slow and Fast AR layers.
        skip_names: name substrings to leave untouched (matched anywhere in
            the dotted module path); defaults to ['embed', 'norm'].
        **kwargs: forwarded to ``quant_class.from_linear`` (e.g. group_size).

    Returns:
        ``(model, count)`` -- the same model object and the number of layers
        replaced. Layers whose conversion raises are reported and left as-is.

    Raises:
        ValueError: if *target* is neither 'slow_ar' nor 'all'.
    """
    if target not in ("slow_ar", "all"):
        raise ValueError(f"target must be 'slow_ar' or 'all', got {target!r}")
    if skip_names is None:
        skip_names = ['embed', 'norm']
    count = 0
    # Snapshot the module list: we replace modules while walking it.
    for name, module in list(model.named_modules()):
        if not name or not isinstance(module, nn.Linear):
            continue  # the root module itself cannot be swapped in place
        if any(s in name for s in skip_names):
            continue
        if target == "slow_ar" and "fast_" in name:
            continue
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)  # "" resolves to model
        try:
            quantized = quant_class.from_linear(module, **kwargs)
            setattr(parent, child_name, quantized)
            count += 1
        except Exception as e:
            # Best effort: keep the float layer and report why.
            print(f" Skip {name}: {e}")
    return model, count
|
|
|
|
| def get_model_size_mb(model): |
| """Get total model size in MB""" |
| total = 0 |
| for p in model.parameters(): |
| total += p.numel() * p.element_size() |
| for b in model.buffers(): |
| total += b.numel() * b.element_size() |
| return total / (1024 * 1024) |
|
|
|
|
| |
| |
| |
|
|
def generate_tts_simple(model, codec, text, output_path, device="cuda", *,
                        max_new_tokens=512, temperature=0.7, top_p=0.7, top_k=30):
    """Generate a TTS sample for *text* without reference audio and save it.

    Args:
        model: text2semantic model (as returned by fish-speech ``init_model``).
            KV caches are set up on the first call and flagged with
            ``model._cache_setup_done`` so later calls reuse them.
        codec: codec model used to decode the generated codes; its
            ``sample_rate`` attribute (default 44100) sets the output rate.
        text: text to synthesize.
        output_path: destination audio file (written with soundfile).
        device: device for the prompt mask tensors; its type also selects
            the autocast backend.
        max_new_tokens, temperature, top_p, top_k: sampling settings passed
            to ``generate`` (defaults are the previously hard-coded values).

    Returns:
        ``(True, duration_seconds)``; failures propagate as exceptions.
    """
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar
    from fish_speech.content_sequence import TextPart
    from fish_speech.conversation import Conversation, Message

    autocast_type = torch.device(device).type

    conv = Conversation()
    conv.add_message(Message(role="user", parts=[TextPart(text="")]))
    conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))
    prompt = conv.encode_for_inference(model.config)

    # NOTE(review): all-false masks / all-zero parts presumably mean "no audio
    # in the prompt" -- confirm against fish-speech's generate().
    codebook_dim = 1 + model.config.num_codebooks
    audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device)
    audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device)

    if not getattr(model, '_cache_setup_done', False):
        model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
        model._cache_setup_done = True

    with torch.autocast(device_type=autocast_type, dtype=DTYPE):
        result = generate(
            model=model, prompt=prompt, max_new_tokens=max_new_tokens,
            audio_masks=audio_masks, audio_parts=audio_parts,
            temperature=temperature, top_p=top_p, top_k=top_k,
            decode_one_token=decode_one_token_ar,
        )
    # NOTE(review): keeps index 0 of the first dim of a 3-D result and adds a
    # batch dim -- confirm this matches generate()'s output layout.
    codes = result[0:1, :, :].unsqueeze(0)
    with torch.autocast(device_type=autocast_type, dtype=DTYPE):
        audio = codec.decode(codes.to(device))
    audio_np = audio.squeeze().cpu().float().numpy()
    sr = getattr(codec, 'sample_rate', 44100)
    sf.write(output_path, audio_np, sr)
    dur = len(audio_np) / sr
    print(f" Saved: {output_path} ({dur:.1f}s)")
    return True, dur
|
|
|
|
def generate_voice_clone(model, codec, text, ref_path, ref_text, output_path, device="cuda", *,
                         max_new_tokens=512, temperature=0.7, top_p=0.7, top_k=30):
    """Generate *text* in the voice of the clip at *ref_path* and save it.

    The reference clip is downmixed to mono, resampled to the codec's rate,
    encoded with *codec*, and placed in the user turn together with
    *ref_text* (which should be its transcript); *text* is the assistant turn.

    Args:
        model: text2semantic model; KV caches are set up once and flagged
            with ``model._cache_setup_done``.
        codec: codec model with ``encode``/``decode``; ``sample_rate``
            (default 44100) is used for both the reference and the output.
        text: text to speak in the cloned voice.
        ref_path: path to the reference audio file.
        ref_text: transcript of the reference audio.
        output_path: destination audio file (written with soundfile).
        device: device for tensors; its type also selects the autocast backend.
        max_new_tokens, temperature, top_p, top_k: sampling settings passed
            to ``generate`` (defaults are the previously hard-coded values).

    Returns:
        ``(True, duration_seconds)``; failures propagate as exceptions.
    """
    import torchaudio
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar
    from fish_speech.content_sequence import TextPart, VQPart
    from fish_speech.conversation import Conversation, Message

    codec_sr = getattr(codec, 'sample_rate', 44100)
    autocast_type = torch.device(device).type

    # Mono, at the same rate the codec is assumed to produce on decode.
    wav, sr = torchaudio.load(ref_path)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != codec_sr:
        wav = torchaudio.functional.resample(wav, sr, codec_sr)
    wav = wav.to(device)

    with torch.autocast(device_type=autocast_type, dtype=DTYPE):
        encoded = codec.encode(wav.unsqueeze(0))
    # encode() may return a tuple; its first element is taken as the codes.
    prompt_tokens = (encoded[0] if isinstance(encoded, tuple) else encoded).cpu().numpy()

    conv = Conversation()
    conv.add_message(Message(role="user", parts=[
        VQPart(codes=prompt_tokens), TextPart(text=ref_text)]))
    conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))

    prompt = conv.encode_for_inference(model.config)
    # NOTE(review): masks/parts are all-false/zero, as in generate_tts_simple;
    # confirm the VQ prompt is carried by encode_for_inference alone.
    codebook_dim = 1 + model.config.num_codebooks
    audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device)
    audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device)

    if not getattr(model, '_cache_setup_done', False):
        model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
        model._cache_setup_done = True

    with torch.autocast(device_type=autocast_type, dtype=DTYPE):
        result = generate(
            model=model, prompt=prompt, max_new_tokens=max_new_tokens,
            audio_masks=audio_masks, audio_parts=audio_parts,
            temperature=temperature, top_p=top_p, top_k=top_k,
            decode_one_token=decode_one_token_ar,
        )
    codes = result[0:1, :, :].unsqueeze(0)
    with torch.autocast(device_type=autocast_type, dtype=DTYPE):
        audio = codec.decode(codes.to(device))
    audio_np = audio.squeeze().cpu().float().numpy()
    sf.write(output_path, audio_np, codec_sr)
    dur = len(audio_np) / codec_sr
    print(f" Voice clone saved: {output_path} ({dur:.1f}s)")
    return True, dur
|
|
|
|
|
|
| |
| |
| |
|
|
def run_phase(phase_id, quant_class, target, codec, ref_audio_path, ref_text,
              test_text, clone_text, output_dir, **qkwargs):
    """Run one quantization phase end-to-end.

    Loads a fresh base model, quantizes its Linear layers with *quant_class*,
    saves the quantized state dict, renders a plain-TTS sample and (when
    *ref_audio_path* exists) a voice-clone sample, then records metrics.

    Args:
        phase_id: label for the output sub-directory and sample file names.
        quant_class: quantized-Linear class passed to ``apply_quantization``.
        target: 'slow_ar' or 'all', forwarded to ``apply_quantization``.
        codec: loaded codec model for encoding/decoding audio.
        ref_audio_path: reference clip for cloning; skipped if falsy or missing.
        ref_text: transcript of the reference clip.
        test_text: text for the plain TTS sample.
        clone_text: text for the voice-clone sample.
        output_dir: root output directory.
        **qkwargs: extra ``from_linear`` arguments (e.g. group_size).

    Returns:
        dict of phase metrics, also written to
        ``<output_dir>/<phase_id>/results.json``. Sample-generation failures
        are recorded as ``*_ok=False``; load/quantize/save errors propagate
        (after the model has been released).
    """
    from fish_speech.models.text2semantic.inference import init_model

    phase_dir = f"{output_dir}/{phase_id}"
    samples_dir = f"{output_dir}/samples"
    os.makedirs(phase_dir, exist_ok=True)
    os.makedirs(samples_dir, exist_ok=True)

    print(f"\n{'='*60}")
    print(f" {phase_id.upper()}: {quant_class.__name__} ({target})")
    print(f"{'='*60}")

    model = None
    try:
        model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
        orig_size = get_model_size_mb(model)

        t0 = time.time()
        model, n_layers = apply_quantization(model, quant_class, target=target, **qkwargs)
        model = model.to(DEVICE)
        t_quant = time.time() - t0

        quant_size = get_model_size_mb(model)
        ratio = orig_size / quant_size if quant_size > 0 else 0
        print(f" {orig_size:.0f} MB -> {quant_size:.0f} MB ({ratio:.2f}x, {n_layers} layers, {t_quant:.1f}s)")

        save_path = f"{phase_dir}/model.safetensors"
        # safetensors' save_file rejects non-contiguous tensors.
        state = {k: v.contiguous() for k, v in model.state_dict().items()}
        save_file(state, save_path)
        del state
        disk_mb = os.path.getsize(save_path) / (1024*1024)
        print(f" Disk: {disk_mb:.0f} MB")

        tts_ok, tts_dur = False, 0
        try:
            tts_ok, tts_dur = generate_tts_simple(
                model, codec, test_text, f"{samples_dir}/{phase_id}_tts.wav")
        except Exception as e:
            print(f" TTS failed: {e}")

        clone_ok, clone_dur = False, 0
        if ref_audio_path and os.path.exists(ref_audio_path):
            try:
                clone_ok, clone_dur = generate_voice_clone(
                    model, codec, clone_text, ref_audio_path, ref_text,
                    f"{samples_dir}/{phase_id}_clone.wav")
            except Exception as e:
                print(f" Clone failed: {e}")
    finally:
        # Release the model even if a step above raised, so a failed phase
        # does not keep its weights alive (e.g. via a stored traceback).
        del model
        gc.collect()
        torch.cuda.empty_cache()

    result = {
        "phase": phase_id, "method": quant_class.__name__, "target": target,
        "original_mb": round(orig_size), "quantized_mb": round(quant_size),
        "disk_mb": round(disk_mb), "compression": round(ratio, 3),
        "n_layers": n_layers, "time_s": round(t_quant, 1),
        "tts_ok": tts_ok, "tts_dur_s": round(tts_dur, 1),
        "clone_ok": clone_ok, "clone_dur_s": round(clone_dur, 1),
    }
    with open(f"{phase_dir}/results.json", "w") as f:
        json.dump(result, f, indent=2)
    return result
|
|
|
|
|
|
| |
| |
| |
|
|
# Sentence pair rendered by every phase for the plain (no-reference) TTS sample.
TEST_TEXT = " ".join((
    "The quick brown fox jumps over the lazy dog.",
    "Artificial intelligence is transforming the way we communicate with machines.",
))

# Sentence pair spoken in the cloned voice for the voice-cloning sample.
CLONE_TEXT = " ".join((
    "Hello everyone, welcome to this special presentation.",
    "Today we explore the fascinating world of neural text to speech synthesis.",
))

# Generic placeholder transcript for a reference clip. It is not the text of
# any particular recording (the auto-generated reference says CELEBRITY_REF_TEXT).
REF_TEXT = "This is a reference voice recording used for demonstration purposes."

# Script the base model reads aloud to produce the auto-generated reference clip.
CELEBRITY_REF_TEXT = " ".join((
    "Good morning. I want to tell you something about the universe.",
    "Every atom in your body came from a star that exploded.",
    "We are all made of star stuff.",
))
|
|
# Experiment table: phase id -> quantized Linear class, which layers to
# convert ('slow_ar' skips names containing "fast_", 'all' converts both
# stacks), and extra keyword arguments for from_linear.
PHASES = {
    phase_id: {"cls": quant_cls, "target": scope, "kwargs": extra}
    for phase_id, quant_cls, scope, extra in (
        ("1a", FP8Linear, "slow_ar", {}),
        ("1b", INT4Linear, "slow_ar", {"group_size": 128}),
        ("2a", INT4Linear, "all", {"group_size": 128}),
        ("2b", INT8Linear, "slow_ar", {}),
        ("2c", INT3Linear, "slow_ar", {"group_size": 128}),
        ("3a", INT2Linear, "slow_ar", {"group_size": 64}),
        ("3b", INT2Linear, "all", {"group_size": 64}),
    )
}
|
|
def main():
    """Command-line entry point: run the requested quantization phases.

    Steps: parse arguments, clone fish-speech if it is missing, load the
    codec, load the BF16 base model once to produce the reference clip (if
    absent) and the baseline samples, run each requested phase via
    ``run_phase``, then print a summary and write
    ``<output>/all_results.json``.
    """
    # Must precede every use of BASE_MODEL in this function, including the
    # argparse default below; declaring it later is a SyntaxError
    # ("name 'BASE_MODEL' is used prior to global declaration").
    global BASE_MODEL
    import subprocess

    parser = argparse.ArgumentParser(description="Fish Speech S2 Pro Quantization")
    parser.add_argument("--phase", default="all", help="Phase to run (1a,1b,2a,2b,2c,3a,3b,all)")
    parser.add_argument("--output", default="./output", help="Output directory")
    parser.add_argument("--model", default=BASE_MODEL, help="Model ID or path")
    parser.add_argument("--ref-audio", default=None, help="Reference audio for voice cloning")
    parser.add_argument("--ref-text", default=None,
                        help="Transcript of the reference audio (voice-clone prompt text)")
    args = parser.parse_args()

    BASE_MODEL = args.model
    output_dir = args.output
    os.makedirs(f"{output_dir}/samples", exist_ok=True)

    # Fetch fish-speech on first use and make it importable (relative to CWD).
    if not os.path.exists("fish-speech"):
        subprocess.run(
            ["git", "clone", "--depth", "1", "https://github.com/fishaudio/fish-speech.git"],
            check=True,
        )
    sys.path.insert(0, "fish-speech")

    from fish_speech.models.text2semantic.inference import init_model
    from fish_speech.models.dac.inference import load_codec_model

    print("Loading codec...")
    codec = load_codec_model(f"{BASE_MODEL}/codec.pth", DEVICE, DTYPE)

    ref_path = args.ref_audio or f"{output_dir}/reference_celebrity.wav"
    # The clone prompt text must be the transcript of the reference audio.
    # The default reference is generated from CELEBRITY_REF_TEXT; a
    # user-supplied clip needs --ref-text, else the generic placeholder.
    if args.ref_text:
        ref_text = args.ref_text
    elif args.ref_audio is None:
        ref_text = CELEBRITY_REF_TEXT
    else:
        ref_text = REF_TEXT

    # The unquantized model is needed for the baseline samples every time,
    # not only when the reference clip has to be generated.
    model_base, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
    try:
        if not os.path.exists(ref_path):
            print("Generating celebrity-style reference audio from base model...")
            try:
                generate_tts_simple(model_base, codec, CELEBRITY_REF_TEXT, ref_path)
                print(f"Reference audio saved: {ref_path}")
                if not args.ref_text:
                    ref_text = CELEBRITY_REF_TEXT
            except Exception as e:
                print(f"Warning: Could not generate reference audio: {e}")
                ref_path = None
        elif args.ref_audio is not None and not args.ref_text:
            print("Warning: --ref-text not given; using a placeholder transcript for the reference audio.")

        try:
            generate_tts_simple(model_base, codec, TEST_TEXT, f"{output_dir}/samples/baseline_bf16_tts.wav")
            if ref_path:
                generate_voice_clone(model_base, codec, CLONE_TEXT, ref_path, ref_text,
                                     f"{output_dir}/samples/baseline_bf16_clone.wav")
        except Exception as e:
            print(f"Warning: Baseline generation issue: {e}")
    finally:
        del model_base
        gc.collect()
        torch.cuda.empty_cache()

    if args.phase == "all":
        phases_to_run = list(PHASES)
    else:
        phases_to_run = [p.strip() for p in args.phase.split(",")]

    all_results = []
    for pid in phases_to_run:
        if pid not in PHASES:
            print(f"Unknown phase: {pid}")
            continue
        cfg = PHASES[pid]
        try:
            r = run_phase(
                f"phase{pid}", cfg["cls"], cfg["target"], codec,
                ref_path, ref_text, TEST_TEXT, CLONE_TEXT, output_dir,
                **cfg["kwargs"],
            )
        except Exception:
            # Keep going: one failing phase should not discard the others.
            print(f"Phase {pid} failed:")
            traceback.print_exc()
            continue
        all_results.append(r)

    print(f"\n{'='*70}")
    print(" QUANTIZATION EXPERIMENT SUMMARY")
    print(f"{'='*70}")
    print(f"{'Phase':<12} {'Method':<12} {'Target':<10} {'Disk MB':<10} {'Ratio':<8} {'TTS':<5} {'Clone':<5}")
    print("-" * 65)
    for r in all_results:
        print(f"{r['phase']:<12} {r['method']:<12} {r['target']:<10} {r['disk_mb']:<10} {r['compression']:<8.2f} "
              f"{'OK' if r['tts_ok'] else 'FAIL':<5} {'OK' if r['clone_ok'] else 'FAIL':<5}")

    with open(f"{output_dir}/all_results.json", "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nAll results saved to {output_dir}/all_results.json")


if __name__ == "__main__":
    main()
|
|
|
|