File size: 9,201 Bytes
3d9bb2a
 
 
 
 
 
5d5d5e8
3d9bb2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
# from keras.layers import ... 
from huggingface_hub import hf_hub_download
import gradio as gr
import h5py

from dl_model_def import make_fs, TwoTowerDual, build_two_tower_model

# ============================================
#  CONFIG
# ============================================

DATA_DIR = "./data/proc"

# Hugging Face Hub coordinates of the trained weights file.
HF_REPO_ID = "GrimSqueaker/OTRec"
HF_WEIGHTS_FILENAME = "model.weights.h5"

# Fetch the weights at import time; WEIGHTS_FILE is the value returned by
# hf_hub_download and is later opened as a local file path.
print("Downloading model weights from Hugging Face Hub...")
WEIGHTS_FILE = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_WEIGHTS_FILENAME)
print(f"Weights downloaded to: {WEIGHTS_FILE}")

# ============================================
#  LOAD TRAINING DATA
# ============================================

def _read_table(stem):
    """Read ``{DATA_DIR}/{stem}.parquet`` into a DataFrame."""
    return pd.read_parquet(f"{DATA_DIR}/{stem}.parquet")


# Each table is loaded and then has its "*_text_embed" column(s) renamed to
# the names used at training time. errors="ignore" makes the rename a no-op
# for columns that are absent.
df_learn = _read_table("df_learn_sub").rename(
    columns={
        "disease_text_embed": "disease_text",
        "target_text_embed": "target_text",
    },
    errors="ignore",
)
disease_df = _read_table("disease_df").rename(
    columns={"disease_text_embed": "disease_text"},
    errors="ignore",
)
target_df = _read_table("target_df").rename(
    columns={"target_text_embed": "target_text"},
    errors="ignore",
)

# ============================================
#  BUILD MODEL + LOAD WEIGHTS
# ============================================

# Build the model from the training frame, then load the downloaded weights.
# Statement order matters here: the session is cleared before the model is
# built, and the weights are loaded only after the model exists.
print("Building TwoTowerDual...")

# 1. Reset Keras Session to ensure layer names start at index 0 (matches clean training)
tf.keras.backend.clear_session()

# 2. Rebuild architecture
# build_two_tower_model comes from dl_model_def; what it reads from df_learn
# is not visible in this file.
model = build_two_tower_model(df_learn)

print("Loading weights...")
try:
    # Try standard load
    model.load_weights(WEIGHTS_FILE)
except ValueError as e:
    # Only ValueError triggers the fallback; any other exception type raised
    # by load_weights propagates and aborts startup.
    print(f"Standard load failed ({e}). Attempting name-mismatch fix...")
    
    # FALLBACK: The training notebook likely generated layer names like 'dise_emb_1' 
    # due to multiple runs. We inspect the .h5 file and map the names.
    with h5py.File(WEIGHTS_FILE, 'r') as f:
        # Top-level keys of the HDF5 file; the code below treats them as
        # layer names (presumably one group per layer -- TODO confirm).
        h5_keys = list(f.keys())
        print(f"Weights file contains layers: {h5_keys}")
        
        # Helper to find the matching key in h5 file for a given prefix
        def match_layer_name(target_attr, prefix):
            """Rename the layer stored at ``model.<target_attr>`` to an h5 key.

            Takes the first entry of ``h5_keys`` that starts with ``prefix``.
            Does nothing when no key matches or when the model has no
            attribute named ``target_attr``.
            """
            # Find key in h5 that starts with prefix (e.g. 'dise_emb')
            match = next((k for k in h5_keys if k.startswith(prefix)), None)
            if match and hasattr(model, target_attr):
                layer = getattr(model, target_attr)
                print(f"Renaming model layer '{layer.name}' to '{match}' to match file.")
                # NOTE(review): writes the private `_name` attribute; whether
                # this affects weight lookup depends on the installed Keras
                # version -- confirm.
                layer._name = match

        # Apply renames for known components
        match_layer_name('dise_emb', 'dise_emb')
        # The 'tower' prefix resolves to the FIRST h5 key starting with
        # 'tower', so q_tower receives whichever such key is listed first.
        match_layer_name('q_tower', 'tower') # Attempt to catch tower/tower_1
        # k_tower might share the name 'tower' prefix in H5, which is tricky in subclasses
        # usually save_weights on subclass saves attributes directly.
        # k_tower is therefore never renamed by this fallback.
        
    # Retry load after renaming
    # Runs after the h5 file has been closed; a failure here is not caught.
    model.load_weights(WEIGHTS_FILE)

print("Weights loaded successfully.")

# ============================================
#  PRECOMPUTE CANDIDATE EMBEDDINGS
# ============================================


print("Precomputing candidate embeddings (batched)...")

# Candidate text and id columns are cast to str before being fed to the
# key-side encoder.
target_texts = target_df["target_text"].astype(str).to_numpy()
target_ids = target_df["targetId"].astype(str).to_numpy()

# Encode in fixed-size batches rather than one call over the whole table,
# to avoid running out of memory.
BATCH_SIZE = 1024  # Conservative batch size for wide inputs
# Progress is reported per batch count. The previous `i % 5000 == 0` test
# never fired after the first batch, because offsets advance in steps of
# BATCH_SIZE and the first non-zero common multiple of 1024 and 5000 is
# 640000.
PROGRESS_EVERY_N_BATCHES = 5

total = len(target_texts)
if total == 0:
    # Fail with a clear message instead of an opaque tf.concat error on an
    # empty list further down.
    raise ValueError("target_df is empty; no candidate targets to embed.")

cand_embs_list = []
for batch_no, start in enumerate(range(0, total, BATCH_SIZE), start=1):
    end = min(start + BATCH_SIZE, total)
    cand_embs_list.append(
        model.encode_k(target_texts[start:end], target_ids[start:end])
    )

    # Report the number of candidates actually processed so far, and always
    # report once the final batch is done.
    if batch_no % PROGRESS_EVERY_N_BATCHES == 0 or end == total:
        print(f"  Processed {end}/{total} candidates...")

# Concatenate all batches back into one tensor, then L2-normalise each row so
# that a dot product with a normalised query vector is a cosine similarity.
cand_embs = tf.concat(cand_embs_list, axis=0)
cand_embs = tf.nn.l2_normalize(cand_embs, axis=1).numpy()

print(f"Candidate embeddings ready. Shape: {cand_embs.shape}")

# ============================================
#  RECOMMENDATION FUNCTION
# ============================================

def recommend_targets(disease_id, top_k=10):
    """Rank candidate targets for one disease and export the ranking as CSV.

    Args:
        disease_id: Value looked up in ``disease_df["diseaseId"]``. A falsy
            value or an id that is not present yields an empty result.
        top_k: Number of highest-scoring targets to keep. Coerced with
            ``int()``; negative values are treated as 0.

    Returns:
        A ``(results, csv_path)`` tuple. ``results`` is a DataFrame of the
        selected rows of ``target_df`` with an added ``score`` column, limited
        to the display columns that exist. ``csv_path`` is the path of a CSV
        file holding the same table. When there is nothing to rank the tuple
        is ``(empty DataFrame, None)``.
    """
    import tempfile

    # 1. Validate input.
    if not disease_id:
        return pd.DataFrame(), None

    row = disease_df.loc[disease_df["diseaseId"] == disease_id]
    if row.empty:
        return pd.DataFrame(), None

    # 2. Encode the query and L2-normalise it, matching the normalisation
    #    applied to the precomputed candidate embeddings.
    disease_text = row["disease_text"].iloc[0]
    q_emb = model.encode_q(
        tf.constant([disease_text]),
        tf.constant([disease_id]),
    )
    q_emb = tf.nn.l2_normalize(q_emb, axis=1).numpy()[0]

    # 3. Similarity of the query against every candidate, one value each.
    raw_sim = cand_embs @ q_emb

    # 4. Map similarity to the final score through the model's cls_head.
    #    Per the original author's note this is a trained sigmoid head that
    #    turns similarity into a probability; it is defined in dl_model_def
    #    and not visible here. The input is reshaped to a single column
    #    before the call.
    scores = model.cls_head(raw_sim.reshape(-1, 1)).numpy().flatten()

    # 5. Indices of the top-k scores, best first. Clamp k so a negative value
    #    cannot turn the slice into "everything except the last |k|".
    k = max(int(top_k), 0)
    idx = np.argsort(scores)[::-1][:k]

    # 6. Build the result table; plain Python floats keep rounding clean.
    results = target_df.iloc[idx].copy()
    results["score"] = [round(float(x), 4) for x in scores[idx]]

    # 7. Keep only the display columns that exist. The description column
    #    may be spelled in singular or plural form.
    desc_col = (
        "functionDescription"
        if "functionDescription" in results.columns
        else "functionDescriptions"
    )
    desired_cols = [
        "targetId",
        "approvedSymbol",
        "approvedName",
        desc_col,
        "score",
    ]
    results = results[[c for c in desired_cols if c in results.columns]]

    # 8. Write the CSV to a unique temporary file. A single fixed filename in
    #    the working directory would be shared by all concurrent sessions,
    #    letting one user's request overwrite another user's download.
    #    delete=False because the file must outlive this function so the
    #    caller can serve it; these files are not cleaned up here.
    with tempfile.NamedTemporaryFile(
        mode="w",
        prefix="recommendations_",
        suffix=".csv",
        delete=False,
        newline="",
        encoding="utf-8",
    ) as handle:
        results.to_csv(handle, index=False)
        csv_path = handle.name

    return results, csv_path

# ============================================
#  GRADIO APP
# ============================================

def search_diseases(query):
    """Return a dropdown update listing diseases whose name or id matches.

    Matching is a case-insensitive *literal* substring test against the
    ``name`` and ``diseaseId`` columns of ``disease_df``; at most 30 matches
    are offered.

    Args:
        query: Free text typed by the user. ``None``, or text shorter than
            two characters after stripping surrounding whitespace, clears
            the dropdown.

    Returns:
        ``gr.update(choices=..., value=...)`` where each choice is a
        ``("<name> (<diseaseId>)", diseaseId)`` pair and ``value`` is the
        first match's id, or ``None`` when nothing matched.
    """
    query = (query or "").strip()
    if len(query) < 2:
        return gr.update(choices=[], value=None)

    # regex=False: the query is user input, not a pattern. With the pandas
    # default (regex=True) a query such as "(" or "c++" raises re.error, and
    # characters like "." or "|" match unintentionally.
    mask = (
        disease_df["name"].str.contains(query, case=False, na=False, regex=False)
        | disease_df["diseaseId"].str.contains(
            query, case=False, na=False, regex=False
        )
    )
    matches = disease_df.loc[mask].head(30)

    choices = [
        (f"{name} ({disease_id})", disease_id)
        for name, disease_id in zip(matches["name"], matches["diseaseId"])
    ]

    # Preselect the first hit.
    first_val = choices[0][1] if choices else None
    return gr.update(choices=choices, value=first_val)

def launch():
    """Assemble the Gradio Blocks interface and start serving it.

    Layout: a title, one row holding the search box, the disease dropdown
    and the top-k slider, then the predictions table, the CSV download and
    a manual "Recommend Targets" button, followed by example queries.
    """
    example_queries = ["synuclein", "diabetes", "doid_0050890"]

    with gr.Blocks() as demo:
        gr.Markdown("# Disease → Target Recommender")
        gr.Markdown("Search for a disease by **Name** or **ID** to get target recommendations.")

        with gr.Row():
            search_box = gr.Textbox(
                label="1. Search Disease",
                placeholder="Type name (e.g., 'Parkinson') or ID...",
                lines=1,
            )
            did_dropdown = gr.Dropdown(
                label="2. Select Disease",
                choices=[],
                interactive=True,
            )
            topk = gr.Slider(1, 400, value=10, step=5, label="Top K Targets")

        # Typing in the search box refreshes the dropdown's options and its
        # selected value.
        search_box.change(fn=search_diseases, inputs=search_box, outputs=did_dropdown)

        # Outputs are stacked vertically below the input row.
        out_df = gr.Dataframe(
            label="Predictions",
            interactive=False,
            wrap=True,
            show_search="filter",
        )
        out_file = gr.File(label="Download CSV")

        btn = gr.Button("Recommend Targets", variant="primary")

        # The same recommendation call is wired to three events, in this
        # order: the manual button, a change of the selected disease (which
        # also covers Example -> search -> dropdown update), and a change of
        # the top-k slider.
        for register_event in (btn.click, did_dropdown.change, topk.change):
            register_event(
                fn=recommend_targets,
                inputs=[did_dropdown, topk],
                outputs=[out_df, out_file],
            )

        gr.Examples(examples=example_queries, inputs=search_box)

    demo.launch()

# Start the UI only when this file is run as a script. The module-level data
# and model loading earlier in the file still executes on a plain import.
if __name__ == "__main__":
    launch()