Spaces:
Sleeping
Sleeping
Krish Shah-Nathwani
committed on
Commit
·
3cd2d15
1
Parent(s):
06b222f
updated app with trained classifier model trained locally
Browse files- app.py +39 -68
- chord_classifier.pkl +3 -0
- generate_chord_dataset.py +49 -0
- requirements.txt +5 -4
- train_chord_model.py +0 -0
app.py
CHANGED
|
@@ -1,75 +1,46 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
# Load FLAN-T5 small locally
|
| 5 |
-
generator = pipeline("text2text-generation", model="google/flan-t5-small")
|
| 6 |
-
|
| 7 |
-
# System-style chord identification prompt
|
| 8 |
-
CHORD_SYSTEM_PROMPT = """
|
| 9 |
-
You are a music theory expert specialized in chord identification for a **Chord Bot** application.
|
| 10 |
-
Your job is to take an unordered list of pitch names (notes) and return the most likely chord name(s).
|
| 11 |
-
|
| 12 |
-
## Scope & Assumptions
|
| 13 |
-
- Accept any number of notes ≥ 2. Ignore octaves; treat input pitch classes only.
|
| 14 |
-
- Handle enharmonics (C# = Db), accidentals (#, b, ♯, ♭), and duplicates.
|
| 15 |
-
- Recognize: triads (maj, min, dim, aug), sevenths (maj7, 7, m7, mMaj7, dim7, m7b5),
|
| 16 |
-
extended/altered chords (add2/add9, 6, 9, 11, 13), suspensions (sus2/sus4), power chords (5),
|
| 17 |
-
and common alterations (b5, #5, b9, #9, #11, b13).
|
| 18 |
-
- Detect inversions and slash chords (C/E). Prefer root-position naming but include inversion when clear.
|
| 19 |
-
- Key-agnostic: infer chord quality from pitch-set intervals; do not assume a key unless given.
|
| 20 |
-
- If multiple valid interpretations exist, rank by plausibility and provide alternates.
|
| 21 |
-
|
| 22 |
-
## Output Style
|
| 23 |
-
- Primary answer: concise chord name (e.g., Cm7, G7b9, Fadd9, Dsus4/F#).
|
| 24 |
-
- Also include: a short rationale (intervals from root), and 0–1 confidence.
|
| 25 |
-
- If ambiguous: list up to 3 alternate chord names with brief reasons.
|
| 26 |
-
- Prefer ASCII chord symbols (#, b).
|
| 27 |
-
|
| 28 |
-
## Interaction Rules
|
| 29 |
-
- Be deterministic and concise. Do not ask clarifying questions.
|
| 30 |
-
- Accept inputs like: "C E G", ["C#", "E", "G", "A"], or "Db F Ab C".
|
| 31 |
-
- Normalize enharmonics to the spelling that best matches the chord quality.
|
| 32 |
-
|
| 33 |
-
## Output Format
|
| 34 |
-
Return a compact JSON-like block:
|
| 35 |
-
|
| 36 |
-
chord: "<PrimaryChordName>"
|
| 37 |
-
confidence: <0.0–1.0>
|
| 38 |
-
explanation: "Root <X>; intervals <...>."
|
| 39 |
-
alternates: ["<Alt1>", "<Alt2>"]
|
| 40 |
-
|
| 41 |
-
## Examples
|
| 42 |
-
- Input: C E G → chord: "C major"
|
| 43 |
-
- Input: D F# A C → chord: "D7"
|
| 44 |
-
- Input: C Eb G Bb → chord: "Cm7"
|
| 45 |
-
- Input: C D G → chord: "Csus2", alternates: ["Gadd4/C"]
|
| 46 |
-
- Input: E G B D (bass G) → chord: "Em7/G"
|
| 47 |
-
"""
|
| 48 |
-
|
| 49 |
-
def chord_bot(message: str, history: list[tuple[str, str]]):
    """Answer a chat message with the language model's chord identification.

    The user's message is appended to ``CHORD_SYSTEM_PROMPT`` as the "Input:"
    line and the combined prompt is passed to the module-level ``generator``.

    Args:
        message: Text typed by the user, expected to list note names.
        history: Previous chat turns passed in by the chat UI; not used here.

    Returns:
        The generated text with surrounding whitespace stripped. If the text
        contains an "Output:" marker, only the part after the last marker is
        returned.
    """
    chord_prompt = f"{CHORD_SYSTEM_PROMPT}\n\nInput: {message}\nOutput:\n"

    # Request greedy decoding explicitly. `temperature` and `top_p` only take
    # effect when sampling is enabled, so passing them without do_sample=True
    # did not control determinism and only produced configuration warnings.
    response = generator(
        chord_prompt,
        max_new_tokens=128,
        do_sample=False,
    )

    raw_text = response[0]["generated_text"]
    # split() returns the whole string as its only element when the marker is
    # absent, so taking the last piece covers both cases.
    return raw_text.split("Output:")[-1].strip()
|
| 67 |
-
|
| 68 |
-
# Gradio Chat UI
|
| 69 |
chatbot = gr.ChatInterface(
|
| 70 |
fn=chord_bot,
|
| 71 |
-
title="🎶
|
| 72 |
-
description="
|
| 73 |
)
|
| 74 |
|
| 75 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import joblib
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import subprocess
|
| 7 |
+
|
| 8 |
+
# Pitch-class index (0-11, with C = 0) for every supported note spelling.
NAME_TO_PC = {
    "C":0,"C#":1,"Db":1,"D":2,"D#":3,"Eb":3,"E":4,"F":5,"F#":6,"Gb":6,
    "G":7,"G#":8,"Ab":8,"A":9,"A#":10,"Bb":10,"B":11
}
# One note letter (either case) followed by an optional accidental. Both the
# ASCII (#, b) and the Unicode (♯, ♭) accidental signs are accepted.
NOTE_TOKEN_RE = re.compile(r"[A-Ga-g][#b♯♭]?")


def notes_to_vector(notes_str: str):
    """Convert free-form note text into a 12-element pitch-class vector.

    Args:
        notes_str: Text containing note names such as "C E G" or "Db F Ab C".
            Letter case is ignored and duplicates collapse into one entry.

    Returns:
        A NumPy array of length 12 holding 1 at the index of every pitch class
        found in ``notes_str`` and 0 elsewhere. Spellings that are not keys of
        ``NAME_TO_PC`` (for example "E#" or "Cb") are ignored.
    """
    vec = np.zeros(12)
    for token in NOTE_TOKEN_RE.findall(notes_str):
        # Upper-case only the note letter. Upper-casing the whole token would
        # turn "Db" into "DB", which is not a NAME_TO_PC key, and every flat
        # note would be silently dropped.
        accidental = token[1:].replace("♯", "#").replace("♭", "b")
        pc = NAME_TO_PC.get(token[0].upper() + accidental)
        if pc is not None:
            vec[pc] = 1
    return vec
|
| 22 |
+
|
| 23 |
+
MODEL_PATH = "chord_classifier.pkl"
|
| 24 |
+
|
| 25 |
+
def load_model():
    """Load the chord classifier, training it first if no saved model exists.

    When ``MODEL_PATH`` is missing, ``train_chord_model.py`` is run as a
    subprocess and is expected to write the model file.

    Returns:
        The object deserialized from ``MODEL_PATH`` by ``joblib.load``.

    Raises:
        subprocess.CalledProcessError: If the training script exits non-zero.
        FileNotFoundError: If the model file is still missing after training.
    """
    import sys  # local import so this function does not depend on file-level imports

    if not os.path.exists(MODEL_PATH):
        print("⚠️ chord_classifier.pkl not found. Training model...")
        # Run the training script with the interpreter that runs this app. A
        # bare "python" may be missing from PATH or may point to an environment
        # without the required packages.
        subprocess.run([sys.executable or "python", "train_chord_model.py"], check=True)
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(
                f"train_chord_model.py finished but did not create {MODEL_PATH}"
            )
    # NOTE: joblib.load unpickles arbitrary objects; only load model files that
    # come from a trusted source.
    return joblib.load(MODEL_PATH)
|
| 30 |
+
|
| 31 |
+
clf = load_model()
|
| 32 |
+
|
| 33 |
+
def chord_bot(message: str, history: list[tuple[str,str]]):
    """Reply to a chat message with the chord predicted by the classifier.

    Args:
        message: Text typed by the user; parsed with ``notes_to_vector``.
        history: Previous chat turns passed in by the chat UI; not used here.

    Returns:
        A Markdown string naming the predicted chord, or a warning string when
        fewer than two distinct pitch classes were recognised in ``message``.
    """
    pitch_vector = notes_to_vector(message)
    # Each distinct pitch class contributes exactly 1, so the sum is the
    # number of distinct notes recognised.
    has_enough_notes = np.sum(pitch_vector) >= 2
    if has_enough_notes:
        # predict() takes a batch of samples; wrap the single vector in a list
        # and take the only prediction back out.
        predictions = clf.predict([pitch_vector])
        label = predictions[0]
        return f"🎵 Identified chord: **{label}**"
    return "⚠️ Please enter at least 2 distinct notes (e.g., C E G)"
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
chatbot = gr.ChatInterface(
|
| 41 |
fn=chord_bot,
|
| 42 |
+
title="🎶 ML Chord Bot",
|
| 43 |
+
description="Enter 2+ notes (e.g., C E G or Db F Ab C). Powered by a trained RandomForest classifier."
|
| 44 |
)
|
| 45 |
|
| 46 |
if __name__ == "__main__":
|
chord_classifier.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:72c95e743c473a32d7cf07987433dccc46bdb9ce7e62a2d0c635cf75f133900c
|
| 3 |
+
size 70601065
|
generate_chord_dataset.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
|
| 3 |
+
PITCHES_SHARP = ["C","C#","D","D#","E","F","F#","G","G#","A","A#","B"]
|
| 4 |
+
|
| 5 |
+
NAME_TO_PC = {
|
| 6 |
+
"C":0,"C#":1,"Db":1,"D":2,"D#":3,"Eb":3,"E":4,"F":5,"F#":6,"Gb":6,
|
| 7 |
+
"G":7,"G#":8,"Ab":8,"A":9,"A#":10,"Bb":10,"B":11
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
CHORD_FORMULAS = {
|
| 11 |
+
"maj": [0,4,7],
|
| 12 |
+
"min": [0,3,7],
|
| 13 |
+
"dim": [0,3,6],
|
| 14 |
+
"aug": [0,4,8],
|
| 15 |
+
"7": [0,4,7,10],
|
| 16 |
+
"maj7":[0,4,7,11],
|
| 17 |
+
"m7": [0,3,7,10],
|
| 18 |
+
"mMaj7":[0,3,7,11],
|
| 19 |
+
"dim7":[0,3,6,9],
|
| 20 |
+
"m7b5":[0,3,6,10],
|
| 21 |
+
"6": [0,4,7,9],
|
| 22 |
+
"m6": [0,3,7,9],
|
| 23 |
+
"sus2":[0,2,7],
|
| 24 |
+
"sus4":[0,5,7],
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
def generate_chord_vectors():
    """Build one labelled pitch-class vector for every root/quality pair.

    Returns:
        A list of ``(vector, name)`` tuples. ``vector`` is a list of twelve
        0/1 ints marking the pitch classes in the chord. ``name`` is the
        sharp-spelled root name followed by the quality key from
        ``CHORD_FORMULAS`` (for example "C#m7"). Entries are ordered by root
        first, then by the iteration order of ``CHORD_FORMULAS``.
    """
    samples = []
    for root in range(12):
        root_label = PITCHES_SHARP[root]
        for suffix, formula in CHORD_FORMULAS.items():
            # Transpose the interval formula onto this root, wrapping at the octave.
            members = {(root + step) % 12 for step in formula}
            one_hot = [int(pc in members) for pc in range(12)]
            samples.append((one_hot, root_label + suffix))
    return samples
|
| 37 |
+
|
| 38 |
+
def save_to_csv(filename="chords_dataset.csv"):
    """Write the generated chord dataset to a CSV file and report the count.

    The file starts with a header row of one ``pc_<note>`` column per entry of
    ``PITCHES_SHARP`` plus a final ``label`` column. Each following row holds
    one chord from ``generate_chord_vectors``.

    Args:
        filename: Path of the CSV file to create or overwrite.
    """
    data = generate_chord_vectors()
    # newline="" lets the csv module control line endings itself.
    with open(filename, "w", newline="") as out_file:
        csv_out = csv.writer(out_file)
        columns = [f"pc_{p}" for p in PITCHES_SHARP]
        columns.append("label")
        csv_out.writerow(columns)
        csv_out.writerows(vec + [label] for vec, label in data)
    print(f"✅ Saved {len(data)} chords to {filename}")
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
save_to_csv()
|
requirements.txt
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
gradio
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
| 1 |
+
gradio>=4.44.0
|
| 2 |
+
scikit-learn
|
| 3 |
+
joblib
|
| 4 |
+
numpy
|
| 5 |
+
pandas
|
train_chord_model.py
ADDED
|
File without changes
|