File size: 10,572 Bytes
b8279df
 
7fbb7de
b8279df
7fbb7de
 
9025387
 
7fbb7de
ee88cac
b8279df
9025387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fbb7de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9025387
 
b8279df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9025387
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import logging
import mimetypes
# import numpy as np
import os
# import torch
# from beat_this.model.postprocessor import Postprocessor
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
# from madmom.features.downbeats import DBNDownBeatTrackingProcessor
from statistics import median, mode, StatisticsError
from typing import List, Tuple

# Register MIME types for JavaScript and WebAssembly.
# mimetypes.add_type() overwrites the extension -> type mapping on every call,
# so the LAST registration for an extension is the one guess_type() returns.
# The generic fallback is therefore registered first and the preferred type
# last; both remain known to the reverse (type -> extension) lookup.
mimetypes.add_type('text/javascript', '.js')  # fallback
mimetypes.add_type('application/javascript', '.js')  # preferred
mimetypes.add_type('application/octet-stream', '.wasm')  # fallback
mimetypes.add_type('application/wasm', '.wasm')  # preferred


# Module-level Flask application; the route decorators below register on it.
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes (no origin restriction is configured here)

# Configure logging: root logger at INFO, plus a logger named after this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Absolute path of the directory containing this script; the file-serving
# routes below pass it to send_from_directory as the base directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
@app.route('/')
def serve_index():
    """Return ``index.html`` from the directory that contains this script."""
    index_response = send_from_directory(current_dir, 'index.html')
    return index_response


@app.route('/<path:path>')
def serve_static(path):
    """Send a file from the script directory, forcing Content-Type for known web suffixes.

    Args:
        path: Request path relative to the script directory.

    Returns:
        The response produced by ``send_from_directory``, with its
        ``Content-Type`` header replaced when ``path`` ends with one of the
        suffixes listed below.
    """
    # Ordered (suffix, content type) pairs; the first matching suffix wins.
    forced_types = (
        ('.js', 'application/javascript'),
        ('.wasm', 'application/wasm'),
        ('.css', 'text/css'),
        ('.html', 'text/html'),
        ('.json', 'application/json'),
    )

    file_response = send_from_directory(current_dir, path)

    for suffix, content_type in forced_types:
        if path.endswith(suffix):
            file_response.headers.set('Content-Type', content_type)
            break

    return file_response


@app.route('/health', methods=['GET'])
def health_check():
    """Report service liveness as a small JSON document."""
    payload = {"status": "healthy", "message": "Beat detection postprocessor is running"}
    return jsonify(payload)


# @app.route('/logits_to_bars', methods=['POST'])
# def postprocess_beats():
#     """
#     Postprocess beat and downbeat logits to extract timing information
#
#     Expected input:
#     {
#         "beat_logits": [array of float values],
#         "downbeat_logits": [array of float values]
#         "min_bpm": min_bpm,
#         "max_bpm": max_bpm,
#         "beats_per_bar": beats_per_bar,
#     }
#
#     Returns:
#     {
#         "bars": {map of bar number to start timings in seconds},
#     }
#     """
#     try:
#         # Get JSON data from request
#         data = request.get_json()
#
#         if not data:
#             return jsonify({"error": "No JSON data provided"}), 400
#
#         # Extract logits
#         beat_logits = np.array(data.get('beat_logits', []))
#         downbeat_logits = np.array(data.get('downbeat_logits', []))
#         beats_per_bar = int(data.get('beats_per_bar', 4))
#         min_bpm = int(data.get('min_bpm', 55.0))
#         max_bpm = int(data.get('max_bpm', 215.0))
#
#         logger.info(f"Received beat_logits: {len(beat_logits)}, downbeat_logits: {len(downbeat_logits)}, beats_per_bar: {beats_per_bar}, min_bpm: {min_bpm}, max_bpm: {max_bpm}")
#
#         # Validate input
#         if len(beat_logits) == 0 or len(downbeat_logits) == 0:
#             return jsonify({"error": "Empty logits arrays provided"}), 400
#
#         if len(beat_logits) != len(downbeat_logits):
#             return jsonify({"error": "Beat and downbeat logits must have the same length"}), 400
#
#         # Process the logits to extract beat and downbeat timings
#         beats, downbeats = process_logits(beat_logits, downbeat_logits, type='minimal',
#                    beats_per_bar=beats_per_bar,
#                    min_bpm=min_bpm,
#                    max_bpm=max_bpm)
#
#         logger.info(f"Processed {len(beats)} beats and {len(downbeats)} downbeats")
#
#         downbeats = downbeats.tolist() if isinstance(downbeats, np.ndarray) else downbeats
#         beats = beats.tolist() if isinstance(beats, np.ndarray) else beats
#
#         estimated_bpm, detected_beats_per_bar, final_indices = analyze_beats(beats, downbeats)
#         print(f"estimated bpm: {estimated_bpm}, detected beats_per_bar: {detected_beats_per_bar}")
#         print(final_indices)
#
#
#         bars = {i+1: beat for i, beat in enumerate(downbeats)}
#
#         return jsonify({
#             "bars": bars,
#             "estimated_bpm": estimated_bpm,
#             "detected_beats_per_bar": detected_beats_per_bar
#         })
#
#     except Exception as e:
#         import traceback
#         logger.error(f"Error in postprocessing: {str(e)}")
#         return jsonify({"error": f"Processing failed: {str(e)}"}), 500


# def process_logits(beat_logits, downbeat_logits, type='minimal',
#                    beats_per_bar=4,
#                    min_bpm=55.0,
#                    max_bpm=215.0, ):
#     """
#     Process beat and downbeat logits to extract timing information
#
#     Args:
#         beat_logits: Array of beat probabilities/logits
#         downbeat_logits: Array of downbeat probabilities/logits
#         type (str): the type of postprocessing to apply. Either "minimal" or "dbn". Default is "minimal".
#         beats_per_bar : int or list
#             Number of beats per bar to be modeled. Can be either a single number
#             or a list or array with bar lengths (in beats).
#         min_bpm : float or list, optional
#             Minimum tempo used for beat tracking [bpm]. If a list is given, each
#             item corresponds to the number of beats per bar at the same position.
#         max_bpm : float or list, optional
#             Maximum tempo used for beat tracking [bpm]. If a list is given, each
#             item corresponds to the number of beats per bar at the same position.
#
#
#     Returns:
#         Tuple of (beats, downbeats) where each is an array of timings in seconds
#     """
#     frames2beats = Postprocessor(type=type)
#     if type == 'dbn' and (beats_per_bar != [3, 4] or min_bpm != 55.0 or max_bpm != 215.0):
#         frames2beats.dbb = DBNDownBeatTrackingProcessor(
#             beats_per_bar=beats_per_bar,
#             min_bpm=min_bpm,
#             max_bpm=max_bpm,
#             fps=50,
#             transition_lambda=100,
#         )
#
#
#     # Convert numpy arrays to PyTorch tensors
#     beat_logits_tensor = torch.tensor(beat_logits, dtype=torch.float32)
#     downbeat_logits_tensor = torch.tensor(downbeat_logits, dtype=torch.float32)
#
#     # Process through the postprocessor
#     beats, downbeats = frames2beats(beat_logits_tensor, downbeat_logits_tensor)
#
#
#     return beats, downbeats



def analyze_beats(beats: List[float], downbeats: List[float]) -> Tuple[float, float, List[int]]:
    """
    Estimate tempo and meter from beat and downbeat times, ignoring outlier bars.

    A bar spans from one downbeat to the next. The last bar runs from the final
    downbeat to the end of ``beats``; its duration is estimated as its beat
    count multiplied by the average beat spacing. Bars whose beat count differs
    from the most common count, or whose duration lies outside ±25% of the
    median duration of the remaining bars, are discarded before the tempo is
    computed.

    Args:
        beats: List of beat positions in seconds
        downbeats: List of downbeat positions in seconds (first beat of each bar)

    Returns:
        Tuple containing:
        - estimated_bpm: BPM computed from the bars that passed all filters
        - beats_per_bar: Most common number of beats per bar
        - valid_bar_indices: Indices of bars that passed all filters

        ``(0, 0, [])`` is returned whenever no tempo can be derived: there are
        no downbeats, no bar survives the filters, the most common beat count
        is zero, or the average bar duration is zero.
    """
    if not downbeats:
        return 0, 0, []

    # Step 1: Count beats and measure the duration of every complete bar
    bar_beats_count = []
    bar_durations = []
    for start_time, end_time in zip(downbeats, downbeats[1:]):
        bar_beats_count.append(sum(1 for beat in beats if start_time <= beat < end_time))
        bar_durations.append(end_time - start_time)

    # The last bar has no closing downbeat: count every beat from the final
    # downbeat onwards and estimate its duration from the average beat spacing.
    last_start = downbeats[-1]
    last_beats = sum(1 for beat in beats if beat >= last_start)
    bar_beats_count.append(last_beats)
    if len(beats) > 1:
        avg_beat_duration = (beats[-1] - beats[0]) / (len(beats) - 1)
        bar_durations.append(last_beats * avg_beat_duration)
    else:
        bar_durations.append(0)

    # Step 2: Keep only bars with the most common beats-per-bar value
    try:
        common_beats_per_bar = mode(bar_beats_count)
    except StatisticsError:
        # Python < 3.8 raises when there is no unique mode; use the median instead
        common_beats_per_bar = median(bar_beats_count)

    # A bar length of zero beats cannot be turned into a tempo
    if common_beats_per_bar == 0:
        return 0, 0, []

    candidates = [
        (index, duration)
        for index, (beat_count, duration) in enumerate(zip(bar_beats_count, bar_durations))
        if beat_count == common_beats_per_bar
    ]
    if not candidates:
        return 0, 0, []

    # Step 3: Remove bars whose duration is outside ±25% of the median
    median_duration = median([duration for _, duration in candidates])
    lower_bound = median_duration * 0.75
    upper_bound = median_duration * 1.25

    final_indices = []
    final_durations = []
    for index, duration in candidates:
        if lower_bound <= duration <= upper_bound:
            final_indices.append(index)
            final_durations.append(duration)

    if not final_durations:
        return 0, 0, []

    # Step 4: Convert the average bar duration to BPM
    avg_bar_duration = sum(final_durations) / len(final_durations)
    if avg_bar_duration <= 0:
        # Degenerate timing (e.g. a single beat): avoid dividing by zero
        return 0, 0, []

    # BPM = 60 / (beat duration in seconds)
    estimated_bpm = 60.0 / (avg_bar_duration / common_beats_per_bar)

    return estimated_bpm, common_beats_per_bar, final_indices



if __name__ == '__main__':
    # Debug mode turns on the Werkzeug reloader and interactive debugger, which
    # must never be reachable from an untrusted network. It stays enabled by
    # default for local development; set FLASK_DEBUG to 0/false/no/off to
    # start the server without it.
    debug_enabled = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in ('0', 'false', 'no', 'off')
    app.run(debug=debug_enabled)