"""Generate a two-speaker MOSS-TTSD dialogue with the model quantized to NF4 via bitsandbytes."""

from __future__ import annotations

import argparse
import importlib.util
from pathlib import Path

import torch
import torchaudio
from transformers import AutoModel, AutoProcessor, BitsAndBytesConfig
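
# Example invocation (a sketch; the script name and flag values are illustrative):
#   python moss_ttsd_nf4.py --gpu-memory-limit-gib 10 --codec-device cpu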

# Allow the flash, memory-efficient, and math scaled-dot-product attention
# backends; the cuDNN SDP backend is explicitly disabled.
torch.backends.cuda.enable_cudnn_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)


MODEL_ID = "OpenMOSS-Team/MOSS-TTSD-v1.0"
CODEC_ID = "OpenMOSS-Team/MOSS-Audio-Tokenizer"

PROMPT_TEXT_S1 = "[S1] In short, we embarked on a mission to make America great again for all Americans."
PROMPT_TEXT_S2 = (
    "[S2] NVIDIA reinvented computing for the first time after 60 years. In fact, Erwin at IBM knows quite "
    "well that the computer has largely been the same since the 60s."
)
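
# The [S1]/[S2] prefixes tag speaker turns; the same convention is used in the
# reference transcripts above and in the dialogue below.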
TEXT_TO_GENERATE = """
[S1] Listen, let's talk business. China. I'm hearing things.
People are saying they're catching up. Fast. What's the real scoop?
Their AI—is it a threat?
[S2] Well, the pace of innovation there is extraordinary, honestly.
They have the researchers, and they have the drive.
[S1] Extraordinary? I don't like that. I want us to be extraordinary.
Are they winning?
[S2] I wouldn't say winning, but their progress is very promising.
They are building massive clusters. They're very determined.
[S1] Promising. There it is. I hate that word.
When China is promising, it means we're losing.
It's a disaster, Jensen. A total disaster.
""".strip()
DEFAULT_AUDIO_TEMPERATURE = 1.1
DEFAULT_AUDIO_TOP_P = 0.9
DEFAULT_AUDIO_TOP_K = 50
DEFAULT_AUDIO_REPETITION_PENALTY = 1.1
DEFAULT_MAX_NEW_TOKENS = 2000


def repo_root() -> Path:
    """Return the directory containing this script."""
    return Path(__file__).resolve().parent


def resolve_attn_implementation(requested: str) -> str | None:
    """Resolve --attn-implementation to a transformers attn_implementation value.

    "none" drops the kwarg, explicit names pass through unchanged, and "auto"
    prefers flash-attention-2 on Ampere-or-newer GPUs (compute capability
    >= 8.0) when flash_attn is installed, falling back to SDPA on CUDA and
    eager on CPU.
    """
    requested_norm = (requested or "").strip().lower()
    if requested_norm == "none":
        return None
    if requested_norm and requested_norm != "auto":
        return requested

    if not torch.cuda.is_available():
        return "eager"

    if importlib.util.find_spec("flash_attn") is not None:
        major, _ = torch.cuda.get_device_capability(0)
        if major >= 8:
            return "flash_attention_2"
    return "sdpa"


def load_wav_mono_resampled(path: Path, target_sr: int) -> torch.Tensor:
    """Load a reference wav, downmix it to mono, and resample it to target_sr."""
    wav, sr = torchaudio.load(str(path))
    if wav.numel() == 0:
        raise ValueError(f"Reference audio is empty: {path}")
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if int(sr) != int(target_sr):
        wav = torchaudio.functional.resample(wav, int(sr), int(target_sr))
    return wav.contiguous()


def resolve_codec_device(requested: str) -> torch.device:
    """Pick the codec device: an explicit value wins; auto keeps it on CPU for <16 GiB GPUs."""
    requested_norm = (requested or "").strip().lower()
    if requested_norm and requested_norm != "auto":
        return torch.device(requested)
    # Guard against CPU-only hosts before querying device properties.
    if not torch.cuda.is_available():
        return torch.device("cpu")
    total_gib = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    return torch.device("cpu" if total_gib < 16 else "cuda:0")


def resolve_max_memory(
    requested_limit_gib: int | None,
    cpu_limit_gib: int,
    reserve_gib: int,
) -> dict[int | str, str] | None:
    """Build an accelerate-style max_memory map, or None to leave device_map unconstrained."""
    if not torch.cuda.is_available():
        return None

    # Without an explicit limit, only cap GPUs with less than 16 GiB,
    # keeping reserve_gib of headroom on each card.
    gpu_limit_gib = requested_limit_gib
    if gpu_limit_gib is None:
        total_gib = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if total_gib < 16:
            gpu_limit_gib = max(4, int(total_gib) - max(1, reserve_gib))

    if gpu_limit_gib is None:
        return None

    limits: dict[int | str, str] = {idx: f"{gpu_limit_gib}GiB" for idx in range(torch.cuda.device_count())}
    limits["cpu"] = f"{cpu_limit_gib}GiB"
    return limits
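
# Note: the resulting max_memory map looks like {0: "10GiB", "cpu": "96GiB"}
# (illustrative values) and is forwarded to transformers' device_map machinery.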


def build_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Run MOSS-TTSD with bitsandbytes NF4 quantization.")
    parser.add_argument("--model-id", default=MODEL_ID)
    parser.add_argument("--codec-id", default=CODEC_ID)
    parser.add_argument(
        "--prompt-audio-s1",
        default=str(repo_root() / "assets" / "audio" / "reference_02_s1.wav"),
        help="Reference wav for speaker [S1].",
    )
    parser.add_argument(
        "--prompt-audio-s2",
        default=str(repo_root() / "assets" / "audio" / "reference_02_s2.wav"),
        help="Reference wav for speaker [S2].",
    )
    parser.add_argument("--prompt-text-s1", default=PROMPT_TEXT_S1)
    parser.add_argument("--prompt-text-s2", default=PROMPT_TEXT_S2)
    parser.add_argument("--text-to-generate", default=TEXT_TO_GENERATE)
    parser.add_argument("--output-dir", default=str(repo_root() / "output_nf4"))
    parser.add_argument("--attn-implementation", default="sdpa", help="auto, none, or an explicit implementation name.")
    parser.add_argument("--codec-device", default="auto")
    parser.add_argument("--gpu-memory-limit-gib", type=int, help="Per-GPU cap for max_memory; defaults to auto.")
    parser.add_argument("--cpu-memory-limit-gib", type=int, default=96)
    parser.add_argument("--gpu-memory-reserve-gib", type=int, default=1)
    parser.add_argument("--offload-folder", default=str(repo_root() / ".offload_nf4"))
    parser.add_argument("--max-new-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS)
    parser.add_argument("--audio-temperature", type=float, default=DEFAULT_AUDIO_TEMPERATURE)
    parser.add_argument("--audio-top-p", type=float, default=DEFAULT_AUDIO_TOP_P)
    parser.add_argument("--audio-top-k", type=int, default=DEFAULT_AUDIO_TOP_K)
    parser.add_argument("--audio-repetition-penalty", type=float, default=DEFAULT_AUDIO_REPETITION_PENALTY)
    return parser.parse_args()


def main() -> None:
    args = build_args()

    if not torch.cuda.is_available():
        raise RuntimeError("The NF4 path requires CUDA. Use the normal fp32/bf16 MOSS-TTSD flow on CPU.")

    # Inference only: no autograd bookkeeping, and TF32 matmuls for speed.
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True

    codec_device = resolve_codec_device(args.codec_device)
    resolved_attn = resolve_attn_implementation(args.attn_implementation)
    max_memory = resolve_max_memory(
        requested_limit_gib=args.gpu_memory_limit_gib,
        cpu_limit_gib=args.cpu_memory_limit_gib,
        reserve_gib=args.gpu_memory_reserve_gib,
    )
    print(
        f"[INFO] visible_gpus={torch.cuda.device_count()} attn_implementation={resolved_attn} "
        f"codec_device={codec_device} max_memory={max_memory}"
    )

    # The processor bundles the text tokenizer and the MOSS audio tokenizer
    # (codec); the codec is moved to its own device and put in eval mode.
    processor = AutoProcessor.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        codec_path=args.codec_id,
    )
    if hasattr(processor, "audio_tokenizer"):
        processor.audio_tokenizer = processor.audio_tokenizer.to(codec_device)
        processor.audio_tokenizer.eval()

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
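
    # NF4 stores weights as 4-bit normal-float blocks; double quantization also
    # compresses the per-block scales, saving roughly another 0.4 bits per
    # parameter. Compute still runs in bfloat16.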

    model_kwargs = {
        "trust_remote_code": True,
        "quantization_config": quant_config,
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
    }
    if resolved_attn:
        model_kwargs["attn_implementation"] = resolved_attn
    if max_memory is not None:
        # Constrained GPUs: cap per-device memory and allow CPU/disk offload.
        offload_folder = Path(args.offload_folder)
        offload_folder.mkdir(parents=True, exist_ok=True)
        model_kwargs["max_memory"] = max_memory
        model_kwargs["offload_folder"] = str(offload_folder)
        model_kwargs["offload_state_dict"] = True

    model = AutoModel.from_pretrained(args.model_id, **model_kwargs)
    model.eval()
    if hasattr(model, "hf_device_map"):
        print(f"[INFO] hf_device_map={model.hf_device_map}")

    # Inputs must land on the same device as the input embeddings, which may
    # not be cuda:0 when device_map spreads the model across devices.
    try:
        input_device = model.get_input_embeddings().weight.device
    except Exception:
        input_device = next(model.parameters()).device

    target_sr = int(processor.model_config.sampling_rate)

    wav1 = load_wav_mono_resampled(Path(args.prompt_audio_s1), target_sr).to(codec_device)
    wav2 = load_wav_mono_resampled(Path(args.prompt_audio_s2), target_sr).to(codec_device)

    # Encode each speaker reference separately (used as the voice reference) and
    # the two concatenated (used as the audio prefix the model continues from).
    reference_audio_codes = processor.encode_audios_from_wav(
        [wav1, wav2],
        sampling_rate=target_sr,
    )
    concat_prompt_wav = torch.cat([wav1, wav2], dim=-1)
    prompt_audio = processor.encode_audios_from_wav(
        [concat_prompt_wav],
        sampling_rate=target_sr,
    )[0]

    # One conversation: the user turn carries the full transcript plus the voice
    # references, and a seeded assistant turn carries the prompt audio that
    # generation continues.
    full_text = f"{args.prompt_text_s1} {args.prompt_text_s2} {args.text_to_generate}".strip()
    conversations = [[
        processor.build_user_message(
            text=full_text,
            reference=reference_audio_codes,
        ),
        processor.build_assistant_message(audio_codes_list=[prompt_audio]),
    ]]

    batch = processor(conversations, mode="continuation")
    input_ids = batch["input_ids"].to(input_device)
    attention_mask = batch["attention_mask"].to(input_device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=args.max_new_tokens,
        audio_temperature=args.audio_temperature,
        audio_top_p=args.audio_top_p,
        audio_top_k=args.audio_top_k,
        audio_repetition_penalty=args.audio_repetition_penalty,
    )
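
    # processor.decode turns generated ids back into assistant messages; with
    # MOSS-TTSD's remote code, each message's audio_codes_list is assumed to
    # hold decoded waveform tensors ready to save (inferred from the usage below).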

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for message_idx, message in enumerate(processor.decode(outputs)):
        for seg_idx, audio in enumerate(message.audio_codes_list):
            if isinstance(audio, torch.Tensor):
                audio_tensor = audio.detach().to(torch.float32).cpu()
            else:
                audio_tensor = torch.tensor(audio, dtype=torch.float32)
            # Collapse any extra dims so torchaudio gets a (1, num_samples) tensor.
            if audio_tensor.ndim > 1:
                audio_tensor = audio_tensor.reshape(-1)
            out_path = output_dir / f"{message_idx}_{seg_idx}.wav"
            torchaudio.save(str(out_path), audio_tensor.unsqueeze(0), target_sr)
            print(f"saved: {out_path}")


if __name__ == "__main__":
    main()