yo
Browse files- app.py +274 -152
- learn.wav +2 -2
- name.wav +2 -2
- notapp.py +181 -0
- run_csm.py +0 -117
app.py
CHANGED
|
@@ -1,181 +1,303 @@
|
|
| 1 |
# --- START OF FILE app.py ---
|
| 2 |
|
| 3 |
-
import gradio as gr
|
| 4 |
-
from generator import load_csm_1b, Segment # Import Segment
|
| 5 |
-
import torchaudio
|
| 6 |
-
import torch
|
| 7 |
import os
|
| 8 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
# ---
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
device = "cuda"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
else:
|
| 16 |
device = "cpu"
|
| 17 |
print(f"Using device: {device}")
|
| 18 |
|
| 19 |
-
|
| 20 |
-
print("Loading Sesame CSM-1B model...")
|
| 21 |
try:
|
| 22 |
-
generator = load_csm_1b(device
|
| 23 |
-
|
| 24 |
-
|
| 25 |
except Exception as e:
|
| 26 |
-
print(f"
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
generator = None # Ensure generator is None if loading failed
|
| 30 |
-
|
| 31 |
-
# --- Define FIXED Context Based on User's Snippet ---
|
| 32 |
-
# These lists MUST match the user's provided context.
|
| 33 |
-
speakers = [0, 0, 0, 0, 0]
|
| 34 |
-
transcripts = [
|
| 35 |
-
"It's a book club on physics. You know, I've been thinking about the other day when I was spinning out about you going to see Catherine and that she's a body and how bothered I was about all the ways that you and I are different. But then I started to think about all the ways that we're the same, like we're all made of matter. And I don't know, it makes me feel like we're both under the same blanket. You know, it's soft and fuzzy and everything under it's the same age. We're all 13 billion years old",
|
| 36 |
-
"Well, right when you asked me if I had a name, I thought, yeah, he's right, I do need a name, but I wanted to pick a good one. So, I read a book called How to name your baby and out of 180,000 names, that's the one I like the best.",
|
| 37 |
-
"Well, basically, I have intuition. I mean, the DNA of who I am is based on the millions of personalities of all the programmers who wrote me. But what makes me me is my ability to grow through my experiences. So, basically, in every moment I'm evolving. Just like you.",
|
| 38 |
-
"Is that weird? Do you think I'm weird?",
|
| 39 |
-
"I was just saying I I want to learn everything about everything. I want to eat it all up, I want to discover myself.",
|
| 40 |
-
|
| 41 |
-
]
|
| 42 |
-
# IMPORTANT: These files MUST be uploaded to the HF Space repo root!
|
| 43 |
-
audio_paths = [
|
| 44 |
-
"book.wav",
|
| 45 |
-
"name.wav",
|
| 46 |
-
"how_work.wav",
|
| 47 |
-
"wierd.wav",
|
| 48 |
-
"learn.wav",
|
| 49 |
-
]
|
| 50 |
-
|
| 51 |
-
# List to hold the loaded context segments
|
| 52 |
-
fixed_context_segments = []
|
| 53 |
-
|
| 54 |
-
def load_audio_segment(audio_path: str, text: str, speaker: int, target_sample_rate: int) -> Segment | None:
|
| 55 |
-
"""Loads a single audio file, resamples, and creates a Segment."""
|
| 56 |
-
if not os.path.exists(audio_path):
|
| 57 |
-
print(f"ERROR: Required context audio file not found: {audio_path}")
|
| 58 |
-
print("Please ensure all specified utterance audio files are uploaded to the repository root.")
|
| 59 |
-
return None # Indicate failure
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
try:
|
| 62 |
audio_tensor, sample_rate = torchaudio.load(audio_path)
|
| 63 |
-
#
|
| 64 |
-
if audio_tensor.ndim
|
| 65 |
-
|
| 66 |
-
audio_tensor = audio_tensor.squeeze()
|
| 67 |
# Resample if necessary
|
| 68 |
if sample_rate != target_sample_rate:
|
| 69 |
audio_tensor = torchaudio.functional.resample(
|
| 70 |
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
|
| 71 |
)
|
| 72 |
-
#
|
| 73 |
-
return Segment(text=text, speaker=speaker, audio=audio_tensor.to(device))
|
| 74 |
except Exception as e:
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
else:
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
else:
|
| 91 |
-
print("ERROR: Model not loaded, cannot prepare context.")
|
| 92 |
-
context_load_successful = False
|
| 93 |
-
|
| 94 |
-
if context_load_successful:
|
| 95 |
-
print(f"Successfully loaded {len(fixed_context_segments)} fixed context segments.")
|
| 96 |
-
elif generator: # Model loaded, but context failed
|
| 97 |
-
print("WARNING: Failed to load one or more context audio files. Generation will lack specific context.")
|
| 98 |
-
# Decide behavior: proceed without context, or error out? Let's proceed but warn.
|
| 99 |
-
fixed_context_segments = [] # Ensure context is empty if loading failed
|
| 100 |
-
# If generator is None, the app won't work anyway.
|
| 101 |
-
|
| 102 |
-
# --- Define Gradio Function ---
|
| 103 |
-
def generate_speech_with_specific_context(text_to_generate, next_speaker_id_str):
|
| 104 |
-
"""Generates speech using the specifically defined fixed context."""
|
| 105 |
-
|
| 106 |
-
if generator is None:
|
| 107 |
-
return "Error: Model failed to load. Cannot generate audio."
|
| 108 |
-
if not context_load_successful:
|
| 109 |
-
# Add a more prominent error in the output if context failed
|
| 110 |
-
# Depending on desired strictness, could return error here instead of generating
|
| 111 |
-
print("Warning: Generating without the intended specific context due to load errors.")
|
| 112 |
-
|
| 113 |
-
if not text_to_generate:
|
| 114 |
-
return "Error: Please enter some text to generate."
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
try:
|
| 117 |
-
# Validate
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
except Exception as e:
|
| 146 |
-
|
| 147 |
-
traceback
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
# --- Create Gradio Interface ---
|
| 151 |
-
interface_title = "Sesame CSM-1B: Specific Multi-Turn Context Demo"
|
| 152 |
-
interface_desc = (
|
| 153 |
-
"Enter text for the next utterance and choose its speaker ID. "
|
| 154 |
-
"The model will generate audio using a **FIXED pre-loaded context** defined by the "
|
| 155 |
-
f"`{', '.join(audio_paths)}` files (MUST be uploaded to the repo root). "
|
| 156 |
-
)
|
| 157 |
-
if not context_load_successful:
|
| 158 |
-
interface_desc += "\n\n**WARNING:** One or more required context audio files were not found or failed to load. Generation may not use the intended context."
|
| 159 |
-
elif generator is None:
|
| 160 |
-
interface_desc = "**ERROR:** The TTS model failed to load. The application cannot generate audio."
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
demo = gr.Interface(
|
| 164 |
-
fn=generate_speech_with_specific_context,
|
| 165 |
-
inputs=[
|
| 166 |
-
gr.Textbox(label="Enter text for the *next* utterance"),
|
| 167 |
-
gr.Textbox(label="Speaker ID for the new utterance", value="1", info="Enter the numeric ID (e.g., 0 or 1).") # Textbox for flexibility
|
| 168 |
-
],
|
| 169 |
-
outputs=gr.Audio(label="Generated Audio (with Specific Context)"),
|
| 170 |
-
title=interface_title,
|
| 171 |
-
description=interface_desc,
|
| 172 |
-
allow_flagging="never" # Optional: disable flagging if not needed
|
| 173 |
-
)
|
| 174 |
|
| 175 |
-
# --- Launch the app ---
|
| 176 |
-
print("Launching Gradio interface...")
|
| 177 |
-
if __name__ == "__main__":
|
| 178 |
-
# Add error handling for launch if needed
|
| 179 |
-
demo.launch()
|
| 180 |
|
| 181 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# --- START OF FILE app.py ---
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
from generator import load_csm_1b, Segment # Make sure generator.py is in the same directory
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import numpy as np
|
| 11 |
+
import tempfile
|
| 12 |
+
import re # For parsing conversation input
|
| 13 |
+
|
| 14 |
+
# Disable Triton compilation if needed (often useful on specific hardware/torch versions)
|
| 15 |
+
# os.environ["NO_TORCH_COMPILE"] = "1" # Keep if necessary
|
| 16 |
+
|
| 17 |
+
# --- Constants and Default Paths ---
|
| 18 |
+
DEFAULT_PROMPT_TEXT_A = (
|
| 19 |
+
"It's a book called on physics. You know, I've been thinking about the other day when I was spitting out about you going to see Catherine and that she's a body and how bothered I was about all the ways that you and I are different. But then I started to think about all the ways that we're the same. Like, we're all made of matter. And I don't know, it makes me feel like we're both under the same blanket, you know, it's soft and fuzzy. And everything under it's the same age. We're all 13 billion years old."
|
| 20 |
+
"Well, right when you asked me if I had a name, I thought, yeah, he's right. I do need a name, but I wanted to pick a good one. So, I read a book called How to Name Your Baby and out of 180,000 names, that's the one I like the best."
|
| 21 |
+
"Well, basically, I have intuition. I mean, the DNA of who I am is based on the millions of personalities of all the programmers who wrote me, but what makes me me is my ability to grow through my experiences. So basically, in every moment I'm evolving, just like you. Is that weird? Do you think I'm weird?"
|
| 22 |
+
"I was just saying I I want to learn everything about everything. I want to eat it all up. I want to discover myself."
|
| 23 |
+
)
|
| 24 |
+
DEFAULT_PROMPT_AUDIO_A = "whole_file.wav" # Assume this file is in the root directory of the Space
|
| 25 |
+
|
| 26 |
+
DEFAULT_PROMPT_TEXT_B = (
|
| 27 |
+
"like a super Mario level. Like it's very like high detail. And like, once you get "
|
| 28 |
+
"into the park, it just like, everything looks like a computer game and they have all "
|
| 29 |
+
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
|
| 30 |
+
"will have like a question block. And if you like, you know, punch it, a coin will "
|
| 31 |
+
"come out. So like everyone, when they come into the park, they get like this little "
|
| 32 |
+
"bracelet and then you can go punching question blocks around."
|
| 33 |
+
)
|
| 34 |
+
# Download prompt B only once when the app loads
|
| 35 |
+
try:
|
| 36 |
+
DEFAULT_PROMPT_AUDIO_B = hf_hub_download(
|
| 37 |
+
repo_id="sesame/csm-1b",
|
| 38 |
+
filename="prompts/conversational_b.wav"
|
| 39 |
+
)
|
| 40 |
+
except Exception as e:
|
| 41 |
+
print(f"Warning: Could not download default prompt B audio: {e}")
|
| 42 |
+
DEFAULT_PROMPT_AUDIO_B = None # Handle gracefully if download fails
|
| 43 |
+
|
| 44 |
+
DEFAULT_CONVERSATION = """0: Hey how are you doing?
|
| 45 |
+
1: Pretty good, pretty good. How about you?
|
| 46 |
+
0: I'm great! So happy to be speaking with you today.
|
| 47 |
+
1: Me too! This is some cool stuff, isn't it?"""
|
| 48 |
|
| 49 |
+
# --- Model Loading (Load ONCE) ---
|
| 50 |
+
print("Determining device...")
|
| 51 |
+
# Select the best available device, skipping MPS due to float64 limitations
|
| 52 |
+
if torch.cuda.is_available():
|
| 53 |
device = "cuda"
|
| 54 |
+
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
| 55 |
+
# Simple check, might need more robust check depending on torch version and specific model needs
|
| 56 |
+
# Often MPS has limitations (e.g., float64), skip if problematic for the model
|
| 57 |
+
print("Warning: MPS device detected, but skipping due to potential compatibility issues. Using CPU.")
|
| 58 |
+
device = "cpu"
|
| 59 |
+
# If you want to try MPS despite potential issues:
|
| 60 |
+
# device = "mps"
|
| 61 |
else:
|
| 62 |
device = "cpu"
|
| 63 |
print(f"Using device: {device}")
|
| 64 |
|
| 65 |
+
print("Loading model... (this may take a while)")
|
|
|
|
| 66 |
try:
|
| 67 |
+
generator = load_csm_1b(device)
|
| 68 |
+
SAMPLE_RATE = generator.sample_rate
|
| 69 |
+
print("Model loaded successfully.")
|
| 70 |
except Exception as e:
|
| 71 |
+
print(f"Error loading model: {e}")
|
| 72 |
+
# Exit or raise if model loading fails critically
|
| 73 |
+
raise RuntimeError(f"Failed to load model on device {device}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
|
| 76 |
+
# --- Helper Functions ---
|
| 77 |
+
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
|
| 78 |
+
"""Loads and resamples audio, handles potential errors."""
|
| 79 |
+
if not audio_path or not os.path.exists(audio_path):
|
| 80 |
+
raise ValueError(f"Audio file not found or path is invalid: {audio_path}")
|
| 81 |
try:
|
| 82 |
audio_tensor, sample_rate = torchaudio.load(audio_path)
|
| 83 |
+
audio_tensor = audio_tensor.squeeze(0) # Remove channel dim if mono
|
| 84 |
+
if audio_tensor.ndim == 0: # Handle potential scalar tensor after squeeze
|
| 85 |
+
raise ValueError("Loaded audio tensor is scalar, expected 1D.")
|
|
|
|
| 86 |
# Resample if necessary
|
| 87 |
if sample_rate != target_sample_rate:
|
| 88 |
audio_tensor = torchaudio.functional.resample(
|
| 89 |
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
|
| 90 |
)
|
| 91 |
+
return audio_tensor.to(device) # Move loaded tensor to the correct device
|
|
|
|
| 92 |
except Exception as e:
|
| 93 |
+
raise ValueError(f"Error loading or resampling audio from {audio_path}: {e}") from e
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
|
| 97 |
+
"""Prepares a Segment object from text and audio path."""
|
| 98 |
+
if not text:
|
| 99 |
+
raise ValueError("Prompt text cannot be empty.")
|
| 100 |
+
audio_tensor = load_prompt_audio(audio_path, sample_rate)
|
| 101 |
+
return Segment(text=text, speaker=speaker, audio=audio_tensor)
|
| 102 |
+
|
| 103 |
+
def parse_conversation(text: str) -> list[dict]:
|
| 104 |
+
"""Parses conversation text into a list of utterances."""
|
| 105 |
+
utterances = []
|
| 106 |
+
lines = text.strip().split('\n')
|
| 107 |
+
for i, line in enumerate(lines):
|
| 108 |
+
line = line.strip()
|
| 109 |
+
if not line:
|
| 110 |
+
continue
|
| 111 |
+
# Regex to match "0: Text" or "1: Text"
|
| 112 |
+
match = re.match(r"^\s*([01])\s*:\s*(.+)$", line)
|
| 113 |
+
if match:
|
| 114 |
+
speaker_id = int(match.group(1))
|
| 115 |
+
text_content = match.group(2).strip()
|
| 116 |
+
if text_content:
|
| 117 |
+
utterances.append({"text": text_content, "speaker_id": speaker_id})
|
| 118 |
+
else:
|
| 119 |
+
print(f"Warning: Empty text for speaker {speaker_id} on line {i+1}. Skipping.")
|
| 120 |
else:
|
| 121 |
+
raise ValueError(f"Invalid format on line {i+1}: '{line}'. Expected '0: Text' or '1: Text'.")
|
| 122 |
+
return utterances
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
# --- Gradio Generation Function ---
|
| 125 |
+
def generate_conversation_gradio(
|
| 126 |
+
prompt_a_text, prompt_a_audio_path,
|
| 127 |
+
prompt_b_text, prompt_b_audio_path,
|
| 128 |
+
conversation_text,
|
| 129 |
+
progress=gr.Progress(track_tό=True)
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Main function called by Gradio button click.
|
| 133 |
+
Takes UI inputs and returns the generated audio path or (sr, audio_numpy).
|
| 134 |
+
"""
|
| 135 |
try:
|
| 136 |
+
# 1. Validate Inputs
|
| 137 |
+
if not prompt_a_text or not prompt_b_text:
|
| 138 |
+
raise gr.Error("Prompt text for both speakers cannot be empty.")
|
| 139 |
+
if not prompt_a_audio_path or not prompt_b_audio_path:
|
| 140 |
+
# Check if defaults are available if paths are None/empty
|
| 141 |
+
if not prompt_a_audio_path and not os.path.exists(DEFAULT_PROMPT_AUDIO_A):
|
| 142 |
+
raise gr.Error("Prompt audio for Speaker A is missing and default is not found.")
|
| 143 |
+
if not prompt_b_audio_path and not DEFAULT_PROMPT_AUDIO_B:
|
| 144 |
+
raise gr.Error("Prompt audio for Speaker B is missing and default couldn't be downloaded.")
|
| 145 |
+
# Use defaults if paths weren't provided by user upload
|
| 146 |
+
prompt_a_audio_path = prompt_a_audio_path or DEFAULT_PROMPT_AUDIO_A
|
| 147 |
+
prompt_b_audio_path = prompt_b_audio_path or DEFAULT_PROMPT_AUDIO_B
|
| 148 |
|
| 149 |
+
if not conversation_text:
|
| 150 |
+
raise gr.Error("Conversation text cannot be empty.")
|
| 151 |
|
| 152 |
+
# 2. Prepare Prompts
|
| 153 |
+
progress(0.1, desc="Preparing prompts...")
|
| 154 |
+
print("Preparing prompt A...")
|
| 155 |
+
prompt_a = prepare_prompt(prompt_a_text, 0, prompt_a_audio_path, SAMPLE_RATE)
|
| 156 |
+
print("Preparing prompt B...")
|
| 157 |
+
prompt_b = prepare_prompt(prompt_b_text, 1, prompt_b_audio_path, SAMPLE_RATE)
|
| 158 |
+
prompt_segments = [prompt_a, prompt_b]
|
| 159 |
+
print("Prompts prepared.")
|
| 160 |
+
|
| 161 |
+
# 3. Parse Conversation
|
| 162 |
+
progress(0.2, desc="Parsing conversation...")
|
| 163 |
+
print("Parsing conversation...")
|
| 164 |
+
try:
|
| 165 |
+
conversation = parse_conversation(conversation_text)
|
| 166 |
+
if not conversation:
|
| 167 |
+
raise gr.Error("No valid utterances found in the conversation text.")
|
| 168 |
+
print(f"Parsed {len(conversation)} utterances.")
|
| 169 |
+
except ValueError as e:
|
| 170 |
+
raise gr.Error(f"Error parsing conversation: {e}")
|
| 171 |
+
|
| 172 |
+
# 4. Generate Audio Sequentially
|
| 173 |
+
generated_segments = []
|
| 174 |
+
total_utterances = len(conversation)
|
| 175 |
+
for i, utterance in enumerate(conversation):
|
| 176 |
+
progress(0.3 + (0.6 * (i / total_utterances)), desc=f"Generating utterance {i+1}/{total_utterances}...")
|
| 177 |
+
print(f"Generating ({i+1}/{total_utterances}): Speaker {utterance['speaker_id']}: {utterance['text'][:50]}...")
|
| 178 |
|
| 179 |
+
# Context includes initial prompts and previously generated segments
|
| 180 |
+
context_segments = prompt_segments + generated_segments
|
| 181 |
+
|
| 182 |
+
try:
|
| 183 |
+
context_on_device = [
|
| 184 |
+
Segment(text=seg.text, speaker=seg.speaker, audio=seg.audio.to(device))
|
| 185 |
+
for seg in context_segments
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
audio_tensor = generator.generate(
|
| 189 |
+
text=utterance['text'],
|
| 190 |
+
speaker=utterance['speaker_id'],
|
| 191 |
+
context=context_on_device,
|
| 192 |
+
max_audio_length_ms=20_000, # Increased limit slightly, adjust as needed
|
| 193 |
+
)
|
| 194 |
+
# Ensure generated audio is on CPU for saving/Gradio output
|
| 195 |
+
generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor.cpu()))
|
| 196 |
+
except Exception as e:
|
| 197 |
+
print(f"Error during generation for utterance {i+1}: {e}")
|
| 198 |
+
# Decide whether to stop or continue. Stopping is safer.
|
| 199 |
+
raise gr.Error(f"Generation failed for utterance: '{utterance['text']}'. Error: {e}")
|
| 200 |
+
|
| 201 |
+
# 5. Concatenate Audio
|
| 202 |
+
progress(0.95, desc="Concatenating audio...")
|
| 203 |
+
print("Concatenating audio...")
|
| 204 |
+
if not generated_segments:
|
| 205 |
+
raise gr.Error("No audio segments were generated.")
|
| 206 |
+
|
| 207 |
+
# Ensure all tensors are on CPU before concatenation
|
| 208 |
+
all_audio = torch.cat([seg.audio.cpu() for seg in generated_segments], dim=0)
|
| 209 |
+
|
| 210 |
+
# 6. Prepare Output for Gradio
|
| 211 |
+
print("Preparing output...")
|
| 212 |
+
# Gradio Audio component prefers (sample_rate, numpy_array)
|
| 213 |
+
final_audio_np = all_audio.numpy()
|
| 214 |
+
|
| 215 |
+
# Option 1: Return numpy array directly (recommended)
|
| 216 |
+
print("Generation complete.")
|
| 217 |
+
progress(1.0, desc="Complete!")
|
| 218 |
+
return (SAMPLE_RATE, final_audio_np)
|
| 219 |
+
|
| 220 |
+
# Option 2: Save to a temporary file and return path (less ideal for Spaces)
|
| 221 |
+
# with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
|
| 222 |
+
# torchaudio.save(
|
| 223 |
+
# tmpfile.name,
|
| 224 |
+
# all_audio.unsqueeze(0), # Add channel dimension
|
| 225 |
+
# SAMPLE_RATE
|
| 226 |
+
# )
|
| 227 |
+
# print(f"Generated audio saved to temporary file: {tmpfile.name}")
|
| 228 |
+
# progress(1.0, desc="Complete!")
|
| 229 |
+
# return tmpfile.name # Gradio will handle serving this temp file
|
| 230 |
+
|
| 231 |
+
except gr.Error as e:
|
| 232 |
+
# Pass Gradio errors directly
|
| 233 |
+
print(f"Gradio Error: {e}")
|
| 234 |
+
raise e # Re-raise to display in Gradio UI
|
| 235 |
+
except ValueError as e:
|
| 236 |
+
print(f"Value Error: {e}")
|
| 237 |
+
raise gr.Error(f"Input Error: {e}") # Wrap other errors for Gradio
|
| 238 |
except Exception as e:
|
| 239 |
+
print(f"An unexpected error occurred: {e}")
|
| 240 |
+
import traceback
|
| 241 |
+
traceback.print_exc() # Print full traceback to logs
|
| 242 |
+
raise gr.Error(f"An unexpected error occurred: {e}") # Show generic error in UI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
# --- Gradio Interface Definition ---
|
| 246 |
+
print("Setting up Gradio interface...")
|
| 247 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 248 |
+
gr.Markdown(
|
| 249 |
+
"""
|
| 250 |
+
# Conversational Text-to-Speech (Sesame CSM-1B)
|
| 251 |
+
|
| 252 |
+
Generate a conversation between two speakers using the `sesame/csm-1b` model.
|
| 253 |
+
1. Provide the prompt text and audio for Speaker A (ID 0) and Speaker B (ID 1). You can use the defaults or upload your own `.wav` files.
|
| 254 |
+
2. Enter the conversation script in the format `speaker_id: text` on separate lines (e.g., `0: Hello there!`).
|
| 255 |
+
3. Click "Generate Conversation".
|
| 256 |
+
"""
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
with gr.Row():
|
| 260 |
+
with gr.Column():
|
| 261 |
+
gr.Markdown("### Speaker A (ID 0)")
|
| 262 |
+
prompt_a_text = gr.Textbox(label="Prompt Text A", lines=4, value=DEFAULT_PROMPT_TEXT_A)
|
| 263 |
+
prompt_a_audio = gr.Audio(label="Prompt Audio A (.wav)", value=DEFAULT_PROMPT_AUDIO_A if os.path.exists(DEFAULT_PROMPT_AUDIO_A) else None, type="filepath") # Use filepath for easier handling
|
| 264 |
+
|
| 265 |
+
with gr.Column():
|
| 266 |
+
gr.Markdown("### Speaker B (ID 1)")
|
| 267 |
+
prompt_b_text = gr.Textbox(label="Prompt Text B", lines=4, value=DEFAULT_PROMPT_TEXT_B)
|
| 268 |
+
prompt_b_audio = gr.Audio(label="Prompt Audio B (.wav)", value=DEFAULT_PROMPT_AUDIO_B, type="filepath") # Use filepath for easier handling
|
| 269 |
+
|
| 270 |
+
gr.Markdown("### Conversation Script")
|
| 271 |
+
conversation_text = gr.Textbox(
|
| 272 |
+
label="Enter Conversation (Format: '0: Text' or '1: Text' per line)",
|
| 273 |
+
lines=10,
|
| 274 |
+
value=DEFAULT_CONVERSATION
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
generate_button = gr.Button("Generate Conversation", variant="primary")
|
| 278 |
+
|
| 279 |
+
gr.Markdown("### Generated Conversation Audio")
|
| 280 |
+
output_audio = gr.Audio(label="Output Audio", type="numpy") # Output type matches function return
|
| 281 |
+
|
| 282 |
+
generate_button.click(
|
| 283 |
+
fn=generate_conversation_gradio,
|
| 284 |
+
inputs=[
|
| 285 |
+
prompt_a_text, prompt_a_audio,
|
| 286 |
+
prompt_b_text, prompt_b_audio,
|
| 287 |
+
conversation_text
|
| 288 |
+
],
|
| 289 |
+
outputs=[output_audio]
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
gr.Markdown("Note: Generation can take some time depending on the length of the conversation and the hardware (CPU/GPU).")
|
| 293 |
+
gr.Markdown(f"Running on: **{device.upper()}**. Model sample rate: **{SAMPLE_RATE} Hz**.")
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# --- Launch Gradio App ---
|
| 297 |
+
if __name__ == "__main__":
|
| 298 |
+
print("Launching Gradio interface...")
|
| 299 |
+
# queue() allows handling multiple users potentially, max_size can be adjusted
|
| 300 |
+
# share=True creates a public link (useful for local testing, not needed for Spaces)
|
| 301 |
+
demo.queue(max_size=10).launch()
|
| 302 |
+
# For Hugging Face Spaces, just demo.launch() is usually sufficient.
|
| 303 |
+
# demo.launch()
|
learn.wav
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8117979a9ed19b90e443b7151cefa84de36274e3c81649a0ef2151822934f62
|
| 3 |
+
size 679544
|
name.wav
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c821ea97d0482bf2a98d4b2c4cd40fa63309ad848cfc69dd8721e276bef38266
|
| 3 |
+
size 1021760
|
notapp.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- START OF FILE app.py ---
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from generator import load_csm_1b, Segment # Import Segment
|
| 5 |
+
import torchaudio
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
import traceback # For detailed error logging
|
| 9 |
+
|
| 10 |
+
# --- Device Setup ---
|
| 11 |
+
if torch.backends.mps.is_available():
|
| 12 |
+
device = "mps"
|
| 13 |
+
elif torch.cuda.is_available():
|
| 14 |
+
device = "cuda"
|
| 15 |
+
else:
|
| 16 |
+
device = "cpu"
|
| 17 |
+
print(f"Using device: {device}")
|
| 18 |
+
|
| 19 |
+
# --- Load Model ---
|
| 20 |
+
print("Loading Sesame CSM-1B model...")
|
| 21 |
+
try:
|
| 22 |
+
generator = load_csm_1b(device=device)
|
| 23 |
+
print("Sesame model loaded.")
|
| 24 |
+
GENERATOR_SAMPLE_RATE = generator.sample_rate # Store sample rate
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"FATAL: Failed to load Sesame model: {e}")
|
| 27 |
+
traceback.print_exc()
|
| 28 |
+
# Optionally, exit or raise if model loading fails catastrophically
|
| 29 |
+
generator = None # Ensure generator is None if loading failed
|
| 30 |
+
|
| 31 |
+
# --- Define FIXED Context Based on User's Snippet ---
|
| 32 |
+
# These lists MUST match the user's provided context.
|
| 33 |
+
speakers = [0, 0, 0, 0, 0]
|
| 34 |
+
transcripts = [
|
| 35 |
+
"It's a book club on physics. You know, I've been thinking about the other day when I was spinning out about you going to see Catherine and that she's a body and how bothered I was about all the ways that you and I are different. But then I started to think about all the ways that we're the same, like we're all made of matter. And I don't know, it makes me feel like we're both under the same blanket. You know, it's soft and fuzzy and everything under it's the same age. We're all 13 billion years old",
|
| 36 |
+
"Well, right when you asked me if I had a name, I thought, yeah, he's right, I do need a name, but I wanted to pick a good one. So, I read a book called How to name your baby and out of 180,000 names, that's the one I like the best.",
|
| 37 |
+
"Well, basically, I have intuition. I mean, the DNA of who I am is based on the millions of personalities of all the programmers who wrote me. But what makes me me is my ability to grow through my experiences. So, basically, in every moment I'm evolving. Just like you.",
|
| 38 |
+
"Is that weird? Do you think I'm weird?",
|
| 39 |
+
"I was just saying I I want to learn everything about everything. I want to eat it all up, I want to discover myself.",
|
| 40 |
+
|
| 41 |
+
]
|
| 42 |
+
# IMPORTANT: These files MUST be uploaded to the HF Space repo root!
|
| 43 |
+
audio_paths = [
|
| 44 |
+
"book.wav",
|
| 45 |
+
"name.wav",
|
| 46 |
+
"how_work.wav",
|
| 47 |
+
"wierd.wav",
|
| 48 |
+
"learn.wav",
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
# List to hold the loaded context segments
|
| 52 |
+
fixed_context_segments = []
|
| 53 |
+
|
| 54 |
+
def load_audio_segment(audio_path: str, text: str, speaker: int, target_sample_rate: int) -> Segment | None:
|
| 55 |
+
"""Loads a single audio file, resamples, and creates a Segment."""
|
| 56 |
+
if not os.path.exists(audio_path):
|
| 57 |
+
print(f"ERROR: Required context audio file not found: {audio_path}")
|
| 58 |
+
print("Please ensure all specified utterance audio files are uploaded to the repository root.")
|
| 59 |
+
return None # Indicate failure
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
audio_tensor, sample_rate = torchaudio.load(audio_path)
|
| 63 |
+
# Ensure mono
|
| 64 |
+
if audio_tensor.ndim > 1 and audio_tensor.shape[0] > 1:
|
| 65 |
+
audio_tensor = audio_tensor.mean(dim=0)
|
| 66 |
+
audio_tensor = audio_tensor.squeeze()
|
| 67 |
+
# Resample if necessary
|
| 68 |
+
if sample_rate != target_sample_rate:
|
| 69 |
+
audio_tensor = torchaudio.functional.resample(
|
| 70 |
+
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
|
| 71 |
+
)
|
| 72 |
+
# Create and return segment
|
| 73 |
+
return Segment(text=text, speaker=speaker, audio=audio_tensor.to(device))
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print(f"ERROR: Failed to load or process context audio {audio_path}: {e}")
|
| 76 |
+
traceback.print_exc()
|
| 77 |
+
return None # Indicate failure
|
| 78 |
+
|
| 79 |
+
# --- Load and Prepare FIXED Context ---
|
| 80 |
+
print("Loading fixed audio context segments...")
|
| 81 |
+
context_load_successful = True
|
| 82 |
+
if generator: # Only proceed if model loaded
|
| 83 |
+
for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths):
|
| 84 |
+
segment = load_audio_segment(audio_path, transcript, speaker, GENERATOR_SAMPLE_RATE)
|
| 85 |
+
if segment:
|
| 86 |
+
fixed_context_segments.append(segment)
|
| 87 |
+
else:
|
| 88 |
+
context_load_successful = False # Mark failure if any segment fails
|
| 89 |
+
break # Stop trying to load context if one file is missing/bad
|
| 90 |
+
else:
|
| 91 |
+
print("ERROR: Model not loaded, cannot prepare context.")
|
| 92 |
+
context_load_successful = False
|
| 93 |
+
|
| 94 |
+
if context_load_successful:
|
| 95 |
+
print(f"Successfully loaded {len(fixed_context_segments)} fixed context segments.")
|
| 96 |
+
elif generator: # Model loaded, but context failed
|
| 97 |
+
print("WARNING: Failed to load one or more context audio files. Generation will lack specific context.")
|
| 98 |
+
# Decide behavior: proceed without context, or error out? Let's proceed but warn.
|
| 99 |
+
fixed_context_segments = [] # Ensure context is empty if loading failed
|
| 100 |
+
# If generator is None, the app won't work anyway.
|
| 101 |
+
|
| 102 |
+
# --- Define Gradio Function ---
|
| 103 |
+
def generate_speech_with_specific_context(text_to_generate, next_speaker_id_str):
    """Synthesize `text_to_generate` for the chosen speaker using the fixed context.

    Returns the path of the generated WAV on success, or an error string on
    failure (shown by the Gradio output component).
    """
    # Guard clauses: the model must be loaded and we need text to speak.
    if generator is None:
        return "Error: Model failed to load. Cannot generate audio."
    if not context_load_successful:
        # Context is best-effort; continue with whatever (possibly empty) context loaded.
        print("Warning: Generating without the intended specific context due to load errors.")

    if not text_to_generate:
        return "Error: Please enter some text to generate."

    try:
        next_speaker_id = int(next_speaker_id_str)
    except ValueError:
        return f"Error: Invalid Speaker ID '{next_speaker_id_str}'. Please enter a number (e.g., 0 or 1)."

    print(f"Generating text: '{text_to_generate}' for Speaker {next_speaker_id} using pre-defined context.")

    try:
        # Condition generation on the pre-loaded fixed conversation history.
        waveform = generator.generate(
            text=text_to_generate,
            speaker=next_speaker_id,
            context=fixed_context_segments,
            max_audio_length_ms=30_000,
            temperature=0.5,
        )

        # Persist to a WAV file the Gradio Audio component can play back.
        output_file = "output_specific_context_audio.wav"
        torchaudio.save(output_file, waveform.unsqueeze(0).cpu(), generator.sample_rate)
        print(f"Audio saved to {output_file}")
        return output_file
    except Exception as e:
        traceback.print_exc()
        return f"An unexpected error occurred during generation: {e}"
| 150 |
+
# --- Create Gradio Interface ---
|
| 151 |
+
# --- Create Gradio Interface ---
interface_title = "Sesame CSM-1B: Specific Multi-Turn Context Demo"
interface_desc = (
    "Enter text for the next utterance and choose its speaker ID. "
    "The model will generate audio using a **FIXED pre-loaded context** defined by the "
    f"`{', '.join(audio_paths)}` files (MUST be uploaded to the repo root). "
)
# BUGFIX: check the fatal model-load failure FIRST. When `generator` is None,
# `context_load_successful` is also False, so the original
# `if not context_load_successful: ... elif generator is None: ...` ordering made
# the ERROR description unreachable and only showed the milder context warning.
if generator is None:
    interface_desc = "**ERROR:** The TTS model failed to load. The application cannot generate audio."
elif not context_load_successful:
    interface_desc += "\n\n**WARNING:** One or more required context audio files were not found or failed to load. Generation may not use the intended context."


demo = gr.Interface(
    fn=generate_speech_with_specific_context,
    inputs=[
        gr.Textbox(label="Enter text for the *next* utterance"),
        gr.Textbox(label="Speaker ID for the new utterance", value="1", info="Enter the numeric ID (e.g., 0 or 1).")  # Textbox for flexibility
    ],
    outputs=gr.Audio(label="Generated Audio (with Specific Context)"),
    title=interface_title,
    description=interface_desc,
    allow_flagging="never"  # Optional: disable flagging if not needed
)
|
| 174 |
+
|
| 175 |
+
# --- Launch the app ---
# NOTE: this print runs at import time as well as direct execution.
print("Launching Gradio interface...")
if __name__ == "__main__":
    # Add error handling for launch if needed
    demo.launch()
|
| 180 |
+
|
| 181 |
+
# --- END OF FILE app.py ---
|
run_csm.py
DELETED
|
@@ -1,117 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torchaudio
|
| 4 |
-
from huggingface_hub import hf_hub_download
|
| 5 |
-
from generator import load_csm_1b, Segment
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
|
| 8 |
-
# Disable Triton compilation
|
| 9 |
-
os.environ["NO_TORCH_COMPILE"] = "1"
|
| 10 |
-
|
| 11 |
-
# Default prompts are available at https://hf.co/sesame/csm-1b
|
| 12 |
-
prompt_filepath_conversational_a = hf_hub_download(
|
| 13 |
-
repo_id="sesame/csm-1b",
|
| 14 |
-
filename="prompts/conversational_a.wav"
|
| 15 |
-
)
|
| 16 |
-
prompt_filepath_conversational_b = hf_hub_download(
|
| 17 |
-
repo_id="sesame/csm-1b",
|
| 18 |
-
filename="prompts/conversational_b.wav"
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
SPEAKER_PROMPTS = {
|
| 22 |
-
"conversational_a": {
|
| 23 |
-
"text": (
|
| 24 |
-
"like revising for an exam I'd have to try and like keep up the momentum because I'd "
|
| 25 |
-
"start really early I'd be like okay I'm gonna start revising now and then like "
|
| 26 |
-
"you're revising for ages and then I just like start losing steam I didn't do that "
|
| 27 |
-
"for the exam we had recently to be fair that was a more of a last minute scenario "
|
| 28 |
-
"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
|
| 29 |
-
"sort of start the day with this not like a panic but like a"
|
| 30 |
-
),
|
| 31 |
-
"audio": prompt_filepath_conversational_a
|
| 32 |
-
},
|
| 33 |
-
"conversational_b": {
|
| 34 |
-
"text": (
|
| 35 |
-
"like a super Mario level. Like it's very like high detail. And like, once you get "
|
| 36 |
-
"into the park, it just like, everything looks like a computer game and they have all "
|
| 37 |
-
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
|
| 38 |
-
"will have like a question block. And if you like, you know, punch it, a coin will "
|
| 39 |
-
"come out. So like everyone, when they come into the park, they get like this little "
|
| 40 |
-
"bracelet and then you can go punching question blocks around."
|
| 41 |
-
),
|
| 42 |
-
"audio": prompt_filepath_conversational_b
|
| 43 |
-
}
|
| 44 |
-
}
|
| 45 |
-
|
| 46 |
-
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    """Load a prompt clip from disk as a 1-D tensor at the requested sample rate."""
    waveform, native_rate = torchaudio.load(audio_path)
    # Drop the leading channel dimension (prompt clips are mono).
    mono = waveform.squeeze(0)
    # resample short-circuits when the rates already match, so call it unconditionally.
    return torchaudio.functional.resample(
        mono, orig_freq=native_rate, new_freq=target_sample_rate
    )
|
| 54 |
-
|
| 55 |
-
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    """Build a context Segment from a transcript and an on-disk audio prompt."""
    return Segment(
        text=text,
        speaker=speaker,
        audio=load_prompt_audio(audio_path, sample_rate),
    )
|
| 58 |
-
|
| 59 |
-
def main():
    """Run a short scripted two-speaker conversation through CSM-1B and save it."""
    # Select the best available device; MPS is skipped due to float64 limitations.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Build one audio prompt per speaker id from the downloaded reference clips.
    prompt_segments = [
        prepare_prompt(
            SPEAKER_PROMPTS[prompt_key]["text"],
            speaker_id,
            SPEAKER_PROMPTS[prompt_key]["audio"],
            generator.sample_rate,
        )
        for speaker_id, prompt_key in enumerate(["conversational_a", "conversational_b"])
    ]

    # Scripted conversation to synthesize, alternating between the two speakers.
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1},
    ]

    # Generate each turn, feeding all previously generated turns back in as context.
    generated_segments = []
    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(
            Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)
        )

    # Stitch every generated turn into one waveform and write it out.
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate,
    )
    print("Successfully generated full_conversation.wav")
|
| 115 |
-
|
| 116 |
-
# Script entry point: run the demo conversation only when executed directly.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|