# loop_maestro / app.py
# Author: jorisvaneyghen
# Last commit: "remove dependencies" (7fbb7de)
import logging
import mimetypes
# import numpy as np
import os
# import torch
# from beat_this.model.postprocessor import Postprocessor
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
# from madmom.features.downbeats import DBNDownBeatTrackingProcessor
from statistics import median, mode, StatisticsError
from typing import List, Tuple
# Register MIME types for JavaScript and WebAssembly.
#
# mimetypes.add_type() REPLACES an existing extension -> type mapping, so for
# each extension the last registration is the one guess_type() returns.
# The secondary types are therefore registered first (which keeps them known
# for type -> extension lookups) and the preferred types last, so that '.js'
# resolves to application/javascript and '.wasm' to application/wasm.
mimetypes.add_type('text/javascript', '.js')  # secondary type for .js
mimetypes.add_type('application/javascript', '.js')  # preferred type for .js
mimetypes.add_type('application/octet-stream', '.wasm')  # secondary type for .wasm
mimetypes.add_type('application/wasm', '.wasm')  # preferred type for .wasm
# Flask application object; the route decorators below register handlers on it.
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes
# Configure logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Absolute path of the directory containing this script; serve_index and
# serve_static serve files from this directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
@app.route('/')
def serve_index():
    """Return index.html from the application directory for the root URL."""
    index_page = send_from_directory(current_dir, 'index.html')
    return index_page
@app.route('/<path:path>')
def serve_static(path):
    """Return a static file (CSS, JS, ONNX models, etc.) from the application directory.

    For a handful of well-known suffixes the Content-Type header is set
    explicitly; every other file keeps the header chosen by
    send_from_directory.
    """
    # Suffix -> Content-Type overrides, checked in order; the first match wins.
    content_type_overrides = (
        ('.js', 'application/javascript'),
        ('.wasm', 'application/wasm'),
        ('.css', 'text/css'),
        ('.html', 'text/html'),
        ('.json', 'application/json'),
    )

    response = send_from_directory(current_dir, path)
    for suffix, content_type in content_type_overrides:
        if path.endswith(suffix):
            response.headers.set('Content-Type', content_type)
            break
    return response
@app.route('/health', methods=['GET'])
def health_check():
    """Report service liveness as a small JSON document."""
    payload = {
        "status": "healthy",
        "message": "Beat detection postprocessor is running",
    }
    return jsonify(payload)
# @app.route('/logits_to_bars', methods=['POST'])
# def postprocess_beats():
# """
# Postprocess beat and downbeat logits to extract timing information
#
# Expected input:
# {
# "beat_logits": [array of float values],
# "downbeat_logits": [array of float values],
# "min_bpm": min_bpm,
# "max_bpm": max_bpm,
# "beats_per_bar": beats_per_bar,
# }
#
# Returns:
# {
# "bars": {map of bar number to start timings in seconds},
# }
# """
# try:
# # Get JSON data from request
# data = request.get_json()
#
# if not data:
# return jsonify({"error": "No JSON data provided"}), 400
#
# # Extract logits
# beat_logits = np.array(data.get('beat_logits', []))
# downbeat_logits = np.array(data.get('downbeat_logits', []))
# beats_per_bar = int(data.get('beats_per_bar', 4))
# min_bpm = int(data.get('min_bpm', 55.0))
# max_bpm = int(data.get('max_bpm', 215.0))
#
# logger.info(f"Received beat_logits: {len(beat_logits)}, downbeat_logits: {len(downbeat_logits)}, beats_per_bar: {beats_per_bar}, min_bpm: {min_bpm}, max_bpm: {max_bpm}")
#
# # Validate input
# if len(beat_logits) == 0 or len(downbeat_logits) == 0:
# return jsonify({"error": "Empty logits arrays provided"}), 400
#
# if len(beat_logits) != len(downbeat_logits):
# return jsonify({"error": "Beat and downbeat logits must have the same length"}), 400
#
# # Process the logits to extract beat and downbeat timings
# beats, downbeats = process_logits(beat_logits, downbeat_logits, type='minimal',
# beats_per_bar=beats_per_bar,
# min_bpm=min_bpm,
# max_bpm=max_bpm)
#
# logger.info(f"Processed {len(beats)} beats and {len(downbeats)} downbeats")
#
# downbeats = downbeats.tolist() if isinstance(downbeats, np.ndarray) else downbeats
# beats = beats.tolist() if isinstance(beats, np.ndarray) else beats
#
# estimated_bpm, detected_beats_per_bar, final_indices = analyze_beats(beats, downbeats)
# print(f"estimated bpm: {estimated_bpm}, detected beats_per_bar: {detected_beats_per_bar}")
# print(final_indices)
#
#
# bars = {i+1: beat for i, beat in enumerate(downbeats)}
#
# return jsonify({
# "bars": bars,
# "estimated_bpm": estimated_bpm,
# "detected_beats_per_bar": detected_beats_per_bar
# })
#
# except Exception as e:
# import traceback
# logger.error(f"Error in postprocessing: {str(e)}")
# return jsonify({"error": f"Processing failed: {str(e)}"}), 500
# def process_logits(beat_logits, downbeat_logits, type='minimal',
# beats_per_bar=4,
# min_bpm=55.0,
# max_bpm=215.0, ):
# """
# Process beat and downbeat logits to extract timing information
#
# Args:
# beat_logits: Array of beat probabilities/logits
# downbeat_logits: Array of downbeat probabilities/logits
# type (str): the type of postprocessing to apply. Either "minimal" or "dbn". Default is "minimal".
# beats_per_bar : int or list
# Number of beats per bar to be modeled. Can be either a single number
# or a list or array with bar lengths (in beats).
# min_bpm : float or list, optional
# Minimum tempo used for beat tracking [bpm]. If a list is given, each
# item corresponds to the number of beats per bar at the same position.
# max_bpm : float or list, optional
# Maximum tempo used for beat tracking [bpm]. If a list is given, each
# item corresponds to the number of beats per bar at the same position.
#
#
# Returns:
# Tuple of (beats, downbeats) where each is an array of timings in seconds
# """
# frames2beats = Postprocessor(type=type)
# if type == 'dbn' and (beats_per_bar != [3, 4] or min_bpm != 55.0 or max_bpm != 215.0):
# frames2beats.dbb = DBNDownBeatTrackingProcessor(
# beats_per_bar=beats_per_bar,
# min_bpm=min_bpm,
# max_bpm=max_bpm,
# fps=50,
# transition_lambda=100,
# )
#
#
# # Convert numpy arrays to PyTorch tensors
# beat_logits_tensor = torch.tensor(beat_logits, dtype=torch.float32)
# downbeat_logits_tensor = torch.tensor(downbeat_logits, dtype=torch.float32)
#
# # Process through the postprocessor
# beats, downbeats = frames2beats(beat_logits_tensor, downbeat_logits_tensor)
#
#
# return beats, downbeats
def analyze_beats(beats: List[float], downbeats: List[float]) -> Tuple[float, float, List[int]]:
    """
    Analyze beats and downbeats to estimate the tempo while ignoring outlier bars.

    A bar spans from one downbeat to the next; the last bar runs from the final
    downbeat to the end of the beat list, and its duration is estimated from
    the average beat spacing. Bars are kept only if they contain the most
    common number of beats and their duration lies within +/-25% of the median
    duration of those bars.

    Args:
        beats: List of beat positions in seconds
        downbeats: List of downbeat positions in seconds (first beat of each bar)

    Returns:
        Tuple containing:
        - estimated_bpm: BPM computed from the average duration of the kept bars
        - beats_per_bar: Most common number of beats per bar
        - valid_bar_indices: Indices of bars that passed all filters

        (0, 0, []) is returned whenever no tempo can be derived: no downbeats,
        no bar passing the filters, bars containing zero beats, or a
        non-positive average bar duration.
    """
    if not downbeats:
        return 0, 0, []

    # Step 1: Count the beats in each bar and measure each bar's duration
    bar_beats_count = []
    bar_durations = []
    for start_time, end_time in zip(downbeats, downbeats[1:]):
        bar_beats_count.append(sum(1 for beat in beats if start_time <= beat < end_time))
        bar_durations.append(end_time - start_time)

    # The last bar has no closing downbeat: count every beat from the final
    # downbeat onwards and estimate its duration from the average beat spacing.
    last_start = downbeats[-1]
    last_beats = sum(1 for beat in beats if beat >= last_start)
    bar_beats_count.append(last_beats)
    if len(beats) > 1:
        avg_beat_duration = (beats[-1] - beats[0]) / (len(beats) - 1)
        bar_durations.append(last_beats * avg_beat_duration)
    else:
        # With fewer than two beats there is no spacing to estimate from
        bar_durations.append(0)

    # Step 2: Keep only bars with the most common beats-per-bar value
    try:
        common_beats_per_bar = mode(bar_beats_count)
    except StatisticsError:
        # Python < 3.8 raises when there is no unique mode; use the median
        common_beats_per_bar = median(bar_beats_count)

    if common_beats_per_bar <= 0:
        # Bars without beats carry no tempo information (and would divide by zero)
        return 0, 0, []

    candidates = [
        (index, duration)
        for index, (beat_count, duration) in enumerate(zip(bar_beats_count, bar_durations))
        if beat_count == common_beats_per_bar
    ]
    if not candidates:
        return 0, 0, []

    # Step 3: Remove bars whose duration is outside +/-25% of the median
    median_duration = median([duration for _, duration in candidates])
    lower_bound = median_duration * 0.75
    upper_bound = median_duration * 1.25
    final_indices = []
    final_durations = []
    for index, duration in candidates:
        if lower_bound <= duration <= upper_bound:
            final_indices.append(index)
            final_durations.append(duration)
    if not final_durations:
        return 0, 0, []

    # Step 4: Convert the average bar duration to BPM
    avg_bar_duration = sum(final_durations) / len(final_durations)
    if avg_bar_duration <= 0:
        # A zero-length bar (e.g. a single downbeat with at most one beat)
        # cannot be converted to a tempo
        return 0, 0, []

    # BPM = 60 / (beat duration in seconds)
    estimated_bpm = 60.0 / (avg_bar_duration / common_beats_per_bar)
    return estimated_bpm, common_beats_per_bar, final_indices
# Start Flask's built-in server only when this file is executed directly
# (not when it is imported as a module).
# NOTE(review): debug=True enables Flask's debug mode; presumably this entry
# point is meant for local development only — confirm before deploying.
if __name__ == '__main__':
    app.run(debug=True)