Spaces:
Sleeping
Sleeping
File size: 6,357 Bytes
08eeac9 56e1924 7515bd3 08eeac9 28f49cd 7515bd3 28f49cd 7515bd3 28f49cd 7515bd3 08eeac9 7515bd3 08eeac9 7515bd3 28f49cd 7515bd3 28f49cd 7515bd3 56e1924 7515bd3 56e1924 7515bd3 56e1924 7515bd3 56e1924 7515bd3 28f49cd 7515bd3 56e1924 7515bd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import gradio as gr
import json
import torch
import numpy as np
import soundfile as sf
from src.config import BaseConfig
from src.reverb import BaseFDN
from flamo.optimize.trainer import Trainer
from flamo.optimize.dataset import DatasetColorless, load_dataset
from flamo.processor import dsp, system
from flamo.optimize.loss import mse_loss, sparsity_loss
def _parse_delay_lengths(rows):
    """Return the first cell of every non-empty row in *rows* as an ``int``.

    Propagates whatever ``int()`` raises when a cell is not integer-like.
    """
    return [int(row[0]) for row in rows if row]


def _params_to_jsonable(params):
    """Return a copy of *params* that ``json.dump`` can serialize.

    Values exposing a ``cpu`` attribute are treated as tensors and converted
    to nested Python lists; every other value is passed through unchanged.
    """
    return {
        key: value.cpu().detach().numpy().tolist() if hasattr(value, "cpu") else value
        for key, value in params.items()
    }


def _peak_normalize(samples):
    """Squeeze *samples* when it has more than one axis and scale its peak to 1.

    An all-zero signal is returned unscaled so no division by zero occurs.
    """
    if samples.ndim > 1:
        samples = samples.squeeze()
    peak = np.abs(samples).max()
    if peak > 0:
        samples = samples / peak
    return samples


def process_fdn(N, delay_lengths, learning_rate, sparsity_weight, max_epochs):
    """
    Train a feedback delay network and export its parameters and impulse response.

    Every exit path returns a 3-tuple so the result always lines up with the
    three output components of the Gradio interface.

    Args:
        N: Number of delay lines (integer)
        delay_lengths: List of rows (as produced by a Gradio dataframe with
            ``type="array"``); the first cell of each row is one delay length
        learning_rate: Learning rate for optimization
        sparsity_weight: Weight for sparsity loss
        max_epochs: Maximum number of training epochs
    Returns:
        ``(message, json_path, audio_path)``. On success ``message`` summarizes
        the configuration, ``json_path`` is the file holding the estimated
        parameters and ``audio_path`` is the WAV file of the impulse response.
        On any failure ``message`` describes the error and both paths are
        ``None``.
    """
    print(f"Number of delay lines (N): {N}")
    print(f"Delay lengths: {delay_lengths}")
    print(f"Type of delay_lengths: {type(delay_lengths)}")
    print(f"Learning rate: {learning_rate}")
    print(f"Sparsity weight: {sparsity_weight}")
    print(f"Max epochs: {max_epochs}")
    try:
        if not delay_lengths:
            return "Error: No delay lengths provided", None, None

        # delay_lengths is a list of rows; take the first column of each row.
        delays = _parse_delay_lengths(delay_lengths)

        # Validate that we have N delay values. The error is returned as a
        # 3-tuple like every other exit path (a bare string here would not
        # match the three declared outputs).
        if len(delays) != N:
            return f"Error: Expected {N} delay lengths, but got {len(delays)}", None, None

        result = (
            "Successfully configured FDN with:\n"
            f"- Number of delay lines: {N}\n"
            f"- Delay lengths: {delays}"
        )

        # Create the config with FDN parameters
        config = BaseConfig.create_with_fdn_params(
            N=N,
            delay_lengths=delays,
        )

        # Initialize BaseFDN with proper parameters
        model = BaseFDN(
            config=config.fdn_config,
            nfft=config.nfft,
            alias_decay_db=config.fdn_config.alias_decay_db,
            device=config.device,
            requires_grad=True,
            delay_lengths=delays,
            output_layer="freq_mag",
        )

        # Target has nfft // 2 + 1 bins, input has nfft samples.
        dataset = DatasetColorless(
            input_shape=(1, config.nfft, 1),
            target_shape=(1, config.nfft // 2 + 1, 1),
            expand=config.fdn_optim_config.dataset_length,
            device=config.device,
        )
        train_loader, valid_loader = load_dataset(
            dataset, batch_size=config.fdn_optim_config.batch_size
        )

        # Initialize training process
        trainer = Trainer(
            model.shell,
            max_epochs=max_epochs,
            lr=learning_rate,
            device=config.device,
            log=False,
        )
        trainer.register_criterion(mse_loss(nfft=config.nfft, device=config.device), 1)
        trainer.register_criterion(sparsity_loss(), sparsity_weight, requires_model=True)

        ## ---------------- TRAIN ---------------- ##
        print("Starting training...")
        trainer.train(train_loader, valid_loader)

        # Save the estimated parameters to a JSON file.
        # NOTE(review): assumes model.get_params() returns a dict-like object
        # (it is iterated with .items()) -- confirm against BaseFDN.
        est_param = model.get_params()
        output_path = "estimated_parameters.json"
        with open(output_path, 'w') as f:
            json.dump(_params_to_jsonable(est_param), f, indent=2)

        # Convert the time response to a 1D array normalized to [-1, 1].
        ir = model.shell.get_time_response()
        ir_audio = _peak_normalize(ir.cpu().detach().numpy())

        # Use the config's sample rate, falling back to 48000 when it has no `fs`.
        sample_rate = getattr(config, 'fs', 48000)

        # Save audio to file using soundfile (avoids Gradio's conversion issues)
        audio_path = "impulse_response.wav"
        sf.write(audio_path, ir_audio, sample_rate)

        return result, output_path, audio_path
    except Exception as e:
        # Top-level boundary for the UI callback: report the failure in the
        # output textbox instead of letting the request crash.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return f"Error processing inputs: {str(e)}", None, None
# Build the web UI: five inputs map positionally onto process_fdn's parameters,
# and the three outputs map onto the (message, json_path, audio_path) it returns.
demo = gr.Interface(
    fn=process_fdn,
    inputs=[
        gr.Number(label="N (Number of delay lines)", value=4, precision=0),
        # Single fixed column; rows can be added or removed to match N.
        gr.Dataframe(
            headers=["Delay Length"],
            type="array",
            col_count=(1, "fixed"),
            row_count=(4, "dynamic"),
            label="Delay Lengths (N integer values)",
        ),
        gr.Number(label="Learning Rate", value=0.01, minimum=0.0001, maximum=1.0, step=0.0001),
        gr.Number(label="Sparsity Loss Weight", value=1.0, minimum=0.0, maximum=10.0, step=0.1),
        gr.Number(label="Max Epochs", value=20, precision=0, minimum=1, maximum=100),
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.File(label="Estimated Parameters (JSON)"),
        gr.Audio(label="Impulse Response", type="numpy"),
    ],
    title="Feedback Delay Network Optimization",
    description=(
        "Configure your homogeneous feedback delay network by specifying N "
        "(number of delay lines) and their corresponding delay lengths. "
        "Submit the values to run optimization and obtain estimated parameters "
        "and playback the resulting impulse response."
    ),
)

# Start the app; the stray trailing "|" that followed this call was not valid
# Python and has been removed.
demo.launch(debug=True)