Spaces:
Running
Running
| import logging | |
| import numpy as np | |
| from gradio_client import Client | |
| import os | |
| import json | |
| import time | |
| import httpx | |
| logger = logging.getLogger(__name__) | |
| from chord_extractor.extractors import Chordino | |
| from chord_extractor import clear_conversion_cache, LabelledChordSequence | |
# Seconds between consecutive chords when synthetic timestamps are generated
# for plain chord sequences (see get_embeddings_from_chord_sequences).
_CONSTANT_GAP_SECS = 2
# Gradio Space identifier handed to gradio_client.Client.
_SEQ_EMBED_SPACE = 'ohollo/chord-seq-embed'
# Multiplied by each file's extracted chord count to produce the reported
# chord length (the reason for 0.7 is not visible here — confirm with the model).
_POST_PROCESS_CHORD_LEN_RATIO = 0.7
# Client-creation retry policy: number of attempts, and pause (seconds) between them.
_MAX_RETRIES, _RETRY_DELAY_SECS = 3, 2
def _create_client(space=_SEQ_EMBED_SPACE, max_retries=_MAX_RETRIES, retry_delay_secs=_RETRY_DELAY_SECS):
    """
    Creates the Gradio client for the chord-sequence embedding Space, retrying on timeouts.

    :param space: Space identifier passed to ``gradio_client.Client``.
    :param max_retries: Total number of attempts; values below 1 are treated as 1 so the
        function never silently returns ``None``.
    :param retry_delay_secs: Seconds to sleep between a timed-out attempt and the next one.
    :return: The constructed ``gradio_client.Client``.
    :raises httpx.TimeoutException: If every attempt times out (the last error is re-raised).
    """
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return Client(space)
        # TimeoutException is the common base of ReadTimeout (previously the only
        # case handled), ConnectTimeout, WriteTimeout and PoolTimeout.
        except httpx.TimeoutException:
            if attempt == attempts:
                raise
            logger.warning(
                "Timed out creating client for %s (attempt %d/%d); retrying in %s s",
                space, attempt, attempts, retry_delay_secs,
            )
            time.sleep(retry_delay_secs)


# Created once when the module is imported; _call_embedding_service reuses it.
_client = _create_client()
def _call_embedding_service(chords_w_timestamps):
    """
    Sends chord sequences to the embedding Space's ``/predict`` endpoint and decodes the reply.

    :param chords_w_timestamps: JSON-serialisable payload; the callers in this module pass a
        list of ``{'label': [...], 'timestamp': [...]}`` dicts, one per sequence.
    :return: The JSON-decoded response; callers read its ``'embeddings'`` entry.
    """
    logger.info(chords_w_timestamps)
    request_body = json.dumps(chords_w_timestamps)
    response_text = _client.predict(request_body, api_name="/predict")
    decoded = json.loads(response_text)
    return decoded
def get_embeddings_from_chord_sequences(chord_sequences: list[list[str]], constant_gap_secs: float = _CONSTANT_GAP_SECS) -> np.ndarray:
    """
    Converts chord sequences into their corresponding embeddings.

    Plain chord sequences carry no timing, so each chord is given a synthetic timestamp:
    the i-th chord of a sequence is placed at ``i * constant_gap_secs`` seconds.

    :param chord_sequences: List of chord sequences, each a list of chord labels
        (any iterable of labels is accepted and copied into a list).
    :param constant_gap_secs: Seconds between consecutive chords in the synthetic timestamps.
    :return: 2-d numpy array of embeddings per chord sequence.
    """
    chords_w_timestamps = []
    for chord_sequence in chord_sequences:
        # Materialise once so tuples/generators serialise as JSON lists and the
        # timestamp count always matches the label count.
        labels = list(chord_sequence)
        chords_w_timestamps.append({
            'label': labels,
            'timestamp': [i * constant_gap_secs for i in range(len(labels))],
        })
    return np.array(_call_embedding_service(chords_w_timestamps)['embeddings'])
def _extract_chords(chordino, file_path):
    """
    Runs Chordino on one file and removes any converted intermediate file afterwards.

    :param chordino: ``Chordino`` extractor instance.
    :param file_path: Path to an existing input file.
    :return: The chord changes returned by ``chordino.extract`` (items expose ``.chord`` and ``.timestamp``).
    """
    # preprocess() returns a path to a converted copy when the input needs conversion,
    # or a falsy value when the original file can be used directly.
    conversion_file_path = chordino.preprocess(file_path)
    try:
        return chordino.extract(conversion_file_path if conversion_file_path else file_path)
    finally:
        # Without this, every converted copy stays on disk for the life of the process.
        # Only this call's file is removed (not the whole conversion cache), and never
        # the caller's own input file.
        if conversion_file_path and os.path.abspath(conversion_file_path) != os.path.abspath(file_path):
            try:
                os.remove(conversion_file_path)
            except OSError:
                logger.warning("Could not remove converted file %s", conversion_file_path, exc_info=True)


def get_embedding_from_filepaths(file_paths: list[str]) -> tuple[np.ndarray, list[int]]:
    """
    Reads chord sequences from the given files and converts them into embeddings.

    :param file_paths: List of paths to the audio files. Can be anything supported by chord-extractor - .mid, .wav, .mp3, .ogg
    :return: Tuple ``(embeddings, chord_lengths)``: a 2-d numpy array of embeddings per chord
        sequence, and for each file the number of extracted chords multiplied by
        ``_POST_PROCESS_CHORD_LEN_RATIO`` and truncated to an int.
    :raises FileNotFoundError: If any path does not point to an existing file.
    """
    # Copy so the paths can be walked twice even if an iterator was passed.
    file_paths = list(file_paths)
    # Check every path before extracting anything, so a bad path late in the list
    # is reported before the earlier files are processed.
    for file_path in file_paths:
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

    chordino = Chordino()
    chords_w_timestamps = []
    chord_lengths = []
    for file_path in file_paths:
        chords = _extract_chords(chordino, file_path)
        chords_w_timestamps.append({
            'label': [chord.chord for chord in chords],
            'timestamp': [chord.timestamp for chord in chords],
        })
        chord_lengths.append(int(len(chords) * _POST_PROCESS_CHORD_LEN_RATIO))
    return np.array(_call_embedding_service(chords_w_timestamps)['embeddings']), chord_lengths