beat_tracking / app.py
ginchiostro's picture
Update app.py
9c743c1 verified
from flask import Flask, render_template, request, jsonify
import os
import numpy as np
import soundfile as sf
import librosa
import tensorflow as tf
# NEW: import hf_hub_download
from huggingface_hub import hf_hub_download
import models
import add_clicks
# ===== Hugging Face Hub settings =====
REPO_ID = "ginchiostro/music_beat_tracking"
# Folder on the Hub that holds the checkpoint; also used as the file stem.
MODEL_NAME = "model_epoch_300"

# --- Fetch the checkpoint files from the Hub ---
# All three files are requested so that each one is present locally before
# the weights are loaded. Only the local path of the ".index" file is kept,
# because the checkpoint prefix is derived from it further down.
_CHECKPOINT_FILES = (
    "checkpoint",
    f"{MODEL_NAME}.index",
    f"{MODEL_NAME}.data-00000-of-00001",
)
_local_paths = [
    hf_hub_download(repo_id=REPO_ID, filename=f"{MODEL_NAME}/{_file_name}")
    for _file_name in _CHECKPOINT_FILES
]
index_path = _local_paths[1]

# --- Instantiate both model variants ---
model = models.bidirectional_model()
model_save = models.bidirectional_model_for_save()
# Each model is called once on a dummy (1, 10, 256) array of ones --
# presumably so its variables exist before weights are loaded/copied
# (TODO confirm against models.py).
for _net in (model, model_save):
    _net(np.ones((1, 10, 256)))

# --- Restore weights ---
# The checkpoint is addressed by its prefix, i.e. the ".index" path with the
# extension removed. Weights are loaded into the "save" variant and then
# copied over to the model that serves requests.
ckpt_prefix, _ = os.path.splitext(index_path)
model_save.load_weights(ckpt_prefix)
model.set_weights(model_save.get_weights())
print("loaded model!")
# ===== Flask application and upload directory =====
UPLOAD_FOLDER = "static/uploaded_songs"

app = Flask(__name__)
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER

# Make sure the static root and the upload directory both exist at import
# time; exist_ok=True keeps this a no-op when they are already there.
for _directory in ("static", UPLOAD_FOLDER):
    os.makedirs(_directory, exist_ok=True)
@app.route("/")
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/upload", methods=["POST"])
def upload_file():
    """Save the file sent in the ``file`` form field into the upload folder.

    The client-supplied filename is reduced to its final path component
    before it is joined to the upload folder, so a name such as
    ``../../app.py`` cannot write outside that folder.

    Returns:
        A plain-text status message. The three pre-existing messages are
        returned exactly as before; a filename that is empty after
        sanitizing (or contains a NUL byte) yields ``"Invalid file name"``
        with HTTP status 400.
    """
    if "file" not in request.files:
        return "No file part"
    f = request.files["file"]
    # ``not f.filename`` also covers a missing (None) filename, which would
    # otherwise crash in os.path.join below.
    if not f.filename:
        return "No selected file"

    # Normalize Windows separators first so basename strips them on any OS,
    # then keep only the last path component.
    safe_name = os.path.basename(f.filename.replace("\\", "/"))
    if safe_name in ("", ".", "..") or "\x00" in safe_name:
        return "Invalid file name", 400

    f.save(os.path.join(app.config["UPLOAD_FOLDER"], safe_name))
    return "File uploaded successfully"
@app.route("/get_beat", methods=["POST"])
def get_beat():
    """Render a click-track version of a previously uploaded song.

    The request body is parsed as JSON and must be a single string: the
    name of a file inside ``UPLOAD_FOLDER``. The result is written to
    ``static/<stem>_clicks.wav`` by ``add_clicks.add_clicks`` using the
    module-level ``model``.

    Returns:
        ``"beats added successfully"`` on success; ``"Invalid song name"``
        with status 400 when the payload is not a usable file name;
        ``"Song not found"`` with status 404 when no such upload exists.
    """
    song_name = request.get_json(force=True)
    # The payload is untrusted: require a string before using it as a path.
    if not isinstance(song_name, str):
        return "Invalid song name", 400

    # Keep only the final path component (after normalizing Windows
    # separators) so both the input and output paths stay inside their
    # intended directories.
    song_name = os.path.basename(song_name.replace("\\", "/"))
    if song_name in ("", ".", "..") or "\x00" in song_name:
        return "Invalid song name", 400

    src = os.path.join(UPLOAD_FOLDER, song_name)
    # Fail with a clear 404 instead of an error from inside add_clicks.
    if not os.path.isfile(src):
        return "Song not found", 404

    out = os.path.join("static", f"{os.path.splitext(song_name)[0]}_clicks.wav")
    add_clicks.add_clicks(
        song=src,
        model=model,
        model_passed=True,
        model_loc="",
        output_name=out,
        constant_tempo=True,
        plot=False,
    )
    return "beats added successfully"
if __name__ == "__main__":
    # Listen on all interfaces. The port is read from the PORT environment
    # variable when set and falls back to 7860 (common for Spaces).
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )