GrimSqueaker committed on
Commit
5d5d5e8
·
verified ·
1 Parent(s): 41629fd

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. .tmp.driveupload/43308 +282 -0
  2. app.py +1 -1
.tmp.driveupload/43308 ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Imports ---------------------------------------------------------------
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
# from keras.layers import ...
from huggingface_hub import hf_hub_download
import gradio as gr
import h5py

# Project-local model definition: the rebuilt architecture must match the
# architecture the saved weights were trained with.
from dl_model_def import make_fs, TwoTowerDual, build_two_tower_model

# ============================================
# CONFIG
# ============================================

# Directory holding the preprocessed parquet files shipped with the app.
DATA_DIR = "./data/proc"

# Download the model weights from your specific HF Repo
print("Downloading model weights from Hugging Face Hub...")
WEIGHTS_FILE = hf_hub_download(
    repo_id="GrimSqueaker/OTRec",
    filename="model.weights.h5"
)
print(f"Weights downloaded to: {WEIGHTS_FILE}")

# ============================================
# LOAD TRAINING DATA
# ============================================

# df_learn: used only to rebuild the model's feature space below;
# disease_df / target_df: lookup tables feeding the two towers.
df_learn = pd.read_parquet(f"{DATA_DIR}/df_learn_sub.parquet")
disease_df = pd.read_parquet(f"{DATA_DIR}/disease_df.parquet")
target_df = pd.read_parquet(f"{DATA_DIR}/target_df.parquet")

# Ensure column names match training. Older exports used the *_embed suffix;
# errors="ignore" makes each rename a no-op when the column is already right.
df_learn = df_learn.rename(columns={
    "disease_text_embed": "disease_text",
    "target_text_embed": "target_text"
}, errors="ignore")

disease_df.rename(columns={"disease_text_embed": "disease_text"}, errors="ignore", inplace=True)

target_df.rename(columns={"target_text_embed": "target_text"}, errors="ignore", inplace=True)

# ============================================
# BUILD MODEL + LOAD WEIGHTS
# ============================================

print("Building TwoTowerDual...")

# 1. Reset Keras Session to ensure layer names start at index 0 (matches clean training)
tf.keras.backend.clear_session()

# 2. Rebuild architecture
model = build_two_tower_model(df_learn)

print("Loading weights...")
try:
    # Try standard load
    model.load_weights(WEIGHTS_FILE)
except ValueError as e:
    print(f"Standard load failed ({e}). Attempting name-mismatch fix...")

    # FALLBACK: The training notebook likely generated layer names like 'dise_emb_1'
    # due to multiple runs. We inspect the .h5 file and map the names.
    with h5py.File(WEIGHTS_FILE, 'r') as f:
        h5_keys = list(f.keys())
        print(f"Weights file contains layers: {h5_keys}")

    # Helper to find the matching key in h5 file for a given prefix
    def match_layer_name(target_attr, prefix):
        # Find key in h5 that starts with prefix (e.g. 'dise_emb')
        match = next((k for k in h5_keys if k.startswith(prefix)), None)
        if match and hasattr(model, target_attr):
            layer = getattr(model, target_attr)
            print(f"Renaming model layer '{layer.name}' to '{match}' to match file.")
            # NOTE(review): _name is a private Keras attribute; this works on the
            # current TF/Keras but may break on upgrade — confirm when bumping TF.
            layer._name = match

    # Apply renames for known components
    match_layer_name('dise_emb', 'dise_emb')
    match_layer_name('q_tower', 'tower')  # Attempt to catch tower/tower_1
    # k_tower might share the name 'tower' prefix in H5, which is tricky in subclasses
    # usually save_weights on subclass saves attributes directly.

    # Retry load after renaming
    model.load_weights(WEIGHTS_FILE)

print("Weights loaded successfully.")
89
+
90
# ============================================
# PRECOMPUTE CANDIDATE EMBEDDINGS
# ============================================
# Encode every candidate target once at startup so recommendation requests
# only have to encode the query disease and take a dot product.

print("Precomputing candidate embeddings (batched)...")

target_texts = target_df["target_text"].astype(str).to_numpy()
target_ids = target_df["targetId"].astype(str).to_numpy()

# Process in batches to avoid OOM on the wide text inputs.
BATCH_SIZE = 1024  # Conservative batch size for wide inputs
cand_embs_list = []

total = len(target_texts)
for i in range(0, total, BATCH_SIZE):
    # Slice the batch (slicing clamps automatically at the array end).
    end = min(i + BATCH_SIZE, total)
    batch_txt = target_texts[i:end]
    batch_ids = target_ids[i:end]

    # Run inference on the batch (keeps memory usage low).
    emb_batch = model.encode_k(batch_txt, batch_ids)
    cand_embs_list.append(emb_batch)

    # FIX: the original condition `i % 5000 == 0` could only fire at i == 0,
    # because i advances in steps of BATCH_SIZE (1024) and no later multiple of
    # 1024 is divisible by 5000. Report progress every 5 batches (~5120 rows).
    if (i // BATCH_SIZE) % 5 == 0:
        print(f" Processed {i}/{total} candidates...")

# Concatenate all batches back into one tensor
cand_embs = tf.concat(cand_embs_list, axis=0)

# Normalize once so scoring can use a plain dot product as cosine similarity.
cand_embs = tf.nn.l2_normalize(cand_embs, axis=1).numpy()

print(f"Candidate embeddings ready. Shape: {cand_embs.shape}")
133
+
134
+ # ============================================
135
+ # RECOMMENDATION FUNCTION
136
+ # ============================================
137
+
138
def recommend_targets(disease_id, top_k=10):
    """Rank candidate targets for one disease and return (DataFrame, csv_path).

    Parameters
    ----------
    disease_id : str
        Identifier expected in ``disease_df["diseaseId"]``.
    top_k : int or float, default 10
        Number of rows to return (Gradio sliders may pass floats; coerced).

    Returns
    -------
    (pandas.DataFrame, str | None)
        The ranked results plus the path to a CSV copy for download, or an
        empty DataFrame and ``None`` when the input is missing/unknown.
    """
    import tempfile  # local import: only needed to create the download file

    # 1. Validate Input
    if not disease_id:
        return pd.DataFrame(), None

    row = disease_df.loc[disease_df["diseaseId"] == disease_id]
    if row.empty:
        return pd.DataFrame(), None

    # 2. Encode Query through the disease (query) tower
    disease_text = row["disease_text"].iloc[0]
    q_emb = model.encode_q(
        tf.constant([disease_text]),
        tf.constant([disease_id])
    )
    q_emb = tf.nn.l2_normalize(q_emb, axis=1).numpy()[0]

    # 3. Raw cosine similarity against the precomputed, L2-normalized
    #    candidate matrix. Shape: (N_targets,)
    raw_sim = cand_embs @ q_emb

    # 4. Convert to Probability (fixes negative scores): the trained
    #    'cls_head' (Sigmoid) maps Similarity -> Probability. Reshape to
    #    (N, 1) because the Keras Dense layer expects a matrix.
    scores = model.cls_head(raw_sim.reshape(-1, 1)).numpy().flatten()

    # 5. Top-K indices, highest score first (slicing tolerates k > N)
    k = int(top_k)
    idx = np.argsort(scores)[::-1][:k]

    # 6. Build Result DataFrame
    results = target_df.iloc[idx].copy()

    # Force standard python float for clean rounding
    results["score"] = [round(float(x), 4) for x in scores[idx]]

    # 7. Select Columns (tolerate either spelling of the description column)
    desc_col = "functionDescription" if "functionDescription" in results.columns else "functionDescriptions"
    desired_cols = [
        "targetId",
        "approvedSymbol",
        "approvedName",
        desc_col,
        "score"
    ]
    results = results[[c for c in desired_cols if c in results.columns]]

    # 8. Save to CSV for download.
    #    FIX: the original wrote a fixed "recommendations.csv" in the working
    #    directory, so concurrent Gradio sessions clobbered each other's file.
    #    A per-call temp file keeps downloads isolated.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", prefix="recommendations_", delete=False
    ) as tmp:
        csv_path = tmp.name
    results.to_csv(csv_path, index=False)

    return results, csv_path
194
+
195
+ # ============================================
196
+ # GRADIO APP
197
+ # ============================================
198
+
199
def search_diseases(query):
    """Return a gr.update(...) refreshing the disease dropdown for a query.

    Matches case-insensitive substrings against the disease name or ID,
    capped at 30 results. The first match is preselected so the dropdown's
    .change handler fires and recommendations refresh automatically.
    """
    # Require at least 2 characters to avoid flooding the dropdown.
    if not query or len(query) < 2:
        return gr.update(choices=[], value=None)

    # FIX: pass regex=False — Series.str.contains treats the pattern as a
    # regular expression by default, so user input containing characters such
    # as '(', '+' or '[' would raise re.error (or silently mis-match).
    mask = (
        disease_df["name"].str.contains(query, case=False, na=False, regex=False) |
        disease_df["diseaseId"].str.contains(query, case=False, na=False, regex=False)
    )

    matches = disease_df.loc[mask].head(30)

    # Gradio dropdown choices: (display label, underlying value) tuples.
    choices = [
        (f"{row['name']} ({row['diseaseId']})", row['diseaseId'])
        for _, row in matches.iterrows()
    ]

    first_val = choices[0][1] if choices else None
    return gr.update(choices=choices, value=first_val)
217
+
218
def launch():
    """Build and launch the Gradio UI: search box -> disease dropdown -> predictions.

    Wiring: typing in the search box repopulates the dropdown via
    search_diseases(); selecting a disease (or moving the slider, or clicking
    the button) runs recommend_targets() and fills the table + CSV download.
    """
    # Example queries shown under the search box (names or lowercase IDs).
    examples = ["synuclein", "diabetes", "doid_0050890"]

    # NOTE(review): the source diff stripped indentation, so the exact layout
    # grouping (which widgets sit inside the Row) is reconstructed — confirm
    # against the rendered app.
    with gr.Blocks() as demo:
        gr.Markdown("# Disease → Target Recommender")
        gr.Markdown("Search for a disease by **Name** or **ID** to get target recommendations.")

        with gr.Row():
            search_box = gr.Textbox(
                label="1. Search Disease",
                placeholder="Type name (e.g., 'Parkinson') or ID...",
                lines=1
            )

            did_dropdown = gr.Dropdown(
                label="2. Select Disease",
                choices=[],
                interactive=True
            )

        topk = gr.Slider(1, 400, value=10, step=5, label="Top K Targets")

        # Search Logic (Updates dropdown options and default value)
        search_box.change(fn=search_diseases, inputs=search_box, outputs=did_dropdown)

        # Output Components (Stacked vertically for full width)
        out_df = gr.Dataframe(
            label="Predictions",
            interactive=False,
            wrap=True,
            show_search="filter",
        )

        out_file = gr.File(label="Download CSV")

        # === TRIGGER LOGIC ===
        # 1. Manual Trigger (Keep the button just in case)
        btn = gr.Button("Recommend Targets", variant="primary")
        btn.click(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )

        # 2. Auto-Trigger on Change
        # This handles the Examples too: Example -> Search -> Dropdown Update -> Trigger
        did_dropdown.change(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )

        # Also update when slider moves
        topk.change(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )

        gr.Examples(examples=examples, inputs=search_box)

    demo.launch()

if __name__ == "__main__":
    launch()
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  import tensorflow as tf
5
  from tensorflow import keras
6
  # from keras.layers import ...
7
-
8
  import gradio as gr
9
  import h5py
10
 
 
4
  import tensorflow as tf
5
  from tensorflow import keras
6
  # from keras.layers import ...
7
+ from huggingface_hub import hf_hub_download
8
  import gradio as gr
9
  import h5py
10