Spaces:
Sleeping
Sleeping
| import logging | |
| import mimetypes | |
| # import numpy as np | |
| import os | |
| # import torch | |
| # from beat_this.model.postprocessor import Postprocessor | |
| from flask import Flask, request, jsonify, send_from_directory | |
| from flask_cors import CORS | |
| # from madmom.features.downbeats import DBNDownBeatTrackingProcessor | |
| from statistics import median, mode, StatisticsError | |
| from typing import List, Tuple | |
# Register MIME types for JavaScript and WebAssembly so file-serving helpers that
# guess Content-Type from the `mimetypes` module label these files correctly.
# mimetypes.add_type() REPLACES any existing mapping for an extension, so register
# exactly one type per extension: a second "fallback" registration would silently
# override the preferred one (previously .wasm ended up as
# application/octet-stream, which browsers reject for streaming WebAssembly
# compilation, and .js ended up as text/javascript).
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('application/wasm', '.wasm')
# Root logger: INFO level; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Absolute path of the directory holding this script; index.html and the
# static assets are served from here.
current_dir = os.path.dirname(os.path.abspath(__file__))

# The Flask application, with CORS enabled for every route.
app = Flask(__name__)
CORS(app)
def serve_index():
    """Return the main HTML page (index.html) from the application directory."""
    # NOTE(review): no @app.route decorator is visible for this view in this
    # listing — confirm it is registered, otherwise it is unreachable.
    page_name = 'index.html'
    return send_from_directory(current_dir, page_name)
# Content-Type forced for selected static-file extensions (matched case-insensitively
# against the end of the requested path).
_STATIC_MIMETYPES = {
    '.js': 'application/javascript',
    '.wasm': 'application/wasm',
    '.css': 'text/css',
    '.html': 'text/html',
    '.json': 'application/json',
}


def serve_static(path):
    """Serve a static file (CSS, JS, ONNX models, etc.) from the application directory.

    Args:
        path: File path relative to ``current_dir``.

    Returns:
        The response from ``send_from_directory``. For paths ending in one of the
        extensions in ``_STATIC_MIMETYPES`` the Content-Type is overridden with
        the listed type, regardless of what the platform MIME database guessed.
    """
    # NOTE(review): no @app.route decorator is visible for this view in this
    # listing — confirm it is registered, otherwise it is unreachable.
    response = send_from_directory(current_dir, path)
    lowered = path.lower()
    forced_type = next(
        (mime for ext, mime in _STATIC_MIMETYPES.items() if lowered.endswith(ext)),
        None,
    )
    if forced_type is not None:
        # Assigning `mimetype` (instead of writing the raw header) lets Werkzeug
        # append "; charset=utf-8" for text/* and application/javascript, which a
        # bare 'text/html' / 'text/css' / 'application/javascript' header drops.
        response.mimetype = forced_type
    return response
def health_check():
    """Report that the service is up.

    Returns:
        A JSON response with a fixed ``status`` and ``message``.
    """
    # NOTE(review): no @app.route decorator is visible for this view in this
    # listing — confirm it is registered, otherwise it is unreachable.
    payload = {
        "status": "healthy",
        "message": "Beat detection postprocessor is running",
    }
    return jsonify(payload)
| # @app.route('/logits_to_bars', methods=['POST']) | |
| # def postprocess_beats(): | |
| # """ | |
| # Postprocess beat and downbeat logits to extract timing information | |
| # | |
| # Expected input: | |
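| # { | |
| # "beat_logits": [array of float values], | |
| # "downbeat_logits": [array of float values], | |
| # "min_bpm": min_bpm (optional, default 55), | |
| # "max_bpm": max_bpm (optional, default 215), | |
| # "beats_per_bar": beats_per_bar (optional, default 4) | |
| # } | |
| # | |
| # Returns: | |
| # { | |
| # "bars": {map of bar number to start timings in seconds}, | |
| # "estimated_bpm": estimated tempo in BPM, | |
| # "detected_beats_per_bar": most common number of beats per bar | |
| # } | |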
| # """ | |
| # try: | |
| # # Get JSON data from request | |
| # data = request.get_json() | |
| # | |
| # if not data: | |
| # return jsonify({"error": "No JSON data provided"}), 400 | |
| # | |
| # # Extract logits | |
| # beat_logits = np.array(data.get('beat_logits', [])) | |
| # downbeat_logits = np.array(data.get('downbeat_logits', [])) | |
| # beats_per_bar = int(data.get('beats_per_bar', 4)) | |
| # min_bpm = int(data.get('min_bpm', 55.0)) | |
| # max_bpm = int(data.get('max_bpm', 215.0)) | |
| # | |
| # logger.info(f"Received beat_logits: {len(beat_logits)}, downbeat_logits: {len(downbeat_logits)}, beats_per_bar: {beats_per_bar}, min_bpm: {min_bpm}, max_bpm: {max_bpm}") | |
| # | |
| # # Validate input | |
| # if len(beat_logits) == 0 or len(downbeat_logits) == 0: | |
| # return jsonify({"error": "Empty logits arrays provided"}), 400 | |
| # | |
| # if len(beat_logits) != len(downbeat_logits): | |
| # return jsonify({"error": "Beat and downbeat logits must have the same length"}), 400 | |
| # | |
| # # Process the logits to extract beat and downbeat timings | |
| # beats, downbeats = process_logits(beat_logits, downbeat_logits, type='minimal', | |
| # beats_per_bar=beats_per_bar, | |
| # min_bpm=min_bpm, | |
| # max_bpm=max_bpm) | |
| # | |
| # logger.info(f"Processed {len(beats)} beats and {len(downbeats)} downbeats") | |
| # | |
| # downbeats = downbeats.tolist() if isinstance(downbeats, np.ndarray) else downbeats | |
| # beats = beats.tolist() if isinstance(beats, np.ndarray) else beats | |
| # | |
| # estimated_bpm, detected_beats_per_bar, final_indices = analyze_beats(beats, downbeats) | |
| # print(f"estimated bpm: {estimated_bpm}, detected beats_per_bar: {detected_beats_per_bar}") | |
| # print(final_indices) | |
| # | |
| # | |
| # bars = {i+1: beat for i, beat in enumerate(downbeats)} | |
| # | |
| # return jsonify({ | |
| # "bars": bars, | |
| # "estimated_bpm": estimated_bpm, | |
| # "detected_beats_per_bar": detected_beats_per_bar | |
| # }) | |
| # | |
| # except Exception as e: | |
| # import traceback | |
| # logger.error(f"Error in postprocessing: {str(e)}") | |
| # return jsonify({"error": f"Processing failed: {str(e)}"}), 500 | |
| # def process_logits(beat_logits, downbeat_logits, type='minimal', | |
| # beats_per_bar=4, | |
| # min_bpm=55.0, | |
| # max_bpm=215.0, ): | |
| # """ | |
| # Process beat and downbeat logits to extract timing information | |
| # | |
| # Args: | |
| # beat_logits: Array of beat probabilities/logits | |
| # downbeat_logits: Array of downbeat probabilities/logits | |
| # type (str): the type of postprocessing to apply. Either "minimal" or "dbn". Default is "minimal". | |
| # beats_per_bar : int or list | |
| # Number of beats per bar to be modeled. Can be either a single number | |
| # or a list or array with bar lengths (in beats). | |
| # min_bpm : float or list, optional | |
| # Minimum tempo used for beat tracking [bpm]. If a list is given, each | |
| # item corresponds to the number of beats per bar at the same position. | |
| # max_bpm : float or list, optional | |
| # Maximum tempo used for beat tracking [bpm]. If a list is given, each | |
| # item corresponds to the number of beats per bar at the same position. | |
| # | |
| # | |
| # Returns: | |
| # Tuple of (beats, downbeats) where each is an array of timings in seconds | |
| # """ | |
| # frames2beats = Postprocessor(type=type) | |
| # if type == 'dbn' and (beats_per_bar != [3, 4] or min_bpm != 55.0 or max_bpm != 215.0): | |
| # frames2beats.dbb = DBNDownBeatTrackingProcessor( | |
| # beats_per_bar=beats_per_bar, | |
| # min_bpm=min_bpm, | |
| # max_bpm=max_bpm, | |
| # fps=50, | |
| # transition_lambda=100, | |
| # ) | |
| # | |
| # | |
| # # Convert numpy arrays to PyTorch tensors | |
| # beat_logits_tensor = torch.tensor(beat_logits, dtype=torch.float32) | |
| # downbeat_logits_tensor = torch.tensor(downbeat_logits, dtype=torch.float32) | |
| # | |
| # # Process through the postprocessor | |
| # beats, downbeats = frames2beats(beat_logits_tensor, downbeat_logits_tensor) | |
| # | |
| # | |
| # return beats, downbeats | |
def analyze_beats(beats: List[float], downbeats: List[float]) -> Tuple[float, float, List[int]]:
    """
    Analyze beats and downbeats to estimate the tempo, discarding outlier bars.

    A bar runs from one downbeat to the next; the final bar runs from the last
    downbeat to the end of ``beats``. Bars whose beat count differs from the most
    common count are discarded, then bars whose duration is outside +/-25% of the
    median remaining duration are discarded.

    Args:
        beats: List of beat positions in seconds
        downbeats: List of downbeat positions in seconds (first beat of each bar)

    Returns:
        Tuple containing:
        - estimated_bpm: Tempo computed from the average duration of the kept bars
        - beats_per_bar: Most common number of beats per bar (a count, not a duration)
        - valid_bar_indices: Indices of bars (in downbeat order) that passed all filters
        ``(0, 0, [])`` is returned when no estimate is possible: no downbeats, no
        bar survives filtering, or the surviving bars have a non-positive duration
        or beat count (which previously raised ZeroDivisionError).
    """
    # Step 1: beat count and duration of every bar delimited by two downbeats.
    bar_beats_count: List[int] = []
    bar_durations: List[float] = []
    for start_time, end_time in zip(downbeats, downbeats[1:]):
        bar_beats_count.append(sum(1 for beat in beats if start_time <= beat < end_time))
        bar_durations.append(end_time - start_time)

    # The last bar has no closing downbeat: count every beat from the last
    # downbeat onward and estimate its length from the average beat spacing.
    if len(downbeats) > 0:
        last_start = downbeats[-1]
        last_beats = sum(1 for beat in beats if beat >= last_start)
        bar_beats_count.append(last_beats)
        if len(beats) > 1:
            avg_beat_duration = (beats[-1] - beats[0]) / (len(beats) - 1)
            bar_durations.append(last_beats * avg_beat_duration)
        else:
            bar_durations.append(0)

    if not bar_beats_count:
        return 0, 0, []

    # Step 2: keep only bars with the most common beats-per-bar value.
    try:
        common_beats_per_bar = mode(bar_beats_count)
    except StatisticsError:
        # Only reachable on Python < 3.8, where mode() rejects multimodal data.
        # Fall back to the median so the value is always defined for Step 4.
        common_beats_per_bar = median(bar_beats_count)
    candidates = [
        (index, duration)
        for index, (count, duration) in enumerate(zip(bar_beats_count, bar_durations))
        if count == common_beats_per_bar
    ]
    if not candidates:
        return 0, 0, []

    # Step 3: drop bars whose duration is more than 25% away from the median.
    median_duration = median([duration for _, duration in candidates])
    lower_bound = median_duration * 0.75
    upper_bound = median_duration * 1.25
    kept = [(index, duration) for index, duration in candidates
            if lower_bound <= duration <= upper_bound]
    if not kept:
        return 0, 0, []

    # Step 4: BPM = 60 / (beat duration in seconds) = 60 / (bar duration / beats per bar).
    avg_bar_duration = sum(duration for _, duration in kept) / len(kept)
    if avg_bar_duration <= 0 or common_beats_per_bar <= 0:
        # Degenerate input (e.g. a single beat, or bars with no beats): no tempo.
        return 0, 0, []
    estimated_bpm = 60.0 / (avg_bar_duration / common_beats_per_bar)
    return estimated_bpm, common_beats_per_bar, [index for index, _ in kept]
if __name__ == '__main__':
    # Development-server entry point. Debug mode (auto-reloader + interactive
    # Werkzeug debugger) stays on by default as before, but can now be turned off
    # with FLASK_DEBUG=0/false/no; an explicit debug= argument would otherwise
    # override that variable. Never expose the debugger on a public interface.
    debug_env = os.environ.get('FLASK_DEBUG')
    if debug_env is None:
        debug = True
    else:
        debug = debug_env.strip().lower() not in ('', '0', 'false', 'no')
    app.run(debug=debug)