ohollo committed on
Commit
c5184df
·
1 Parent(s): 12a0c2c

Pad short chord sequences

Browse files
Files changed (4) hide show
  1. app.py +37 -11
  2. src/analysis.py +9 -0
  3. src/convert.py +17 -2
  4. src/methodology.py +1 -1
app.py CHANGED
@@ -18,8 +18,10 @@ logger = logging.getLogger(__name__)
18
  INDEX_LOCATION = './assets/chords_20251021.index'
19
  LABELS_LOCATION = './assets/all_labels.csv'
20
  LOOKUP_DS_NAME = 'ohollo/lmd_chords'
21
- CLOSE_THRESHOLD = 0.99
22
  SCALER_DICT_LOCATION = './assets/quantile_transformers.joblib'
 
 
23
 
24
  # Load models and data
25
  print("Loading models and data...")
@@ -33,18 +35,31 @@ lookup = ds['train'].to_pandas().set_index('track_id')[['title', 'artist']]
33
  ea = EmbeddingsAnalysis(index, all_labels, lookup, scalers, close_threshold=CLOSE_THRESHOLD)
34
  print("Models loaded successfully!")
35
 
 
 
 
 
36
  def _parse_chord_input(chord_text):
37
  if not chord_text.strip():
38
  return []
39
-
40
  # Try comma separation first, then space separation
41
  if ',' in chord_text:
42
  chords = [chord.strip() for chord in chord_text.split(',') if chord.strip()]
43
  else:
44
  chords = chord_text.split()
45
-
46
  return chords
47
 
 
 
 
 
 
 
 
 
 
48
  def _neighbours_to_dict(neighbours_list):
49
  result = []
50
  for neighbours_group in neighbours_list:
@@ -59,9 +74,9 @@ def _neighbours_to_dict(neighbours_list):
59
  result.append(group_result)
60
  return result
61
 
62
- def _perform_analysis(embeddings, sequence_lengths):
63
  scores = ea.get_scores(embeddings, sequence_lengths)
64
- neighbours = ea.get_neighbours(embeddings, limit=10)
65
  score = scores[0]
66
  # Convert neighbours to list of dicts for the first sequence
67
  neighbours_dict = []
@@ -81,11 +96,11 @@ def analyze_chord_sequence_text(chord_text: str) -> tuple[float, list[dict]]:
81
  Analyze a chord sequence from text input. Analysis is in the form of
82
  an originality score and a list of similar songs from a non-exhaustive
83
  sample set of songs in the system data store.
84
-
85
  Args:
86
- chord_text: Chord sequence input as text (comma-separated or space-separated). IMPORTANT: Unless length is explicitly specified, for accurate results, provide the complete sequence of chords that would feature in a typical song, e.g. "C, Am, F, G, C, Am, F, G, C, Am, F, G, C, Am, F, G, ..." NOT just "C, Am, F, G". This could be 20-30 chords for a three minute song.
87
  Returns:
88
- tuple[float, list[dict]]: Originality score and list of dictionaries, each representing a similar song. You may infer that some items in the list are essentially the same song - if so don't repeat them to the user. Also some songs are more famous than others in the results - if you come across a famous one, then highlight it.
89
  """
90
  logging.info(f"Analyzing chord sequence: {chord_text}")
91
  try:
@@ -93,7 +108,11 @@ def analyze_chord_sequence_text(chord_text: str) -> tuple[float, list[dict]]:
93
  if not chords:
94
  return None, None
95
  embeddings = get_embeddings_from_chord_sequences([chords])
96
- score, neighbours = _perform_analysis(embeddings, [len(chords)])
 
 
 
 
97
  return score, neighbours
98
  except Exception as e:
99
  logger.error(f"Error analyzing chord sequence: {e}")
@@ -178,7 +197,10 @@ with gr.Blocks(title="Harmonic Analysis Tool", theme=gr.themes.Soft()) as app:
178
  ["Am, F, C, G"],
179
  ["D, A, Bm, G"],
180
  ["Em, C, G, D"],
181
- ["F, C, Dm, Bb"]
 
 
 
182
  ],
183
  inputs=[chord_input]
184
  )
@@ -200,7 +222,11 @@ with gr.Blocks(title="Harmonic Analysis Tool", theme=gr.themes.Soft()) as app:
200
  file_display = gr.Markdown(label="File Info")
201
  audio_scores_output = gr.Markdown(label="Analysis Results")
202
  audio_neighbours_output = gr.Markdown(label="Similar Songs")
203
-
 
 
 
 
204
  # Event handlers
205
  analyze_btn.click(
206
  fn=_format_chord_analysis_for_ui,
 
18
  INDEX_LOCATION = './assets/chords_20251021.index'
19
  LABELS_LOCATION = './assets/all_labels.csv'
20
  LOOKUP_DS_NAME = 'ohollo/lmd_chords'
21
+ CLOSE_THRESHOLD = 0.9
22
  SCALER_DICT_LOCATION = './assets/quantile_transformers.joblib'
23
+ MIN_SEQUENCE_LENGTH_FOR_NEIGHBOURS = 24
24
+ HOW_IT_WORKS_MD_LOCATION = './how_it_works.md'
25
 
26
  # Load models and data
27
  print("Loading models and data...")
 
35
  ea = EmbeddingsAnalysis(index, all_labels, lookup, scalers, close_threshold=CLOSE_THRESHOLD)
36
  print("Models loaded successfully!")
37
 
38
# Load the "How It Works" tab content once, at import time.
# The encoding is pinned explicitly: without it open() falls back to the
# platform's locale encoding, so non-ASCII characters in the Markdown could be
# mis-decoded or fail to load on hosts whose default is not UTF-8.
# NOTE(review): assumes how_it_works.md is stored as UTF-8 - confirm.
with open(HOW_IT_WORKS_MD_LOCATION, 'r', encoding='utf-8') as f:
    how_it_works_content = f.read()
41
+
42
  def _parse_chord_input(chord_text):
43
  if not chord_text.strip():
44
  return []
45
+
46
  # Try comma separation first, then space separation
47
  if ',' in chord_text:
48
  chords = [chord.strip() for chord in chord_text.split(',') if chord.strip()]
49
  else:
50
  chords = chord_text.split()
51
+
52
  return chords
53
 
54
+
55
+ def _pad_sequence_by_repetition(sequence, min_length):
56
+ if len(sequence) >= min_length:
57
+ return sequence
58
+ result = sequence.copy()
59
+ while len(result) < min_length:
60
+ result.extend(sequence)
61
+ return result
62
+
63
  def _neighbours_to_dict(neighbours_list):
64
  result = []
65
  for neighbours_group in neighbours_list:
 
74
  result.append(group_result)
75
  return result
76
 
77
+ def _perform_analysis(embeddings, sequence_lengths, neighbour_embeddings=None):
78
  scores = ea.get_scores(embeddings, sequence_lengths)
79
+ neighbours = ea.get_neighbours(neighbour_embeddings if neighbour_embeddings is not None else embeddings, limit=10)
80
  score = scores[0]
81
  # Convert neighbours to list of dicts for the first sequence
82
  neighbours_dict = []
 
96
  Analyze a chord sequence from text input. Analysis is in the form of
97
  an originality score and a list of similar songs from a non-exhaustive
98
  sample set of songs in the system data store.
99
+
100
  Args:
101
+ chord_text: Chord sequence input as text (comma-separated or space-separated). IMPORTANT: Unless length is explicitly specified, for accurate results, provide the complete sequence of chords that would feature in a typical song, e.g. "C, Am, F, G, C, Am, F, G, C, Am, F, G, C, Am, F, G, ..." NOT just "C, Am, F, G". This could be 20-30 chords for a three minute song.
102
  Returns:
103
+ tuple[float, list[dict]]: Originality score and list of dictionaries, each representing a similar song. You may infer that some items in the list are essentially the same song - if so don't repeat them to the user. Also some songs are more famous than others in the results - if you come across a famous one, then highlight it.
104
  """
105
  logging.info(f"Analyzing chord sequence: {chord_text}")
106
  try:
 
108
  if not chords:
109
  return None, None
110
  embeddings = get_embeddings_from_chord_sequences([chords])
111
+ neighbour_embeddings = None
112
+ if len(chords) < MIN_SEQUENCE_LENGTH_FOR_NEIGHBOURS:
113
+ padded_chords = _pad_sequence_by_repetition(chords, MIN_SEQUENCE_LENGTH_FOR_NEIGHBOURS)
114
+ neighbour_embeddings = get_embeddings_from_chord_sequences([padded_chords])
115
+ score, neighbours = _perform_analysis(embeddings, [len(chords)], neighbour_embeddings)
116
  return score, neighbours
117
  except Exception as e:
118
  logger.error(f"Error analyzing chord sequence: {e}")
 
197
  ["Am, F, C, G"],
198
  ["D, A, Bm, G"],
199
  ["Em, C, G, D"],
200
+ ["F, C, Dm, Bb"],
201
+ ["A7, D7, A7, E7, D7, A7"],
202
+ ["Am, F, C, G, Am, F, C, G, C, G, Am, F, C, G, Am, F, "
203
+ "Am, F, C, G, Am, F, C, G, C, G, Am, F, C, G, Am, F"]
204
  ],
205
  inputs=[chord_input]
206
  )
 
222
  file_display = gr.Markdown(label="File Info")
223
  audio_scores_output = gr.Markdown(label="Analysis Results")
224
  audio_neighbours_output = gr.Markdown(label="Similar Songs")
225
+
226
+ # Tab 3: How It Works
227
+ with gr.TabItem("How It Works"):
228
+ gr.Markdown(how_it_works_content)
229
+
230
  # Event handlers
231
  analyze_btn.click(
232
  fn=_format_chord_analysis_for_ui,
src/analysis.py CHANGED
@@ -8,6 +8,15 @@ from src.scorer import EmbeddingsOriginalityScorer
8
 
9
 
10
  class EmbeddingsAnalysis:
 
 
 
 
 
 
 
 
 
11
  def __init__(self, index, all_labels, lookup, scalers, close_threshold=0.95):
12
  all_labels_np = all_labels['track_id'].to_numpy()
13
  all_lengths_np = all_labels['length'].to_numpy()
 
8
 
9
 
10
  class EmbeddingsAnalysis:
11
+ """
12
+ Facade for analyzing embeddings, combining neighbor search and originality scoring.
13
+
14
+ :param index: FAISS index for similarity search.
15
+ :param all_labels: DataFrame containing 'track_id' and 'length' columns for indexed entries.
16
+ :param lookup: Pandas DataFrame containing metadata for each indexed entry.
17
+ :param scalers: Dictionary mapping length ranges to quantile transformers for score normalization.
18
+ :param close_threshold: Similarity threshold for neighbor search.
19
+ """
20
  def __init__(self, index, all_labels, lookup, scalers, close_threshold=0.95):
21
  all_labels_np = all_labels['track_id'].to_numpy()
22
  all_lengths_np = all_labels['length'].to_numpy()
src/convert.py CHANGED
@@ -2,18 +2,33 @@ import numpy as np
2
  from gradio_client import Client
3
  import os
4
  import json
 
 
5
 
6
  from chord_extractor.extractors import Chordino
7
  from chord_extractor import clear_conversion_cache, LabelledChordSequence
8
 
9
- _CONSTANT_GAP_SECS = 2
10
  _SEQ_EMBED_SPACE = 'ohollo/chord-seq-embed'
11
  _POST_PROCESS_CHORD_LEN_RATIO = 0.7
 
 
12
 
13
 
14
- _client = Client(_SEQ_EMBED_SPACE)
 
 
 
 
 
 
 
 
 
 
15
 
16
  def _call_embedding_service(chords_w_timestamps):
 
17
  result = _client.predict(json.dumps(chords_w_timestamps), api_name="/predict")
18
  return json.loads(result)
19
 
 
2
  from gradio_client import Client
3
  import os
4
  import json
5
+ import time
6
+ import httpx
7
 
8
  from chord_extractor.extractors import Chordino
9
  from chord_extractor import clear_conversion_cache, LabelledChordSequence
10
 
11
+ _CONSTANT_GAP_SECS = 2
12
  _SEQ_EMBED_SPACE = 'ohollo/chord-seq-embed'
13
  _POST_PROCESS_CHORD_LEN_RATIO = 0.7
14
+ _MAX_RETRIES = 3
15
+ _RETRY_DELAY_SECS = 2
16
 
17
 
18
def _create_client():
    """Create the Gradio client for the embedding Space, retrying on timeouts.

    Up to ``_MAX_RETRIES`` attempts are made to construct
    ``Client(_SEQ_EMBED_SPACE)``, pausing ``_RETRY_DELAY_SECS`` seconds
    between attempts.

    Returns:
        The connected ``Client`` instance.

    Raises:
        httpx.TimeoutException: Re-raised from the final attempt when every
            attempt timed out.
        RuntimeError: If ``_MAX_RETRIES`` is less than 1, so no attempt could
            be made.
    """
    for attempt in range(1, _MAX_RETRIES + 1):
        try:
            return Client(_SEQ_EMBED_SPACE)
        # TimeoutException is the common base of httpx's connect/read/write/
        # pool timeouts, so a connect timeout is retried just like the read
        # timeout that was handled before.
        except httpx.TimeoutException:
            if attempt == _MAX_RETRIES:
                raise
            time.sleep(_RETRY_DELAY_SECS)
    # Only reachable when the loop body never ran; fail loudly rather than
    # handing back None as the client.
    raise RuntimeError("_MAX_RETRIES must be at least 1 to create the embedding client")
27
+
28
+ _client = _create_client()
29
 
30
  def _call_embedding_service(chords_w_timestamps):
31
+ print(chords_w_timestamps)
32
  result = _client.predict(json.dumps(chords_w_timestamps), api_name="/predict")
33
  return json.loads(result)
34
 
src/methodology.py CHANGED
@@ -12,7 +12,7 @@ class _TransformerProtocol:
12
 
13
  class CountBasedMethodology(ABC):
14
  @abstractmethod
15
- def execute(self, neighbours_df: pd.DataFrame) -> pd.Series:
16
  ...
17
 
18
  @abstractmethod
 
12
 
13
  class CountBasedMethodology(ABC):
14
  @abstractmethod
15
+ def execute(self, neighbours_df: pd.DataFrame, lengths: pd.Series) -> pd.Series:
16
  ...
17
 
18
  @abstractmethod