# Ovi / app.py
# don0726's picture
# Update app.py
# bf19d3d verified
import os

# 🔥 Cache fix
# huggingface_hub resolves its cache directories from the environment when it
# is first imported (gradio imports it too, and omnivoice may do so
# transitively), so these variables must be set BEFORE the third-party imports
# below. Set afterwards, they are silently ignored.
os.environ["HF_HOME"] = "/app/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
os.environ["HF_HUB_CACHE"] = "/app/hf_cache"

import time  # noqa: E402

import torch  # noqa: E402
import torchaudio  # noqa: E402
import gradio as gr  # noqa: E402
from fastapi import FastAPI, UploadFile, File, Form  # noqa: E402
from fastapi.responses import FileResponse  # noqa: E402
from huggingface_hub import snapshot_download  # noqa: E402
from omnivoice import OmniVoice  # noqa: E402

# ASGI application; the API route and the Gradio UI are both attached to it.
app = FastAPI()
# =========================
# 🔹 Load model ONCE
# =========================
# Runs at import time so every request reuses the same in-memory model
# instead of reloading it per call.
print("⏳ Downloading model...")
model_path = snapshot_download(
    repo_id="k2-fsa/OmniVoice",
    local_dir="omnivoice_model",
    local_dir_use_symlinks=False,
)
print("⏳ Loading model...")
device = "cpu"
model = OmniVoice.from_pretrained(
    model_path,
    device_map=device,
    torch_dtype=torch.float32,
)
print("✅ Model ready!")
# =========================
# 🔹 Core function
# =========================
def generate_voice(audio_path, text, lang):
    """Clone the voice in *audio_path* and speak *text* with it.

    The reference clip is down-mixed to mono, resampled to 24 kHz and written
    to a scratch WAV file that is passed to ``model.generate``. The generated
    audio is saved as a 24 kHz WAV file in the current working directory.

    Args:
        audio_path: Path of the reference audio file (anything
            ``torchaudio.load`` can read).
        text: Text to synthesize.
        lang: Language code; passed to the model as ``language`` and also
            prefixed to the text as ``[lang]``.

    Returns:
        Path of the generated WAV file. The caller owns the file and is
        responsible for deleting it.

    Raises:
        ValueError: If the model returns an empty list/tuple of outputs.
    """
    import uuid  # local import keeps this block independent of the file header

    sample_rate = 24000

    waveform, sr = torchaudio.load(audio_path)
    # mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # resample
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)

    # A random token (rather than a one-second timestamp) keeps concurrent
    # calls from overwriting each other's files.
    token = uuid.uuid4().hex
    processed_path = f"processed_{token}.wav"
    output_path = f"output_{token}.wav"

    try:
        torchaudio.save(processed_path, waveform, sample_rate)
        audio_out = model.generate(
            text=f"[{lang}] {text}",
            ref_audio=processed_path,
            ref_text=None,
            language=lang,
        )
    finally:
        # The scratch reference file is only needed while generating.
        if os.path.exists(processed_path):
            os.remove(processed_path)

    # The model may hand back a batch; only the first item is used.
    if isinstance(audio_out, (list, tuple)):
        if not audio_out:
            raise ValueError("model.generate returned no audio")
        audio_out = audio_out[0]
    if not isinstance(audio_out, torch.Tensor):
        audio_out = torch.as_tensor(audio_out)
    # torchaudio.save expects a (channels, samples) tensor.
    if audio_out.dim() == 1:
        audio_out = audio_out.unsqueeze(0)

    torchaudio.save(output_path, audio_out.cpu(), sample_rate)
    return output_path
# =========================
# 🔹 API Endpoint
# =========================
import threading  # noqa: E402  (only needed by the /clone endpoint)

# Inference now runs in a worker thread; this lock keeps API requests
# one-at-a-time, as they were when the call blocked the event loop.
_CLONE_LOCK = threading.Lock()


def _clone_blocking(input_path, text, lang):
    """Run ``generate_voice`` under the API lock (called in a worker thread)."""
    with _CLONE_LOCK:
        return generate_voice(input_path, text, lang)


@app.post("/clone")
async def clone_voice(
    audio: UploadFile = File(...),
    text: str = Form(...),
    lang: str = Form(...),
):
    """Clone the uploaded voice and return the synthesized speech as WAV.

    Args:
        audio: Uploaded reference recording.
        text: Text to synthesize.
        lang: Language code forwarded to ``generate_voice``.

    Returns:
        A ``FileResponse`` streaming the generated audio as ``output.wav``.
        The generated file is deleted once the response has been sent.
    """
    import uuid

    from fastapi.concurrency import run_in_threadpool
    from starlette.background import BackgroundTask

    # Random name: a one-second timestamp collides under concurrent uploads.
    input_path = f"input_{uuid.uuid4().hex}.wav"
    try:
        with open(input_path, "wb") as f:
            f.write(await audio.read())
        # CPU inference is slow and blocking; running it off the event loop
        # keeps the server (including the mounted Gradio UI) responsive.
        output_path = await run_in_threadpool(
            _clone_blocking, input_path, text, lang
        )
    finally:
        # The uploaded copy is no longer needed once generation finished/failed.
        if os.path.exists(input_path):
            os.remove(input_path)

    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename="output.wav",
        # Remove the generated file after it has been streamed to the client.
        background=BackgroundTask(os.remove, output_path),
    )
# =========================
# 🔹 Gradio UI
# =========================
def gradio_fn(audio, text, lang):
    """Gradio callback: return the cloned-speech file path, or None without a clip.

    Args:
        audio: File path of the reference audio, or None when nothing was
            provided.
        text: Text to synthesize.
        lang: Language code forwarded to ``generate_voice``.
    """
    return None if audio is None else generate_voice(audio, text, lang)
# Web UI: reference audio + text + language code in, cloned speech out.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Audio(type="filepath", label="Upload Reference Audio"),
        gr.Textbox(label="Text"),
        gr.Textbox(label="Language Code (en, hi, etc)"),
    ],
    outputs=gr.Audio(label="Cloned Voice"),
    title="OmniVoice Voice Cloning (CPU)",
    description="Upload voice + enter text + language → get cloned speech",
)
# Mount Gradio to FastAPI
# The UI is served at "/" on the same app that exposes the /clone endpoint.
app = gr.mount_gradio_app(app, demo, path="/")