import os
import sys
import subprocess
import traceback
import gradio as gr
import numpy as np
import librosa
import spaces
import torch
from pathlib import Path
from huggingface_hub import snapshot_download
# Location of the upstream fish-speech sources and the local checkout folder.
REPO_URL = "https://github.com/fishaudio/fish-speech.git"
REPO_DIR = "fish-speech"

# Clone only when the checkout is missing; check=True aborts start-up if the
# git command exits with a non-zero status.
if not os.path.exists(REPO_DIR):
    subprocess.run(
        ["git", "clone", REPO_URL, REPO_DIR],
        check=True,
    )

# Run from inside the checkout so relative paths used later in this file
# (e.g. "fish_speech/configs/...") resolve, and put the checkout first on
# sys.path so the `fish_speech` package can be imported.
os.chdir(REPO_DIR)
sys.path.insert(0, os.getcwd())
from fish_speech.models.text2semantic.inference import init_model, generate_long
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# dtype passed to both the text2semantic model and the codec.
precision = torch.bfloat16

# Download (or reuse) the "fishaudio/s2-pro" snapshot; the returned value is
# used below as the checkpoint location.
checkpoint_dir = snapshot_download(repo_id="fishaudio/s2-pro")

# init_model hands back the model together with its single-token decode
# callable; both are passed to generate_long at inference time.
llama_model, decode_one_token = init_model(
    checkpoint_path=checkpoint_dir,
    device=device,
    precision=precision,
    compile=False,
)

# Set up the model caches for a batch of one, sized to the model's configured
# maximum sequence length and using the dtype of the model's own parameters.
with torch.device(device):
    llama_model.setup_caches(
        max_batch_size=1,
        max_seq_len=llama_model.config.max_seq_len,
        dtype=next(llama_model.parameters()).dtype,
    )
def load_codec(codec_checkpoint_path, target_device, target_precision):
    """Instantiate the codec from its YAML config and load checkpoint weights.

    Args:
        codec_checkpoint_path: Path of the checkpoint file given to
            ``torch.load``.
        target_device: Device the loaded codec is moved to.
        target_precision: dtype the loaded codec is cast to.

    Returns:
        The instantiated codec, switched to eval mode and moved to
        ``target_device`` / ``target_precision``.
    """
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Relative path: resolved against the current working directory.
    config = OmegaConf.load(Path("fish_speech/configs/modded_dac_vq.yaml"))
    codec = instantiate(config)

    # NOTE(review): torch.load is called without weights_only, so the
    # checkpoint is unpickled as-is -- confirm the checkpoint source is trusted.
    weights = torch.load(codec_checkpoint_path, map_location="cpu")

    # Unwrap checkpoints that keep their weights under a "state_dict" entry.
    if "state_dict" in weights:
        weights = weights["state_dict"]

    # If any key mentions "generator", keep only the "generator." entries and
    # drop that marker from their names.
    has_generator_keys = any("generator" in key for key in weights)
    if has_generator_keys:
        generator_weights = {}
        for key, value in weights.items():
            if "generator." in key:
                generator_weights[key.replace("generator.", "")] = value
        weights = generator_weights

    # strict=False: missing or unexpected keys do not raise.
    codec.load_state_dict(weights, strict=False)
    codec.eval()
    codec.to(device=target_device, dtype=target_precision)
    return codec
# Shared codec instance, loaded from "codec.pth" inside the downloaded
# checkpoint directory and placed on the same device/dtype as the main model.
codec_model = load_codec(
    os.path.join(checkpoint_dir, "codec.pth"),
    device,
    precision,
)
@torch.no_grad()
def encode_reference_audio(audio_path):
    """Encode an audio file into codec indices.

    The file is loaded as mono at ``codec_model.sample_rate``, wrapped into a
    single-item batch and passed through ``codec_model.encode``.

    Args:
        audio_path: Path of the audio file, as accepted by ``librosa.load``.

    Returns:
        The indices of the first (only) batch item, cut along the last axis to
        the feature length reported by the codec for that item.
    """
    samples, _ = librosa.load(audio_path, sr=codec_model.sample_rate, mono=True)
    waveform = torch.from_numpy(samples).to(device)

    # Match the codec's parameter dtype before feeding it the audio.
    codec_dtype = next(codec_model.parameters()).dtype
    # Two leading singleton axes are added in front of the sample axis.
    batch = waveform.unsqueeze(0).unsqueeze(0).to(dtype=codec_dtype)
    batch_lengths = torch.tensor(
        [waveform.size(0)], device=device, dtype=torch.long
    )

    codes, code_lengths = codec_model.encode(batch, batch_lengths)
    valid_length = code_lengths[0]
    return codes[0, :, :valid_length]
@torch.no_grad()
def decode_codes_to_audio(merged_codes):
    """Decode codec indices back into audio.

    Args:
        merged_codes: Tensor of codec indices without a batch axis.

    Returns:
        Element ``[0, 0]`` of the tensor returned by
        ``codec_model.from_indices``.
    """
    # A leading batch axis is added before handing the codes to the codec.
    batched_codes = merged_codes.unsqueeze(0)
    decoded = codec_model.from_indices(batched_codes)
    return decoded[0, 0]
# Lazily created faster-whisper model, shared by every transcription request.
whisper_model = None


def get_whisper_model():
    """Return the shared Whisper model, creating it on the first call.

    The model is built on the module-level ``device`` ("cuda" or "cpu")
    rather than a hard-coded "cuda", so transcription follows the same CPU
    fallback as the rest of this file.

    Returns:
        The cached ``WhisperModel`` ("large-v3", int8 compute type).
    """
    global whisper_model
    if whisper_model is None:
        # Imported here so faster_whisper is only loaded when transcription
        # is actually requested.
        from faster_whisper import WhisperModel

        whisper_model = WhisperModel(
            "large-v3",
            device=device,
            compute_type="int8",
        )
    return whisper_model
@spaces.GPU(duration=60)
def transcribe_audio(audio_path):
    """Transcribe an uploaded audio file with the shared Whisper model.

    Args:
        audio_path: Path of the audio file, or ``None`` when nothing was
            uploaded.

    Returns:
        The stripped segment texts joined by single spaces.

    Raises:
        gr.Error: If no file was given, no speech was detected, or any other
            exception occurred while transcribing.
    """
    if audio_path is None:
        raise gr.Error("Please upload a reference audio file first.")

    try:
        gr.Info("Transcribing audio with Whisper large-v3...")
        whisper = get_whisper_model()
        segments, info = whisper.transcribe(audio_path, beam_size=5, vad_filter=True)

        pieces = [segment.text.strip() for segment in segments]
        transcript = " ".join(pieces).strip()
        if not transcript:
            raise gr.Error("Whisper could not detect any speech in the audio.")

        gr.Info(f"Detected language: {info.language} ({info.language_probability:.0%} confidence)")
        return transcript
    except gr.Error:
        # User-facing errors raised above pass through untouched.
        raise
    except Exception as e:
        # Anything else is logged to the console and surfaced as a UI error.
        traceback.print_exc()
        raise gr.Error(f"Transcription error: {str(e)}")
# Rough per-unit durations (in seconds) behind the progress estimate.
_SECONDS_PER_WORD = 0.4
_SECONDS_PER_UNSPACED_CHAR = 0.25

# Code-point ranges of characters written without spaces between words.
_UNSPACED_SCRIPT_RANGES = (
    (0x3040, 0x30FF),  # Hiragana and Katakana
    (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0xF900, 0xFAFF),  # CJK Compatibility Ideographs
)


def _is_unspaced_script_char(char):
    """Return True when ``char`` falls in one of ``_UNSPACED_SCRIPT_RANGES``."""
    code_point = ord(char)
    return any(low <= code_point <= high for low, high in _UNSPACED_SCRIPT_RANGES)


def estimate_duration(text):
    """Estimate, in whole seconds, how long the audio for ``text`` will be.

    Whitespace-separated words count 0.4s each. Kana and Han characters are
    counted one by one at 0.25s each, because text in those scripts has no
    spaces and would otherwise be treated as a single word. Text without
    such characters yields exactly ``max(5, int(words * 0.4))``.

    Other unspaced scripts (e.g. Thai) are still counted by whitespace only.

    Args:
        text: The text that is going to be synthesized.

    Returns:
        The estimate as an int, never lower than 5.
    """
    unspaced_chars = 0
    remainder = []
    for char in text:
        if _is_unspaced_script_char(char):
            unspaced_chars += 1
            # Replace with a space so the character is not counted again as
            # part of a whitespace-separated word.
            remainder.append(" ")
        else:
            remainder.append(char)

    words = len("".join(remainder).split())
    seconds = words * _SECONDS_PER_WORD + unspaced_chars * _SECONDS_PER_UNSPACED_CHAR
    return max(5, int(seconds))
@spaces.GPU(duration=180)
def tts_inference(
    text,
    ref_audio,
    ref_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
):
    """Synthesize speech for ``text``, optionally prompted by a reference voice.

    Args:
        text: Text to synthesize; must contain non-whitespace characters.
        ref_audio: Path of a reference audio file, or ``None``.
        ref_text: Transcript of the reference audio, or empty/``None``.
        max_new_tokens: Forwarded to ``generate_long``.
        chunk_length: Forwarded to ``generate_long``.
        top_p: Forwarded to ``generate_long``.
        repetition_penalty: Forwarded to ``generate_long``.
        temperature: Forwarded to ``generate_long``.

    Returns:
        A ``(sample_rate, samples)`` tuple where ``sample_rate`` is
        ``codec_model.sample_rate`` and ``samples`` is an int16 NumPy array.

    Raises:
        gr.Error: If ``text`` is empty, nothing was generated, or any other
            exception occurred during inference.
    """
    try:
        if not text or not text.strip():
            raise gr.Error("Please enter some text to synthesize.")

        est = estimate_duration(text)
        gr.Info(f"Generating audio... estimated ~{est}s depending on text length.")

        # The reference prompt is passed as a (transcript, tokens) pair or not
        # at all, so generate_long never receives a transcript without tokens.
        has_reference_audio = ref_audio is not None
        has_reference_text = bool(ref_text and ref_text.strip())
        prompt_text_list = None
        prompt_tokens_list = None
        if has_reference_audio and has_reference_text:
            prompt_text_list = [ref_text]
            prompt_tokens_list = [encode_reference_audio(ref_audio).cpu()]
        elif has_reference_audio:
            # Tell the user instead of silently ignoring the uploaded audio.
            gr.Warning(
                "Reference audio was provided without its transcript, "
                "so voice cloning is skipped."
            )

        generator = generate_long(
            model=llama_model,
            device=device,
            decode_one_token=decode_one_token,
            text=text,
            num_samples=1,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=30,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            compile=False,
            iterative_prompt=True,
            chunk_length=chunk_length,
            prompt_text=prompt_text_list,
            prompt_tokens=prompt_tokens_list,
        )

        # Collect the codes of every "sample" response and stop at the first
        # "next" response (presumably the end of the single requested sample).
        codes = []
        for response in generator:
            if response.action == "sample":
                codes.append(response.codes)
            elif response.action == "next":
                break

        if not codes:
            raise gr.Error("No audio was generated. Please check your input text.")

        # Multiple chunks are concatenated along dim 1.
        merged_codes = codes[0] if len(codes) == 1 else torch.cat(codes, dim=1)
        merged_codes = merged_codes.to(device)

        audio_waveform = decode_codes_to_audio(merged_codes)
        audio_np = audio_waveform.cpu().float().numpy()
        # Scale to the int16 range and clip before the cast.
        audio_np = (audio_np * 32767).clip(-32768, 32767).astype(np.int16)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return (codec_model.sample_rate, audio_np)
    except gr.Error:
        # User-facing errors raised above pass through untouched.
        raise
    except Exception as e:
        # Log the full traceback to the console, show a short message in the
        # UI, and keep the original exception attached as the cause.
        traceback.print_exc()
        raise gr.Error(f"Inference error: {str(e)}") from e
# Inline tags offered in the UI, in display order.
TAGS = [
    "[pause]",
    "[emphasis]",
    "[laughing]",
    "[inhale]",
    "[chuckle]",
    "[tsk]",
    "[singing]",
    "[excited]",
    "[laughing tone]",
    "[interrupting]",
    "[chuckling]",
    "[excited tone]",
    "[volume up]",
    "[echo]",
    "[angry]",
    "[low volume]",
    "[sigh]",
    "[low voice]",
    "[whisper]",
    "[screaming]",
    "[shouting]",
    "[loud]",
    "[surprised]",
    "[short pause]",
    "[exhale]",
    "[delight]",
    "[panting]",
    "[audience laughter]",
    "[with strong accent]",
    "[volume down]",
    "[clearing throat]",
    "[sad]",
    "[moaning]",
    "[shocked]",
]

# Every tag rendered through the per-tag template and joined by single spaces.
TAGS_HTML = " ".join(f'{t}' for t in TAGS)
with gr.Blocks(title="Fish Audio S2 Pro") as app:
gr.Markdown(
f"""
State-of-the-Art Dual-Autoregressive Text-to-Speech · Model Page ↗ · GitHub ↗
80+ languages supported · Zero-shot voice cloning · 15,000+ inline emotion tags
15,000+ unique tags supported. Use free-form descriptions like
[whisper in small voice] or [professional broadcast tone].
Common tags:
Tier 1: Japanese · English · Chinese |
Tier 2: Korean · Spanish · Portuguese · Arabic · Russian · French · German
Also supported: sv, it, tr, no, nl, cy, eu, ca, da, gl, ta, hu, fi, pl, et, hi,
la, ur, th, vi, jw, bn, yo, sl, cs, sw, nn, he, ms, uk, id, kk, bg, lv, my, tl, sk, ne, fa,
af, el, bo, hr, ro, sn, mi, yi, am, be, km, is, az, sd, br, sq, ps, mn, ht, ml, sr, sa, te,
ka, bs, pa, lt, kn, si, hy, mr, as, gu, fo, and more.
Language is auto-detected from the input text — no configuration needed.