# koko/app.py — from commit 58cd85b ("Update app.py") by don0726.
import logging
import os
import shutil
import tempfile
import uuid

import gradio as gr
import uvicorn
from fastapi import BackgroundTasks, FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse

from core.cloner import KokoClone
# -----------------------------
# FastAPI
# -----------------------------
# ASGI application that hosts the POST /clone route (and, further down,
# the mounted Gradio UI).
app = FastAPI()

# The cloning model is built once at import time and shared, as a
# module-level global, by the /clone endpoint and the Gradio handler.
print("Loading model...")
cloner = KokoClone()
# -----------------------------
# API ROUTE
# -----------------------------
logger = logging.getLogger(__name__)


def _remove_quietly(path):
    """Delete the file at ``path``; do nothing if it does not exist."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


@app.post("/clone")
async def clone_voice(
    source_audio: UploadFile = File(...),
    reference_audio: UploadFile = File(...),
):
    """Convert ``source_audio`` to the voice of the speaker in ``reference_audio``.

    Both uploads are saved to uniquely named files in the system temp
    directory and passed to ``cloner.convert``. The generated WAV is then
    sent back as ``output.wav``.

    Returns:
        A ``FileResponse`` (``audio/wav``) on success. The output file is
        deleted after the response has been sent.
        A ``JSONResponse`` ``{"error": "<message>"}`` with status 500 if
        saving the uploads or converting fails.
    """
    tmp_dir = tempfile.gettempdir()
    source_path = os.path.join(tmp_dir, f"source_{uuid.uuid4().hex}.wav")
    ref_path = os.path.join(tmp_dir, f"ref_{uuid.uuid4().hex}.wav")
    output_path = os.path.join(tmp_dir, f"output_{uuid.uuid4().hex}.wav")

    try:
        # Save the uploads to disk; cloner.convert works on file paths.
        with open(source_path, "wb") as f:
            shutil.copyfileobj(source_audio.file, f)
        with open(ref_path, "wb") as f:
            shutil.copyfileobj(reference_audio.file, f)

        # convert() is a plain synchronous call. Running it in the threadpool
        # keeps the event loop free; that loop also serves the Gradio UI
        # mounted on this same app.
        await run_in_threadpool(
            cloner.convert,
            source_audio=source_path,
            reference_audio=ref_path,
            output_path=output_path,
        )
        if not os.path.exists(output_path):
            raise RuntimeError("conversion finished without writing an output file")
    except Exception as e:
        logger.exception("Voice cloning failed")
        # Don't leave a partially written output behind.
        _remove_quietly(output_path)
        # Use a non-2xx status so clients can detect failure; the body shape is unchanged.
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # The inputs are no longer needed, whether or not conversion succeeded.
        for p in (source_path, ref_path):
            _remove_quietly(p)

    # Delete the output only after FileResponse has finished streaming it.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_quietly, output_path)
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename="output.wav",
        background=cleanup,
    )
# -----------------------------
# GRADIO UI
# -----------------------------
def convert_voice(source_audio_path, ref_audio_path):
    """Gradio click handler: clone the reference speaker's voice onto the source audio.

    Args:
        source_audio_path: Path to the source recording (the UI's ``gr.Audio``
            uses ``type="filepath"``). May be ``None`` if nothing was provided.
        ref_audio_path: Path to the reference-voice recording. May be ``None``.

    Returns:
        Path of the converted WAV file, shown in the output player.

    Raises:
        gr.Error: if either input is missing. The UI then shows a readable
            message instead of ``None`` being passed to ``cloner.convert``.
    """
    if not source_audio_path or not ref_audio_path:
        raise gr.Error("Please provide both source audio and reference audio.")

    # Write into the temp dir rather than cluttering the app's working directory.
    output_file = os.path.join(
        tempfile.gettempdir(), f"converted_{uuid.uuid4().hex}.wav"
    )
    cloner.convert(
        source_audio=source_audio_path,
        reference_audio=ref_audio_path,
        output_path=output_file,
    )
    return output_file
# Two-column layout: the inputs and the trigger on the left, the result on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Voice Clone")

    with gr.Row():
        # Left column: the two recordings, delivered to the handler as file paths.
        with gr.Column():
            source_audio = gr.Audio(label="Source Audio", type="filepath")
            reference_audio = gr.Audio(label="Reference Audio", type="filepath")
            btn = gr.Button("Clone")

        # Right column: the player for the converted file.
        with gr.Column():
            output_audio = gr.Audio(label="Output")

    # Clicking the button calls convert_voice(source, reference) and loads
    # the returned file path into the output player.
    btn.click(
        convert_voice,
        inputs=[source_audio, reference_audio],
        outputs=output_audio,
    )
# Serve the Gradio UI from the same FastAPI app, under /ui, alongside POST /clone.
app = gr.mount_gradio_app(app, demo, path="/ui")

if __name__ == "__main__":
    # Defaults stay 0.0.0.0:7860. Gradio's standard GRADIO_SERVER_NAME /
    # GRADIO_SERVER_PORT environment variables can override them, so the
    # same file also runs on other hosts and ports without edits.
    uvicorn.run(
        app,
        host=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )