Krish Shah-Nathwani commited on
Commit
3cd2d15
·
1 Parent(s): 06b222f

updated app with trained classifier model trained locally

Browse files
app.py CHANGED
@@ -1,75 +1,46 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- # Load FLAN-T5 small locally
5
- generator = pipeline("text2text-generation", model="google/flan-t5-small")
6
-
7
- # System-style chord identification prompt
8
- CHORD_SYSTEM_PROMPT = """
9
- You are a music theory expert specialized in chord identification for a **Chord Bot** application.
10
- Your job is to take an unordered list of pitch names (notes) and return the most likely chord name(s).
11
-
12
- ## Scope & Assumptions
13
- - Accept any number of notes ≥ 2. Ignore octaves; treat input pitch classes only.
14
- - Handle enharmonics (C# = Db), accidentals (#, b, ♯, ♭), and duplicates.
15
- - Recognize: triads (maj, min, dim, aug), sevenths (maj7, 7, m7, mMaj7, dim7, m7b5),
16
- extended/altered chords (add2/add9, 6, 9, 11, 13), suspensions (sus2/sus4), power chords (5),
17
- and common alterations (b5, #5, b9, #9, #11, b13).
18
- - Detect inversions and slash chords (C/E). Prefer root-position naming but include inversion when clear.
19
- - Key-agnostic: infer chord quality from pitch-set intervals; do not assume a key unless given.
20
- - If multiple valid interpretations exist, rank by plausibility and provide alternates.
21
-
22
- ## Output Style
23
- - Primary answer: concise chord name (e.g., Cm7, G7b9, Fadd9, Dsus4/F#).
24
- - Also include: a short rationale (intervals from root), and 0–1 confidence.
25
- - If ambiguous: list up to 3 alternate chord names with brief reasons.
26
- - Prefer ASCII chord symbols (#, b).
27
-
28
- ## Interaction Rules
29
- - Be deterministic and concise. Do not ask clarifying questions.
30
- - Accept inputs like: "C E G", ["C#", "E", "G", "A"], or "Db F Ab C".
31
- - Normalize enharmonics to the spelling that best matches the chord quality.
32
-
33
- ## Output Format
34
- Return a compact JSON-like block:
35
-
36
- chord: "<PrimaryChordName>"
37
- confidence: <0.0–1.0>
38
- explanation: "Root <X>; intervals <...>."
39
- alternates: ["<Alt1>", "<Alt2>"]
40
-
41
- ## Examples
42
- - Input: C E G → chord: "C major"
43
- - Input: D F# A C → chord: "D7"
44
- - Input: C Eb G Bb → chord: "Cm7"
45
- - Input: C D G → chord: "Csus2", alternates: ["Gadd4/C"]
46
- - Input: E G B D (bass G) → chord: "Em7/G"
47
- """
48
-
49
- def chord_bot(message: str, history: list[tuple[str, str]]):
50
- # Embed user input into the system prompt
51
- chord_prompt = f"{CHORD_SYSTEM_PROMPT}\n\nInput: {message}\nOutput:\n"
52
-
53
- response = generator(
54
- chord_prompt,
55
- max_new_tokens=128,
56
- temperature=0.0, # deterministic
57
- top_p=0.95
58
- )
59
-
60
- raw_text = response[0]["generated_text"]
61
- # Extract the block after "Output:" if present
62
- if "Output:" in raw_text:
63
- answer = raw_text.split("Output:")[-1].strip()
64
- else:
65
- answer = raw_text.strip()
66
- return answer
67
-
68
- # Gradio Chat UI
69
  chatbot = gr.ChatInterface(
70
  fn=chord_bot,
71
- title="🎶 FLAN-T5 Chord Bot",
72
- description="Type 2+ notes (e.g., C E G or Db F Ab C) and I'll identify the chord."
73
  )
74
 
75
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import joblib
3
+ import numpy as np
4
+ import os
5
+ import re
6
+ import subprocess
7
+
8
+ NAME_TO_PC = {
9
+ "C":0,"C#":1,"Db":1,"D":2,"D#":3,"Eb":3,"E":4,"F":5,"F#":6,"Gb":6,
10
+ "G":7,"G#":8,"Ab":8,"A":9,"A#":10,"Bb":10,"B":11
11
+ }
12
+ NOTE_TOKEN_RE = re.compile(r"[A-Ga-g](?:#|b)?")
13
+
14
+ def notes_to_vector(notes_str: str):
15
+ tokens = NOTE_TOKEN_RE.findall(notes_str)
16
+ pcs = [NAME_TO_PC.get(t.upper(), None) for t in tokens]
17
+ pcs = [p for p in pcs if p is not None]
18
+ vec = np.zeros(12)
19
+ for p in pcs:
20
+ vec[p] = 1
21
+ return vec
22
+
23
+ MODEL_PATH = "chord_classifier.pkl"
24
+
25
+ def load_model():
26
+ if not os.path.exists(MODEL_PATH):
27
+ print("⚠️ chord_classifier.pkl not found. Training model...")
28
+ subprocess.run(["python", "train_chord_model.py"], check=True)
29
+ return joblib.load(MODEL_PATH)
30
+
31
+ clf = load_model()
32
+
33
+ def chord_bot(message: str, history: list[tuple[str,str]]):
34
+ vec = notes_to_vector(message)
35
+ if np.sum(vec) < 2:
36
+ return "⚠️ Please enter at least 2 distinct notes (e.g., C E G)"
37
+ label = clf.predict([vec])[0]
38
+ return f"🎵 Identified chord: **{label}**"
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  chatbot = gr.ChatInterface(
41
  fn=chord_bot,
42
+ title="🎶 ML Chord Bot",
43
+ description="Enter 2+ notes (e.g., C E G or Db F Ab C). Powered by a trained RandomForest classifier."
44
  )
45
 
46
  if __name__ == "__main__":
chord_classifier.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72c95e743c473a32d7cf07987433dccc46bdb9ce7e62a2d0c635cf75f133900c
3
+ size 70601065
generate_chord_dataset.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+
3
+ PITCHES_SHARP = ["C","C#","D","D#","E","F","F#","G","G#","A","A#","B"]
4
+
5
+ NAME_TO_PC = {
6
+ "C":0,"C#":1,"Db":1,"D":2,"D#":3,"Eb":3,"E":4,"F":5,"F#":6,"Gb":6,
7
+ "G":7,"G#":8,"Ab":8,"A":9,"A#":10,"Bb":10,"B":11
8
+ }
9
+
10
+ CHORD_FORMULAS = {
11
+ "maj": [0,4,7],
12
+ "min": [0,3,7],
13
+ "dim": [0,3,6],
14
+ "aug": [0,4,8],
15
+ "7": [0,4,7,10],
16
+ "maj7":[0,4,7,11],
17
+ "m7": [0,3,7,10],
18
+ "mMaj7":[0,3,7,11],
19
+ "dim7":[0,3,6,9],
20
+ "m7b5":[0,3,6,10],
21
+ "6": [0,4,7,9],
22
+ "m6": [0,3,7,9],
23
+ "sus2":[0,2,7],
24
+ "sus4":[0,5,7],
25
+ }
26
+
27
+ def generate_chord_vectors():
28
+ data = []
29
+ for root_pc in range(12):
30
+ for quality, intervals in CHORD_FORMULAS.items():
31
+ pcs = [(root_pc + i) % 12 for i in intervals]
32
+ vec = [1 if i in pcs else 0 for i in range(12)]
33
+ root_name = PITCHES_SHARP[root_pc]
34
+ chord_name = root_name + quality
35
+ data.append((vec, chord_name))
36
+ return data
37
+
38
+ def save_to_csv(filename="chords_dataset.csv"):
39
+ data = generate_chord_vectors()
40
+ with open(filename, "w", newline="") as f:
41
+ writer = csv.writer(f)
42
+ header = [f"pc_{p}" for p in PITCHES_SHARP] + ["label"]
43
+ writer.writerow(header)
44
+ for vec, label in data:
45
+ writer.writerow(vec + [label])
46
+ print(f"✅ Saved {len(data)} chords to {filename}")
47
+
48
+ if __name__ == "__main__":
49
+ save_to_csv()
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- gradio
2
- transformers
3
- torch
4
- sentencepiece
 
 
1
+ gradio>=4.44.0
2
+ scikit-learn
3
+ joblib
4
+ numpy
5
+ pandas
train_chord_model.py ADDED
File without changes