Spaces:
Sleeping
Sleeping
| # dl_model_def.py | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras.utils import FeatureSpace | |
| # REMOVED: from keras.layers import ... | |
# Vocabulary cap for the TextVectorization layers inside each FeatureSpace.
MAX_TOK = 160_000
# Output dimensionality of the disease-id embedding (also the default
# embedding width added to a tower's input size in build_tower).
EMB_ID = 64
def make_fs():
    """Build a FeatureSpace with a single bag-of-words "text" feature.

    The feature vectorizes strings into per-token counts (at most
    MAX_TOK distinct tokens), emits them as floats, and the space
    concatenates all features into one output tensor.
    """
    text_vectorizer = keras.layers.TextVectorization(
        max_tokens=MAX_TOK,
        output_mode="count",
    )
    text_feature = FeatureSpace.feature(
        preprocessor=text_vectorizer,
        dtype="string",
        output_mode="float",
    )
    return FeatureSpace({"text": text_feature}, output_mode="concat")
def build_tower(input_dim: int, EMB_ID: int = 64) -> keras.Model:
    """Build a wide-plus-deep encoder tower with a residual-style merge.

    The input width is ``input_dim + EMB_ID`` (the caller's encoded
    feature width plus the embedding width it will concatenate on).

    Args:
        input_dim: Width of the encoded feature vector.
        EMB_ID: Extra width reserved for an id embedding; pass 0 when
            the tower receives text features only. (Name kept for
            caller compatibility despite shadowing the module constant.)

    Returns:
        A functional ``keras.Model`` mapping the concatenated input to a
        384-dimensional representation.
    """
    inputs = keras.Input(shape=(input_dim + EMB_ID,))
    normed = keras.layers.LayerNormalization()(inputs)

    # Wide path: a single linear projection of the normalized input.
    wide = keras.layers.Dense(384, activation="linear")(normed)

    # Deep path: nonlinear block, projected back to the wide path's width.
    hidden = keras.layers.Dense(384, activation="elu")(normed)
    hidden = keras.layers.LayerNormalization()(hidden)
    hidden = keras.layers.Dropout(0.35)(hidden)
    hidden = keras.layers.Dense(64, activation="elu")(hidden)
    hidden = keras.layers.Dropout(0.15)(hidden)
    hidden = keras.layers.Dense(384, activation="linear")(hidden)

    # Residual-style combination of the two paths.
    merged = keras.layers.Add()([wide, hidden])
    return keras.Model(inputs, merged, name="tower")
class TwoTowerDual(keras.Model):
    """Two-tower model scoring (disease query, target candidate) pairs.

    The query side encodes disease text plus a learned disease-id
    embedding; the candidate side encodes target text only. Cosine
    similarity between the two tower outputs feeds two heads: a sigmoid
    classification head ("cls") and a linear regression head ("score").
    """

    def __init__(
        self,
        dise_lookup,
        dise_emb,
        q_fs,
        k_fs,
        q_tower,
        k_tower,
        concat_layer,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Preprocessing / embedding sub-layers supplied by the builder.
        self.dise_lookup = dise_lookup
        self.dise_emb = dise_emb
        self.q_fs = q_fs
        self.k_fs = k_fs
        # Encoder towers and the layer joining text features with the
        # disease embedding on the query side.
        self.q_tower = q_tower
        self.k_tower = k_tower
        self.concat = concat_layer
        # normalize=True makes the dot product a cosine similarity.
        self.dot = keras.layers.Dot(axes=-1, normalize=True, name="cosine")
        # Binary head on top of the scalar similarity.
        self.cls_head = keras.layers.Dense(
            1,
            activation="sigmoid",
            name="cls",
        )
        # Linear regression head; small positive bias start.
        self.score_head = keras.layers.Dense(
            1,
            activation=None,
            name="score",
            bias_initializer=tf.keras.initializers.Constant(0.049),
        )
        # Kept for parity with the original code: exposes the module-level
        # tower factory on the instance.
        self.build_tower = build_tower

    def encode_q(self, txt, did):
        """Encode a query: text features concatenated with the disease embedding."""
        text_feats = self.q_fs({"text": txt})
        id_emb = self.dise_emb(self.dise_lookup(did))
        return self.q_tower(self.concat([text_feats, id_emb]))

    def encode_k(self, txt, tid):
        """Encode a candidate from its text only (``tid`` is currently unused)."""
        return self.k_tower(self.k_fs({"text": txt}))

    def call(self, feats):
        """Score a batch of query/candidate pairs.

        Expects ``feats["query"]`` with keys "disease_text"/"diseaseId"
        and ``feats["candidate"]`` with keys "target_text"/"targetId".
        Returns ``{"cls": sigmoid prob, "score": linear value}``.
        """
        query = feats["query"]
        candidate = feats["candidate"]
        q_vec = self.encode_q(query["disease_text"], query["diseaseId"])
        k_vec = self.encode_k(candidate["target_text"], candidate["targetId"])
        sim = self.dot([q_vec, k_vec])
        return {"cls": self.cls_head(sim), "score": self.score_head(sim)}
def _adapt_feature_space(fs, texts) -> None:
    """Adapt *fs* on a batched, prefetched tf.data pipeline over *texts*."""
    ds = (
        tf.data.Dataset.from_tensor_slices({"text": texts})
        .batch(4096)
        .prefetch(tf.data.AUTOTUNE)
    )
    fs.adapt(ds)


def build_two_tower_model(df_learn) -> TwoTowerDual:
    """Assemble a TwoTowerDual model from a training frame.

    Args:
        df_learn: Frame with columns "disease_text", "target_text",
            "diseaseId", "targetId" (presumably a pandas DataFrame —
            uses ``.iloc``; confirm against callers).

    Returns:
        A built ``TwoTowerDual`` whose weights were created by one
        dummy forward pass.
    """
    # 1) One bag-of-words FeatureSpace per tower, adapted on the
    #    corresponding text column.
    q_fs = make_fs()
    k_fs = make_fs()
    _adapt_feature_space(q_fs, df_learn["disease_text"])
    _adapt_feature_space(k_fs, df_learn["target_text"])

    # 2) Disease-id lookup + trainable embedding (query side only).
    dise_lookup = keras.layers.StringLookup(name="disease_lookup")
    dise_lookup.adapt(df_learn["diseaseId"])
    dise_emb = keras.layers.Embedding(
        input_dim=dise_lookup.vocabulary_size(),
        output_dim=EMB_ID,
        name="dise_emb",
    )

    # 3) Towers. The query tower sees text features + the EMB_ID-wide
    #    disease embedding. The candidate tower sees text features only,
    #    so its embedding width is 0 — this replaces the original
    #    "shape[-1] - EMB_ID" hack (which relied on build_tower adding
    #    EMB_ID back) with the same resulting input width, stated plainly.
    q_tower = build_tower(q_fs.get_encoded_features().shape[-1])
    k_tower = build_tower(k_fs.get_encoded_features().shape[-1], EMB_ID=0)
    concat = keras.layers.Concatenate(name="concat")

    # 4) Wire everything into the model.
    model = TwoTowerDual(
        dise_lookup=dise_lookup,
        dise_emb=dise_emb,
        q_fs=q_fs,
        k_fs=k_fs,
        q_tower=q_tower,
        k_tower=k_tower,
        concat_layer=concat,
        name="two_tower_dual",
    )

    # Run one dummy batch so every layer builds its variables.
    dummy = {
        "query": {
            "disease_text": tf.constant(["dummy"]),
            "diseaseId": tf.constant([df_learn["diseaseId"].iloc[0]]),
        },
        "candidate": {
            "target_text": tf.constant(["dummy target"]),
            "targetId": tf.constant([df_learn["targetId"].iloc[0]]),
        },
    }
    _ = model(dummy)
    return model