File size: 9,525 Bytes
521453c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""Phase 1a: FP8 Quantization + Voice Clone Sample Generation
Uses the proven per-row symmetric FP8 approach from drbaph/s2-pro-fp8
"""
import os
import sys
import json
import time
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from collections import OrderedDict

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Install fish-speech's Python dependencies and put /app/fish-speech on sys.path
def setup(repo_path="/app/fish-speech"):
    """Install the runtime dependencies and make the fish-speech sources importable.

    Args:
        repo_path: Directory prepended to ``sys.path`` so that the
            ``fish_speech`` package can be imported from it.

    The install is best-effort: a failing ``pip`` only produces a warning, so
    an environment that already has the packages keeps working offline.
    """
    import subprocess

    packages = ["einops", "loguru", "ormsgpack", "hydra-core", "omegaconf"]
    # Run pip through the current interpreter so the packages land in the
    # environment this script is executing in, not whichever `pip` is on PATH.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *packages],
        check=False,
    )
    if result.returncode != 0:
        print(
            f"WARNING: pip install exited with code {result.returncode}",
            file=sys.stderr,
        )

    # Avoid stacking duplicate entries when setup() is called more than once.
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)

setup()

from fish_speech.models.text2semantic.llama import DualARTransformer
from fish_speech.models.text2semantic.inference import (
    init_model, generate, decode_one_token_ar
)
from fish_speech.models.dac.inference import load_codec_model

# Model identifier handed to init_model(); main() also appends "/codec.pth"
# to it to locate the codec checkpoint.
MODEL_ID = "fishaudio/s2-pro"
# Directory where main() writes model_fp8.safetensors and results.json.
OUTPUT_DIR = "/app/output/phase1_fp8"
# Directory where main() writes the generated .wav samples.
SAMPLES_DIR = "/app/output/samples"

# ==========================================
# FP8 Quantization (per-row symmetric)
# ==========================================
class FP8Linear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.empty(out_features, 1, dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.zeros(out_features, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear):
        fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None)
        FP8_MAX = 448.0
        w = linear.weight.data.bfloat16()
        scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX
        scale = scale.clamp(min=1e-12)
        w_fp8 = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        fp8.weight.data.copy_(w_fp8)
        fp8.weight_scale.data.copy_(scale)
        if linear.bias is not None:
            fp8.bias.data.copy_(linear.bias.data.bfloat16())
        return fp8

    def forward(self, x):
        w = self.weight.to(torch.bfloat16) * self.weight_scale
        return torch.nn.functional.linear(x, w, self.bias)

def quantize_model_fp8(model):
    """Swap ``nn.Linear`` layers for ``FP8Linear`` in place.

    Every ``torch.nn.Linear`` whose qualified module name does not contain
    the substring ``"fast_"`` is replaced (presumably this excludes the Fast
    AR sub-network so that only the Slow AR is quantized; confirm against the
    model's module names).

    Args:
        model: Module tree to quantize; modified in place.

    Returns:
        ``(model, count)``, where ``count`` is the number of layers replaced.
        If ``model`` itself is an ``nn.Linear`` there is no parent to
        re-attach to, so the converted layer is returned in its place.
    """
    if isinstance(model, torch.nn.Linear):
        # The root's qualified name is "", so setattr(model, "", ...) would
        # register a junk child and leave the layer itself unquantized.
        print("Quantized 1 linear layers to FP8")
        return FP8Linear.from_linear(model), 1

    # Snapshot the targets first so the tree is not mutated while walking it.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and "fast_" not in name
    ]
    for name, module in targets:
        parent_path, _, attr = name.rpartition(".")
        # get_submodule("") returns the model itself for top-level children.
        parent = model.get_submodule(parent_path)
        setattr(parent, attr, FP8Linear.from_linear(module))

    count = len(targets)
    print(f"Quantized {count} linear layers to FP8")
    return model, count


# ==========================================
# Generate reference audio (synthetic voice)
# ==========================================
def create_reference_audio(
    text=(
        "Hello, my name is Morgan. Welcome to this special presentation about "
        "the future of technology and innovation. I hope you enjoy this journey "
        "as much as I do."
    ),
    output_path="reference.wav",
):
    """Placeholder for reference-audio synthesis.

    No audio is produced and nothing is written to disk: ``text`` is ignored
    and ``output_path`` is returned unchanged.

    Args:
        text: Script the reference speaker would read (currently unused).
        output_path: Path that the reference audio would be written to.

    Returns:
        ``output_path``, exactly as passed in.
    """
    return output_path

# ==========================================
# Voice Clone + Generate Sample
# ==========================================
@torch.inference_mode()
def generate_sample(model, codec, text, ref_audio_path, ref_text, output_path, device="cuda"):
    """Synthesize ``text`` and write the audio to ``output_path``.

    When ``ref_audio_path`` is given, the reference clip is down-mixed to
    mono, resampled to the codec's sample rate, encoded with ``codec`` and
    added to the conversation as a ``user`` message (followed by ``ref_text``
    when provided). When it is ``None`` the reference message is skipped and
    only the ``assistant`` text message is sent.

    Args:
        model: Text-to-semantic model; must expose ``config`` and
            ``setup_caches``.
        codec: Audio codec exposing ``encode`` and ``decode``; its
            ``sample_rate`` attribute is used when present, otherwise 44100.
        text: Text to synthesize.
        ref_audio_path: Path of the reference clip, or ``None`` for no
            voice cloning.
        ref_text: Transcript of the reference clip, or ``None``.
        output_path: Destination audio file.
        device: Device used for the codec and the generation tensors.

    Returns:
        ``output_path``.

    NOTE(review): the exact tensor shapes returned by ``generate`` and
    expected by ``codec.decode`` are not visible here; the slicing and
    ``unsqueeze`` below are kept as they were. Confirm against fish-speech.
    """
    from fish_speech.conversation import Conversation, Message
    from fish_speech.content_sequence import TextPart, VQPart

    # autocast wants the bare device type ("cuda"), not e.g. "cuda:0".
    autocast_device = torch.device(device).type
    codec_sr = codec.sample_rate if hasattr(codec, 'sample_rate') else 44100

    conv = Conversation()

    if ref_audio_path is not None:
        import torchaudio

        # Load the reference clip as mono at the codec's sample rate.
        wav, sr = torchaudio.load(ref_audio_path)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != codec_sr:
            wav = torchaudio.functional.resample(wav, sr, codec_sr)

        # Encode to VQ tokens.
        wav = wav.to(device)
        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            encoded = codec.encode(wav.unsqueeze(0))
        # Some codecs return (codes, ...) rather than the codes alone.
        prompt_tokens = encoded[0] if isinstance(encoded, tuple) else encoded
        if isinstance(prompt_tokens, torch.Tensor):
            prompt_tokens = prompt_tokens.cpu().numpy()

        ref_parts = [VQPart(codes=prompt_tokens)]
        if ref_text is not None:
            ref_parts.append(TextPart(text=ref_text))
        conv.add_message(Message(role="user", parts=ref_parts))

    conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))

    # Encode conversation
    prompt = conv.encode_for_inference(model.config)

    # All-zero audio masks/parts sized to the prompt, one row per codebook
    # plus one extra row.
    codebook_dim = 1 + model.config.num_codebooks
    audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device)
    audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device)

    # Generate
    model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=torch.bfloat16)
    model._cache_setup_done = True

    result = generate(
        model=model,
        prompt=prompt,
        max_new_tokens=1024,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        temperature=0.7,
        top_p=0.7,
        top_k=30,
        decode_one_token=decode_one_token_ar,
    )

    # Decode to audio
    codes = result[0:1, :, :]
    with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
        audio = codec.decode(codes.unsqueeze(0).to(device))

    audio_np = audio.squeeze().cpu().float().numpy()
    sf.write(output_path, audio_np, codec_sr)
    print(f"Saved: {output_path}")
    return output_path


# ==========================================
# Main
# ==========================================
def main():
    """Run phase 1a end to end.

    Steps: load the base model and codec, try to render a bf16 baseline
    sample, quantize the model with ``quantize_model_fp8``, save the
    quantized ``state_dict`` as safetensors, try to render an FP8 sample, and
    write a ``results.json`` summary into ``OUTPUT_DIR``.

    Sample generation is best-effort: a failure is printed and the run
    continues, so quantization and saving still happen.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    device = "cuda"
    sample_text = "Hello, I am speaking to you today about an exciting breakthrough in artificial intelligence."

    def _nbytes(tensors):
        # Total storage, in bytes, of an iterable of tensors.
        return sum(t.numel() * t.element_size() for t in tensors)

    print("=" * 60)
    print("PHASE 1a: FP8 Quantization of Fish Speech S2 Pro")
    print("=" * 60)

    # Load base model
    print("[1/6] Loading base model...")
    model, _decode_fn = init_model(MODEL_ID, device, torch.bfloat16, compile=False)

    # Record original size
    orig_params = sum(p.numel() for p in model.parameters())
    orig_bytes = _nbytes(model.parameters())
    print(f"Original: {orig_params/1e9:.2f}B params, {orig_bytes/1e9:.2f} GB")

    print("[2/6] Loading codec...")
    codec = load_codec_model(f"{MODEL_ID}/codec.pth", device, torch.bfloat16)

    # Baseline sample without a reference clip (no voice cloning).
    print("[3/6] Generating baseline bf16 sample...")
    try:
        generate_sample(
            model, codec,
            text=sample_text,
            ref_audio_path=None,
            ref_text=None,
            output_path=f"{SAMPLES_DIR}/baseline_bf16.wav",
            device=device
        )
    except Exception as e:
        print(f"Baseline generation had issue: {e}; continuing without a baseline sample")

    # FP8 Quantize
    print("[4/6] Quantizing to FP8...")
    model_fp8, n_quantized = quantize_model_fp8(model)
    model_fp8 = model_fp8.to(device)

    # FP8Linear keeps weight/scale/bias as buffers, so parameters() alone
    # would leave every quantized layer out of the total.
    quant_bytes = _nbytes(model_fp8.parameters()) + sum(
        _nbytes(module.buffers(recurse=False))
        for module in model_fp8.modules()
        if isinstance(module, FP8Linear)
    )
    print(f"FP8: {quant_bytes/1e9:.2f} GB ({quant_bytes/orig_bytes*100:.1f}% of original)")

    # Save quantized model
    print("[5/6] Saving FP8 model...")
    state_dict = model_fp8.state_dict()
    save_path = f"{OUTPUT_DIR}/model_fp8.safetensors"

    from safetensors.torch import save_file
    save_file(state_dict, save_path)
    file_size_gb = os.path.getsize(save_path) / 1e9
    print(f"Saved FP8 model: {file_size_gb:.2f} GB")

    # Generate FP8 sample
    print("[6/6] Generating FP8 sample...")
    try:
        generate_sample(
            model_fp8, codec,
            text=sample_text,
            ref_audio_path=None,
            ref_text=None,
            output_path=f"{SAMPLES_DIR}/phase1_fp8.wav",
            device=device
        )
    except Exception as e:
        print(f"FP8 generation issue: {e}")

    # Summary; the compression ratio compares the in-memory parameter bytes
    # of the original model with the size of the saved file.
    results = {
        "phase": "1a_fp8",
        "original_size_gb": round(orig_bytes / 1e9, 2),
        "quantized_size_gb": round(file_size_gb, 2),
        "compression_ratio": round(orig_bytes / (file_size_gb * 1e9), 2),
        "n_quantized_layers": n_quantized,
        "params_billions": round(orig_params / 1e9, 2),
        "method": "FP8 per-row symmetric weight-only",
    }
    with open(f"{OUTPUT_DIR}/results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 60)
    print("PHASE 1a COMPLETE")
    print(f"Original: {results['original_size_gb']} GB")
    print(f"FP8: {results['quantized_size_gb']} GB ({results['compression_ratio']}x compression)")
    print("=" * 60)

if __name__ == "__main__":
    main()