Spaces:
Paused
Paused
| import os | |
| import tempfile | |
| import gradio as gr | |
| import torch | |
| from zerorvc import RVCTrainer, pretrained_checkpoints, SynthesizerTrnMs768NSFsid | |
| from zerorvc.trainer import TrainingCheckpoint | |
| from datasets import load_from_disk | |
| from huggingface_hub import snapshot_download | |
| from .zero import zero | |
| from .model import accelerator, device | |
| from .constants import BATCH_SIZE, ROOT_EXP_DIR, TRAINING_EPOCHS | |
def train_model(exp_dir: str, progress=gr.Progress()):
    """Train the model for ``TRAINING_EPOCHS`` epochs and save the result.

    The dataset is loaded from ``<exp_dir>/dataset``. Training resumes from
    the trainer's latest checkpoint in ``<exp_dir>/checkpoints`` or, when
    there is none, from the pretrained checkpoints. The last checkpoint
    yielded by the trainer is saved to the checkpoint directory together with
    its ``G`` model.

    Args:
        exp_dir: Experiment directory containing the ``dataset`` folder.
        progress: Gradio progress tracker used to wrap the training iterator.

    Returns:
        A summary string with the losses of the last checkpoint.

    Raises:
        gr.Error: If the dataset directory does not exist, or if the trainer
            finishes without yielding any checkpoint.
    """
    dataset = os.path.join(exp_dir, "dataset")
    if not os.path.exists(dataset):
        raise gr.Error("Dataset not found. Please prepare the dataset first.")

    ds = load_from_disk(dataset)
    checkpoint_dir = os.path.join(exp_dir, "checkpoints")
    trainer = RVCTrainer(checkpoint_dir)

    resume_from = trainer.latest_checkpoint()
    if resume_from is None:
        resume_from = pretrained_checkpoints()
        gr.Info("Starting training from pretrained checkpoints.")
    else:
        gr.Info(f"Resuming training from {resume_from}")

    tqdm = progress.tqdm(
        trainer.train(
            dataset=ds["train"],
            resume_from=resume_from,
            batch_size=BATCH_SIZE,
            epochs=TRAINING_EPOCHS,
            accelerator=accelerator,
        ),
        total=TRAINING_EPOCHS,
        unit="epochs",
        desc="Training",
    )

    # Stays None when the training iterator yields nothing, so that case can
    # be reported explicitly instead of failing on an unbound local below.
    latest = None
    for ckpt in tqdm:
        info = f"Epoch: {ckpt.epoch} loss: (gen: {ckpt.loss_gen:.4f}, fm: {ckpt.loss_fm:.4f}, mel: {ckpt.loss_mel:.4f}, kl: {ckpt.loss_kl:.4f}, disc: {ckpt.loss_disc:.4f})"
        print(info)
        latest = ckpt

    if latest is None:
        raise gr.Error("Training produced no checkpoint.")

    # Only the final checkpoint is persisted.
    latest.save(trainer.checkpoint_dir)
    latest.G.save_pretrained(trainer.checkpoint_dir)

    result = f"{TRAINING_EPOCHS} epochs trained. Latest loss: (gen: {latest.loss_gen:.4f}, fm: {latest.loss_fm:.4f}, mel: {latest.loss_mel:.4f}, kl: {latest.loss_kl:.4f}, disc: {latest.loss_disc:.4f})"

    # Drop the trainer reference before asking CUDA to release cached memory.
    del trainer
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return result
def upload_model(exp_dir: str, repo: str, hf_token: str):
    """Push the model stored in the experiment's checkpoints to the Hub.

    The model is loaded from ``<exp_dir>/checkpoints`` and pushed to ``repo``
    with ``private=True``, authenticating with ``hf_token``.

    Args:
        exp_dir: Experiment directory containing the ``checkpoints`` folder.
        repo: Target repository ID.
        hf_token: Token passed to ``push_to_hub``.

    Raises:
        gr.Error: If the checkpoints directory does not exist.
    """
    model_path = os.path.join(exp_dir, "checkpoints")
    if not os.path.exists(model_path):
        raise gr.Error("Model not found")

    gr.Info("Uploading model")
    SynthesizerTrnMs768NSFsid.from_pretrained(model_path).push_to_hub(
        repo, token=hf_token, private=True
    )
    gr.Info("Model uploaded successfully")
def upload_checkpoints(exp_dir: str, repo: str, hf_token: str):
    """Push the experiment's training checkpoints to the Hub.

    A trainer is created over ``<exp_dir>/checkpoints`` and its
    ``push_to_hub`` is called for ``repo`` with ``private=True``.

    Args:
        exp_dir: Experiment directory containing the ``checkpoints`` folder.
        repo: Target repository ID.
        hf_token: Token passed to ``push_to_hub``.

    Raises:
        gr.Error: If the checkpoints directory does not exist.
    """
    ckpt_path = os.path.join(exp_dir, "checkpoints")
    if not os.path.exists(ckpt_path):
        raise gr.Error("Checkpoints not found")

    gr.Info("Uploading checkpoints")
    RVCTrainer(ckpt_path).push_to_hub(repo, token=hf_token, private=True)
    gr.Info("Checkpoints uploaded successfully")
def fetch_model(exp_dir: str, repo: str, hf_token: str):
    """Download the model files from a Hub repository into the experiment.

    Only ``README.md``, ``config.json`` and ``model.safetensors`` are
    requested; they are placed in ``<exp_dir>/checkpoints``.

    Args:
        exp_dir: Experiment directory. When empty, a new temporary directory
            is created under ``ROOT_EXP_DIR`` and used instead.
        repo: Source repository ID.
        hf_token: Token passed to ``snapshot_download``.

    Returns:
        The experiment directory that was used.
    """
    # Fall back to a fresh experiment directory when none was supplied.
    exp_dir = exp_dir or tempfile.mkdtemp(dir=ROOT_EXP_DIR)
    download_to = os.path.join(exp_dir, "checkpoints")

    gr.Info("Fetching model")
    snapshot_download(
        repo,
        token=hf_token,
        local_dir=download_to,
        allow_patterns=["README.md", "config.json", "model.safetensors"],
    )
    gr.Info("Model fetched successfully")

    return exp_dir
def fetch_checkpoints(exp_dir: str, repo: str, hf_token: str):
    """Download a Hub repository snapshot into the experiment's checkpoints.

    The snapshot of ``repo`` is placed in ``<exp_dir>/checkpoints``; no file
    pattern filter is applied.

    Args:
        exp_dir: Experiment directory. When empty, a new temporary directory
            is created under ``ROOT_EXP_DIR`` and used instead.
        repo: Source repository ID.
        hf_token: Token passed to ``snapshot_download``.

    Returns:
        The experiment directory that was used.
    """
    # Fall back to a fresh experiment directory when none was supplied.
    exp_dir = exp_dir or tempfile.mkdtemp(dir=ROOT_EXP_DIR)
    download_to = os.path.join(exp_dir, "checkpoints")

    gr.Info("Fetching checkpoints")
    snapshot_download(repo, token=hf_token, local_dir=download_to)
    gr.Info("Checkpoints fetched successfully")

    return exp_dir
class TrainTab:
    """Gradio tab with the training button and Hugging Face sync controls."""

    def __init__(self):
        """Create the tab object; components are created later by ``ui``."""

    def ui(self):
        """Create the tab's components and keep references to them on self."""

        def primary_button(label):
            # Every button on this tab uses the "primary" variant.
            return gr.Button(value=label, variant="primary")

        training_intro = (
            "You can start training the model by clicking the button below. "
            f"Each time you click the button, the model will train for {TRAINING_EPOCHS} epochs, which takes about 3 minutes on ZeroGPU (A100). "
        )
        sync_intro = "You can upload the trained model and checkpoints to Hugging Face for sharing or further training."

        gr.Markdown("# Training")
        gr.Markdown(training_intro)

        with gr.Row():
            self.train_btn = primary_button("Train")
            self.result = gr.Textbox(label="Training Result", lines=3)

        gr.Markdown("## Sync Model and Checkpoints with Hugging Face")
        gr.Markdown(sync_intro)

        self.repo = gr.Textbox(label="Repository ID", placeholder="username/repo")

        with gr.Row():
            self.upload_model_btn = primary_button("Upload Model")
            self.upload_checkpoints_btn = primary_button("Upload Checkpoints")

        with gr.Row():
            self.fetch_mode_btn = primary_button("Fetch Model")
            self.fetch_checkpoints_btn = primary_button("Fetch Checkpoints")

    def build(self, exp_dir: gr.Textbox, hf_token: gr.Textbox):
        """Attach the click handlers to the buttons created by ``ui``.

        Args:
            exp_dir: Textbox holding the experiment directory; also receives
                the directory returned by the fetch handlers.
            hf_token: Textbox holding the Hugging Face token.
        """
        self.train_btn.click(
            fn=train_model, inputs=[exp_dir], outputs=[self.result]
        )

        # (button, handler, whether the handler's return value goes to exp_dir)
        sync_actions = (
            (self.upload_model_btn, upload_model, False),
            (self.upload_checkpoints_btn, upload_checkpoints, False),
            (self.fetch_mode_btn, fetch_model, True),
            (self.fetch_checkpoints_btn, fetch_checkpoints, True),
        )
        for button, handler, writes_exp_dir in sync_actions:
            click_kwargs = {
                "fn": handler,
                "inputs": [exp_dir, self.repo, hf_token],
            }
            if writes_exp_dir:
                click_kwargs["outputs"] = [exp_dir]
            button.click(**click_kwargs)