gyroing's picture
Update app.py
d28dce8 verified
import os
import torch
import torchaudio
import gradio as gr
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS
from safetensors.torch import load_file as load_safetensors
from huggingface_hub import hf_hub_download
# --- تنظیمات اولیه ---
# همیشه روی CPU اجرا شود
device = "cpu"
my_token = os.getenv("HF_TOKEN")
print("--- Starting Application on CPU ---")
print("Loading model structure...")
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)
print("Downloading weights...")
try:
model_path = hf_hub_download(
repo_id="Thomcles/Chatterbox-TTS-Persian-Farsi",
filename="t3_fa.safetensors",
token=my_token
)
print("Weights downloaded.")
except Exception as e:
print(f"Download Error: {e}")
raise e
print("Loading weights into model...")
# لود کردن مستقیم روی CPU
t3_state = load_safetensors(model_path, device="cpu")
multilingual_model.t3.load_state_dict(t3_state)
multilingual_model.t3.to(device).eval()
print("Model ready!")
# --- تابع اصلی تولید صدا ---
def generate_audio(text):
if not text:
return None
# محدودیت ۳۰۰ کاراکتر
if len(text) > 300:
text = text[:300]
try:
# تولید صدا (بدون نیاز به جابجایی به GPU)
with torch.no_grad(): # این دستور مصرف رم را کم می‌کند
wav_tensor = multilingual_model.generate(text, language_id=None)
# تبدیل خروجی برای گرادیو
audio_numpy = wav_tensor.squeeze().cpu().numpy()
return (multilingual_model.sr, audio_numpy)
except Exception as e:
print(f"Error: {e}")
raise gr.Error(f"خطا در تولید صدا: {e}")
# --- رابط کاربری ---
with gr.Blocks(title="مبدل متن به گفتار فارسی Chatterbox", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# مبدل متن به گفتار فارسی
(CPU)
"""
)
with gr.Row():
with gr.Column():
text_input = gr.TextArea(
label="متن فارسی را وارد کنید",
value="سلام! این یک تست روی پردازنده معمولی است.",
lines=3,
rtl=True,
max_length=300
)
submit_btn = gr.Button("🎧 تولید صدا", variant="primary")
with gr.Column():
audio_output = gr.Audio(label="پخش صدا", type="numpy")
submit_btn.click(
fn=generate_audio,
inputs=text_input,
outputs=audio_output
)
demo.launch()