File size: 6,531 Bytes
7c15d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81e9b11
7c15d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50c4964
7c15d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bdb15b
 
 
 
 
 
 
 
 
7c15d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50c4964
7c15d15
 
 
 
 
 
 
 
 
23f8086
 
 
 
 
 
50c4964
7c15d15
 
4107d9e
7c15d15
6f6cbbd
50c4964
 
6f6cbbd
 
 
 
 
 
 
 
7c15d15
6f6cbbd
 
 
 
 
 
 
7c15d15
 
 
 
 
 
 
 
 
 
6f6cbbd
50c4964
7c15d15
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
ProtTale Gradio demo — runs on a Hugging Face Space (CPU).

On startup the app pulls the checkpoint from `Mulah/ProtTale` (HF model hub)
and keeps it in the Space's local cache. Inference is then done with the
standard `Blip2Stage2` pipeline exactly like `predict_single.py`.
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

from argparse import Namespace

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from model.blip2_stage2 import Blip2Stage2
from data_provider.stage2_dm import InferenceCollater


# Where the demo fetches its fine-tuned weights from at startup
# (passed to hf_hub_download below).
CKPT_REPO: str = "Mulah/ProtTale"  # Hugging Face Hub repo id
CKPT_FILE: str = "checkpoint.ckpt"  # file name inside that repo


def build_args(ckpt_path: str) -> Namespace:
    """Return the fixed configuration namespace used to construct the demo model.

    Every value is hard-coded except ``init_checkpoint``, which is set to
    *ckpt_path*. The settings are unpacked into ``Namespace`` in the order
    listed, so ``vars()`` of the result keeps that order.

    NOTE(review): several entries (learning rates, evaluation paths, epoch
    counts) look training-only; they are presumably still read by
    ``Blip2Stage2.__init__``, so verify before pruning any of them. The key
    ``enbale_gradient_checkpointing`` is misspelled but must stay exactly as
    is — presumably the model code reads it under that name.
    """
    settings = {
        # Architecture: protein encoder, query tokens, bridge model, LLM, and
        # their LoRA / PEFT settings.
        "plm_model": "esmc_300m",
        "encoder_type": "auto",
        "num_query_token": 4,
        "plm_tune": "lora",
        "plm_lora_r": 4,
        "plm_lora_alpha": 8,
        "plm_lora_dropout": 0.1,
        "bert_name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        "cross_attention_freq": 2,
        "llm_name": "facebook/galactica-1.3b",
        "llm_tune": "lora",
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
        "peft_dir": "",
        "peft_config": "",
        # Decoding and tokenisation lengths (read by predict() and the collater).
        "num_beams": 3,
        "do_sample": False,
        "max_inference_len": 128,
        "min_inference_len": 1,
        "text_max_len": 128,
        "prot_max_len": 1024,
        # Evaluation / data settings.
        "caption_eval_epoch": 1,
        "inference_on_training_data": False,
        "train_reliability_head_only": False,
        "report_go_wang_on_test": False,
        "report_go_wang_on_val": False,
        "save_predictions": False,
        "ia_path": "evals/tools/IA.txt",
        "go_files_tsv_path": "evals/tools/go_files.tsv",
        "test_set_path": "",
        "valid_set_path": "",
        "root": "data/SwissProtV3",
        # Memory and checkpoint settings.
        "enbale_gradient_checkpointing": False,
        "init_checkpoint": ckpt_path,
        "stage1_path": "",
        "stage2_path": "",
        # Optimiser / schedule settings.
        "init_lr": 1e-4,
        "min_lr": 1e-5,
        "warmup_lr": 1e-6,
        "warmup_steps": 0,
        "lr_decay_rate": 0.9,
        "scheduler": "None",
        "weight_decay": 0.05,
        "reliability_lr": 1e-4,
        "max_epochs": 1,
        # Run bookkeeping.
        "filename": "predict_single",
        "seed": 42,
        "reliability_binary": True,
    }
    return Namespace(**settings)


def _load_weights(model: torch.nn.Module, ckpt_path: str) -> list:
    """Copy every checkpoint tensor that matches *model* by name and shape.

    Checkpoint entries that do not exist in the model are ignored. Entries
    whose shape differs from the model's are skipped, so the model keeps its
    own initial value for them. This matches the original filtered,
    non-strict load, except that skipped and never-loaded counts are now
    printed. A silent mismatch (for example, a reliability head with a
    different output size) then shows up in the Space logs.

    The checkpoint dict is a local of this function, so its tensors can be
    released as soon as the weights are copied into *model*. It is no longer
    pinned by module globals for the lifetime of the process.

    Returns:
        ``(name, shape_tuple)`` for every checkpoint entry whose name contains
        ``"reliability_head"``, for the debug report printed at startup.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints. Newer torch releases default to weights_only=True,
    # which may reject checkpoints that pickle non-tensor objects. Confirm
    # against the pinned torch version.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Weights may be nested under "state_dict" (presumably a training-framework
    # checkpoint); otherwise treat the file itself as a bare state dict.
    state_dict = ckpt.get("state_dict", ckpt)
    model_sd = model.state_dict()

    loadable = {}
    shape_mismatch = []
    not_in_model = 0
    for name, tensor in state_dict.items():
        target = model_sd.get(name)
        if target is None:
            not_in_model += 1
        elif target.shape != tensor.shape:
            shape_mismatch.append((name, tuple(tensor.shape), tuple(target.shape)))
        else:
            loadable[name] = tensor

    model.load_state_dict(loadable, strict=False)
    # Every key in `loadable` comes from model_sd, so the difference is the
    # number of model tensors that kept their constructor-time values.
    never_loaded = len(model_sd) - len(loadable)
    print(
        f"Loaded {len(loadable):,} tensors from checkpoint; "
        f"{not_in_model:,} checkpoint keys not in model; "
        f"{len(shape_mismatch):,} skipped for shape mismatch; "
        f"{never_loaded:,} model tensors not loaded from checkpoint."
    )
    for name, ckpt_shape, model_shape in shape_mismatch:
        print(f"[warn] skipped {name}: checkpoint shape {ckpt_shape} != model shape {model_shape}")

    return [(k, tuple(v.shape)) for k, v in state_dict.items() if "reliability_head" in k]


print("Downloading checkpoint from HF hub (first run only) ...")
CKPT_PATH = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE)
print(f"Checkpoint ready at: {CKPT_PATH}")

print("Building model ...")
MODEL_ARGS = build_args(CKPT_PATH)
MODEL = Blip2Stage2(MODEL_ARGS)

print("Loading weights ...")
_ckpt_head_shapes = _load_weights(MODEL, CKPT_PATH)
MODEL.eval()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = MODEL.to(DEVICE)

BLIP2 = MODEL.blip2
COLLATER = InferenceCollater(
    tokenizer=BLIP2.llm_tokenizer,
    prot_tokenizer=BLIP2.plm_tokenizer,
    text_max_len=MODEL_ARGS.text_max_len,
    prot_max_len=MODEL_ARGS.prot_max_len,
)
PROMPT = "Swiss-Prot description: "

# Startup diagnostics: does the constructed reliability head match the one
# stored in the checkpoint?
print(f"[debug] MODEL_ARGS.reliability_binary = {getattr(MODEL_ARGS, 'reliability_binary', '<missing>')}")
print(f"[debug] BLIP2.reliability_binary     = {getattr(BLIP2, 'reliability_binary', '<missing>')}")
_rh_w = BLIP2.reliability_head[1].weight
print(f"[debug] reliability_head output dim  = {_rh_w.shape[0]} (binary=2, 4-class=4)")
print(f"[debug] ckpt mtime / size            = {os.path.getmtime(CKPT_PATH):.0f}, {os.path.getsize(CKPT_PATH):,} bytes")
for _k, _shape in _ckpt_head_shapes:
    print(f"[debug] ckpt has {_k}: shape={_shape}")
print("Ready.")


@torch.no_grad()
def predict(sequence: str):
    """Run one protein sequence through the model for the Gradio UI.

    Input cleaning:
    - drop FASTA header lines (lines starting with ``>``);
    - remove every whitespace character, including tabs and the ``\\r`` of
      Windows line endings;
    - upper-case the result;
    - truncate to ``MODEL_ARGS.prot_max_len`` residues.

    Args:
        sequence: Raw text from the sequence textbox; ``None`` is treated as
            empty.

    Returns:
        ``(prediction, reliability_label, p_positive)``: three strings for the
        three output textboxes. ``p_positive`` is the last value of the
        flattened reliability probabilities, formatted to 4 decimals, or ``""``
        if the model returned no probabilities. For empty input, the first
        slot holds a prompt asking for a sequence and the other two are empty.
    """
    # Keep only sequence lines so a pasted FASTA record works as-is.
    residue_lines = [
        line for line in (sequence or "").splitlines() if not line.lstrip().startswith(">")
    ]
    # str.split() with no argument splits on every kind of whitespace, so
    # spaces, tabs, and stray "\r" characters are all removed.
    sequence = "".join("".join(residue_lines).split()).upper()
    if not sequence:
        return "Please paste a protein amino-acid sequence.", "", ""
    sequence = sequence[: MODEL_ARGS.prot_max_len]

    # NOTE(review): the fields after (sequence, prompt) appear to be
    # placeholders for the collater's batch tuple (target text, reliability
    # label, ...) that generation does not use. Confirm against InferenceCollater.
    batch = [(sequence, PROMPT, "", 0.0, [], 0)]
    prot_tokens, prompt_tokens, r_tensor, _ = COLLATER(batch)
    prot_tokens = {k: (v.to(DEVICE) if torch.is_tensor(v) else v) for k, v in prot_tokens.items()}
    prompt_tokens = prompt_tokens.to(DEVICE)
    r_tensor = r_tensor.to(DEVICE)

    samples = {"prot_batch": prot_tokens, "prompt_batch": prompt_tokens, "reliability": r_tensor}
    pred_texts, r_pred, _, _, r_probs = BLIP2.generate(
        samples,
        do_sample=MODEL_ARGS.do_sample,
        num_beams=MODEL_ARGS.num_beams,
        max_length=MODEL_ARGS.max_inference_len,
        min_length=MODEL_ARGS.min_inference_len,
    )

    pred = pred_texts[0]
    # r_pred may be a tensor or a plain sequence; take the single item either way.
    r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
    # Flatten r_probs (a tensor, or a list possibly nested one level) to floats.
    if torch.is_tensor(r_probs):
        flat = r_probs.flatten().cpu().tolist()
    else:
        flat = [float(x) for sub in r_probs for x in (sub if isinstance(sub, (list, tuple)) else [sub])]
    print(f"[debug] r_probs raw flat = {flat}")  # remove after verifying
    # With a binary head, the last entry is presumably the probability of
    # class 1 ("Reliable"). Verify this if the head size changes.
    p_text = f"{float(flat[-1]):.4f}" if flat else ""
    return pred, format_reliability(r), p_text


# Default text in the sequence box. Split only for line length; the
# concatenated value is unchanged.
# NOTE(review): this looks like human preproinsulin (UniProt P01308). Verify.
EXAMPLE_SEQ = (
    "MALWMRLLPLLALLALWGPDPAAA"
    "FVNQHLCGSHLVEALYLVCGERGFFYTPKT"
    "RREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKR"
    "GIVEQCCTSICSLYQLENYCN"
)

# Human-readable names for the reliability classes returned by the model.
RELIABILITY_LABELS = {
    1.0: "Reliable",
    0.0: "Unreliable",
}


def format_reliability(r: float) -> str:
    """Render reliability class *r* as ``"<label> (class=<r>)"``.

    Classes missing from RELIABILITY_LABELS are shown as "Unknown". The class
    value uses the ``g`` format, so ``1.0`` prints as ``1``.
    """
    name = RELIABILITY_LABELS[r] if r in RELIABILITY_LABELS else "Unknown"
    return f"{name} (class={r:g})"


# Intro text rendered under the logo.
_INTRO_MARKDOWN = (
    "Protein amino-acid sequence → Swiss-Prot-style function description, "
    "with a reliability score for the generated text.\n\n"
    "**Note:** this Space runs on CPU; a single beam-3 generation typically takes ~30–120 s."
)

with gr.Blocks(title="ProtTale") as demo:
    # Header: frameless, unlabeled logo (no download button) plus intro text.
    gr.Image(
        value="prottale_logo.png",
        height=120,
        container=False,
        show_label=False,
        show_download_button=False,
    )
    gr.Markdown(_INTRO_MARKDOWN)

    # Input area, pre-filled with the example sequence.
    with gr.Row():
        seq_in = gr.Textbox(value=EXAMPLE_SEQ, lines=6, label="Protein sequence")
    run_btn = gr.Button("Predict", variant="primary")

    # Output boxes, created top to bottom in the order predict() returns its values.
    with gr.Column():
        pred_out, r_out, p_out = (
            gr.Textbox(label="Predicted function"),
            gr.Textbox(label="Reliability"),
            gr.Textbox(label="P(positive)"),
        )

    run_btn.click(fn=predict, inputs=[seq_in], outputs=[pred_out, r_out, p_out])

if __name__ == "__main__":
    # Allow at most 8 requests to wait in the queue, then serve on all
    # interfaces at port 7860.
    demo.queue(max_size=8)
    demo.launch(server_name="0.0.0.0", server_port=7860)