# VITS-TEST / test_inference.py
# (Hugging Face upload header preserved as a comment: kushan1988,
#  "Upload 231 files", revision 0b65cde — the raw residue was not valid Python.)
from flask import Flask, jsonify
from flask_socketio import SocketIO, send
import numpy as np
import json
import soundfile as sf
import os
import logging
from inference import infer_tool
from inference.infer_tool import Svc
from scipy.io import wavfile
from spkmix import spk_mix_map
import soundfile
from inference import infer_tool
from inference.infer_tool import Svc
logging.getLogger('numba').setLevel(logging.WARNING)
chunks_dict = infer_tool.read_temp("inference/chunks_temp.json")
app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*")
# Load the machine learning model when the Flask app starts
model_path = 'logs/44k/G_354400.pth'
config_path = 'configs/config.json'
cluster_model_path = "logs/44k/kmeans_10000.pt"
diffusion_model_path = 'logs/44k/diffusion/model_0.pt'
diffusion_config_path = 'logs/44k/diffusion/config.yaml'
device = 'cpu'
svc_model = None
# Utility function to generate a WAV file from response data
def generate_wav_file_from_response_data(response_data, output_file):
# Assume the sample rate is 44100 Hz
sample_rate = 44100
# Parse the JSON string back into a numpy array
audio_data = np.array(response_data)
# Generate a .wav file from the numpy array
sf.write(output_file, audio_data, sample_rate)
def inference_with_model(file_path):
# Perform the prediction using the model
global svc_model
if svc_model is None:
svc_model = Svc(
model_path,
config_path,
device,
cluster_model_path,
False, # enhance
diffusion_model_path,
diffusion_config_path,
False, # shallow_diffusion
False, # only_diffusion
False, # use_spk_mix
False # feature_retrieval
)
# Set the parameters for inference
spk_list = ['America']
use_spk_mix = False
if len(spk_mix_map) > 1:
use_spk_mix = True
kwarg = {
"raw_audio_path": file_path,
"spk": spk_list,
"tran": 0,
"slice_db": -40,
"cluster_infer_ratio": 0,
"auto_predict_f0": False,
"noice_scale": 0.4,
"pad_seconds": 0.5,
"clip_seconds": 0,
"lg_num": 0,
"lgr_num": 0.75,
"f0_predictor": 'pm',
"enhancer_adaptive_key": 0,
"cr_threshold": 0.05,
"k_step": 100,
"use_spk_mix": use_spk_mix,
"second_encoding": False,
"loudness_envelope_adjustment": 1
}
audio = svc_model.slice_inference(**kwarg)
result_path = os.path.join(os.curdir, "results", "result.wav")
sf.write(result_path, audio, svc_model.target_sample)
print('inference complete')
return result_path
@socketio.on('audio')
def handle_audio(json_audio):
print('audio recived')
audio_chunk = np.array(json.loads(json_audio)) # Convert the JSON audio chunk to anumpy array
file_path = save_to_wav(audio_chunk) # Save the audio chunk to a .wav file
res_path = inference_with_model(file_path) # Perform inference on the .wav file
# Read the audio data from the generated file
with open(res_path, "rb") as file:
audio_data = file.read()
# Send the audio data back to the client
send(audio_data, binary=True)
def save_to_wav(audio_data):
# Define the output file path
output_file = "input.wav" # Update this to your desired path
# Save the audio data to a .wav file
wavfile.write(output_file, 44100, audio_data)
return output_file
if __name__ == '__main__':
socketio.run(app, host="0.0.0.0", port=3000) # Make sure to use 0.0.0.0 to allow external connections