OTRec / dl_model_def.py
GrimSqueaker's picture
Upload folder using huggingface_hub
3d9bb2a verified
# dl_model_def.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import FeatureSpace
# REMOVED: from keras.layers import ...
# Vocabulary cap handed to TextVectorization as `max_tokens` in make_fs.
MAX_TOK = 160_000
# Output dimension of the disease-ID embedding created in build_two_tower_model.
EMB_ID = 64
@keras.utils.register_keras_serializable(package="OTRec")
def make_fs():
    """Create a fresh FeatureSpace containing a single "text" feature.

    The feature takes string input and runs it through a
    ``TextVectorization`` preprocessor in "count" mode whose vocabulary is
    capped at ``MAX_TOK`` tokens. The space itself is configured with
    ``output_mode="concat"``.

    Returns:
        A new ``FeatureSpace``. No ``adapt`` call is made here.
    """
    vectorizer = keras.layers.TextVectorization(
        max_tokens=MAX_TOK,
        output_mode="count",
    )
    text_feature = FeatureSpace.feature(
        preprocessor=vectorizer,
        dtype="string",
        output_mode="float",
    )
    return FeatureSpace({"text": text_feature}, output_mode="concat")
# @keras.utils.register_keras_serializable() # added to here instead of inside the model
# def build_tower(input_dim: int,EMB_ID:int=64) -> keras.Model:
# inp = keras.Input(shape=(input_dim + EMB_ID,))
# x = keras.layers.LayerNormalization()(inp)
# # x = keras.layers.BatchNormalization()(inp)
# ## BatchNormalization
# x = keras.layers.Dropout(0.2)(x)
# # x = keras.layers.Dense(768, activation="gelu")(x)
# # out = keras.layers.Dense(256, activation="tanh")(x)
# # out = keras.layers.Dense(256, activation="gelu")(inp)
# # out = keras.layers.Dense(256, activation="linear")(x) # orig, 95.9 auc
# # out = keras.layers.Dense(256, activation="gelu")(x) #
# out = keras.layers.Dense(512, activation="elu")(x)
# return keras.Model(inp, out, name="tower")
@keras.utils.register_keras_serializable()
def build_tower(input_dim: int, EMB_ID: int = 64) -> keras.Model:
    """Build the functional "tower" model: a wide path plus a deep path, summed.

    Args:
        input_dim: Base feature width. The model's input layer has width
            ``input_dim + EMB_ID``.
        EMB_ID: Extra width added to ``input_dim`` for the input layer.
            Defaults to 64.

    Returns:
        A ``keras.Model`` named "tower" whose output is the element-wise sum
        of two 384-unit projections of the layer-normalized input.
    """
    features = keras.Input(shape=(input_dim + EMB_ID,))
    normed = keras.layers.LayerNormalization()(features)

    # Wide path: one linear projection of the normalized input.
    wide = keras.layers.Dense(384, activation="linear")(normed)

    # Deep path: expand, normalize, regularize, squeeze through a 64-unit
    # bottleneck, then project back to 384 so it can be added to the wide path.
    # Layers are constructed and applied in the listed order.
    deep_stack = (
        keras.layers.Dense(384, activation="elu"),
        keras.layers.LayerNormalization(),
        keras.layers.Dropout(0.35),
        keras.layers.Dense(64, activation="elu"),
        keras.layers.Dropout(0.15),
        keras.layers.Dense(384, activation="linear"),
    )
    hidden = normed
    for layer in deep_stack:
        hidden = layer(hidden)

    # Residual-style merge; no normalization is applied after the sum.
    merged = keras.layers.Add()([wide, hidden])
    return keras.Model(features, merged, name="tower")
@keras.utils.register_keras_serializable(package="OTRec")
class TwoTowerDual(keras.Model):
    """Two-tower model that scores a (query, candidate) pair with two heads.

    The query side combines featurized disease text with an embedding of the
    disease ID; the candidate side uses featurized target text only. The two
    tower outputs are compared with a normalized dot product (layer named
    "cosine"), and that similarity feeds a sigmoid head ("cls") and a linear
    head ("score").
    """

    def __init__(
        self,
        dise_lookup,
        dise_emb,
        q_fs,
        k_fs,
        q_tower,
        k_tower,
        concat_layer,
        **kwargs,
    ):
        """Store the pre-built components and create the similarity/head layers.

        Args:
            dise_lookup: Layer applied to the raw disease ID before embedding.
            dise_emb: Layer applied to the output of ``dise_lookup``.
            q_fs: Callable applied to ``{"text": ...}`` on the query side.
            k_fs: Callable applied to ``{"text": ...}`` on the candidate side.
            q_tower: Model applied to the concatenated query features.
            k_tower: Model applied to the candidate text features.
            concat_layer: Layer that joins the query text vector with the
                disease-ID embedding.
            **kwargs: Forwarded to ``keras.Model.__init__``.
        """
        super().__init__(**kwargs)

        # Injected components, assigned in the same order as before.
        self.dise_lookup = dise_lookup
        self.dise_emb = dise_emb
        self.q_fs = q_fs
        self.k_fs = k_fs
        self.q_tower = q_tower
        self.k_tower = k_tower
        self.concat = concat_layer

        # Similarity between the two tower outputs.
        self.dot = keras.layers.Dense if False else keras.layers.Dot(
            axes=-1, normalize=True, name="cosine"
        )

        # Classification head. Explicit constant kernel/bias initializers were
        # experimented with and are currently disabled, so the Dense defaults
        # apply.
        self.cls_head = keras.layers.Dense(1, activation="sigmoid", name="cls")

        # Regression head: linear output with a small constant starting bias.
        self.score_head = keras.layers.Dense(
            1,
            activation=None,
            name="score",
            bias_initializer=tf.keras.initializers.Constant(0.049),
        )

        # Keep a handle on the module-level tower factory on the instance.
        self.build_tower = build_tower

    def encode_q(self, txt, did):
        """Encode a query from its disease text ``txt`` and disease ID ``did``."""
        text_vec = self.q_fs({"text": txt})
        id_vec = self.dise_emb(self.dise_lookup(did))
        return self.q_tower(self.concat([text_vec, id_vec]))

    def encode_k(self, txt, tid):
        """Encode a candidate from its target text ``txt``.

        ``tid`` is accepted for signature symmetry with ``encode_q`` but is
        not used.
        """
        return self.k_tower(self.k_fs({"text": txt}))

    def call(self, feats):
        """Run both towers and return the two head outputs.

        Args:
            feats: Nested mapping with ``feats["query"]`` providing
                "disease_text" and "diseaseId", and ``feats["candidate"]``
                providing "target_text" and "targetId".

        Returns:
            A dict with keys "cls" (sigmoid head output) and "score"
            (linear head output).
        """
        query = feats["query"]
        q_vec = self.encode_q(query["disease_text"], query["diseaseId"])

        candidate = feats["candidate"]
        k_vec = self.encode_k(candidate["target_text"], candidate["targetId"])

        similarity = self.dot([q_vec, k_vec])
        return {
            "cls": self.cls_head(similarity),
            "score": self.score_head(similarity),
        }
@keras.utils.register_keras_serializable()  # added
def build_two_tower_model(df_learn) -> TwoTowerDual:
    """Assemble, adapt and build a ``TwoTowerDual`` model from training data.

    Args:
        df_learn: Column-indexable table providing "disease_text",
            "target_text", "diseaseId" and "targetId". ``.iloc`` is used on
            its columns, so presumably a pandas DataFrame - confirm with the
            caller.

    Returns:
        A ``TwoTowerDual`` named "two_tower_dual" that has already been
        called once on a one-row dummy batch so its weights exist.
    """
    # 1) Feature spaces: one per side, each adapted on its own text column.
    q_fs = make_fs()
    k_fs = make_fs()
    for space, column in ((q_fs, "disease_text"), (k_fs, "target_text")):
        text_ds = tf.data.Dataset.from_tensor_slices({"text": df_learn[column]})
        space.adapt(text_ds.batch(4096).prefetch(tf.data.AUTOTUNE))

    # 2) Disease-ID lookup and embedding.
    dise_lookup = keras.layers.StringLookup(name="disease_lookup")
    dise_lookup.adapt(df_learn["diseaseId"])
    dise_emb = keras.layers.Embedding(
        input_dim=dise_lookup.vocabulary_size(),
        output_dim=EMB_ID,
        name="dise_emb",
    )

    # 3) Towers, both produced by the module-level build_tower factory.
    q_text_width = q_fs.get_encoded_features().shape[-1]
    q_tower = build_tower(q_text_width)

    # build_tower widens its input by its own EMB_ID argument, but the
    # candidate tower receives text features only, so the width is reduced by
    # EMB_ID here to cancel that. NOTE(review): this relies on build_tower's
    # default matching the module-level EMB_ID.
    k_text_width = k_fs.get_encoded_features().shape[-1]
    k_tower = build_tower(k_text_width - EMB_ID)

    concat = keras.layers.Concatenate(name="concat")

    # 4) Wire everything into the model.
    model = TwoTowerDual(
        dise_lookup=dise_lookup,
        dise_emb=dise_emb,
        q_fs=q_fs,
        k_fs=k_fs,
        q_tower=q_tower,
        k_tower=k_tower,
        concat_layer=concat,
        name="two_tower_dual",
    )

    # 5) Call once on a single dummy row (IDs taken from the first data row)
    #    so that all layers get built.
    query_part = {
        "disease_text": tf.constant(["dummy"]),
        "diseaseId": tf.constant([df_learn["diseaseId"].iloc[0]]),
    }
    candidate_part = {
        "target_text": tf.constant(["dummy target"]),
        "targetId": tf.constant([df_learn["targetId"].iloc[0]]),
    }
    model({"query": query_part, "candidate": candidate_part})
    return model