Spaces:
Sleeping
Sleeping
| """ | |
| ConTime 모델 구현을 위한 커스텀 레이어 | |
| """ | |
| import tensorflow as tf | |
class ODEFunc(tf.keras.layers.Layer):
    """LSTM-style right-hand side of the ODE driving the hidden and cell state.

    Given the current input ``x_t`` and the states ``h`` and ``c``, the layer
    computes the gates ``i``, ``f``, ``o`` (sigmoid) and the candidate ``g``
    from the concatenation ``[x_t, h]`` and returns the derivatives::

        dc/dt = i * g + f * c - c
        dh/dt = o * activation(c) - h

    Args:
        hidden_dim: Size of the hidden and cell state.
        activation: Name (or callable) of the activation applied to the
            candidate ``g`` and to the cell state in ``dh/dt``. It is resolved
            with ``tf.keras.activations.get``. The default ``'tanh'``
            reproduces the previously hard-coded behaviour.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, hidden_dim, activation='tanh', **kwargs):
        super(ODEFunc, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        # Keep the raw argument for get_config(); use the resolved callable in call().
        self.activation = activation
        self._activation_fn = tf.keras.activations.get(activation)

    @staticmethod
    def _feature_dim(input_shape):
        """Extract the feature (last-axis) size from the accepted shape forms.

        Accepts a bare int, a single shape (list / tuple / ``tf.TensorShape``),
        or a list/tuple of shapes, in which case the first shape is used.
        """
        shape = input_shape
        # A collection of shapes: only the first one (the x input) matters.
        if (isinstance(shape, (list, tuple)) and len(shape) > 0
                and isinstance(shape[0], (list, tuple, tf.TensorShape))):
            shape = shape[0]

        if isinstance(shape, tf.TensorShape):
            feature_dim = tf.compat.dimension_value(shape[-1])
        elif isinstance(shape, (list, tuple)):
            feature_dim = shape[-1]
        else:
            # Callers may pass the feature dimension directly as an int.
            feature_dim = shape

        if feature_dim is None:
            raise ValueError(
                'ODEFunc requires a statically known feature dimension, '
                'got input_shape={!r}'.format(input_shape))
        return int(feature_dim)

    def build(self, input_shape):
        """Create the gate kernels and biases.

        Args:
            input_shape: Feature dimension as an int, a shape whose last entry
                is the feature dimension, or a list of such shapes.
        """
        feature_dim = self._feature_dim(input_shape)
        kernel_shape = (feature_dim + self.hidden_dim, self.hidden_dim)
        bias_shape = (self.hidden_dim,)

        # Weights for the LSTM gates. Creation order is kept stable because it
        # determines the order of get_weights() / saved checkpoints.
        self.W_i = self.add_weight(shape=kernel_shape,
                                   initializer='glorot_uniform', name='W_i')
        self.W_f = self.add_weight(shape=kernel_shape,
                                   initializer='glorot_uniform', name='W_f')
        self.W_o = self.add_weight(shape=kernel_shape,
                                   initializer='glorot_uniform', name='W_o')
        self.W_g = self.add_weight(shape=kernel_shape,
                                   initializer='glorot_uniform', name='W_g')
        self.b_i = self.add_weight(shape=bias_shape, initializer='zeros', name='b_i')
        self.b_f = self.add_weight(shape=bias_shape, initializer='zeros', name='b_f')
        self.b_o = self.add_weight(shape=bias_shape, initializer='zeros', name='b_o')
        self.b_g = self.add_weight(shape=bias_shape, initializer='zeros', name='b_g')
        super(ODEFunc, self).build(input_shape)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super(ODEFunc, self).get_config()
        config.update({
            'hidden_dim': self.hidden_dim,
            'activation': self.activation
        })
        return config

    def call(self, x_t, h, c):
        """Evaluate the state derivatives.

        Args:
            x_t: Input at the current position; concatenated with ``h`` on the
                last axis.
            h: Hidden state.
            c: Cell state.

        Returns:
            Tuple ``(dh_dt, dc_dt)``.
        """
        # Concatenate the input with the hidden state.
        combined = tf.concat([x_t, h], axis=-1)

        # Gates.
        i = tf.nn.sigmoid(tf.matmul(combined, self.W_i) + self.b_i)
        f = tf.nn.sigmoid(tf.matmul(combined, self.W_f) + self.b_f)
        o = tf.nn.sigmoid(tf.matmul(combined, self.W_o) + self.b_o)
        g = self._activation_fn(tf.matmul(combined, self.W_g) + self.b_g)

        # Derivatives: each is "discrete LSTM update minus current state".
        dc_dt = i * g + f * c - c
        dh_dt = o * self._activation_fn(c) - h
        return dh_dt, dc_dt
class ContinuousLSTMLayer(tf.keras.layers.Layer):
    """Continuous-time LSTM that evolves its state with RK4 between observations.

    The layer is called on ``[x, time_diffs]``. For every sequence position the
    hidden and cell state are integrated with ``ode_steps`` RK4 sub-steps over
    an interval of ``min(time_diff, max_dt)``, with the input at that position
    held fixed during the integration.

    Args:
        hidden_dim: Size of the hidden and cell state.
        return_sequences: If True return the hidden state at every position,
            otherwise only the one at the last position of the output.
        dt: Stored and serialized; not read by the integration code in this
            class.
        ode_steps: Number of RK4 sub-steps per sequence position.
        max_dt: Upper bound applied to each time gap before integrating.
        reverse: If True, process the sequence back to front (outputs are
            flipped back to the original order afterwards).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, hidden_dim, return_sequences=True, dt=0.1,
                 ode_steps=5, max_dt=3.0, reverse=False, **kwargs):
        super(ContinuousLSTMLayer, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.return_sequences = return_sequences
        self.dt = dt
        self.ode_steps = ode_steps
        self.max_dt = max_dt
        self.reverse = reverse

    def build(self, input_shape):
        """Create the inner ``ODEFunc``.

        Args:
            input_shape: ``[(batch, seq, features), (batch, seq)]``.
        """
        x_input_shape = input_shape[0]
        feature_dim = x_input_shape[-1]  # feature dimension of x
        self.ode_func = ODEFunc(self.hidden_dim)
        # ODEFunc only needs the feature dimension.
        self.ode_func.build(feature_dim)
        super(ContinuousLSTMLayer, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, seq, hidden)`` or ``(batch, hidden)``."""
        if self.return_sequences:
            return (input_shape[0][0], input_shape[0][1], self.hidden_dim)
        else:
            return (input_shape[0][0], self.hidden_dim)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super(ContinuousLSTMLayer, self).get_config()
        config.update({
            'hidden_dim': self.hidden_dim,
            'return_sequences': self.return_sequences,
            'dt': self.dt,
            'ode_steps': self.ode_steps,
            'max_dt': self.max_dt,
            'reverse': self.reverse
        })
        return config

    def rk4_step(self, x_t, h, c, dt):
        """Advance ``(h, c)`` by one classical Runge-Kutta (RK4) step of size ``dt``.

        ``dt`` may be a scalar or any tensor that broadcasts against ``h``.
        """
        # k1: slope at the start of the interval
        k1_h, k1_c = self.ode_func(x_t, h, c)
        # k2: slope at the midpoint, using k1
        h_k2 = h + 0.5 * dt * k1_h
        c_k2 = c + 0.5 * dt * k1_c
        k2_h, k2_c = self.ode_func(x_t, h_k2, c_k2)
        # k3: slope at the midpoint, using k2
        h_k3 = h + 0.5 * dt * k2_h
        c_k3 = c + 0.5 * dt * k2_c
        k3_h, k3_c = self.ode_func(x_t, h_k3, c_k3)
        # k4: slope at the end of the interval, using k3
        h_k4 = h + dt * k3_h
        c_k4 = c + dt * k3_c
        k4_h, k4_c = self.ode_func(x_t, h_k4, c_k4)
        # Weighted combination of the four slopes.
        h_new = h + (dt / 6.0) * (k1_h + 2 * k2_h + 2 * k3_h + k4_h)
        c_new = c + (dt / 6.0) * (k1_c + 2 * k2_c + 2 * k3_c + k4_c)
        return h_new, c_new

    def solve_ode(self, x_t, h, c, dt_value):
        """Integrate the state over ``dt_value`` with ``ode_steps`` RK4 sub-steps.

        Args:
            x_t: Input at the current position.
            h: Hidden state, ``[batch, hidden]``.
            c: Cell state, ``[batch, hidden]``.
            dt_value: Time gap, either a scalar or one value per sample.

        Returns:
            Tuple ``(h, c)`` after integration.
        """
        # Use each sample's own time gap. Reshaping to [-1, 1] turns a scalar
        # into [1, 1] and a per-sample vector into [batch, 1]; both broadcast
        # against the [batch, hidden] state. (Previously only dt_value[0] was
        # used, i.e. the first sample's gap was applied to the whole batch.)
        per_sample_dt = tf.reshape(tf.cast(dt_value, h.dtype), [-1, 1])

        # Cap the interval, then split it into equal sub-steps.
        capped_dt = tf.minimum(per_sample_dt, self.max_dt)
        sub_dt = capped_dt / tf.cast(self.ode_steps, capped_dt.dtype)

        # RK4 integration.
        current_h, current_c = h, c
        for _ in range(self.ode_steps):
            current_h, current_c = self.rk4_step(x_t, current_h, current_c, sub_dt)
        return current_h, current_c

    def call(self, inputs):
        """Run the layer.

        Args:
            inputs: Pair ``(x, time_diffs)`` where ``x`` is indexed as
                ``x[:, t, :]`` and ``time_diffs`` as ``time_diffs[:, t]``.

        Returns:
            ``[batch, seq, hidden]`` if ``return_sequences`` else
            ``[batch, hidden]``.
        """
        x, time_diffs = inputs
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]

        # Choose the processing direction.
        if self.reverse:
            x = tf.reverse(x, axis=[1])
            time_diffs = tf.reverse(time_diffs, axis=[1])

        # Zero initial state; batch size is only known at run time.
        h = tf.zeros((batch_size, self.hidden_dim), dtype=tf.float32)
        c = tf.zeros((batch_size, self.hidden_dim), dtype=tf.float32)

        # One slot per sequence position.
        outputs = tf.TensorArray(
            dtype=tf.float32,
            size=seq_len,
            dynamic_size=False,
            element_shape=tf.TensorShape([None, self.hidden_dim])
        )

        def cond(t, h, c, outputs):
            return t < seq_len

        def body(t, h, c, outputs):
            x_t = x[:, t, :]
            dt_t = time_diffs[:, t]
            # ODE integration over this position's time gap.
            h_new, c_new = self.solve_ode(x_t, h, c, dt_t)
            # Pin the static shape so it matches the loop invariants.
            h_new = tf.ensure_shape(h_new, [None, self.hidden_dim])
            c_new = tf.ensure_shape(c_new, [None, self.hidden_dim])
            outputs = outputs.write(t, h_new)
            return t + 1, h_new, c_new, outputs

        # Sequential scan over the sequence. Back-propagation through the loop
        # is enabled by default, so no explicit back_prop argument is needed.
        _, final_h, final_c, outputs = tf.while_loop(
            cond, body,
            [0, h, c, outputs],
            shape_invariants=[
                tf.TensorShape([]),                       # t: scalar
                tf.TensorShape([None, self.hidden_dim]),  # h: [batch, hidden]
                tf.TensorShape([None, self.hidden_dim]),  # c: [batch, hidden]
                tf.TensorShape(None)                      # outputs: TensorArray
            ],
            parallel_iterations=1,  # keep the steps strictly sequential
        )

        all_outputs = outputs.stack()                       # [seq_len, batch, hidden]
        all_outputs = tf.transpose(all_outputs, [1, 0, 2])  # [batch, seq_len, hidden]

        # Restore the original sequence order.
        if self.reverse:
            all_outputs = tf.reverse(all_outputs, axis=[1])

        if self.return_sequences:
            return all_outputs
        else:
            # NOTE(review): with reverse=True the outputs were just flipped
            # back, so index -1 is the state after the FIRST processed step,
            # not the final state of the scan (final_h). Confirm this is the
            # intended result for reverse + return_sequences=False.
            return all_outputs[:, -1, :]
class BidirectionalContinuousLSTMLayer(tf.keras.layers.Layer):
    """Runs a forward and a backward ``ContinuousLSTMLayer`` and merges them.

    Args:
        hidden_dim: Hidden size of each direction.
        return_sequences: If True return the merged sequence, otherwise only
            its last position.
        dt: Forwarded to both directional layers.
        ode_steps: Forwarded to both directional layers.
        merge_mode: ``'concat'``, ``'ave'``, ``'sum'`` or ``'mul'``. Any other
            value is treated as ``'concat'``.
        max_dt: Forwarded to both directional layers; the default matches
            ``ContinuousLSTMLayer``'s own default.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    # Merge modes that keep the feature size at hidden_dim. Everything else
    # (including unknown values) concatenates, giving 2 * hidden_dim.
    _ELEMENTWISE_MERGE_MODES = ('ave', 'sum', 'mul')

    def __init__(self, hidden_dim=64, return_sequences=True, dt=0.1, ode_steps=5,
                 merge_mode='concat', max_dt=3.0, **kwargs):
        super(BidirectionalContinuousLSTMLayer, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.return_sequences = return_sequences
        self.dt = dt
        self.ode_steps = ode_steps
        self.merge_mode = merge_mode
        self.max_dt = max_dt

    def build(self, input_shape):
        """Create and build the two directional layers."""
        # Both directions always return full sequences; the last-position
        # selection for return_sequences=False happens after merging.
        self.forward_layer = ContinuousLSTMLayer(
            hidden_dim=self.hidden_dim,
            return_sequences=True,
            dt=self.dt,
            ode_steps=self.ode_steps,
            max_dt=self.max_dt,
            reverse=False
        )
        self.backward_layer = ContinuousLSTMLayer(
            hidden_dim=self.hidden_dim,
            return_sequences=True,
            dt=self.dt,
            ode_steps=self.ode_steps,
            max_dt=self.max_dt,
            reverse=True
        )
        self.forward_layer.build(input_shape)
        self.backward_layer.build(input_shape)
        super(BidirectionalContinuousLSTMLayer, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """Return the merged output shape, consistent with ``call``."""
        # Mirror call(): only the element-wise modes keep hidden_dim; 'concat'
        # and any unrecognized mode concatenate the two directions.
        if self.merge_mode in self._ELEMENTWISE_MERGE_MODES:
            output_dim = self.hidden_dim
        else:
            output_dim = self.hidden_dim * 2
        if self.return_sequences:
            return (input_shape[0][0], input_shape[0][1], output_dim)
        else:
            return (input_shape[0][0], output_dim)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super(BidirectionalContinuousLSTMLayer, self).get_config()
        config.update({
            'hidden_dim': self.hidden_dim,
            'return_sequences': self.return_sequences,
            'dt': self.dt,
            'ode_steps': self.ode_steps,
            'merge_mode': self.merge_mode,
            'max_dt': self.max_dt
        })
        return config

    def call(self, inputs):
        """Apply both directions to ``inputs`` and merge their outputs.

        Args:
            inputs: Passed unchanged to both directional layers.

        Returns:
            The merged sequence, or its last position when
            ``return_sequences`` is False.
        """
        forward_output = self.forward_layer(inputs)
        backward_output = self.backward_layer(inputs)

        # Merge according to merge_mode.
        if self.merge_mode == 'ave':
            merged = (forward_output + backward_output) / 2.0
        elif self.merge_mode == 'sum':
            merged = forward_output + backward_output
        elif self.merge_mode == 'mul':
            merged = forward_output * backward_output
        else:
            # 'concat', and the fallback for unrecognized modes.
            merged = tf.concat([forward_output, backward_output], axis=-1)

        if not self.return_sequences:
            # NOTE(review): the backward output is already restored to the
            # original order, so its last position is the state after only
            # one backward step. Confirm this is the intended summary.
            merged = merged[:, -1, :]
        return merged