Spaces:
Runtime error
Runtime error
| import os | |
| import yaml | |
| from argparse import Action | |
| from ast import literal_eval | |
| from torch.cuda import is_available | |
| from torch import get_num_threads, set_num_threads | |
# Root directory for cached model files; override with the AUTOT_CACHE
# environment variable.
CACHE_DIR = os.getenv(
    "AUTOT_CACHE",
    os.path.expanduser("~/.cache/torch/models"),
)

# Publish PYANNOTE_CACHE in the process environment (presumably consumed by
# pyannote); a value the user has already set is preserved.
os.environ["PYANNOTE_CACHE"] = os.getenv(
    "PYANNOTE_CACHE",
    os.path.join(CACHE_DIR, "pyannote"),
)

WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
# NOTE(review): this is always derived from CACHE_DIR, even when the user set
# PYANNOTE_CACHE to a different directory above -- confirm that is intended.
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")

# Prefer a local config.yaml placed in PYANNOTE_DEFAULT_PATH. Otherwise fall
# back to a pair of what look like Hugging Face Hub repo ids (the loader that
# consumes this tuple is not visible here -- TODO confirm how it is used).
_LOCAL_PYANNOTE_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
PYANNOTE_DEFAULT_CONFIG = (
    _LOCAL_PYANNOTE_CONFIG
    if os.path.exists(_LOCAL_PYANNOTE_CONFIG)
    else ("Jaikinator/ScrAIbe", "pyannote/speaker-diarization-3.1")
)

# Torch device: SCRAIBE_TORCH_DEVICE env var, else "cuda" when available,
# else "cpu".
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")

# Default thread count: the SCRAIBE_NUM_THREADS env var, else torch's current
# thread count capped at 8. os.getenv returns a str when the variable is set,
# so convert explicitly to keep the type an int in both cases (a non-integer
# env value raises ValueError at import time instead of leaking a str).
SCRAIBE_NUM_THREADS = int(
    os.getenv("SCRAIBE_NUM_THREADS", min(8, get_num_threads()))
)
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
    """Point a diarization pipeline YAML config at a local segmentation model.

    Rewrites ``pipeline.params.segmentation`` in *file_path* so the pipeline
    can use the given segmentation model offline, without manual editing of
    the file.

    Args:
        file_path (str): Path to the YAML config; it is rewritten in place.
        path_to_segmentation (str, optional): Path to the segmentation model.
            Defaults to ``pytorch_model.bin`` inside ``PYANNOTE_DEFAULT_PATH``.

    Raises:
        FileNotFoundError: If the segmentation model or *file_path* does not
            exist. The config file is left unchanged in that case.
        KeyError / TypeError: If the YAML does not contain a
            ``pipeline`` -> ``params`` mapping.
    """
    segmentation_path = path_to_segmentation or os.path.join(
        PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")

    # Validate first: fail fast without parsing or touching the config file.
    if not os.path.exists(segmentation_path):
        raise FileNotFoundError(
            f"Segmentation model not found at {segmentation_path}")

    with open(file_path, "r", encoding="utf-8") as stream:
        yml = yaml.safe_load(stream)

    yml["pipeline"]["params"]["segmentation"] = segmentation_path

    # sort_keys=False keeps the file's original key order instead of
    # alphabetising the whole config on every rewrite.
    with open(file_path, "w", encoding="utf-8") as stream:
        yaml.dump(yml, stream, sort_keys=False)
def _validate_num_threads(value, option_name):
    """Return *value* if it is a positive int, otherwise raise ValueError.

    *option_name* is only used in the type-error message.
    """
    # bool is a subclass of int; reject it so ``True`` is not silently
    # accepted as "1 thread".
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError(f"Type of {option_name} must be int, but the type is {type(value)}")
    if value < 1:
        raise ValueError(f"Number of threads must be a positive integer, {value} was given")
    return value


def set_threads(parse_threads=None,
                yaml_threads=None):
    """Apply a thread count to torch and record it in SCRAIBE_NUM_THREADS.

    The command-line value takes precedence over the YAML value. When both
    are None nothing is changed.

    Args:
        parse_threads (int, optional): Value of ``--num-threads`` from the CLI.
        yaml_threads (int, optional): ``num_threads`` from a YAML config.

    Raises:
        ValueError: If the selected value is not an int (bools are rejected)
            or is smaller than 1.
    """
    global SCRAIBE_NUM_THREADS
    if parse_threads is not None:
        # probably already enforced by an int-typed parser argument
        num_threads = _validate_num_threads(parse_threads, "--num-threads")
    elif yaml_threads is not None:
        num_threads = _validate_num_threads(yaml_threads, "num_threads")
    else:
        return
    set_num_threads(num_threads)
    SCRAIBE_NUM_THREADS = num_threads
class ParseKwargs(Action):
    """Argparse action that collects ``KEY=VALUE`` tokens into a dict.

    Each value is parsed with ``ast.literal_eval`` so numbers, booleans,
    ``None``, lists, etc. arrive as Python objects. Anything that is not a
    valid literal is kept as the raw string.

    Example: ``--kw a=1 b=text c=[1,2]`` -> ``{"a": 1, "b": "text", "c": [1, 2]}``
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # With nargs=None argparse passes a single str; iterating it would
        # walk characters, so wrap it.
        if isinstance(values, str):
            values = [values]
        kwargs = {}
        for item in values:
            # Split on the first '=' only so values may themselves contain '='.
            key, sep, raw_value = item.partition("=")
            if not sep:
                parser.error(
                    f"{option_string or self.dest}: expected KEY=VALUE, got {item!r}")
            try:
                kwargs[key] = literal_eval(raw_value)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a Python literal: keep the raw string.
                kwargs[key] = raw_value
        setattr(namespace, self.dest, kwargs)