"""Flask front end for the ``ginchiostro/music_beat_tracking`` beat-tracking model.

On import, the model checkpoint is fetched from the Hugging Face Hub and loaded,
then these routes are registered:

* ``/``         - renders ``index.html``
* ``/upload``   - stores an uploaded song under ``static/uploaded_songs``
* ``/get_beat`` - runs ``add_clicks.add_clicks`` on a previously uploaded song
"""

import os

from flask import Flask, render_template, request, jsonify
import numpy as np
import soundfile as sf
import librosa
import tensorflow as tf
from huggingface_hub import hf_hub_download

import models
import add_clicks

# ===== Hub config =====
REPO_ID = "ginchiostro/music_beat_tracking"
MODEL_NAME = "model_epoch_300"  # folder on the Hub containing the checkpoint files

# --- Download checkpoint files (cached locally by huggingface_hub after the first run) ---
# All three files are fetched so they exist next to each other locally;
# TensorFlow then loads the checkpoint through its prefix (path without ".index").
_ = hf_hub_download(repo_id=REPO_ID, filename=f"{MODEL_NAME}/checkpoint")
index_path = hf_hub_download(
    repo_id=REPO_ID, filename=f"{MODEL_NAME}/{MODEL_NAME}.index"
)
_ = hf_hub_download(
    repo_id=REPO_ID, filename=f"{MODEL_NAME}/{MODEL_NAME}.data-00000-of-00001"
)

# Build the inference model plus a second instance used only to load the checkpoint.
model = models.bidirectional_model()
model_save = models.bidirectional_model_for_save()
# Run one dummy forward pass on each model so their variables exist before
# weights are loaded/copied. Input layout presumably (batch, frames, 256
# features) -- TODO confirm against models.py.
model(np.ones((1, 10, 256)))
model_save(np.ones((1, 10, 256)))

# IMPORTANT: TF checkpoints are loaded via the prefix (path without ".index").
ckpt_prefix = os.path.splitext(index_path)[0]
# NOTE: chaining .expect_partial() here would silence warnings about
# checkpoint values that are not restored into the model.
model_save.load_weights(ckpt_prefix)
# Copy the restored weights into the inference model; this assumes both model
# builders produce the same weight list (order and shapes).
model.set_weights(model_save.get_weights())
print("loaded model!")

# ===== Flask app =====
app = Flask(__name__)

UPLOAD_FOLDER = "static/uploaded_songs"
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER


def _safe_filename(name):
    """Reduce a client-supplied file name to a bare base name.

    Directory components (both "/" and "\\" separators) are stripped, so the
    result can never point outside the folder it is joined with. Returns ""
    when nothing usable remains (e.g. "", ".", ".."), so callers can reject it.
    """
    # Treat backslashes as separators too: the client's OS is unknown, and on
    # POSIX os.path.basename would otherwise keep "..\\..\\x" intact.
    base = os.path.basename(name.replace("\\", "/"))
    if base in ("", ".", ".."):
        return ""
    return base


@app.route("/")
def index():
    """Render the single-page UI."""
    return render_template("index.html")


@app.route("/upload", methods=["POST"])
def upload_file():
    """Save the multipart field ``file`` into UPLOAD_FOLDER.

    The client-provided file name is sanitized with ``_safe_filename`` so an
    upload cannot be written outside UPLOAD_FOLDER.
    """
    if "file" not in request.files:
        return "No file part"
    f = request.files["file"]
    if not f.filename:
        return "No selected file"
    filename = _safe_filename(f.filename)
    if not filename:
        return "Invalid file name", 400
    f.save(os.path.join(app.config["UPLOAD_FOLDER"], filename))
    return "File uploaded successfully"


@app.route("/get_beat", methods=["POST"])
def get_beat():
    """Run ``add_clicks.add_clicks`` on a previously uploaded song.

    The request body must be a JSON string holding the uploaded file's name
    (e.g. ``"song.wav"``). ``static/<stem>_clicks.wav`` is passed to
    ``add_clicks`` as ``output_name``.

    Returns 400 for a non-string or unusable name and 404 when no such upload
    exists.
    """
    # force=True parses the body as JSON regardless of the Content-Type header.
    song_name = request.get_json(force=True)
    if not isinstance(song_name, str):
        return "Expected a JSON string with the song file name", 400
    # The name comes from the client: strip any directory part so neither the
    # input nor the output path can escape their folders.
    song_name = _safe_filename(song_name)
    if not song_name:
        return "Invalid song name", 400

    src = os.path.join(UPLOAD_FOLDER, song_name)
    if not os.path.isfile(src):
        return "Song not found", 404
    out = os.path.join("static", f"{os.path.splitext(song_name)[0]}_clicks.wav")

    add_clicks.add_clicks(
        song=src,
        model=model,
        model_passed=True,
        model_loc="",
        output_name=out,
        constant_tempo=True,
        plot=False,
    )
    return "beats added successfully"


if __name__ == "__main__":
    # Run on 0.0.0.0:7860 (common for Spaces); PORT env var overrides the port.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)