# MOSS-TTSD-NF4 / run_moss_ttsd_nf4.py
# NF4-quantized MOSS-TTSD inference script (uploaded by groxaxo, commit 3afa0cd).
from __future__ import annotations
import argparse
import importlib.util
from pathlib import Path
import torch
import torchaudio
from transformers import AutoModel, AutoProcessor, BitsAndBytesConfig
# Match the demo's SDPA backend settings: keep flash / memory-efficient / math
# scaled-dot-product-attention kernels enabled and disable the cuDNN SDP backend.
torch.backends.cuda.enable_cudnn_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

# Hugging Face repo ids: the TTS model and its companion audio codec/tokenizer.
MODEL_ID = "OpenMOSS-Team/MOSS-TTSD-v1.0"
CODEC_ID = "OpenMOSS-Team/MOSS-Audio-Tokenizer"

# Default transcripts matching the two reference (voice prompt) wav files;
# [S1]/[S2] tags identify the speaker of each utterance.
PROMPT_TEXT_S1 = "[S1] In short, we embarked on a mission to make America great again for all Americans."
PROMPT_TEXT_S2 = (
    "[S2] NVIDIA reinvented computing for the first time after 60 years. In fact, Erwin at IBM knows quite "
    "well that the computer has largely been the same since the 60s."
)

# Default dialogue script to synthesize; speaker turns alternate via [S1]/[S2].
TEXT_TO_GENERATE = """
[S1] Listen, let's talk business. China. I'm hearing things.
People are saying they're catching up. Fast. What's the real scoop?
Their AI—is it a threat?
[S2] Well, the pace of innovation there is extraordinary, honestly.
They have the researchers, and they have the drive.
[S1] Extraordinary? I don't like that. I want us to be extraordinary.
Are they winning?
[S2] I wouldn't say winning, but their progress is very promising.
They are building massive clusters. They're very determined.
[S1] Promising. There it is. I hate that word.
When China is promising, it means we're losing.
It's a disaster, Jensen. A total disaster.
""".strip()

# Default sampling hyper-parameters for audio-token generation (overridable via CLI).
DEFAULT_AUDIO_TEMPERATURE = 1.1
DEFAULT_AUDIO_TOP_P = 0.9
DEFAULT_AUDIO_TOP_K = 50
DEFAULT_AUDIO_REPETITION_PENALTY = 1.1
DEFAULT_MAX_NEW_TOKENS = 2000
def repo_root() -> Path:
    """Return the directory containing this script, with symlinks resolved."""
    script_path = Path(__file__).resolve()
    return script_path.parent
def resolve_attn_implementation(requested: str) -> str | None:
requested_norm = (requested or "").strip().lower()
if requested_norm == "none":
return None
if requested_norm and requested_norm != "auto":
return requested
if not torch.cuda.is_available():
return "eager"
if importlib.util.find_spec("flash_attn") is not None:
major, _ = torch.cuda.get_device_capability(0)
if major >= 8:
return "flash_attention_2"
return "sdpa"
def load_wav_mono_resampled(path: Path, target_sr: int) -> torch.Tensor:
    """Load a wav file as a mono (1, samples) tensor resampled to *target_sr*.

    Raises:
        ValueError: when the loaded audio contains no samples.
    """
    waveform, native_sr = torchaudio.load(str(path))
    if waveform.numel() == 0:
        raise ValueError(f"Reference audio is empty: {path}")
    # Downmix multi-channel audio by averaging, keeping the channel dim.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    source_rate, target_rate = int(native_sr), int(target_sr)
    if source_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, source_rate, target_rate)
    return waveform.contiguous()
def resolve_codec_device(requested: str) -> torch.device:
    """Resolve the device for the audio codec.

    An explicit device string is honored as-is (after trimming/lowercasing,
    so inputs like ``" CPU "`` parse instead of crashing). ``"auto"``/empty
    places the codec on ``cuda:0`` only when the first GPU has >= 16 GiB of
    VRAM, otherwise on CPU to leave room for the quantized model; without
    CUDA it always falls back to CPU.
    """
    normalized = (requested or "").strip().lower()
    if normalized and normalized != "auto":
        # Fix: build the device from the normalized string — the original
        # passed the raw input, so " CPU " raised inside torch.device().
        return torch.device(normalized)
    # Guard: the original dereferenced device 0 unconditionally, which
    # raises on CPU-only hosts.
    if not torch.cuda.is_available():
        return torch.device("cpu")
    total_gib = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    return torch.device("cpu" if total_gib < 16 else "cuda:0")
def resolve_max_memory(
requested_limit_gib: int | None,
cpu_limit_gib: int,
reserve_gib: int,
) -> dict[int | str, str] | None:
if not torch.cuda.is_available():
return None
gpu_limit_gib = requested_limit_gib
if gpu_limit_gib is None:
total_gib = torch.cuda.get_device_properties(0).total_memory / (1024**3)
if total_gib < 16:
gpu_limit_gib = max(4, int(total_gib) - max(1, reserve_gib))
if gpu_limit_gib is None:
return None
limits: dict[int | str, str] = {idx: f"{gpu_limit_gib}GiB" for idx in range(torch.cuda.device_count())}
limits["cpu"] = f"{cpu_limit_gib}GiB"
return limits
def build_args() -> argparse.Namespace:
    """Parse and return the command-line arguments for the NF4 run."""
    cli = argparse.ArgumentParser(description="Run MOSS-TTSD with bitsandbytes NF4 quantization.")
    # Model / codec repo ids.
    cli.add_argument("--model-id", default=MODEL_ID)
    cli.add_argument("--codec-id", default=CODEC_ID)
    # Reference audio + matching transcripts for the two speakers.
    cli.add_argument(
        "--prompt-audio-s1",
        default=str(repo_root() / "assets" / "audio" / "reference_02_s1.wav"),
    )
    cli.add_argument(
        "--prompt-audio-s2",
        default=str(repo_root() / "assets" / "audio" / "reference_02_s2.wav"),
    )
    cli.add_argument("--prompt-text-s1", default=PROMPT_TEXT_S1)
    cli.add_argument("--prompt-text-s2", default=PROMPT_TEXT_S2)
    cli.add_argument("--text-to-generate", default=TEXT_TO_GENERATE)
    cli.add_argument("--output-dir", default=str(repo_root() / "output_nf4"))
    # Placement / memory knobs.
    cli.add_argument("--attn-implementation", default="sdpa")
    cli.add_argument("--codec-device", default="auto")
    cli.add_argument("--gpu-memory-limit-gib", type=int)
    cli.add_argument("--cpu-memory-limit-gib", type=int, default=96)
    cli.add_argument("--gpu-memory-reserve-gib", type=int, default=1)
    cli.add_argument("--offload-folder", default=str(repo_root() / ".offload_nf4"))
    # Generation hyper-parameters.
    cli.add_argument("--max-new-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS)
    cli.add_argument("--audio-temperature", type=float, default=DEFAULT_AUDIO_TEMPERATURE)
    cli.add_argument("--audio-top-p", type=float, default=DEFAULT_AUDIO_TOP_P)
    cli.add_argument("--audio-top-k", type=int, default=DEFAULT_AUDIO_TOP_K)
    cli.add_argument("--audio-repetition-penalty", type=float, default=DEFAULT_AUDIO_REPETITION_PENALTY)
    return cli.parse_args()
def main() -> None:
    """Load the NF4-quantized MOSS-TTSD model and synthesize dialogue audio.

    Pipeline: parse CLI args -> load processor + audio codec -> load the
    4-bit quantized model -> encode reference prompt audio -> build a
    continuation-mode conversation -> generate -> decode and save wav files.

    Raises:
        RuntimeError: when CUDA is unavailable (the NF4 path requires it).
    """
    args = build_args()
    if not torch.cuda.is_available():
        raise RuntimeError("The NF4 path requires CUDA. Use the normal fp32/bf16 MOSS-TTSD flow on CPU.")
    # Inference only: disable autograd and allow TF32 matmuls for speed.
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    codec_device = resolve_codec_device(args.codec_device)
    resolved_attn = resolve_attn_implementation(args.attn_implementation)
    max_memory = resolve_max_memory(
        requested_limit_gib=args.gpu_memory_limit_gib,
        cpu_limit_gib=args.cpu_memory_limit_gib,
        reserve_gib=args.gpu_memory_reserve_gib,
    )
    print(
        f"[INFO] visible_gpus={torch.cuda.device_count()} attn_implementation={resolved_attn} "
        f"codec_device={codec_device} max_memory={max_memory}"
    )
    # Processor bundles the text tokenizer and (via codec_path) the audio codec.
    processor = AutoProcessor.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        codec_path=args.codec_id,
    )
    if hasattr(processor, "audio_tokenizer"):
        # The codec may sit on CPU when VRAM is tight (see resolve_codec_device).
        processor.audio_tokenizer = processor.audio_tokenizer.to(codec_device)
        processor.audio_tokenizer.eval()
    # NF4 with double quantization and bf16 compute dtype.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model_kwargs = {
        "trust_remote_code": True,
        "quantization_config": quant_config,
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
    }
    if resolved_attn:
        model_kwargs["attn_implementation"] = resolved_attn
    if max_memory is not None:
        # Cap per-device memory and let accelerate spill overflow weights to disk.
        offload_folder = Path(args.offload_folder)
        offload_folder.mkdir(parents=True, exist_ok=True)
        model_kwargs["max_memory"] = max_memory
        model_kwargs["offload_folder"] = str(offload_folder)
        model_kwargs["offload_state_dict"] = True
    model = AutoModel.from_pretrained(args.model_id, **model_kwargs)
    model.eval()
    if hasattr(model, "hf_device_map"):
        print(f"[INFO] hf_device_map={model.hf_device_map}")
    # Inputs must land on the same device as the (possibly sharded) embeddings.
    try:
        input_device = model.get_input_embeddings().weight.device
    except Exception:
        input_device = next(model.parameters()).device
    target_sr = int(processor.model_config.sampling_rate)
    wav1 = load_wav_mono_resampled(Path(args.prompt_audio_s1), target_sr).to(codec_device)
    wav2 = load_wav_mono_resampled(Path(args.prompt_audio_s2), target_sr).to(codec_device)
    # Per-speaker reference codes condition the speaker identities.
    reference_audio_codes = processor.encode_audios_from_wav(
        [wav1, wav2],
        sampling_rate=target_sr,
    )
    # Concatenated prompt audio seeds the assistant turn for continuation mode.
    concat_prompt_wav = torch.cat([wav1, wav2], dim=-1)
    prompt_audio = processor.encode_audios_from_wav(
        [concat_prompt_wav],
        sampling_rate=target_sr,
    )[0]
    full_text = f"{args.prompt_text_s1} {args.prompt_text_s2} {args.text_to_generate}".strip()
    conversations = [[
        processor.build_user_message(
            text=full_text,
            reference=reference_audio_codes,
        ),
        processor.build_assistant_message(audio_codes_list=[prompt_audio]),
    ]]
    batch = processor(conversations, mode="continuation")
    input_ids = batch["input_ids"].to(input_device)
    attention_mask = batch["attention_mask"].to(input_device)
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=args.max_new_tokens,
        audio_temperature=args.audio_temperature,
        audio_top_p=args.audio_top_p,
        audio_top_k=args.audio_top_k,
        audio_repetition_penalty=args.audio_repetition_penalty,
    )
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Decode generated messages and write one wav per audio segment.
    for message_idx, message in enumerate(processor.decode(outputs)):
        for seg_idx, audio in enumerate(message.audio_codes_list):
            # NOTE(review): despite the attribute name, entries appear to be decoded
            # waveforms here (they are saved directly) — confirm with processor.decode.
            audio_tensor = audio.detach().to(torch.float32).cpu() if isinstance(audio, torch.Tensor) else torch.tensor(audio, dtype=torch.float32)
            # Flatten any extra dims so torchaudio.save gets a (1, samples) tensor.
            if audio_tensor.ndim > 1:
                audio_tensor = audio_tensor.reshape(-1)
            out_path = output_dir / f"{message_idx}_{seg_idx}.wav"
            torchaudio.save(str(out_path), audio_tensor.unsqueeze(0), target_sr)
            print(f"saved: {out_path}")
# Script entry point.
if __name__ == "__main__":
    main()