kshahnathwani commited on
Commit
b3f44b9
·
verified ·
1 Parent(s): 445c40e

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -4
inference.py CHANGED
@@ -10,7 +10,7 @@ NAME_TO_PC = {
10
  }
11
  NOTE_TOKEN_RE = re.compile(r"[A-Ga-g](?:#|b)?")
12
 
13
- # Load model at startup
14
  MODEL_PATH = os.path.join(os.path.dirname(__file__), "chord_classifier.pkl")
15
  clf = joblib.load(MODEL_PATH)
16
 
@@ -23,12 +23,13 @@ def notes_to_vector(notes_str: str):
23
  vec[p] = 1
24
  return vec
25
 
26
- def predict(notes: str):
 
27
  """
28
- Main inference function.
29
  Hugging Face will call this with {"inputs": "C E G"}.
 
30
  """
31
- vec = notes_to_vector(notes)
32
  if np.sum(vec) < 2:
33
  return {"label": "Invalid input (need 2+ notes)"}
34
  label = clf.predict([vec])[0]
 
10
  }
11
  NOTE_TOKEN_RE = re.compile(r"[A-Ga-g](?:#|b)?")
12
 
13
+ # Load model
14
  MODEL_PATH = os.path.join(os.path.dirname(__file__), "chord_classifier.pkl")
15
  clf = joblib.load(MODEL_PATH)
16
 
 
23
  vec[p] = 1
24
  return vec
25
 
26
+ # 👇 Hugging Face Inference API entrypoint
27
+ def predict(inputs: str):
28
  """
 
29
  Hugging Face will call this with {"inputs": "C E G"}.
30
+ Must return JSON-serializable output.
31
  """
32
+ vec = notes_to_vector(inputs)
33
  if np.sum(vec) < 2:
34
  return {"label": "Invalid input (need 2+ notes)"}
35
  label = clf.predict([vec])[0]