"""Custom layers used to implement the ConTime model."""
import tensorflow as tf


class ODEFunc(tf.keras.layers.Layer):
    """Right-hand side of a continuous-time LSTM cell.

    Computes the LSTM gates from the current input ``x_t`` and hidden state
    ``h`` and returns the time derivatives ``(dh_dt, dc_dt)``. Each derivative
    is "discrete LSTM update minus current state", so integrating it relaxes
    the state toward the standard LSTM update.

    Args:
        hidden_dim: Size of the hidden and cell states.
        activation: Activation applied to the candidate gate and to the cell
            state in the hidden-state update (``'tanh'`` by default, as in a
            standard LSTM). Any identifier accepted by
            ``tf.keras.activations.get``.
    """

    def __init__(self, hidden_dim, activation='tanh', **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.activation = activation
        self._activation_fn = tf.keras.activations.get(activation)

    @staticmethod
    def _feature_dim(input_shape):
        """Return the last-axis size from an int, a shape, or a list of shapes."""
        if isinstance(input_shape, int):
            return input_shape
        shape = input_shape
        # A list of shapes (one per call argument): the first one belongs to x_t.
        if (isinstance(shape, (list, tuple)) and shape
                and isinstance(shape[0], (list, tuple, tf.TensorShape))):
            shape = shape[0]
        shape = tf.TensorShape(shape)
        if shape.rank is None or shape.rank == 0:
            raise ValueError(
                f'ODEFunc cannot infer the feature size from input_shape={input_shape!r}')
        feature_dim = tf.compat.dimension_value(shape[-1])
        if feature_dim is None:
            raise ValueError(
                f'ODEFunc needs a known last dimension, got input_shape={input_shape!r}')
        return int(feature_dim)

    def build(self, input_shape):
        """Create the gate weights.

        Args:
            input_shape: The feature size of ``x_t`` as an int, a shape such as
                ``(batch, features)``, or a list of shapes whose first entry is
                the shape of ``x_t``.

        Raises:
            ValueError: if the feature size cannot be determined.
        """
        feature_dim = self._feature_dim(input_shape)
        # Every gate sees the concatenation [x_t, h].
        in_dim = feature_dim + self.hidden_dim

        # Input, forget, output and candidate gate weights.
        self.W_i = self.add_weight(shape=(in_dim, self.hidden_dim),
                                   initializer='glorot_uniform', name='W_i')
        self.W_f = self.add_weight(shape=(in_dim, self.hidden_dim),
                                   initializer='glorot_uniform', name='W_f')
        self.W_o = self.add_weight(shape=(in_dim, self.hidden_dim),
                                   initializer='glorot_uniform', name='W_o')
        self.W_g = self.add_weight(shape=(in_dim, self.hidden_dim),
                                   initializer='glorot_uniform', name='W_g')

        self.b_i = self.add_weight(shape=(self.hidden_dim,), initializer='zeros', name='b_i')
        self.b_f = self.add_weight(shape=(self.hidden_dim,), initializer='zeros', name='b_f')
        self.b_o = self.add_weight(shape=(self.hidden_dim,), initializer='zeros', name='b_o')
        self.b_g = self.add_weight(shape=(self.hidden_dim,), initializer='zeros', name='b_g')

        super().build(input_shape)

    def get_config(self):
        config = super().get_config()
        config.update({
            'hidden_dim': self.hidden_dim,
            'activation': self.activation,
        })
        return config

    def call(self, x_t, h, c):
        """Return ``(dh_dt, dc_dt)`` for input ``x_t`` and states ``h``, ``c``."""
        combined = tf.concat([x_t, h], axis=-1)

        i = tf.nn.sigmoid(tf.matmul(combined, self.W_i) + self.b_i)
        f = tf.nn.sigmoid(tf.matmul(combined, self.W_f) + self.b_f)
        o = tf.nn.sigmoid(tf.matmul(combined, self.W_o) + self.b_o)
        g = self._activation_fn(tf.matmul(combined, self.W_g) + self.b_g)

        # (discrete LSTM update) - (current state)
        dc_dt = i * g + f * c - c
        dh_dt = o * self._activation_fn(c) - h
        return dh_dt, dc_dt


class ContinuousLSTMLayer(tf.keras.layers.Layer):
    """LSTM whose state is integrated as an ODE between irregular time steps.

    Call with ``[x, time_diffs]``. Here ``x`` is ``(batch, seq, features)`` and
    ``time_diffs`` holds one elapsed-time value per sample and step,
    ``(batch, seq)``; a trailing size-1 axis is also accepted. At every step
    the ``ODEFunc`` dynamics are integrated over ``min(time_diff, max_dt)``
    using ``ode_steps`` RK4 sub-steps.

    Args:
        hidden_dim: Size of the hidden and cell states.
        return_sequences: If True, return every hidden state as
            ``(batch, seq, hidden_dim)``. Otherwise return the final hidden
            state, ``(batch, hidden_dim)``.
        dt: Stored and serialized, but not used by the integrator.
        ode_steps: Number of RK4 sub-steps per sequence element.
        max_dt: Upper clip applied to each time difference.
        reverse: Process the sequence back to front. Sequence outputs are
            returned in the original time order.
    """

    def __init__(self, hidden_dim, return_sequences=True, dt=0.1, ode_steps=5,
                 max_dt=3.0, reverse=False, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.return_sequences = return_sequences
        self.dt = dt
        self.ode_steps = ode_steps
        self.max_dt = max_dt
        self.reverse = reverse

    def build(self, input_shape):
        # input_shape is [x_shape, time_diffs_shape]; only x's feature axis
        # sizes the cell.
        x_shape = input_shape[0]
        feature_dim = x_shape[-1]
        self.ode_func = ODEFunc(self.hidden_dim)
        self.ode_func.build((None, feature_dim))
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        if self.return_sequences:
            return (input_shape[0][0], input_shape[0][1], self.hidden_dim)
        return (input_shape[0][0], self.hidden_dim)

    def get_config(self):
        config = super().get_config()
        config.update({
            'hidden_dim': self.hidden_dim,
            'return_sequences': self.return_sequences,
            'dt': self.dt,
            'ode_steps': self.ode_steps,
            'max_dt': self.max_dt,
            'reverse': self.reverse,
        })
        return config

    def rk4_step(self, x_t, h, c, dt):
        """Advance ``(h, c)`` by one classical RK4 step of size ``dt``.

        ``dt`` may be a scalar or any tensor that broadcasts against ``h``
        (e.g. ``(batch, 1)`` for per-sample step sizes).
        """
        half = 0.5 * dt
        k1_h, k1_c = self.ode_func(x_t, h, c)
        k2_h, k2_c = self.ode_func(x_t, h + half * k1_h, c + half * k1_c)
        k3_h, k3_c = self.ode_func(x_t, h + half * k2_h, c + half * k2_c)
        k4_h, k4_c = self.ode_func(x_t, h + dt * k3_h, c + dt * k3_c)

        h_new = h + (dt / 6.0) * (k1_h + 2 * k2_h + 2 * k3_h + k4_h)
        c_new = c + (dt / 6.0) * (k1_c + 2 * k2_c + 2 * k3_c + k4_c)
        return h_new, c_new

    def solve_ode(self, x_t, h, c, dt_value):
        """Integrate the cell dynamics over one time gap.

        Args:
            x_t: Input at the current step (held constant during integration).
            h: Hidden state, ``(batch, hidden_dim)``.
            c: Cell state, ``(batch, hidden_dim)``.
            dt_value: Elapsed time, either a scalar or one value per batch
                element (``(batch,)`` or ``(batch, 1)``).

        Returns:
            ``(h, c)`` after integrating over ``min(dt_value, max_dt)``.
        """
        # Keep one gap per sample as a (batch, 1) column so it broadcasts over
        # the hidden units. A scalar becomes (1, 1) and broadcasts to the
        # whole batch.
        dt = tf.reshape(tf.cast(dt_value, h.dtype), [-1, 1])
        # Clip very long gaps, then split the gap into equal RK4 sub-steps.
        sub_dt = tf.minimum(dt, self.max_dt) / float(self.ode_steps)

        for _ in range(self.ode_steps):
            h, c = self.rk4_step(x_t, h, c, sub_dt)
        return h, c

    def call(self, inputs):
        x, time_diffs = inputs

        if self.reverse:
            # NOTE(review): time_diffs is reversed together with x, so each
            # element keeps the gap the forward pass attaches to it. If
            # time_diffs[:, t] means "time since step t-1", the backward pass
            # may instead want the gap to step t+1. Confirm the intended
            # semantics.
            x = tf.reverse(x, axis=[1])
            time_diffs = tf.reverse(time_diffs, axis=[1])

        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]

        h = tf.zeros((batch_size, self.hidden_dim), dtype=tf.float32)
        c = tf.zeros((batch_size, self.hidden_dim), dtype=tf.float32)

        outputs = tf.TensorArray(
            dtype=tf.float32,
            size=seq_len,
            dynamic_size=False,
            element_shape=tf.TensorShape([None, self.hidden_dim]),
        )

        def cond(t, h, c, outputs):
            return t < seq_len

        def body(t, h, c, outputs):
            h_new, c_new = self.solve_ode(x[:, t, :], h, c, time_diffs[:, t])
            # Keep the loop-carried shapes consistent with the invariants below.
            h_new = tf.ensure_shape(h_new, [None, self.hidden_dim])
            c_new = tf.ensure_shape(c_new, [None, self.hidden_dim])
            return t + 1, h_new, c_new, outputs.write(t, h_new)

        _, final_h, _, outputs = tf.while_loop(
            cond,
            body,
            [0, h, c, outputs],
            shape_invariants=[
                tf.TensorShape([]),                     # t
                tf.TensorShape([None, self.hidden_dim]),  # h
                tf.TensorShape([None, self.hidden_dim]),  # c
                tf.TensorShape(None),                   # outputs (TensorArray)
            ],
            parallel_iterations=1,  # steps depend on each other; run in order
        )

        if not self.return_sequences:
            # The state after the whole (possibly reversed) sequence has been
            # consumed; for reverse=True this is NOT the output at the original
            # last position.
            return final_h

        all_outputs = tf.transpose(outputs.stack(), [1, 0, 2])  # [batch, seq, hidden]
        if self.reverse:
            # Restore the original time order.
            all_outputs = tf.reverse(all_outputs, axis=[1])
        return all_outputs


class BidirectionalContinuousLSTMLayer(tf.keras.layers.Layer):
    """Runs a forward and a backward ``ContinuousLSTMLayer`` and merges them.

    Args:
        hidden_dim: Hidden size of each direction.
        return_sequences: If True, merge the per-step outputs. Otherwise merge
            each direction's final state (forward: after the last step;
            backward: after the first step).
        dt: Passed through to both directions (stored there, not used by
            their integrator).
        ode_steps: RK4 sub-steps per sequence element.
        merge_mode: One of ``'concat'``, ``'ave'``, ``'sum'``, ``'mul'``.
        max_dt: Upper clip on each time difference, passed to both directions.

    Raises:
        ValueError: if ``merge_mode`` is not one of the supported modes.
    """

    _MERGE_MODES = ('concat', 'ave', 'sum', 'mul')

    def __init__(self, hidden_dim=64, return_sequences=True, dt=0.1, ode_steps=5,
                 merge_mode='concat', max_dt=3.0, **kwargs):
        super().__init__(**kwargs)
        if merge_mode not in self._MERGE_MODES:
            raise ValueError(
                f'merge_mode must be one of {self._MERGE_MODES}, got {merge_mode!r}')
        self.hidden_dim = hidden_dim
        self.return_sequences = return_sequences
        self.dt = dt
        self.ode_steps = ode_steps
        self.merge_mode = merge_mode
        self.max_dt = max_dt

    def build(self, input_shape):
        common = dict(
            hidden_dim=self.hidden_dim,
            # Each direction returns its own true final state when sequences are
            # not requested, so the backward half summarizes the whole sequence.
            return_sequences=self.return_sequences,
            dt=self.dt,
            ode_steps=self.ode_steps,
            max_dt=self.max_dt,
        )
        self.forward_layer = ContinuousLSTMLayer(reverse=False, **common)
        self.backward_layer = ContinuousLSTMLayer(reverse=True, **common)
        self.forward_layer.build(input_shape)
        self.backward_layer.build(input_shape)
        super().build(input_shape)

    def compute_output_shape(self, input_shape):
        if self.merge_mode == 'concat':
            output_dim = self.hidden_dim * 2
        else:  # 'ave', 'sum', 'mul'
            output_dim = self.hidden_dim

        if self.return_sequences:
            return (input_shape[0][0], input_shape[0][1], output_dim)
        return (input_shape[0][0], output_dim)

    def get_config(self):
        config = super().get_config()
        config.update({
            'hidden_dim': self.hidden_dim,
            'return_sequences': self.return_sequences,
            'dt': self.dt,
            'ode_steps': self.ode_steps,
            'merge_mode': self.merge_mode,
            'max_dt': self.max_dt,
        })
        return config

    def _merge(self, forward_output, backward_output):
        """Combine the two directions according to ``merge_mode``."""
        if self.merge_mode == 'concat':
            return tf.concat([forward_output, backward_output], axis=-1)
        if self.merge_mode == 'ave':
            return (forward_output + backward_output) / 2.0
        if self.merge_mode == 'sum':
            return forward_output + backward_output
        return forward_output * backward_output  # 'mul' (validated in __init__)

    def call(self, inputs):
        forward_output = self.forward_layer(inputs)
        backward_output = self.backward_layer(inputs)
        return self._merge(forward_output, backward_output)