Spaces:
Runtime error
Runtime error
| import time | |
| from pathlib import Path | |
| import shutil | |
| import json | |
| import argbind | |
| import audiotools as at | |
| from gradio_client import Client, handle_file | |
| from pythonosc.osc_server import ThreadingOSCUDPServer | |
| from pythonosc.udp_client import SimpleUDPClient | |
| from pythonosc.dispatcher import Dispatcher | |
| import torch | |
| class Timer: | |
| def __init__(self): | |
| self.times = {} | |
| def tick(self, name: str): | |
| self.times[name] = time.time() | |
| def tock(self, name: str): | |
| toc = time.time() - self.times[name] | |
| print(f"{name} took {toc} seconds") | |
| return toc | |
| def __str__(self): | |
| return str(self.times) | |
| timer = Timer() | |
| DOWNLOADS_DIR = ".gradio" | |
| def clear_file(file): | |
| file = Path(file) | |
| if file.exists(): | |
| file.unlink() | |
| class OSCManager: | |
| def __init__( | |
| self, | |
| ip: str, | |
| s_port: str, | |
| r_port: str, | |
| process_fn: callable, | |
| # param_change_callback: callable = None | |
| ): | |
| self.ip = ip | |
| self.s_port = s_port | |
| self.r_port = r_port | |
| # register the process_fn | |
| self.process_fn = process_fn | |
| print(f"will send to {ip}:{s_port}") | |
| self.client = SimpleUDPClient(ip, s_port) | |
| def start_server(self,): | |
| dispatcher = Dispatcher() | |
| dispatcher.map("/process", self.process_fn) | |
| def send_heartbeat(_, *args): | |
| # print("Received heartbeat") | |
| self.client.send_message("/heartbeat", "pong") | |
| dispatcher.map("/heartbeat", lambda a, *r: send_heartbeat(a, *r)) | |
| dispatcher.map("/cleanup", lambda a, *r: clear_file(r[0])) | |
| dispatcher.set_default_handler(lambda a, *r: print(a, r)) | |
| server = ThreadingOSCUDPServer((self.ip, self.r_port), dispatcher) | |
| print(f"Serving on {server.server_address}") | |
| server.serve_forever() | |
| def error(self, msg: str): | |
| self.client.send_message("/error", msg) | |
| def log(self, msg: str): | |
| self.client.send_message("/log", msg) | |
| class GradioOSCClient: | |
| def __init__(self, | |
| ip: str, | |
| s_port: int, r_port: int, | |
| vampnet_url: str = None, # url for vampnet | |
| ): | |
| self.osc_manager = OSCManager( | |
| ip=ip, s_port=s_port, r_port=r_port, | |
| process_fn=self.process, | |
| ) | |
| self.clients = {} | |
| if vampnet_url is not None: | |
| self.clients["vampnet"] = Client(src=vampnet_url, download_files=DOWNLOADS_DIR) | |
| assert len(self.clients) > 0, "At least one client must be specified!" | |
| self.batch_size = 2# TODO: automatically get batch size from client. | |
| self.osc_manager.log("hello from gradio client!") | |
| self.inf_idx = 0 | |
| def param_changed(self, param_name, new_value): | |
| print(f"Parameter {param_name} changed to {new_value}") | |
| def vampnet_process(self, address: str, *args): | |
| client = self.clients["vampnet"] | |
| # query id --- audiofile ---- model_choice --- periodic --- drop --- seed | |
| query_id = args[0] | |
| client_type = args[1] | |
| audio_path = Path(args[2]) | |
| model_choice = args[3] | |
| periodic_p = args[4] | |
| dropout = args[5] | |
| seed = args[6] | |
| looplength_ms = args[7] | |
| typical_filter = args[8] | |
| typical_mass = args[9] | |
| typical_min_tokens = args[10] | |
| upper_codebook_mask = args[11] | |
| onset_mask_width = args[12] | |
| sampling_steps = args[13] | |
| temperature = args[14] | |
| top_p = args[15] | |
| beat_mask_ms = args[16] | |
| num_feedback_steps = args[17] | |
| if not audio_path.exists(): | |
| print(f"File {audio_path} does not exist") | |
| self.osc_manager.error(f"File {audio_path} does not exist") | |
| return | |
| sig = at.AudioSignal(audio_path) | |
| sig.to_mono() | |
| sig.sample_rate = 48000 # HOT PATCH (FIXME IN MAX: sample rate is being forced to 48k) | |
| # grab the looplength only | |
| # TODO: although I added this, | |
| # the max patch is still configured to crop anything past the looplength off | |
| # so we'll have to change that in order to make an effect. | |
| end_sample = int((looplength_ms * sig.sample_rate) / 1000) | |
| # grab the remainder of the waveform | |
| num_cut_samples = sig.samples.shape[-1] - end_sample | |
| cut_wav = sig.samples[..., -num_cut_samples:] | |
| sig.samples = sig.samples[..., :end_sample] | |
| # write the file back | |
| sig.write(audio_path) | |
| timer.tick("predict") | |
| print(f"Processing {address} with args {args}") | |
| # breakpoint() | |
| job = client.submit( | |
| input_audio=handle_file(audio_path), | |
| sampletemp=temperature, | |
| top_p=top_p, | |
| periodic_p=periodic_p, | |
| dropout=dropout, | |
| stretch_factor=1, | |
| onset_mask_width=onset_mask_width, | |
| typical_filtering=bool(typical_filter), | |
| typical_mass=typical_mass, | |
| typical_min_tokens=typical_min_tokens, | |
| seed=seed, | |
| model_choice=model_choice, | |
| n_mask_codebooks=upper_codebook_mask, | |
| pitch_shift_amt=0, | |
| sample_cutoff=1.0, | |
| sampling_steps=sampling_steps, | |
| beat_mask_ms=int(beat_mask_ms), | |
| num_feedback_steps=num_feedback_steps, | |
| api_name="/vamp_1" | |
| ) | |
| while not job.done(): | |
| time.sleep(0.1) | |
| self.osc_manager.client.send_message("/progress", [query_id, str(job.status().code)]) | |
| result = job.result() | |
| # audio_file = result | |
| # audio_files = [audio_file] * self.batch_size | |
| audio_files = list(result[:self.batch_size]) | |
| # if each file is missing a .wav at the end, add it | |
| first_audio = audio_files[0] | |
| if not first_audio.endswith(".wav"): | |
| for audio_file in set(audio_files): | |
| if not audio_file.endswith(".wav"): | |
| shutil.move(audio_file, f"{audio_file}.wav") | |
| audio_file = f"{audio_file}.wav" | |
| audio_files = [f"{audio}.wav" for audio in audio_files if not audio.endswith(".wav")] | |
| for audio_file in audio_files: | |
| # load the file, add the cut samples back | |
| sig = at.AudioSignal(audio_file) | |
| sig.resample(48000) | |
| sig.samples = torch.cat([sig.samples, cut_wav], dim=-1) | |
| sig.write(audio_file) | |
| seed = result[-1] | |
| timer.tock("predict") | |
| # send a message that the process is done | |
| self.osc_manager.log(f"query {query_id} has been processed") | |
| self.osc_manager.client.send_message("/process-result", [query_id] + audio_files) | |
| def process(self, address: str, *args): | |
| query_id = args[0] | |
| client_type = args[1] | |
| audio_path = Path(args[2]) | |
| if client_type == "vampnet": | |
| self.vampnet_process(address, *args) | |
| return | |
| elif client_type == "sketch2sound": | |
| self.process_s2s(address, *args) | |
| return | |
| else: | |
| raise ValueError(f"Unknown client type {client_type}") | |
| def gradio_main( | |
| vampnet_url: str = None | |
| ): | |
| system = GradioOSCClient( | |
| vampnet_url=vampnet_url, | |
| ip="127.0.0.1", s_port=8003, r_port=8001, | |
| ) | |
| system.osc_manager.start_server() | |
| if __name__ == "__main__": | |
| try: | |
| gradio_main = argbind.bind(gradio_main, without_prefix=True) | |
| args = argbind.parse_args() | |
| with argbind.scope(args): | |
| gradio_main() | |
| except Exception as e: | |
| import shutil | |
| shutil.rmtree(DOWNLOADS_DIR, ignore_errors=True) | |
| raise e |