Spaces:
Sleeping
Sleeping
| from flask import Flask, render_template, request, jsonify | |
| import os | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| import tensorflow as tf | |
| # NEW: import hf_hub_download | |
| from huggingface_hub import hf_hub_download | |
| import models | |
| import add_clicks | |
# ===== Hub config =====
# Everything in this section runs at import time: it downloads the checkpoint
# files and builds the TF models before the Flask app is created.
REPO_ID = "ginchiostro/music_beat_tracking"
MODEL_NAME = "model_epoch_300" # folder on the Hub containing the checkpoint files
# --- Download checkpoint files (cached on first run) ---
# We fetch all three to ensure they exist locally; TF will use the prefix (path without .index)
# NOTE(review): each file is fetched by a separate call with no pinned
# `revision`; if the Hub repo changes between calls the files could come from
# different commits and no longer sit side by side. Consider pinning a revision.
_ = hf_hub_download(repo_id=REPO_ID, filename=f"{MODEL_NAME}/checkpoint")
# The returned local path of the .index file is kept to derive the checkpoint prefix below.
index_path = hf_hub_download(repo_id=REPO_ID, filename=f"{MODEL_NAME}/{MODEL_NAME}.index")
_ = hf_hub_download(repo_id=REPO_ID, filename=f"{MODEL_NAME}/{MODEL_NAME}.data-00000-of-00001")
# Build models exactly as before
# `model` is the instance used for inference by the request handlers;
# `model_save` is the architecture the checkpoint is loaded into.
model = models.bidirectional_model()
model_save = models.bidirectional_model_for_save()
# Call each model once on a dummy batch so its variables are created before
# weights are loaded/copied. Presumably the input is (batch, time, 256
# features) -- TODO confirm against models.py.
model(np.ones((1, 10, 256)))
model_save(np.ones((1, 10, 256)))
# IMPORTANT: TF checkpoints are loaded via the prefix (path without .index)
ckpt_prefix = os.path.splitext(index_path)[0]
# NOTE(review): the status object returned here is discarded. If TF warns about
# unrestored checkpoint values (e.g. optimizer slots), chaining
# `.expect_partial()` onto this call may silence them -- confirm before enabling.
model_save.load_weights(ckpt_prefix)
# Copy the loaded weights into the inference model. This assumes both models
# produce identically shaped weight lists in the same order -- verify in models.py.
model.set_weights(model_save.get_weights())
print("loaded model!")
# ===== Flask app =====
app = Flask(__name__)

# Songs uploaded by clients are stored here (a path relative to the current
# working directory); handlers look the folder up via app.config.
UPLOAD_FOLDER = "static/uploaded_songs"
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)

# Make sure both the output folder ("static") and the upload folder exist
# before any request handler tries to write into them.
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# NOTE(review): no @app.route decorator is visible on this handler in SOURCE;
# confirm it is registered with Flask elsewhere, otherwise it is unreachable.
def index():
    """Render and return the landing page."""
    template_name = "index.html"
    return render_template(template_name)
| def upload_file(): | |
| if "file" not in request.files: | |
| return "No file part" | |
| f = request.files["file"] | |
| if f.filename == "": | |
| return "No selected file" | |
| f.save(os.path.join(app.config["UPLOAD_FOLDER"], f.filename)) | |
| return "File uploaded successfully" | |
| def get_beat(): | |
| song_name = request.get_json(force=True) | |
| src = os.path.join(UPLOAD_FOLDER, song_name) | |
| out = os.path.join("static", f"{os.path.splitext(song_name)[0]}_clicks.wav") | |
| add_clicks.add_clicks( | |
| song=src, | |
| model=model, | |
| model_passed=True, | |
| model_loc="", | |
| output_name=out, | |
| constant_tempo=True, | |
| plot=False, | |
| ) | |
| return "beats added successfully" | |
if __name__ == "__main__":
    # Listen on every interface. The port defaults to 7860 (common for
    # Spaces) unless the PORT environment variable overrides it.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(debug=False, host="0.0.0.0", port=listen_port)