GrimSqueaker commited on
Commit
3d9bb2a
·
verified ·
1 Parent(s): 27eb6dc

Upload folder using huggingface_hub

Browse files
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Deploys the repo to the Hugging Face Space on every push to main.
name: Run Python script

on:
  push:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        # v2 runs on a deprecated Node runtime and fails on current runners.
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'

      - name: Install Gradio
        run: python -m pip install gradio

      - name: Log in to Hugging Face
        run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'

      - name: Deploy to Spaces
        run: gradio deploy
README.md CHANGED
@@ -1,12 +1,36 @@
1
- ---
2
- title: OTRec
3
- emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 6.1.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: OTRec
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 6.0.1
6
+ ---
7
+ # Disease–Target Recommender (Open Targets)
8
+
9
+ This Space exposes a two-tower recommender model trained on Open Targets–derived
10
+ disease–target data. Given a **disease ID** (matching the `diseaseId` column from
11
+ the preprocessed data), it returns a ranked list of predicted **target IDs**.
12
+
13
+ The backend is a TensorFlow / Keras model with:
14
+ - A **query tower** for diseases (disease text + disease ID embedding)
15
+ - A **key tower** for targets (target text only)
16
+ - Cosine similarity between disease and target embeddings
17
+
18
+ All candidate target embeddings are precomputed at startup for fast inference.
19
+
20
+ ---
21
+
22
+ ## Files and structure
23
+
24
+ Expected repo layout:
25
+
26
+ ```text
27
+ .
28
+ ├── app.py
29
+ ├── requirements.txt
30
+ ├── model.weights.h5
31
+ └── data/
32
+ └── proc/
33
+ ├── disease_df.parquet
34
+ ├── target_df.parquet
35
+ └── df_learn_sub.parquet
36
+ ```
__pycache__/dl_model_def.cpython-310.pyc ADDED
Binary file (3.82 kB). View file
 
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras

import gradio as gr
import h5py
# FIX: hf_hub_download was called below but never imported, which raises a
# NameError as soon as the Space starts.
from huggingface_hub import hf_hub_download

from dl_model_def import make_fs, TwoTowerDual, build_two_tower_model

# ============================================
# CONFIG
# ============================================

# Directory containing the preprocessed parquet files.
DATA_DIR = "./data/proc"

# Download the model weights from the specific HF repo.
print("Downloading model weights from Hugging Face Hub...")
WEIGHTS_FILE = hf_hub_download(
    repo_id="GrimSqueaker/OTRec",
    filename="model.weights.h5",
)
print(f"Weights downloaded to: {WEIGHTS_FILE}")
# ============================================
# LOAD TRAINING DATA
# ============================================

df_learn = pd.read_parquet(f"{DATA_DIR}/df_learn_sub.parquet")
disease_df = pd.read_parquet(f"{DATA_DIR}/disease_df.parquet")
target_df = pd.read_parquet(f"{DATA_DIR}/target_df.parquet")

# Align column names with the ones used at training time; errors="ignore"
# makes the rename a no-op when the columns are already in the new form.
_text_renames = {
    "disease_text_embed": "disease_text",
    "target_text_embed": "target_text",
}
df_learn = df_learn.rename(columns=_text_renames, errors="ignore")
disease_df = disease_df.rename(columns={"disease_text_embed": "disease_text"}, errors="ignore")
target_df = target_df.rename(columns={"target_text_embed": "target_text"}, errors="ignore")
44
+
# ============================================
# BUILD MODEL + LOAD WEIGHTS
# ============================================
# Rebuilds the architecture from code, then loads the downloaded .h5 weights.
# If the straightforward load fails on a layer-name mismatch, it renames the
# in-memory layers to match the names found inside the weights file and retries.

print("Building TwoTowerDual...")

# 1. Reset Keras Session to ensure layer names start at index 0 (matches clean training)
tf.keras.backend.clear_session()

# 2. Rebuild architecture
model = build_two_tower_model(df_learn)

print("Loading weights...")
try:
    # Try standard load
    model.load_weights(WEIGHTS_FILE)
except ValueError as e:
    print(f"Standard load failed ({e}). Attempting name-mismatch fix...")

    # FALLBACK: The training notebook likely generated layer names like 'dise_emb_1'
    # due to multiple runs. We inspect the .h5 file and map the names.
    with h5py.File(WEIGHTS_FILE, 'r') as f:
        h5_keys = list(f.keys())
        print(f"Weights file contains layers: {h5_keys}")

    # Helper to find the matching key in h5 file for a given prefix
    def match_layer_name(target_attr, prefix):
        # Find key in h5 that starts with prefix (e.g. 'dise_emb')
        match = next((k for k in h5_keys if k.startswith(prefix)), None)
        if match and hasattr(model, target_attr):
            layer = getattr(model, target_attr)
            print(f"Renaming model layer '{layer.name}' to '{match}' to match file.")
            # NOTE(review): `_name` is a private Keras attribute; this relies on
            # internal behavior — confirm it still works on Keras upgrades.
            layer._name = match

    # Apply renames for known components
    match_layer_name('dise_emb', 'dise_emb')
    match_layer_name('q_tower', 'tower')  # Attempt to catch tower/tower_1
    # k_tower might share the name 'tower' prefix in H5, which is tricky in subclasses
    # usually save_weights on subclass saves attributes directly.

    # Retry load after renaming; raises if names still do not line up.
    model.load_weights(WEIGHTS_FILE)

print("Weights loaded successfully.")
# ============================================
# PRECOMPUTE CANDIDATE EMBEDDINGS
# ============================================
# Encode every candidate target once at startup so that serving a query only
# needs one matrix-vector product.

print("Precomputing candidate embeddings (batched)...")

target_texts = target_df["target_text"].astype(str).to_numpy()
target_ids = target_df["targetId"].astype(str).to_numpy()

# Process in batches to avoid OOM on the full candidate set.
BATCH_SIZE = 1024  # Conservative batch size for wide inputs
cand_embs_list = []

total = len(target_texts)
for i in range(0, total, BATCH_SIZE):
    end = min(i + BATCH_SIZE, total)
    # Run inference on the slice only, keeping peak memory low.
    emb_batch = model.encode_k(target_texts[i:end], target_ids[i:end])
    cand_embs_list.append(emb_batch)

    # FIX: the previous check `i % 5000 == 0` almost never fired because `i`
    # advances in steps of 1024; report progress every 5 batches instead.
    if (i // BATCH_SIZE) % 5 == 0:
        print(f"  Processed {i}/{total} candidates...")

# Concatenate all batches back into one tensor and L2-normalize the rows so a
# dot product against an L2-normalized query embedding is cosine similarity.
cand_embs = tf.concat(cand_embs_list, axis=0)
cand_embs = tf.nn.l2_normalize(cand_embs, axis=1).numpy()

print(f"Candidate embeddings ready. Shape: {cand_embs.shape}")
# ============================================
# RECOMMENDATION FUNCTION
# ============================================

def recommend_targets(disease_id, top_k=10):
    """Rank candidate targets for one disease.

    Args:
        disease_id: value matching ``disease_df["diseaseId"]``.
        top_k: number of rows to return.

    Returns:
        (results_df, csv_path): ranked DataFrame with a "score" column and the
        path of a CSV copy for download. Returns (empty DataFrame, None) when
        the input is empty or the disease ID is unknown.
    """
    # 1. Validate input
    if not disease_id:
        return pd.DataFrame(), None

    row = disease_df.loc[disease_df["diseaseId"] == disease_id]
    if row.empty:
        return pd.DataFrame(), None

    # 2. Encode the query (disease text + disease-ID embedding)
    disease_text = row["disease_text"].iloc[0]
    q_emb = model.encode_q(
        tf.constant([disease_text]),
        tf.constant([disease_id]),
    )
    q_emb = tf.nn.l2_normalize(q_emb, axis=1).numpy()[0]

    # 3. Raw cosine similarity against the precomputed candidates, shape (N,)
    raw_sim = cand_embs @ q_emb

    # 4. Map similarity -> probability via the trained sigmoid 'cls' head.
    #    Reshape to (N, 1) because the Dense layer expects a matrix.
    scores = model.cls_head(raw_sim.reshape(-1, 1)).numpy().flatten()

    # 5. Top K indices, highest probability first
    k = int(top_k)
    idx = np.argsort(scores)[::-1][:k]

    # 6. Build the result frame; force plain floats for clean rounding
    results = target_df.iloc[idx].copy()
    results["score"] = [round(float(x), 4) for x in scores[idx]]

    # 7. Keep only the display columns that actually exist
    desc_col = "functionDescription" if "functionDescription" in results.columns else "functionDescriptions"
    desired_cols = ["targetId", "approvedSymbol", "approvedName", desc_col, "score"]
    results = results[[c for c in desired_cols if c in results.columns]]

    # 8. Save a per-request CSV for download.
    # FIX: the previous fixed "recommendations.csv" path was shared by all
    # concurrent Gradio users, so simultaneous requests could overwrite each
    # other's download; a tempfile gives every request its own file.
    import tempfile
    fd, csv_path = tempfile.mkstemp(prefix="recommendations_", suffix=".csv")
    os.close(fd)
    results.to_csv(csv_path, index=False)

    return results, csv_path
# ============================================
# GRADIO APP
# ============================================

def search_diseases(query):
    """Populate the disease dropdown from a free-text name/ID search."""
    # Require at least two characters before searching.
    if not query or len(query) < 2:
        return gr.update(choices=[], value=None)

    # Case-insensitive substring match on either the display name or the ID.
    name_hits = disease_df["name"].str.contains(query, case=False, na=False)
    id_hits = disease_df["diseaseId"].str.contains(query, case=False, na=False)
    matches = disease_df.loc[name_hits | id_hits].head(30)

    choices = []
    for _, rec in matches.iterrows():
        choices.append((f"{rec['name']} ({rec['diseaseId']})", rec['diseaseId']))

    # Preselect the first hit so downstream triggers fire immediately.
    default = choices[0][1] if choices else None
    return gr.update(choices=choices, value=default)
def launch():
    """Build the Gradio Blocks UI and start the app."""
    examples = ["synuclein", "diabetes", "doid_0050890"]

    with gr.Blocks() as demo:
        gr.Markdown("# Disease → Target Recommender")
        gr.Markdown("Search for a disease by **Name** or **ID** to get target recommendations.")

        with gr.Row():
            search_box = gr.Textbox(
                label="1. Search Disease",
                placeholder="Type name (e.g., 'Parkinson') or ID...",
                lines=1,
            )

            did_dropdown = gr.Dropdown(
                label="2. Select Disease",
                choices=[],
                interactive=True,
            )

        topk = gr.Slider(1, 400, value=10, step=5, label="Top K Targets")

        # Typing refreshes the dropdown options and default value.
        search_box.change(fn=search_diseases, inputs=search_box, outputs=did_dropdown)

        # Output components, stacked vertically for full width.
        out_df = gr.Dataframe(
            label="Predictions",
            interactive=False,
            wrap=True,
            show_search="filter",
        )

        out_file = gr.File(label="Download CSV")

        # === TRIGGER LOGIC ===
        # All three triggers share identical wiring.
        rec_wiring = dict(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file],
        )

        # 1. Manual trigger (kept as an explicit fallback).
        btn = gr.Button("Recommend Targets", variant="primary")
        btn.click(**rec_wiring)

        # 2. Auto-trigger on selection change.
        #    Handles examples too: Example -> Search -> Dropdown update -> Trigger.
        did_dropdown.change(**rec_wiring)

        # 3. Re-rank when the slider moves.
        topk.change(**rec_wiring)

        gr.Examples(examples=examples, inputs=search_box)

    demo.launch()

if __name__ == "__main__":
    launch()
data/proc/df_learn_sub.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20f1834245178fbf27f385aaef6a757921ced1c6a37fce1fe29d86b1d11a4854
3
+ size 25162164
data/proc/disease_df.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:def4b7f42efca118bdbc84745249ff37fa8c9dc1ac8740feff17ffb99ac3c316
3
+ size 13255480
data/proc/target_df.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3182b52698a4513eeb88cf678189d62897815263ff170fd23d4978e3a869f823
3
+ size 27290365
dl_model_def.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# dl_model_def.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import FeatureSpace

MAX_TOK = 160_000  # vocabulary cap for the count-vectorized text feature
EMB_ID = 64        # width of the disease-ID embedding

@keras.utils.register_keras_serializable(package="OTRec")
def make_fs():
    """Build a FeatureSpace with a single count-vectorized 'text' feature."""
    text_feature = FeatureSpace.feature(
        preprocessor=keras.layers.TextVectorization(
            max_tokens=MAX_TOK,
            output_mode="count",
        ),
        dtype="string",
        output_mode="float",
    )
    return FeatureSpace({"text": text_feature}, output_mode="concat")
25
+
@keras.utils.register_keras_serializable()
def build_tower(input_dim: int, EMB_ID: int = 64) -> keras.Model:
    """Residual projection tower: a wide linear path plus a deep bottleneck path.

    Args:
        input_dim: width of the vectorized text features; the ID-embedding
            width is added internally to form the full input size.
        EMB_ID: width of the ID embedding concatenated to the text features.

    Returns:
        keras.Model mapping (input_dim + EMB_ID,) -> (384,), named "tower".
    """
    inp = keras.Input(shape=(input_dim + EMB_ID,))
    norm_x = keras.layers.LayerNormalization()(inp)

    # Path 1: the wide linear projection.
    linear_out = keras.layers.Dense(384, activation="linear")(norm_x)

    # Path 2: non-linear bottleneck for complex interactions.
    deep = keras.layers.Dense(384, activation="elu")(norm_x)
    deep = keras.layers.LayerNormalization()(deep)  # norm inside the deep block
    deep = keras.layers.Dropout(0.35)(deep)

    deep = keras.layers.Dense(64, activation="elu")(deep)
    deep = keras.layers.Dropout(0.15)(deep)
    deep = keras.layers.Dense(384, activation="linear")(deep)

    # Residual-style combination of the two paths.
    out = keras.layers.Add()([linear_out, deep])

    return keras.Model(inp, out, name="tower")
@keras.utils.register_keras_serializable(package="OTRec")
class TwoTowerDual(keras.Model):
    """Two-tower (query/key) model scored by cosine similarity.

    The query tower encodes disease text plus a learned disease-ID embedding;
    the key tower encodes target text only. The cosine similarity between the
    two embeddings feeds two heads: a sigmoid classifier ("cls") and a linear
    regression head ("score").
    """

    def __init__(self,
                 dise_lookup,
                 dise_emb,
                 q_fs,
                 k_fs,
                 q_tower,
                 k_tower,
                 concat_layer,
                 **kwargs):
        super().__init__(**kwargs)
        # Sub-layers are built and adapted externally (see build_two_tower_model)
        # and injected here.
        self.dise_lookup = dise_lookup  # StringLookup: diseaseId -> vocab index
        self.dise_emb = dise_emb        # Embedding over the disease-ID vocabulary
        self.q_fs = q_fs                # FeatureSpace for disease text
        self.k_fs = k_fs                # FeatureSpace for target text
        self.q_tower = q_tower
        self.k_tower = k_tower
        self.concat = concat_layer
        # normalize=True makes the Dot layer compute cosine similarity.
        self.dot = keras.layers.Dot(axes=-1, normalize=True, name="cosine")
        self.cls_head = keras.layers.Dense(1, activation="sigmoid",
                                           name="cls",
                                           # 1. Start with a high scaling factor so Sigmoid isn't trapped in the middle.
                                           # (This is trainable, so the model can lower it if 20 is too high).
                                           # kernel_initializer=tf.keras.initializers.Constant(5.0),
                                           # bias_initializer=tf.keras.initializers.Constant(-2.2)
                                           )
        self.score_head = keras.layers.Dense(
            1,
            activation=None,
            name="score",
            bias_initializer=tf.keras.initializers.Constant(0.049),
        )
        self.build_tower = build_tower  # added new!

    def encode_q(self, txt, did):
        # Query embedding: text features concatenated with the ID embedding.
        return self.q_tower(
            self.concat([
                self.q_fs({"text": txt}),
                self.dise_emb(self.dise_lookup(did)),
            ])
        )

    def encode_k(self, txt, tid):
        # Key embedding from target text only; `tid` is accepted for API
        # symmetry with encode_q but is intentionally unused.
        txt_vec = self.k_fs({"text": txt})
        return self.k_tower(txt_vec)

    def call(self, feats):
        """Forward pass over a {"query": {...}, "candidate": {...}} feature dict."""
        q = self.encode_q(
            feats["query"]["disease_text"],
            feats["query"]["diseaseId"],
        )
        k = self.encode_k(
            feats["candidate"]["target_text"],
            feats["candidate"]["targetId"],
        )
        sim = self.dot([q, k])
        prob = self.cls_head(sim)
        reg = self.score_head(sim)
        return {"cls": prob, "score": reg}
@keras.utils.register_keras_serializable()  # added
def build_two_tower_model(df_learn) -> TwoTowerDual:
    """Construct and dummy-build a TwoTowerDual from the training dataframe.

    Args:
        df_learn: DataFrame with columns "disease_text", "target_text",
            "diseaseId" and "targetId", used to adapt the text vectorizers
            and the disease-ID lookup.

    Returns:
        A TwoTowerDual whose variables have been created via a dummy forward
        pass, so load_weights() can be called immediately afterwards.
    """
    # 1) Feature spaces (one vocabulary per tower), adapted on the raw text
    q_fs = make_fs()
    k_fs = make_fs()

    q_fs.adapt(
        tf.data.Dataset.from_tensor_slices({"text": df_learn["disease_text"]})
        .batch(4096)
        .prefetch(tf.data.AUTOTUNE)
    )
    k_fs.adapt(
        tf.data.Dataset.from_tensor_slices({"text": df_learn["target_text"]})
        .batch(4096)
        .prefetch(tf.data.AUTOTUNE)
    )

    # 2) Lookup + embedding for disease IDs
    dise_lookup = keras.layers.StringLookup(name="disease_lookup")
    dise_lookup.adapt(df_learn["diseaseId"])
    dise_emb = keras.layers.Embedding(
        input_dim=dise_lookup.vocabulary_size(),
        output_dim=EMB_ID,
        name="dise_emb",
    )

    # 3) Towers. build_tower adds EMB_ID to its input width internally, so the
    # key tower — which receives text features only, never an ID embedding
    # (see TwoTowerDual.encode_k) — passes its encoded width minus EMB_ID to
    # end up with the correct net input size.
    q_tower = build_tower(q_fs.get_encoded_features().shape[-1])
    k_tower = build_tower(k_fs.get_encoded_features().shape[-1] - EMB_ID)

    concat = keras.layers.Concatenate(name="concat")

    # 4) Assemble the model
    model = TwoTowerDual(
        dise_lookup=dise_lookup,
        dise_emb=dise_emb,
        q_fs=q_fs,
        k_fs=k_fs,
        q_tower=q_tower,
        k_tower=k_tower,
        concat_layer=concat,
        name="two_tower_dual",
    )

    # Dummy forward pass so every variable is created before load_weights()
    dummy = {
        "query": {
            "disease_text": tf.constant(["dummy"]),
            "diseaseId": tf.constant([df_learn["diseaseId"].iloc[0]]),
        },
        "candidate": {
            "target_text": tf.constant(["dummy target"]),
            "targetId": tf.constant([df_learn["targetId"].iloc[0]]),
        },
    }
    _ = model(dummy)

    return model
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ tensorflow==2.16.*
2
+ numpy
3
+ pandas
4
+ pyarrow
5
+ gradio
6
+ huggingface-hub