import logging
import mimetypes
# import numpy as np
import os
# import torch
# from beat_this.model.postprocessor import Postprocessor
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
# from madmom.features.downbeats import DBNDownBeatTrackingProcessor
from statistics import median, mode, StatisticsError
from typing import List, Tuple

# Add MIME types for JavaScript and WebAssembly.
# mimetypes.add_type() REPLACES the type already registered for an extension,
# so the generic type is registered first and the preferred type last: the
# last call is the one mimetypes.guess_type() will return.
mimetypes.add_type('text/javascript', '.js')
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('application/octet-stream', '.wasm')
mimetypes.add_type('application/wasm', '.wasm')

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get the directory where this script is located
current_dir = os.path.dirname(os.path.abspath(__file__))

# Content-Type headers forced by serve_static, keyed by file-name suffix.
_STATIC_CONTENT_TYPES = {
    '.js': 'application/javascript',
    '.wasm': 'application/wasm',
    '.css': 'text/css',
    '.html': 'text/html',
    '.json': 'application/json',
}


@app.route('/')
def serve_index():
    """Serve the main HTML page"""
    return send_from_directory(current_dir, 'index.html')


# The rule must carry a <path:path> variable: without it this view was
# registered on '/' a second time (shadowed by serve_index) and Flask had no
# value to pass for the required ``path`` argument.
@app.route('/<path:path>')
def serve_static(path):
    """
    Serve static files (CSS, JS, ONNX models, etc.)

    Args:
        path: Path of the requested file, relative to the directory that
            contains this script. May contain slashes.

    Returns:
        The file response, with the Content-Type header overridden for
        .js, .wasm, .css, .html and .json files.
    """
    # NOTE(review): this exposes every file in the script directory,
    # including this source file. Consider a dedicated static directory.
    response = send_from_directory(current_dir, path)

    # Set correct Content-Type headers for specific file types
    for suffix, content_type in _STATIC_CONTENT_TYPES.items():
        if path.endswith(suffix):
            response.headers.set('Content-Type', content_type)
            break

    return response


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy", "message": "Beat detection postprocessor is running"})


# @app.route('/logits_to_bars', methods=['POST'])
# def postprocess_beats():
#     """
#     Postprocess beat and downbeat logits to extract timing information
#
#     Expected input:
#     {
#         "beat_logits": [array of float values],
#         "downbeat_logits": [array of float values]
#         "min_bpm": min_bpm,
#         "max_bpm": max_bpm,
#         "beats_per_bar": beats_per_bar,
#     }
#
#     Returns:
#     {
#         "bars": {map of bar number to start timings in seconds},
#     }
#     """
#     try:
#         # Get JSON data from request
#         data = request.get_json()
#
#         if not data:
#             return jsonify({"error": "No JSON data provided"}), 400
#
#         # Extract logits
#         beat_logits = np.array(data.get('beat_logits', []))
#         downbeat_logits = np.array(data.get('downbeat_logits', []))
#         beats_per_bar = int(data.get('beats_per_bar', 4))
#         min_bpm = int(data.get('min_bpm', 55.0))
#         max_bpm = int(data.get('max_bpm', 215.0))
#
#         logger.info(f"Received beat_logits: {len(beat_logits)}, downbeat_logits: {len(downbeat_logits)}, beats_per_bar: {beats_per_bar}, min_bpm: {min_bpm}, max_bpm: {max_bpm}")
#
#         # Validate input
#         if len(beat_logits) == 0 or len(downbeat_logits) == 0:
#             return jsonify({"error": "Empty logits arrays provided"}), 400
#
#         if len(beat_logits) != len(downbeat_logits):
#             return jsonify({"error": "Beat and downbeat logits must have the same length"}), 400
#
#         # Process the logits to extract beat and downbeat timings
#         beats, downbeats = process_logits(beat_logits, downbeat_logits, type='minimal',
#                                           beats_per_bar=beats_per_bar,
#                                           min_bpm=min_bpm,
#                                           max_bpm=max_bpm)
#
#         logger.info(f"Processed {len(beats)} beats and {len(downbeats)} downbeats")
#
#         downbeats = downbeats.tolist() if isinstance(downbeats, np.ndarray) else downbeats
#         beats = beats.tolist() if isinstance(beats, np.ndarray) else beats
#
#         estimated_bpm, detected_beats_per_bar, final_indices = analyze_beats(beats, downbeats)
#         print(f"estimated bpm: {estimated_bpm}, detected beats_per_bar: {detected_beats_per_bar}")
#         print(final_indices)
#
#
#         bars = {i+1: beat for i, beat in enumerate(downbeats)}
#
#         return jsonify({
#             "bars": bars,
#             "estimated_bpm": estimated_bpm,
#             "detected_beats_per_bar": detected_beats_per_bar
#         })
#
#     except Exception as e:
#         import traceback
#         logger.error(f"Error in postprocessing: {str(e)}")
#         return jsonify({"error": f"Processing failed: {str(e)}"}), 500


# def process_logits(beat_logits, downbeat_logits, type='minimal',
#                    beats_per_bar=4,
#                    min_bpm=55.0,
#                    max_bpm=215.0, ):
#     """
#     Process beat and downbeat logits to extract timing information
#
#     Args:
#         beat_logits: Array of beat probabilities/logits
#         downbeat_logits: Array of downbeat probabilities/logits
#         type (str): the type of postprocessing to apply. Either "minimal" or "dbn". Default is "minimal".
#         beats_per_bar : int or list
#             Number of beats per bar to be modeled. Can be either a single number
#             or a list or array with bar lengths (in beats).
#         min_bpm : float or list, optional
#             Minimum tempo used for beat tracking [bpm]. If a list is given, each
#             item corresponds to the number of beats per bar at the same position.
#         max_bpm : float or list, optional
#             Maximum tempo used for beat tracking [bpm]. If a list is given, each
#             item corresponds to the number of beats per bar at the same position.
#
#
#     Returns:
#         Tuple of (beats, downbeats) where each is an array of timings in seconds
#     """
#     frames2beats = Postprocessor(type=type)
#     if type == 'dbn' and (beats_per_bar != [3, 4] or min_bpm != 55.0 or max_bpm != 215.0):
#         frames2beats.dbb = DBNDownBeatTrackingProcessor(
#             beats_per_bar=beats_per_bar,
#             min_bpm=min_bpm,
#             max_bpm=max_bpm,
#             fps=50,
#             transition_lambda=100,
#         )
#
#
#     # Convert numpy arrays to PyTorch tensors
#     beat_logits_tensor = torch.tensor(beat_logits, dtype=torch.float32)
#     downbeat_logits_tensor = torch.tensor(downbeat_logits, dtype=torch.float32)
#
#     # Process through the postprocessor
#     beats, downbeats = frames2beats(beat_logits_tensor, downbeat_logits_tensor)
#
#
#     return beats, downbeats


def analyze_beats(beats: List[float], downbeats: List[float]) -> Tuple[float, int, List[int]]:
    """
    Analyze beats and downbeats to calculate BPM and clean outliers.

    Bars are delimited by consecutive downbeats; the final bar runs from the
    last downbeat to the end of ``beats`` and its duration is estimated from
    the average beat spacing. Bars whose beat count differs from the most
    common count, or whose duration is outside +/-25% of the median duration
    of the remaining bars, are discarded before the tempo is computed.

    Args:
        beats: List of beat positions in seconds
        downbeats: List of downbeat positions in seconds (first beat of each bar)

    Returns:
        Tuple containing:
        - estimated_bpm: Calculated BPM after removing outliers
        - beats_per_bar: Most common number of beats per bar
        - valid_bar_indices: Indices of bars that passed all filters

        ``(0, 0, [])`` is returned whenever no tempo can be estimated: no
        downbeats, no beats inside the typical bar, no bar surviving the
        filters, or a non-positive average bar duration.
    """
    if len(downbeats) == 0:
        return 0, 0, []

    # Step 1: Calculate beats per bar and bar durations
    bar_beats_count = []
    bar_durations = []
    for start_time, end_time in zip(downbeats, downbeats[1:]):
        # Beats that fall between this downbeat and the next one
        bar_beats_count.append(sum(1 for beat in beats if start_time <= beat < end_time))
        bar_durations.append(end_time - start_time)

    # The last bar has no closing downbeat: count beats from the last
    # downbeat onwards and estimate its duration from the average beat spacing.
    last_start = downbeats[-1]
    last_beats = sum(1 for beat in beats if beat >= last_start)
    if len(beats) > 1:
        avg_beat_duration = (beats[-1] - beats[0]) / (len(beats) - 1)
        last_duration = last_beats * avg_beat_duration
    else:
        last_duration = 0
    bar_beats_count.append(last_beats)
    bar_durations.append(last_duration)

    # Step 2: Remove bars with outlier beats per bar
    try:
        common_beats_per_bar = mode(bar_beats_count)
    except StatisticsError:
        # Python < 3.8 raises when there is no unique mode; use the median
        # count instead so the name is always bound.
        common_beats_per_bar = median(bar_beats_count)

    # A typical bar without beats cannot yield a tempo (and would divide by
    # zero in step 4).
    if common_beats_per_bar <= 0:
        return 0, 0, []

    candidates = [
        (index, duration)
        for index, (beat_count, duration) in enumerate(zip(bar_beats_count, bar_durations))
        if beat_count == common_beats_per_bar
    ]
    if not candidates:
        return 0, 0, []

    # Step 3: Remove bars with outlier durations (outside +/-25% of the median)
    median_duration = median(duration for _, duration in candidates)
    lower_bound = median_duration * 0.75
    upper_bound = median_duration * 1.25
    kept = [
        (index, duration)
        for index, duration in candidates
        if lower_bound <= duration <= upper_bound
    ]
    if not kept:
        return 0, 0, []

    # Step 4: Calculate average bar duration and convert to BPM
    final_indices = [index for index, _ in kept]
    avg_bar_duration = sum(duration for _, duration in kept) / len(kept)
    if avg_bar_duration <= 0:
        # Degenerate timing (e.g. a single beat): no meaningful tempo.
        return 0, 0, []

    # BPM = 60 / (beat duration in seconds)
    estimated_bpm = 60.0 / (avg_bar_duration / common_beats_per_bar)
    return estimated_bpm, common_beats_per_bar, final_indices


if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug debugger and reloader;
    # intended for local development only.
    app.run(debug=True)