Spaces:
Bradarr
/
Runtime error

csm-1b / app.py
Bradarr's picture
Update app.py
6989477 verified
import os
import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
from generator import Segment, load_csm_1b # We'll use load_csm_1b *later*
from huggingface_hub import hf_hub_download, login, HfApi
from watermarking import watermark
import whisper # We'll use whisper.load_model *later*
from transformers import AutoTokenizer, AutoModelForCausalLM # We'll use these *later*
import logging
from transformers import GenerationConfig
# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Authentication and Configuration ---
# Fail fast at startup with a logged, readable message if required settings
# are missing or malformed.
try:
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN not found in environment variables.")
    login(token=api_key)
    # Check for a missing/blank value *before* parsing: os.getenv returns None
    # when unset, and calling .split on None raised AttributeError, which the
    # except clause below does not catch.
    watermark_key_raw = os.getenv("WATERMARK_KEY")
    if not (watermark_key_raw or "").strip():
        raise ValueError("WATERMARK_KEY not found or invalid in environment variables.")
    # split() with no argument tolerates repeated/leading/trailing whitespace;
    # int() still raises ValueError on any non-integer token.
    CSM_1B_HF_WATERMARK = [int(part) for part in watermark_key_raw.split()]
    gpu_timeout = int(os.getenv("GPU_TIMEOUT", 120))
except (ValueError, TypeError) as e:
    logging.error(f"Configuration error: {e}")
    raise
# Markdown shown as the first element of the page.
SPACE_INTRO_TEXT: str = """
# Sesame CSM 1B - Conversational Demo
This demo allows you to have a conversation with Sesame CSM 1B, leveraging Whisper for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
*Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
"""
# --- Constants ---
SPEAKER_ID: int = 0  # speaker id attached to AI turns (user turns are tagged 1)
MAX_CONTEXT_SEGMENTS: int = 3  # cap on Segments kept as CSM generation context
MAX_GEMMA_LENGTH: int = 128  # max_new_tokens for each Gemma reply
# --- Global Conversation History ---
# Module-level list of Segments shared across requests; reset by clear_history().
conversation_history: list = []
# --- Model Downloading (PRE-DOWNLOAD, NO LOADING) ---
# 1. Download Sesame CSM 1B
csm_1b_model_path = "csm_1b_ckpt.pt"  # Local path for the downloaded model
try:
    if os.path.exists(csm_1b_model_path):
        logging.info("Sesame CSM 1B model already downloaded.")
    else:
        # The repo publishes the weights as "ckpt.pt"; fetch it into the
        # working directory, then give it a model-specific local name.
        hf_hub_download(
            repo_id="sesame/csm-1b",
            filename="ckpt.pt",
            local_dir=".",
            local_dir_use_symlinks=False,
        )
        os.rename("ckpt.pt", csm_1b_model_path)
        logging.info("Sesame CSM 1B model downloaded.")
except Exception as e:
    logging.error(f"Error downloading Sesame CSM 1B: {e}")
    raise
# 2. Download Whisper (via whisper.load_model's own download_root mechanism)
whisper_model_name = "small.en"
whisper_local_dir = "whisper_model"  # Local directory for Whisper
# Written only after load_model succeeds. Checking this marker instead of the
# directory means a failed or interrupted download (which previously left an
# empty directory behind) is retried on the next start rather than skipped.
whisper_download_marker = os.path.join(
    whisper_local_dir, f".{whisper_model_name}.download_complete"
)
try:
    if os.path.exists(whisper_download_marker):
        logging.info("Whisper model already downloaded.")
    else:
        os.makedirs(whisper_local_dir, exist_ok=True)
        # Whisper uses its own download method; this call is expected to fetch
        # everything needed into download_root. The loaded model is discarded.
        whisper.load_model(whisper_model_name, download_root=whisper_local_dir)
        with open(whisper_download_marker, "w", encoding="utf-8") as marker_file:
            marker_file.write(whisper_model_name + "\n")
        logging.info("Whisper model downloaded.")
except Exception as e:
    # Deliberately non-fatal: infer() calls whisper.load_model with the same
    # download_root, which presumably retries the download on first request.
    logging.error(f"Whisper model download failed with exception: {e}")
# 3. Download Gemma 3 1B (using hf_hub_download, individual files)
gemma_repo_id = "google/gemma-3-1b-it"
gemma_local_path = os.path.abspath("gemma_model")  # Absolute path
# Written only after every repo file has been fetched. Checking this marker
# instead of the directory means an interrupted multi-file download is
# resumed/retried on the next start instead of leaving a half-populated
# directory that is treated as complete forever.
gemma_download_marker = os.path.join(gemma_local_path, ".download_complete")
try:
    if os.path.exists(gemma_download_marker):
        logging.info("Gemma 3 1B model and tokenizer files already downloaded.")
    else:
        os.makedirs(gemma_local_path, exist_ok=True)  # Create the directory
        api = HfApi()
        # List all files in the repository, then download each one individually.
        repo_files = api.list_repo_files(gemma_repo_id)
        for repo_file in repo_files:
            hf_hub_download(
                repo_id=gemma_repo_id,
                filename=repo_file,
                local_dir=gemma_local_path,
                local_dir_use_symlinks=False,  # Ensure files are copied, not linked
            )
        with open(gemma_download_marker, "w", encoding="utf-8") as marker_file:
            marker_file.write(gemma_repo_id + "\n")
        logging.info("Gemma 3 1B model and tokenizer files downloaded.")
except Exception as e:
    logging.error(f"Error downloading Gemma 3 1B: {e}")
    raise
# --- Helper Functions ---
def transcribe_audio(audio_path: str, whisper_model) -> str:
    """Transcribe the speech in an audio file with a loaded Whisper model.

    Args:
        audio_path: Path to the audio file to transcribe.
        whisper_model: A model returned by ``whisper.load_model``.

    Returns:
        The transcript with surrounding whitespace removed. On any failure the
        error is logged and the fixed string
        ``"Error: Could not transcribe audio."`` is returned instead of raising.
    """
    try:
        audio = whisper.load_audio(audio_path)
        # No whisper.pad_or_trim() here: that helper belongs to the low-level
        # decode() API. transcribe() handles long input itself, and trimming
        # first silently discarded everything after the first 30 seconds.
        result = whisper_model.transcribe(audio)
        return result["text"].strip()
    except Exception as e:
        logging.error(f"Whisper transcription error: {e}")
        return "Error: Could not transcribe audio."
def _cut_at_first_marker(text: str, markers) -> str:
    """Return ``text`` up to the earliest occurrence of any marker, stripped.

    Empty or ``None`` markers are ignored; if no marker occurs the whole text
    is kept.
    """
    cut = len(text)
    for marker in markers:
        if not marker:
            continue
        index = text.find(marker)
        if index != -1 and index < cut:
            cut = index
    return text[:cut].strip()


def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:
    """Generate a single-turn Gemma chat reply to ``text``.

    Args:
        text: The user's utterance (the Whisper transcript in this app).
        model_gemma: Causal LM used for generation.
        tokenizer_gemma: Matching tokenizer providing ``apply_chat_template``.
        device: Device the prompt tensor is moved to (should match the model).

    Returns:
        The assistant reply with chat-control markers removed. On any failure
        the error is logged and a fixed apology string is returned instead of
        raising.
    """
    try:
        messages = [{"role": "user", "content": text}]
        input_ids = tokenizer_gemma.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        generation_config = GenerationConfig(
            max_new_tokens=MAX_GEMMA_LENGTH,
            early_stopping=True,
        )
        generated_output = model_gemma.generate(input_ids, generation_config=generation_config)
        decoded_output = tokenizer_gemma.decode(generated_output[0], skip_special_tokens=False)
        start_token = "<start_of_turn>model"
        start_index = decoded_output.find(start_token)
        if start_index == -1:
            return decoded_output
        reply = decoded_output[start_index + len(start_token):]
        # skip_special_tokens=False keeps turn/eos markers in the text. Cut at
        # the first one so "<end_of_turn>", "<eos>" or a runaway next turn is
        # not returned (and then spoken by the TTS stage).
        return _cut_at_first_marker(reply, ("<end_of_turn>", tokenizer_gemma.eos_token))
    except Exception as e:
        logging.error(f"Gemma response generation error: {e}")
        return "I'm sorry, I encountered an error generating a response."
def load_audio(audio_path: str, generator) -> torch.Tensor:
    """Load ``audio_path`` as a single-channel waveform at ``generator.sample_rate``.

    The loaded tensor is averaged over dim 0 (the channel axis in torchaudio's
    default channels-first layout). Resampling is applied only when the file's
    rate differs from the generator's.

    Raises:
        gr.Error: if loading or resampling fails; the original exception is
            chained as the cause.
    """
    try:
        waveform, file_rate = torchaudio.load(audio_path)
        mono = waveform.mean(dim=0)
        if file_rate == generator.sample_rate:
            return mono
        return torchaudio.functional.resample(
            mono, orig_freq=file_rate, new_freq=generator.sample_rate
        )
    except Exception as e:
        logging.error(f"Audio loading error: {e}")
        raise gr.Error("Could not load or process the audio file.") from e
def clear_history():
    """Reset the module-level conversation history.

    Rebinds the global to a fresh empty list (rather than mutating the old
    one), logs the reset, and returns a status message for the UI.
    """
    global conversation_history
    status_message = "Conversation history cleared."
    conversation_history = []
    logging.info(status_message)
    return status_message
# --- Main Inference Function ---
@spaces.GPU(duration=gpu_timeout)  # GPU decorator
def infer(user_audio) -> tuple[int, np.ndarray]:
    """Handle one conversational turn: user speech in, synthesized reply out.

    Loads the CSM generator, Whisper and Gemma on the selected device (on every
    call), then delegates the turn to ``_infer``.

    Args:
        user_audio: Filepath of the recorded/uploaded audio (the Gradio input
            uses ``type="filepath"``); falsy when nothing was provided.

    Returns:
        ``(sample_rate, int16 samples)`` as produced by ``_infer``.

    Raises:
        gr.Error: on any failure. A ``gr.Error`` raised downstream is passed
            through unchanged instead of being wrapped a second time.
    """
    if torch.cuda.is_available():
        device = "cuda"
        logging.info(f"CUDA is available! Using device: {torch.cuda.get_device_name(0)}")
    else:
        device = "cpu"
        logging.info("CUDA is NOT available. Using CPU.")
    try:
        # Validate input before paying for several large model loads.
        if not user_audio:
            raise ValueError("No audio input received.")
        # --- Model Loading (ONLY inside infer, after GPU is available) ---
        generator = load_csm_1b(csm_1b_model_path, device)
        logging.info("Sesame CSM 1B loaded successfully.")
        whisper_model = whisper.load_model(whisper_model_name, device=device, download_root=whisper_local_dir)
        logging.info(f"Whisper model '{whisper_model_name}' loaded successfully.")
        tokenizer_gemma = AutoTokenizer.from_pretrained(gemma_local_path)
        model_gemma = AutoModelForCausalLM.from_pretrained(gemma_local_path).to(device)
        logging.info("Gemma 3 1B pt model loaded successfully.")
        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)
    except gr.Error:
        # Already a user-facing error (and logged where it was raised);
        # re-wrapping would duplicate the "An error occurred..." prefix.
        raise
    except Exception as e:
        logging.exception(f"Inference error: {e}")
        raise gr.Error(f"An error occurred during processing: {e}") from e
def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device) -> tuple[int, np.ndarray]:
    """Run one turn with already-loaded models and update the shared history.

    Steps: transcribe the user audio, generate a Gemma reply, synthesize it
    with CSM using the conversation history as context, append the user and
    AI segments (trimmed to ``MAX_CONTEXT_SEGMENTS``), then watermark and
    convert the reply audio to int16.

    NOTE(review): if the GPU decorator on ``infer`` runs calls in a separate
    worker process, mutations to the global history here may not persist
    between requests; confirm.

    Returns:
        ``(generator.sample_rate, int16 numpy array)`` for the Gradio output.

    Raises:
        gr.Error: on any failure. A ``gr.Error`` from the synthesis step is
            passed through unchanged rather than wrapped again.
    """
    global conversation_history
    try:
        user_text = transcribe_audio(user_audio, whisper_model)
        logging.info(f"User: {user_text}")
        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)
        logging.info(f"AI: {ai_text}")
        try:
            ai_audio = generator.generate(
                text=ai_text,
                speaker=SPEAKER_ID,
                context=conversation_history,
                max_audio_length_ms=10_000,
            )
            logging.info("Audio generated successfully.")
        except Exception as e:
            logging.error(f"Sesame response generation error: {e}")
            raise gr.Error(f"Sesame response generation error: {e}") from e
        user_segment = Segment(speaker=1, text=user_text, audio=load_audio(user_audio, generator))
        ai_segment = Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio)
        conversation_history.append(user_segment)
        conversation_history.append(ai_segment)
        # Two segments are added per turn, so drop *all* overflow from the
        # front. Popping a single element let the history grow by one every
        # turn past the cap.
        excess = len(conversation_history) - MAX_CONTEXT_SEGMENTS
        if excess > 0:
            del conversation_history[:excess]
        audio_tensor, wm_sample_rate = watermark(
            generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
        )
        audio_tensor = torchaudio.functional.resample(
            audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
        )
        # Scale float samples to int16. Clamp first: a full-scale +1.0 sample
        # maps to 32768, which overflows int16 and wraps to -32768.
        ai_audio_array = (audio_tensor * 32768).clamp(-32768, 32767).to(torch.int16).cpu().numpy()
        return generator.sample_rate, ai_audio_array
    except gr.Error:
        raise
    except Exception as e:
        logging.exception(f"Error in _infer: {e}")
        raise gr.Error(f"An error occurred during processing: {e}") from e
# --- Gradio Interface ---
with gr.Blocks() as app:
    # Components appear on the page in the order they are created here.
    gr.Markdown(SPACE_INTRO_TEXT)
    audio_input = gr.Audio(label="Your Input", type="filepath")  # passed to infer as a path
    audio_output = gr.Audio(label="AI Response")
    clear_button = gr.Button("Clear Conversation History")
    status_display = gr.Textbox(label="Status", visible=False)  # receives clear_history's message
    btn = gr.Button("Generate Response")

    # Event wiring.
    btn.click(fn=infer, inputs=[audio_input], outputs=[audio_output])
    clear_button.click(fn=clear_history, outputs=[status_display])

app.launch(ssr_mode=False)