| | import runpod |
| | import requests |
| | from voice_generation import generate_wav |
| | import boto3 |
| | import os |
| | import uuid |
| | from pydub import AudioSegment |
| | import time |
| | from typing import Optional |
| |
|
| |
|
# AWS credentials read from the environment (None when a variable is unset).
# handler() refuses to run without them; upload_file_to_s3() passes them to boto3.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')

# Pre-loaded voice models a job can select via "voice_model_id";
# every id maps to the weights file weights/<id>.pth.
models = {
    model_id: f'weights/{model_id}.pth'
    for model_id in ('kanye', 'rose-bp', 'jungkook', 'iu', 'drake', 'ariana-grande')
}

print('run handler. Removed 2nd gen')
| |
|
| |
|
def combine_audio(voice_path: str, instrumental_path: str, optional_path: Optional[str] = None,
                  output_path: str = "combined.mp3", output_format: str = "mp3") -> str:
    """Mix the voice, the instrumental and an optional extra track into one file.

    The instrumental is the base of the mix; the voice and then the optional
    track are overlaid on it. Every track is first right-padded with silence to
    the length of the LONGEST track (the optional one included), because
    ``AudioSegment.overlay`` never extends the base segment and would otherwise
    cut off whatever runs past it.

    Args:
        voice_path: Path to the generated voice track.
        instrumental_path: Path to the instrumental track.
        optional_path: Optional path to an extra track (the handler passes the
            backing vocal here). Falsy means "no extra track".
        output_path: Destination file of the mix. Defaults to ``"combined.mp3"``.
        output_format: Format passed to ``AudioSegment.export``. Defaults to ``"mp3"``.

    Returns:
        ``output_path``.
    """
    # Load in the same order as the mix: base first, then the overlays.
    tracks = [AudioSegment.from_file(instrumental_path), AudioSegment.from_file(voice_path)]
    if optional_path:
        tracks.append(AudioSegment.from_file(optional_path))

    # len() of an AudioSegment is its duration in milliseconds.
    length = max(len(track) for track in tracks)
    padded = [track + AudioSegment.silent(duration=length - len(track)) for track in tracks]

    combined = padded[0]
    for track in padded[1:]:
        combined = combined.overlay(track)

    # export() hands back the still-open output file object; close it so the
    # handle is not leaked.
    combined.export(output_path, format=output_format).close()
    return output_path
| |
|
| |
|
def upload_file_to_s3(local_file_path, s3_file_path, bucket_name='voice-gen-audios', region='eu-north-1'):
    """Upload a local file to S3 and return its public-style object URL.

    Args:
        local_file_path: Path of the file to upload.
        s3_file_path: Object key to store it under.
        bucket_name: Target bucket. Defaults to ``'voice-gen-audios'``.
        region: Bucket region, used only to build the returned URL.
            Defaults to ``'eu-north-1'``.

    Returns:
        ``{"url": ...}`` on success, or ``{"error": ...}`` when the upload fails
        or the local file cannot be read. The error text includes the cause.
    """
    from urllib.parse import quote

    s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
    try:
        s3.upload_file(local_file_path, bucket_name, s3_file_path)
    except (boto3.exceptions.S3UploadFailedError, OSError) as e:
        return {"error": f"failed to upload file {local_file_path} to s3 as {s3_file_path}: {e}"}
    # Quote the key so names with spaces or other reserved characters still
    # produce a valid URL; '/' stays unescaped as the key's path separator.
    return {"url": f"https://{bucket_name}.s3.{region}.amazonaws.com/{quote(s3_file_path)}"}
| |
|
| |
|
def clean_up_files(remove_voice_model=False, remove_back_vocal=False):
    """Delete the intermediate files a job leaves in the working directory.

    Always targets instrumental.mp3, vocal.mp3, output_voice.wav and
    combined.mp3. Every file is attempted even when an earlier one cannot be
    removed, so one missing file no longer leaves the rest behind.

    Args:
        remove_voice_model: Also delete the downloaded ``voice_model.pth``.
        remove_back_vocal: Also delete the downloaded ``back_vocal.mp3``.

    Returns:
        ``{"success": "files removed successfully"}`` when every file was
        removed, otherwise ``{"error": "failed to remove file <names>"}`` where
        ``<names>`` is the comma-separated list of files that could not be removed.
    """
    files = [
        "instrumental.mp3",
        "vocal.mp3",
        "output_voice.wav",
        "combined.mp3",
    ]
    if remove_voice_model:
        files.append("voice_model.pth")
    if remove_back_vocal:
        files.append("back_vocal.mp3")

    failed = []
    for file in files:
        try:
            os.remove(file)
        except OSError:
            # Missing or locked: record it and keep cleaning the remaining files.
            failed.append(file)

    if failed:
        return {"error": f"failed to remove file {', '.join(failed)}"}
    return {"success": "files removed successfully"}
| |
|
| |
|
def get_voice_model(event, timeout=60):
    """Resolve the voice model a job asks for.

    A ``voice_model_id`` takes priority and must name one of the pre-loaded
    ``models``. Otherwise ``voice_model_url`` is downloaded into
    ``voice_model.pth`` in the working directory.

    Args:
        event: RunPod job event; reads ``event["input"]["voice_model_id"]`` and
            ``event["input"]["voice_model_url"]``.
        timeout: Connect/read timeout in seconds for the model download, so a
            stalled server cannot hang the worker indefinitely.

    Returns:
        ``{"model_path": path}`` on success, ``{"error": message}`` otherwise.
    """
    job_input = event["input"]
    voice_model_id = job_input.get("voice_model_id", "")
    voice_model_url = job_input.get("voice_model_url", "")

    if not voice_model_url and not voice_model_id:
        return {"error": "voice_model_url or voice_model_id is required"}

    if voice_model_id:
        if voice_model_id not in models:
            return {"error": "model not found in pre-loaded models"}
        return {"model_path": models[voice_model_id]}

    print("downloading voice_model")
    try:
        voice_model_response = requests.get(voice_model_url, timeout=timeout)
    except requests.exceptions.RequestException as e:
        # Connection errors, timeouts and malformed URLs become an error result
        # instead of an exception escaping the handler.
        return {"error": f"failed to download voice_model, error: {e}"}
    if voice_model_response.status_code != 200:
        return {"error": f"failed to download voice_model, error: {voice_model_response.text}"}

    with open("voice_model.pth", "wb") as f:
        f.write(voice_model_response.content)

    return {"model_path": "voice_model.pth"}
| |
|
| |
|
def get_method(event):
    """Return the ``method`` forwarded to ``generate_wav``.

    Reads ``event["input"]["method"]``; only ``"pm"`` and ``"harvest"`` are
    accepted, and anything else (including a missing key) yields ``"pm"``.
    """
    requested = event["input"].get("method", "pm")
    return requested if requested in ("pm", "harvest") else "pm"
| |
|
| |
|
def get_index_rate(event, default=0.6):
    """Return the ``index_rate`` forwarded to ``generate_wav``.

    Reads ``event["input"]["index_rate"]`` and coerces it to ``float`` so that
    numeric strings from JSON input (e.g. ``"0.3"``) are accepted.

    Args:
        event: RunPod job event.
        default: Value used when the key is missing, not convertible to a
            number, or outside ``[0, 1]``. Defaults to ``0.6``.

    Returns:
        A float in ``[0, 1]``.
    """
    raw = event["input"].get("index_rate", default)
    try:
        index_rate = float(raw)
    except (TypeError, ValueError):
        return default
    # The chained comparison is False for NaN as well, so NaN is rejected too.
    if not 0 <= index_rate <= 1:
        return default
    return index_rate
| |
|
| |
|
def _download_to_file(url, path, timeout=60):
    """Download ``url`` and write the response body to ``path``.

    Returns True on HTTP 200. Returns False on any other status or on a
    request-level failure (connection error, timeout, malformed URL); ``path``
    is written only on success. ``timeout`` is the requests connect/read timeout.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.exceptions.RequestException as e:
        print(f"download of {url} failed: {e}")
        return False
    if response.status_code != 200:
        return False
    with open(path, "wb") as f:
        f.write(response.content)
    return True


def handler(event):
    """RunPod serverless entry point.

    Pipeline:
    1. Resolve the voice model.
    2. Download the vocal, the instrumental and (optionally) the backing-vocal tracks.
    3. Run ``generate_wav`` on the vocal.
    4. Mix the generated voice over the instrumental.
    5. Upload the mix and the generated voice to S3.
    6. Delete the intermediate files.

    Every failure is returned as ``{"error": message}``. This is the output
    shape the RunPod serverless SDK reports as a failed job. On success the
    result holds both S3 URLs, the caller's ``user_id`` and the generation
    time in seconds.
    """
    print(event)
    file_id = str(uuid.uuid4())
    job_input = event["input"]
    user_id = job_input.get("user_id", "not provided")

    if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
        return {"error": "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are missing from environment variables"}

    voice_model = get_voice_model(event)
    if "error" in voice_model:
        return {"error": voice_model["error"]}

    # Truthiness checks so an explicit null URL counts as "not provided".
    vocal_url = job_input.get("vocal_url", "")
    if not vocal_url:
        return {"error": "vocal_url is required"}

    instrumental_url = job_input.get("instrumental_url", "")
    if not instrumental_url:
        return {"error": "instrumental_url is required"}

    if not _download_to_file(vocal_url, "vocal.mp3"):
        return {"error": "failed to download vocal_file"}
    if not _download_to_file(instrumental_url, "instrumental.mp3"):
        return {"error": "failed to download instrumental_file"}

    back_vocal_url = job_input.get("back_vocal_url", "")
    if back_vocal_url and not _download_to_file(back_vocal_url, "back_vocal.mp3"):
        return {"error": "failed to download back_vocal_file"}

    # perf_counter is monotonic, so the measured duration cannot go negative
    # if the wall clock is adjusted mid-generation.
    generation_start = time.perf_counter()
    generation = generate_wav(
        audio_file='vocal.mp3',
        method=get_method(event),
        index_rate=get_index_rate(event),
        output_file='output_voice.wav',
        model_path=voice_model.get("model_path"),
    )
    time_taken_generation = time.perf_counter() - generation_start
    print(f"generation took {time_taken_generation} seconds")
    if "error" in generation:
        return {"error": generation.get("error")}

    combine_audio(
        "output_voice.wav",
        "instrumental.mp3",
        "back_vocal.mp3" if back_vocal_url else None,
    )
    if not os.path.exists("combined.mp3"):
        return {"error": "failed to combine audio"}

    # Upload sequentially and stop at the first failure instead of uploading
    # the second file for a job that is going to fail anyway.
    combined = upload_file_to_s3("combined.mp3", f"{file_id}.mp3")
    if combined_error := combined.get("error"):
        return {"error": combined_error}

    output_voice = upload_file_to_s3("output_voice.wav", f"{file_id}-generated-voice.wav")
    if output_voice_error := output_voice.get("error"):
        return {"error": output_voice_error}

    cleanup_result = clean_up_files(
        remove_voice_model=voice_model.get("model_path") == "voice_model.pth",
        remove_back_vocal=bool(back_vocal_url),
    )
    if cleanup_error := cleanup_result.get("error"):
        # Both uploads succeeded, so the result is valid; a leftover temp file
        # is a worker-side problem and must not discard the URLs.
        print(f"cleanup failed: {cleanup_error}")

    return {
        "combined_url": combined.get("url"),
        "output_voice_url": output_voice.get("url"),
        "user_id": user_id,
        "time_taken_generation": time_taken_generation,
    }
| |
|
| |
|
# Start the RunPod serverless worker only when this file is run as a script
# (e.g. `python handler.py`), so importing the module does not block forever.
if __name__ == "__main__":
    runpod.serverless.start({"handler": handler})
| |
|