# csm-1b-lora-fft / app.py
# Source page residue: author "matroks", commit abc41b6 "Create app.py"
# (verified), file size 2.49 kB (raw / history / blame views).
import numpy as np
import torch
import soundfile as sf
from transformers import AutoProcessor
from peft import PeftModel
from transformers import CsmForConditionalGeneration
# Choose the compute device once; the models and inputs below are moved to it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Target rate for the reference clip and the sample rate of the written WAV.
sampling_rate = 24_000

base_id = "unsloth/csm-1b"  # base checkpoint id (Hub repo or local path)
adapter_id = "TurkishCodeMan/csm-1b-lora-fft"  # LoRA adapter loaded on top of base_id

processor = AutoProcessor.from_pretrained(base_id)

# "auto" lets transformers pick the dtype stored with the checkpoint.
base = CsmForConditionalGeneration.from_pretrained(base_id, torch_dtype="auto")
base = base.to(device)

# Attach the adapter, make sure everything sits on `device`, and switch to
# eval mode (disables dropout-style training behavior).
model = PeftModel.from_pretrained(base, adapter_id)
model = model.to(device)
model.eval()
def _resample_linear(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
if orig_sr == target_sr:
return audio
if audio.ndim == 2:
audio = audio.mean(axis=1)
n = audio.shape[0]
new_n = int(round(n * (target_sr / orig_sr)))
if new_n <= 1:
return audio[:1].astype(np.float32)
x_old = np.linspace(0.0, 1.0, num=n, endpoint=True)
x_new = np.linspace(0.0, 1.0, num=new_n, endpoint=True)
return np.interp(x_new, x_old, audio).astype(np.float32)
# Reference clip used as voice context in the first conversation turn.
# The path is relative, so it resolves against the current working directory.
ref_path = "reference.wav"
ref_audio, ref_sr = sf.read(ref_path, dtype="float32")

# soundfile returns (frames, channels) for multi-channel files: average to mono.
ref_audio = (
    ref_audio.mean(axis=1).astype(np.float32) if ref_audio.ndim == 2 else ref_audio
)
# Bring the clip to `sampling_rate` unless the file already uses it.
ref_audio = (
    ref_audio
    if ref_sr == sampling_rate
    else _resample_linear(ref_audio, ref_sr, sampling_rate)
)

# Text that goes into the same turn as ref_audio.
# NOTE(review): this is a placeholder, not what the clip actually says;
# replace it with the real transcript of reference.wav.
ref_text = "Reference transcript (optional)."
# Sentence the model is asked to produce audio for.
target_text = "We extend the standard NIAH task, to investigate model behavior in previously underexplored settings."
# Speaker id assigned to every conversation turn.
speaker_role = "0"
# Two turns from the same speaker role:
#   1. the reference text (behind a language hint) together with the clip;
#   2. only target_text, which the generate call below turns into audio.
# NOTE(review): the "Please speak english" prefix becomes part of the first
# turn's text. If the chat template treats that text as the clip's
# transcript, it no longer matches the audio; confirm this is intended.
conversation = [
    dict(
        role=speaker_role,
        content=[
            dict(type="text", text="Please speak english\n\n" + ref_text),
            dict(type="audio", audio=ref_audio),
        ],
    ),
    dict(
        role=speaker_role,
        content=[dict(type="text", text=target_text)],
    ),
]
# Render the conversation through the processor's chat template into
# PyTorch tensors, then move them onto the model's device.
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = inputs.to(device)

# Top-level sampling settings first, then the depth_decoder_* counterparts.
# NOTE(review): temperature/top_k/top_p only matter when sampling is enabled.
# do_sample / depth_decoder_do_sample are not passed here, so they come from
# the checkpoint's generation config; confirm.
with torch.no_grad():
    out = model.generate(
        **inputs,
        output_audio=True,
        max_new_tokens=200,
        temperature=0.3,
        top_k=50,
        top_p=1.0,
        depth_decoder_temperature=0.6,
        depth_decoder_top_k=0,
        depth_decoder_top_p=0.7,
    )

# out[0] is presumably the waveform for the single conversation in the batch
# (TODO confirm shape). Cast to float32 on the CPU because numpy cannot hold
# bfloat16 tensors.
generated_audio = out[0].detach().to(device="cpu", dtype=torch.float32).numpy()
sf.write("generated_audio.wav", generated_audio, sampling_rate)
print("Wrote generated_audio.wav")