File size: 2,499 Bytes
3ab6186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c743c1
3ab6186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from flask import Flask, render_template, request, jsonify
import os
import numpy as np
import soundfile as sf
import librosa
import tensorflow as tf

# NEW: import hf_hub_download
from huggingface_hub import hf_hub_download

import models
import add_clicks

# ===== Hub config =====
REPO_ID = "ginchiostro/music_beat_tracking"
MODEL_NAME = "model_epoch_300"  # Hub folder that holds the checkpoint files

# --- Fetch checkpoint files (hf_hub_download caches them after the first run) ---
# All three files are pulled so they exist side by side locally. Only the
# .index path is kept, because TensorFlow addresses a checkpoint by its prefix.
hf_hub_download(repo_id=REPO_ID, filename=f"{MODEL_NAME}/checkpoint")
index_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=f"{MODEL_NAME}/{MODEL_NAME}.index",
)
hf_hub_download(
    repo_id=REPO_ID,
    filename=f"{MODEL_NAME}/{MODEL_NAME}.data-00000-of-00001",
)

# --- Build both network variants ---
# Each model is called once on a dummy (1, 10, 256) batch before any weights
# are touched -- presumably so its variables exist; confirm against models.py.
model = models.bidirectional_model()
model_save = models.bidirectional_model_for_save()
_warmup_batch = np.ones((1, 10, 256))
model(_warmup_batch)
model_save(_warmup_batch)

# --- Restore weights ---
# Drop the ".index" extension to get the checkpoint prefix, load it into
# model_save, then copy those weights into `model` (the one used by /get_beat).
ckpt_prefix, _ = os.path.splitext(index_path)
model_save.load_weights(ckpt_prefix)
model.set_weights(model_save.get_weights())
print("loaded model!")

# ===== Flask app =====
app = Flask(__name__)

# Uploaded songs are stored under static/; /get_beat writes its output into
# static/ as well. Both directories are created up front if missing.
UPLOAD_FOLDER = "static/uploaded_songs"
for _folder in ("static", UPLOAD_FOLDER):
    os.makedirs(_folder, exist_ok=True)
app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER

@app.route("/")
def index():
    """Render the landing page from the ``index.html`` template."""
    template_name = "index.html"
    return render_template(template_name)

def _safe_song_name(name):
    """Reduce a client-supplied file name to a safe, single path component.

    Both "/" and "\\" are treated as separators and only the last component
    is kept, so names such as "../../app.py" cannot escape the upload or
    output directories. Ordinary browser-supplied names (already bare file
    names) pass through unchanged.

    Args:
        name: Value received from the client; anything that is not a str
            is rejected.

    Returns:
        The sanitized file name, or None if nothing usable remains
        (empty, ".", "..", or containing a NUL byte).
    """
    if not isinstance(name, str) or "\x00" in name:
        return None
    base = os.path.basename(name.replace("\\", "/"))
    if base in ("", ".", ".."):
        return None
    return base


@app.route("/upload", methods=["POST"])
def upload_file():
    """Store the multipart field ``file`` inside the upload folder.

    The client-supplied file name is sanitized with ``_safe_song_name`` so a
    crafted name cannot write outside ``app.config["UPLOAD_FOLDER"]``.

    Returns:
        A plain-text status message. The historical messages (200 status)
        are kept; a name that cannot be sanitized yields a 400.
    """
    if "file" not in request.files:
        return "No file part"
    f = request.files["file"]
    if f.filename == "":
        return "No selected file"
    # filename may be None or contain path components -- never trust it raw.
    safe_name = _safe_song_name(f.filename)
    if safe_name is None:
        return "Invalid file name", 400
    f.save(os.path.join(app.config["UPLOAD_FOLDER"], safe_name))
    return "File uploaded successfully"


@app.route("/get_beat", methods=["POST"])
def get_beat():
    """Run beat tracking on a previously uploaded song.

    The request body is JSON holding the song's file name (a bare string).
    ``add_clicks.add_clicks`` is invoked with the loaded ``model`` and asked to
    produce ``static/<stem>_clicks.wav``.

    Returns:
        "beats added successfully" on success; 400 for an unusable name;
        404 when the song is not present in the upload folder.
    """
    # The name feeds both the input path and the output path, so sanitize it
    # to keep reads inside UPLOAD_FOLDER and writes inside static/.
    song_name = _safe_song_name(request.get_json(force=True))
    if song_name is None:
        return "Invalid song name", 400
    src = os.path.join(UPLOAD_FOLDER, song_name)
    if not os.path.isfile(src):
        return "Song not found", 404
    out = os.path.join("static", f"{os.path.splitext(song_name)[0]}_clicks.wav")
    add_clicks.add_clicks(
        song=src,
        model=model,
        model_passed=True,
        model_loc="",
        output_name=out,
        constant_tempo=True,
        plot=False,
    )
    return "beats added successfully"

if __name__ == "__main__":
    # Listen on every interface; PORT overrides the default 7860 (common for Spaces).
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )