ArianatorQualquer commited on
Commit
0ec98a2
·
verified ·
1 Parent(s): e1952b6

Create gui-gradio-full.py

Browse files
Files changed (1) hide show
  1. gui-gradio-full.py +214 -0
gui-gradio-full.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import threading
5
+ import queue
6
+ import json
7
+
8
+ # Função para rodar subprocessos com saída ao vivo
9
+ def run_subprocess(cmd, output_queue):
10
+ try:
11
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
12
+ for line in process.stdout:
13
+ output_queue.put(line)
14
+ process.wait()
15
+ if process.returncode == 0:
16
+ output_queue.put("Process completed successfully!")
17
+ else:
18
+ output_queue.put(f"Process failed with return code {process.returncode}")
19
+ except Exception as e:
20
+ output_queue.put(f"An error occurred: {str(e)}")
21
+
22
+ # Gerenciador de configurações e presets
23
+ class ConfigManager:
24
+ def __init__(self, filepath="settings.json"):
25
+ self.filepath = filepath
26
+ self.settings = self.load_settings()
27
+
28
+ def save_settings(self):
29
+ with open(self.filepath, "w") as f:
30
+ json.dump(self.settings, f, indent=2, ensure_ascii=False)
31
+
32
+ def load_settings(self):
33
+ if os.path.exists(self.filepath):
34
+ with open(self.filepath, "r") as f:
35
+ return json.load(f)
36
+ return {"saved_combinations": {}}
37
+
38
+ def get_saved_combinations(self):
39
+ return self.settings.get("saved_combinations", {})
40
+
41
+ def add_combination(self, name, combination):
42
+ self.settings["saved_combinations"][name] = combination
43
+ self.save_settings()
44
+
45
+ config_manager = ConfigManager()
46
+
47
+ # Funções principais (Treinamento, Inferência, etc.)
48
+ def run_training(
49
+ model_type, dataset_type, config_path, start_checkpoint, metrics, results_path,
50
+ data_paths, valid_paths, num_workers, device_ids
51
+ ):
52
+ if not (model_type and dataset_type and config_path and results_path and data_paths and valid_paths and metrics):
53
+ return "Error: Missing required inputs for training."
54
+
55
+ cmd = [
56
+ "python", "train.py",
57
+ "--model_type", model_type,
58
+ "--dataset_type", str(dataset_type),
59
+ "--config_path", config_path,
60
+ "--results_path", results_path,
61
+ "--data_path", *data_paths.split(';'),
62
+ "--valid_path", *valid_paths.split(';'),
63
+ "--metrics", *metrics.split(';'),
64
+ "--num_workers", str(num_workers),
65
+ "--device_ids", device_ids
66
+ ]
67
+
68
+ if start_checkpoint:
69
+ cmd += ["--start_check_point", start_checkpoint]
70
+
71
+ output_queue = queue.Queue()
72
+ threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
73
+
74
+ output = []
75
+ while not output_queue.empty():
76
+ output.append(output_queue.get())
77
+ return "\n".join(output)
78
+
79
+ def run_inference(
80
+ model_type, config_path, start_checkpoint, input_folder, store_dir,
81
+ extract_instrumental, additional_metrics
82
+ ):
83
+ if not (model_type and config_path and input_folder and store_dir):
84
+ return "Error: Missing required inputs for inference."
85
+
86
+ cmd = [
87
+ "python", "inference.py",
88
+ "--model_type", model_type,
89
+ "--config_path", config_path,
90
+ "--input_folder", input_folder,
91
+ "--store_dir", store_dir
92
+ ]
93
+
94
+ if start_checkpoint:
95
+ cmd += ["--start_check_point", start_checkpoint]
96
+ if extract_instrumental:
97
+ cmd += ["--extract_instrumental"]
98
+ if additional_metrics:
99
+ cmd += ["--metrics", *additional_metrics.split(';')]
100
+
101
+ output_queue = queue.Queue()
102
+ threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
103
+
104
+ output = []
105
+ while not output_queue.empty():
106
+ output.append(output_queue.get())
107
+ return "\n".join(output)
108
+
109
+ # Função para salvar presets
110
+ def add_preset(name, model_type, config_path, checkpoint, metrics):
111
+ if not name:
112
+ return "Error: Name is required to save a preset."
113
+
114
+ config_manager.add_combination(name, {
115
+ "model_type": model_type,
116
+ "config_path": config_path,
117
+ "checkpoint": checkpoint,
118
+ "metrics": metrics
119
+ })
120
+ return f"Preset '{name}' saved successfully."
121
+
122
+ # Carregar presets
123
+ saved_presets = config_manager.get_saved_combinations()
124
+ preset_names = list(saved_presets.keys())
125
+
126
+ def load_preset(name):
127
+ if name in saved_presets:
128
+ preset = saved_presets[name]
129
+ return preset["model_type"], preset["config_path"], preset["checkpoint"], preset["metrics"]
130
+ return "", "", "", ""
131
+
132
+ # Interface Gradio
133
+ with gr.Blocks() as gui:
134
+ gr.Markdown("# 🎵 Music Source Separation GUI - Complete Version")
135
+
136
+ # Aba: Treinamento
137
+ with gr.Tab("Training"):
138
+ with gr.Row():
139
+ model_type = gr.Dropdown(
140
+ label="Model Type",
141
+ choices=["apollo", "bandit", "bandit_v2", "bs_roformer", "htdemucs", "mdx23c", "mel_band_roformer", "scnet", "scnet_unofficial", "segm_models", "swin_upernet", "torchseg"]
142
+ )
143
+ dataset_type = gr.Dropdown(label="Dataset Type", choices=["1", "2", "3", "4"])
144
+
145
+ with gr.Row():
146
+ config_path = gr.Textbox(label="Config File Path")
147
+ start_checkpoint = gr.Textbox(label="Start Checkpoint (Optional)")
148
+
149
+ metrics = gr.Textbox(label="Metrics (comma-separated, e.g., SDR, SIR)")
150
+ results_path = gr.Textbox(label="Results Path")
151
+ data_paths = gr.Textbox(label="Data Paths (separated by ';')")
152
+ valid_paths = gr.Textbox(label="Validation Paths (separated by ';')")
153
+
154
+ num_workers = gr.Number(label="Number of Workers", value=4)
155
+ device_ids = gr.Textbox(label="Device IDs (comma-separated)", value="0")
156
+ train_output = gr.Textbox(label="Training Output", interactive=False)
157
+
158
+ gr.Button("Run Training").click(
159
+ run_training,
160
+ inputs=[
161
+ model_type, dataset_type, config_path, start_checkpoint, metrics,
162
+ results_path, data_paths, valid_paths, num_workers, device_ids
163
+ ],
164
+ outputs=train_output
165
+ )
166
+
167
+ # Aba: Inferência
168
+ with gr.Tab("Inference"):
169
+ with gr.Row():
170
+ infer_model_type = gr.Dropdown(
171
+ label="Model Type",
172
+ choices=["apollo", "bandit", "bandit_v2", "bs_roformer", "htdemucs", "mdx23c", "mel_band_roformer", "scnet", "scnet_unofficial", "segm_models", "swin_upernet", "torchseg"]
173
+ )
174
+ infer_config_path = gr.Textbox(label="Config File Path")
175
+
176
+ infer_checkpoint = gr.Textbox(label="Start Checkpoint (Optional)")
177
+ input_folder = gr.Textbox(label="Input Folder")
178
+ store_dir = gr.Textbox(label="Output Folder")
179
+ extract_instrumental = gr.Checkbox(label="Extract Instrumental", value=False)
180
+ additional_metrics = gr.Textbox(label="Additional Metrics (comma-separated)")
181
+
182
+ infer_output = gr.Textbox(label="Inference Output", interactive=False)
183
+
184
+ gr.Button("Run Inference").click(
185
+ run_inference,
186
+ inputs=[
187
+ infer_model_type, infer_config_path, infer_checkpoint, input_folder,
188
+ store_dir, extract_instrumental, additional_metrics
189
+ ],
190
+ outputs=infer_output
191
+ )
192
+
193
+ # Aba: Presets
194
+ with gr.Tab("Presets"):
195
+ preset_name = gr.Textbox(label="Preset Name")
196
+ preset_model_type = gr.Textbox(label="Model Type")
197
+ preset_config_path = gr.Textbox(label="Config Path")
198
+ preset_checkpoint = gr.Textbox(label="Checkpoint")
199
+ preset_metrics = gr.Textbox(label="Metrics")
200
+ preset_feedback = gr.Textbox(label="Feedback", interactive=False)
201
+
202
+ gr.Button("Save Preset").click(
203
+ add_preset,
204
+ inputs=[preset_name, preset_model_type, preset_config_path, preset_checkpoint, preset_metrics],
205
+ outputs=preset_feedback
206
+ )
207
+
208
+ preset_dropdown = gr.Dropdown(choices=preset_names, label="Load Preset")
209
+ gr.Button("Load Preset").click(
210
+ load_preset, inputs=preset_dropdown,
211
+ outputs=[preset_model_type, preset_config_path, preset_checkpoint, preset_metrics]
212
+ )
213
+
214
+ gui.launch()