Spaces:
Sleeping
Sleeping
| """Wave front arrival detection using STA/LTA algorithm""" | |
| import numpy as np | |
| from typing import Tuple, Optional, List | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class STALTA:
    """Short-Term Average / Long-Term Average detector.

    The detector compares the mean absolute amplitude over a short trailing
    window (STA) with the mean over a long trailing window (LTA). An arrival
    is declared where the STA/LTA ratio exceeds ``threshold``, with at least
    ``min_gap`` samples between consecutive detections.
    """

    def __init__(self, sta_window: int = 120, lta_window: int = 3600):
        """
        Args:
            sta_window: short-term average window [samples]
            lta_window: long-term average window [samples]
        """
        self.sta_window = sta_window
        self.lta_window = lta_window
        self.threshold = 3.5  # default threshold
        self.min_gap = 300  # minimum gap between detections [samples]

    def compute_ratio(self, data: np.ndarray) -> np.ndarray:
        """Compute the STA/LTA ratio for every sample of ``data``.

        Args:
            data: 1-D sequence of samples; absolute values are averaged.

        Returns:
            Array with one ratio per sample. The first ``lta_window``
            samples (where no full LTA window exists yet) and samples whose
            LTA is not positive are left at 0.
        """
        n = len(data)
        ratio = np.zeros(n)

        # A zero-length window would divide 0 by 0 and turn every ratio into
        # NaN (which silently suppresses all detections), so each window is
        # at least one sample long.
        sta_len = max(1, self.sta_window)
        lta_len = max(1, self.lta_window)
        if n <= lta_len:
            return ratio

        # Cumulative sum lets every moving average be a single subtraction.
        cumsum = np.cumsum(np.abs(data))

        idx = np.arange(lta_len, n)
        # The STA window is clipped at the start of the record in case it is
        # configured longer than the LTA window.
        sta_start = np.maximum(idx - sta_len, 0)
        lta_start = idx - lta_len

        sta = (cumsum[idx] - cumsum[sta_start]) / (idx - sta_start)
        lta = (cumsum[idx] - cumsum[lta_start]) / (idx - lta_start)

        # Only divide where the long-term level is strictly positive; NaN
        # compares False and therefore also stays at 0.
        usable = lta > 0
        ratio[idx[usable]] = sta[usable] / lta[usable]
        return ratio

    def _pick_onsets(self, ratio: np.ndarray) -> List[int]:
        """Return indices where ``ratio`` exceeds the threshold, keeping at
        least ``min_gap`` samples between consecutive picks."""
        picks: List[int] = []
        # Start one full gap before index 0 so that a trigger on the very
        # first sample is still accepted.
        previous = -self.min_gap
        for candidate in np.flatnonzero(ratio > self.threshold):
            candidate = int(candidate)
            if candidate - previous >= self.min_gap:
                picks.append(candidate)
                previous = candidate
        return picks

    def detect(self, data: np.ndarray, threshold: Optional[float] = None) -> List[int]:
        """Detect arrivals using STA/LTA.

        Args:
            data: 1-D sequence of samples.
            threshold: optional new trigger threshold. When given it is
                stored on the instance, so it also applies to later calls.

        Returns:
            list of arrival indices
        """
        if threshold is not None:
            self.threshold = threshold

        detections = self._pick_onsets(self.compute_ratio(data))
        logger.info("Detected %d arrivals", len(detections))
        return detections

    def detect_with_confidence(self, data: np.ndarray) -> List[Tuple[int, float]]:
        """Detect arrivals with confidence scores.

        Args:
            data: 1-D sequence of samples.

        Returns:
            List of ``(index, confidence)`` tuples. Confidence is the
            relative excess of the ratio over the threshold, capped at 1.0.
        """
        ratio = self.compute_ratio(data)
        return [
            (
                idx,
                # Confidence based on how much the threshold is exceeded.
                float(min(1.0, (ratio[idx] - self.threshold) / self.threshold)),
            )
            for idx in self._pick_onsets(ratio)
        ]
class WaveFrontDetector:
    """Multi-station wave front detector.

    Wraps an STA/LTA detector whose windows (2 minutes short-term,
    60 minutes long-term) are converted from seconds to samples using the
    configured sampling rate.
    """

    # Number of samples after the arrival index searched for the peak
    # absolute amplitude.
    PEAK_WINDOW = 100

    def __init__(self, sampling_rate: float = 1.0/60):  # 1 minute default
        """
        Args:
            sampling_rate: samples per second (default: one sample per
                minute)

        Raises:
            ValueError: if ``sampling_rate`` is not a positive number.
        """
        # Written as "not > 0" so that NaN is rejected as well.
        if not sampling_rate > 0:
            raise ValueError(
                f"sampling_rate must be positive, got {sampling_rate!r}"
            )
        self.sampling_rate = sampling_rate
        self.stalta = STALTA(
            sta_window=self._seconds_to_samples(120, sampling_rate),   # 2 minutes
            lta_window=self._seconds_to_samples(3600, sampling_rate),  # 60 minutes
        )

    @staticmethod
    def _seconds_to_samples(seconds: float, sampling_rate: float) -> int:
        """Convert a duration in seconds to a window length of >= 1 sample."""
        # Plain int() truncation turns e.g. 120 s at one sample per 5 minutes
        # into a zero-length window, which the detector cannot average over.
        return max(1, int(round(seconds * sampling_rate)))

    def detect_arrival(self, station_data: dict) -> dict:
        """Detect arrival at single station.

        Args:
            station_data: dict with the samples under ``'data'`` and an
                optional ``'id'`` (defaults to ``'unknown'``).

        Returns:
            Dict with ``'station_id'`` and ``'arrival_detected'``. When an
            arrival is found it also holds ``'arrival_time_seconds'``,
            ``'arrival_index'``, ``'confidence'`` and ``'peak_amplitude'``
            for the first detection.
        """
        data = station_data['data']
        station_id = station_data.get('id', 'unknown')

        # Compute STA/LTA
        arrivals = self.stalta.detect_with_confidence(data)
        if not arrivals:
            return {
                'station_id': station_id,
                'arrival_detected': False
            }

        # Only the earliest detection is reported.
        arrival_idx, confidence = arrivals[0]
        peak_segment = data[arrival_idx:arrival_idx + self.PEAK_WINDOW]
        return {
            'station_id': station_id,
            'arrival_detected': True,
            'arrival_time_seconds': arrival_idx / self.sampling_rate,
            'arrival_index': arrival_idx,
            'confidence': confidence,
            'peak_amplitude': float(np.max(np.abs(peak_segment)))
        }

    def detect_multi_station(self, stations: List[dict]) -> List[dict]:
        """Detect arrivals across multiple stations.

        Args:
            stations: list of station dicts as accepted by
                ``detect_arrival``.

        Returns:
            One result dict per station, sorted by arrival time; stations
            without a detection are placed last.
        """
        results = [self.detect_arrival(station) for station in stations]
        # Sort by arrival time
        results.sort(key=lambda x: x.get('arrival_time_seconds', float('inf')))
        return results

    def compute_celerity(self, station1: dict, station2: dict,
                         distance: float) -> Optional[float]:
        """Compute wave celerity between two stations.

        Args:
            station1: first station detection result
            station2: second station detection result
            distance: distance between stations [m]

        Returns:
            celerity [m/s], or None if either arrival was not detected or
            station2 did not record the arrival strictly after station1
        """
        if not (station1.get('arrival_detected') and station2.get('arrival_detected')):
            return None

        time_diff = (station2['arrival_time_seconds'] -
                     station1['arrival_time_seconds'])
        # A zero or negative travel time cannot yield a meaningful speed.
        if time_diff <= 0:
            return None
        return distance / time_diff

    def estimate_source(self, stations: List[dict], positions: List[tuple]) -> Optional[tuple]:
        """Estimate source location from arrival times (simplified).

        NOTE: this is a placeholder. With three or more stations it always
        returns ``(0, 0)`` and does not use ``positions`` or the arrival
        times.

        Returns:
            None when fewer than 3 stations are given, otherwise ``(0, 0)``.
        """
        # Requires at least 3 stations
        if len(stations) < 3:
            return None
        # Simple grid search for source
        # (Full implementation would use more sophisticated methods)
        return (0, 0)  # placeholder