File size: 1,992 Bytes
2962055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from scipy.signal import butter, filtfilt

class LowPassFilter(BaseEstimator, TransformerMixin):
    """Butterworth low-pass filter for the 'x', 'y', 'z' columns of a DataFrame.

    The filter coefficients are designed with ``scipy.signal.butter`` and
    applied with ``scipy.signal.filtfilt`` along the row axis, each time
    ``transform`` is called. The transformer is stateless: ``fit`` learns
    nothing.
    """

    def __init__(self, cutoff_frequency=5, sampling_rate=25, order=4):
        """
        Initialize the LowPassFilter class.

        Parameters:
        - cutoff_frequency: The cutoff frequency for the low-pass filter (default: 5 Hz).
        - sampling_rate: The sampling rate of the accelerometer data (default: 25 Hz).
        - order: The order of the filter (default: 4).
        """
        # Parameters are stored untouched and validated lazily at filter time
        # (NOTE(review): this follows the scikit-learn estimator convention as
        # understood here; confirm against the developer guide).
        self.cutoff_frequency = cutoff_frequency
        self.sampling_rate = sampling_rate
        self.order = order

    def _butter_lowpass_filter(self, data):
        """
        Apply a Butterworth low-pass filter to the data.

        Parameters:
        - data: A NumPy array containing the accelerometer data to be filtered;
          filtering runs along axis 0.

        Returns:
        - A filtered NumPy array.

        Raises:
        - ValueError: If the sampling rate is not positive, or the cutoff
          frequency does not lie strictly between 0 and the Nyquist frequency.
        """
        if self.sampling_rate <= 0:
            raise ValueError(
                f"sampling_rate must be positive, got {self.sampling_rate!r}."
            )

        nyquist = 0.5 * self.sampling_rate
        # butter() needs a normalized cutoff strictly inside (0, 1); check here
        # so the error names the parameters the caller actually set.
        if not 0 < self.cutoff_frequency < nyquist:
            raise ValueError(
                f"cutoff_frequency must be between 0 and the Nyquist frequency "
                f"({nyquist} Hz), got {self.cutoff_frequency!r}."
            )

        normalized_cutoff = self.cutoff_frequency / nyquist
        b, a = butter(self.order, normalized_cutoff, btype='low', analog=False)
        return filtfilt(b, a, data, axis=0)

    def fit(self, X, y=None):
        """Do nothing and return the transformer itself; there is no state to learn."""
        return self

    def transform(self, X):
        """
        Apply the low-pass filter to the accelerometer data.

        Parameters:
        - X: A DataFrame with 'x', 'y', and 'z' columns representing the accelerometer data.

        Returns:
        - A new DataFrame with filtered 'x', 'y', and 'z' columns; other
          columns are carried over unchanged. The input DataFrame is not modified.

        Raises:
        - ValueError: If any of the 'x', 'y', 'z' columns is missing.
        """
        axes = ['x', 'y', 'z']
        if any(axis not in X.columns for axis in axes):
            raise ValueError("The input DataFrame must contain 'x', 'y', and 'z' columns.")

        # Work on a copy so the caller's frame (possibly a slice of a larger
        # one) is never overwritten as a side effect of transforming.
        filtered = X.copy()
        filtered[axes] = self._butter_lowpass_filter(X[axes].to_numpy())
        print("Low-pass filter applied successfully.")

        return filtered