# NOTE(review): removed non-Python artifact lines ("Spaces:", "Runtime error") left by a copy/extraction tool.
| import os | |
| import re | |
| import shutil | |
| from dataclasses import field | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import torch | |
| from datasets import concatenate_datasets, load_from_disk | |
| from wandb import Audio | |
| from datasets import load_from_disk, concatenate_datasets | |
def list_field(default=None, metadata=None):
    """Dataclass helper: declare a field whose default is supplied by a factory.

    Wrapping `default` in a factory avoids the shared-mutable-default pitfall of
    writing `field(default=[...])` directly.
    """

    def _default_factory():
        return default

    return field(default_factory=_default_factory, metadata=metadata)
# Directory-name patterns for saved checkpoints.
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
CHECKPOINT_CODEC_PREFIX = "checkpoint"
_RE_CODEC_CHECKPOINT = re.compile(r"^checkpoint-(\d+)$")


def get_last_checkpoint(folder):
    """Return the path of the newest `checkpoint-<step>-epoch-<epoch>` directory in `folder`.

    "Newest" means the largest step number. Returns None when `folder` holds no
    matching checkpoint directory.
    """
    candidates = [
        name
        for name in os.listdir(folder)
        if _RE_CHECKPOINT.search(name) is not None and os.path.isdir(os.path.join(folder, name))
    ]
    if not candidates:
        return None

    def _step_of(name):
        return int(_RE_CHECKPOINT.search(name).groups()[0])

    return os.path.join(folder, max(candidates, key=_step_of))
def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
    """Helper function to sort saved checkpoints from oldest to newest.

    Scans `output_dir` for directories matching `<checkpoint_prefix>-<step>` and
    returns their paths ordered by ascending step number.
    """
    step_and_path = []
    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
    for path in glob_checkpoints:
        regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
        # A successful match always has groups for this pattern, so checking the
        # match object alone is sufficient (the `.groups() is not None` test was dead code).
        if regex_match is not None:
            step_and_path.append((int(regex_match.groups()[0]), path))
    return [path for _, path in sorted(step_and_path)]


def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> None:
    """Helper function to delete old checkpoints.

    Keeps at most `save_total_limit` checkpoints in `output_dir`, deleting the
    oldest ones first. A None or non-positive limit disables rotation. `logger`,
    when provided, receives one info line per deleted checkpoint.
    """
    if save_total_limit is None or save_total_limit <= 0:
        return
    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
    if len(checkpoints_sorted) <= save_total_limit:
        return
    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
    for checkpoint in checkpoints_sorted[:number_of_checkpoints_to_delete]:
        # Guard the log call: `logger` defaults to None, and the original crashed
        # with AttributeError whenever rotation ran without a logger (bug fix).
        if logger is not None:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
        shutil.rmtree(checkpoint, ignore_errors=True)
def save_codec_checkpoint(output_dir, dataset, step):
    """Persist `dataset` to disk under `<output_dir>/checkpoint-<step>`."""
    target = os.path.join(output_dir, f"{CHECKPOINT_CODEC_PREFIX}-{step}")
    dataset.save_to_disk(target)
def load_codec_checkpoint(checkpoint_path):
    """Load and return a dataset previously written by `save_codec_checkpoint`."""
    return load_from_disk(checkpoint_path)
def sorted_codec_checkpoints(output_dir=None) -> List[str]:
    """Helper function to sort saved checkpoints from oldest to newest.

    Returns the `checkpoint-<step>` paths under `output_dir` ordered by
    ascending step number.
    """
    step_and_path = []
    for path in (str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_CODEC_PREFIX}-*")):
        regex_match = re.match(f".*{CHECKPOINT_CODEC_PREFIX}-([0-9]+)", path)
        # A successful match always has groups for this pattern, so checking the
        # match object alone is sufficient (the `.groups() is not None` test was dead code).
        if regex_match is not None:
            step_and_path.append((int(regex_match.groups()[0]), path))
    return [path for _, path in sorted(step_and_path)]
def load_all_codec_checkpoints(output_dir=None) -> List[str]:
    """Helper function to load and concat all checkpoints."""
    # Load oldest-to-newest so row order matches the order the shards were saved in.
    loaded = [load_from_disk(path) for path in sorted_codec_checkpoints(output_dir=output_dir)]
    return concatenate_datasets(loaded, axis=0)
def get_last_codec_checkpoint_step(folder) -> int:
    """Return the step number of the newest `checkpoint-<step>` entry in `folder`.

    Creates `folder` and returns 0 when it does not exist yet; also returns 0
    when it contains no checkpoint entries.
    """
    # `os.path.isdir` is False for a missing path, so it covers the exists-check too.
    if not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
        return 0
    # Parse the step out of each entry NAME. The original re-searched the joined
    # path with r"checkpoint-(\d+)", so a `checkpoint-<n>` component earlier in
    # `folder`'s own path would be matched first and yield the wrong step (bug fix).
    steps = [
        int(_RE_CODEC_CHECKPOINT.search(name).groups()[0])
        for name in os.listdir(folder)
        if _RE_CODEC_CHECKPOINT.search(name) is not None
    ]
    if not steps:
        return 0
    return max(steps)
def log_metric(
    accelerator,
    metrics: Dict,
    train_time: float,
    step: int,
    epoch: int,
    learning_rate: float = None,
    prefix: str = "train",
):
    """Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
    payload = {}
    for name, value in metrics.items():
        # Codebook metrics get their own wandb section so they don't crowd the main charts.
        key = f"codebook_{prefix}/{name}" if "codebook" in name else f"{prefix}/{name}"
        payload[key] = value
    payload[f"{prefix}/time"] = train_time
    payload[f"{prefix}/epoch"] = epoch
    if learning_rate is not None:
        payload[f"{prefix}/learning_rate"] = learning_rate
    accelerator.log(payload, step=step)
def log_pred(
    accelerator,
    pred_descriptions: List[str],
    pred_prompts: List[str],
    transcriptions: List[str],
    audios: List[torch.Tensor],
    si_sdr_measures: List[float],
    sampling_rate: int,
    step: int,
    prefix: str = "eval",
    num_lines: int = 200000,
):
    """Helper function to log target/predicted transcriptions to weights and biases (wandb).

    Logs one table row per prediction (capped at `num_lines`) plus up to 100 audio
    samples. When `si_sdr_measures` is provided, a "Noise estimation" column is
    appended to the table. Only the main process logs.
    """
    if not accelerator.is_main_process:
        return
    wandb_tracker = accelerator.get_tracker("wandb")
    # pretty name for current step: step 50000 -> step 50k
    cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
    prefix_pretty = prefix.replace("/", "-")

    # Build columns/rows once instead of duplicating the whole log_table call per branch.
    columns = ["Target descriptions", "Target prompts", "Predicted transcriptions"]
    if si_sdr_measures is None:
        str_data = [
            [pred_descriptions[i], pred_prompts[i], transcriptions[i]]
            for i in range(len(pred_descriptions))
        ]
    else:
        columns = columns + ["Noise estimation"]
        str_data = [
            [pred_descriptions[i], pred_prompts[i], transcriptions[i], si_sdr_measures[i]]
            for i in range(len(pred_descriptions))
        ]
    # log as a table with the appropriate headers
    wandb_tracker.log_table(
        table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
        columns=columns,
        data=str_data[:num_lines],
        step=step,
        commit=False,
    )

    # wandb can only load 100 audios per step
    wandb_tracker.log(
        {
            "Speech samples": [
                Audio(
                    audio,
                    caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
                    sample_rate=sampling_rate,
                )
                for (i, audio) in enumerate(audios[: min(len(audios), 100)])
            ]
        },
        step=step,
    )