Spaces:
Running
Running
File size: 6,531 Bytes
7c15d15 81e9b11 7c15d15 50c4964 7c15d15 7bdb15b 7c15d15 50c4964 7c15d15 23f8086 50c4964 7c15d15 4107d9e 7c15d15 6f6cbbd 50c4964 6f6cbbd 7c15d15 6f6cbbd 7c15d15 6f6cbbd 50c4964 7c15d15 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 | """
ProtTale Gradio demo — runs on a Hugging Face Space (CPU).
On startup the app pulls the checkpoint from `Mulah/ProtTale` (HF model hub)
and keeps it in the Space's local cache. Inference is then done with the
standard `Blip2Stage2` pipeline exactly like `predict_single.py`.
"""
import os
# Set before the third-party imports below, presumably so the tokenizer and
# HF-hub libraries pick these up when they are first imported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
from argparse import Namespace
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from model.blip2_stage2 import Blip2Stage2
from data_provider.stage2_dm import InferenceCollater
# Hugging Face Hub repository and file name of the trained checkpoint that is
# fetched with hf_hub_download at startup.
CKPT_REPO: str = "Mulah/ProtTale"
CKPT_FILE: str = "checkpoint.ckpt"
def build_args(ckpt_path: str) -> Namespace:
    """Return the fixed inference configuration handed to ``Blip2Stage2``.

    Every value is hard-coded except ``init_checkpoint``, which is set to
    ``ckpt_path``. Training/evaluation fields are included with inert values,
    presumably because the model constructor reads them — confirm before pruning.

    Args:
        ckpt_path: Local filesystem path of the downloaded checkpoint.

    Returns:
        An ``argparse.Namespace`` whose attributes mirror the training CLI flags.
    """
    settings = dict(
        # Protein encoder (plm_*) and its LoRA adapter.
        plm_model="esmc_300m",
        encoder_type="auto",
        num_query_token=4,
        plm_tune="lora",
        plm_lora_r=4,
        plm_lora_alpha=8,
        plm_lora_dropout=0.1,
        # BERT backbone and cross-attention cadence.
        bert_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        cross_attention_freq=2,
        # Language model and its LoRA adapter.
        llm_name="facebook/galactica-1.3b",
        llm_tune="lora",
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        peft_dir="",
        peft_config="",
        # Decoding.
        num_beams=3,
        do_sample=False,
        max_inference_len=128,
        min_inference_len=1,
        # Input length limits.
        text_max_len=128,
        prot_max_len=1024,
        # Evaluation / data locations (not exercised by this demo).
        caption_eval_epoch=1,
        inference_on_training_data=False,
        train_reliability_head_only=False,
        report_go_wang_on_test=False,
        report_go_wang_on_val=False,
        save_predictions=False,
        ia_path="evals/tools/IA.txt",
        go_files_tsv_path="evals/tools/go_files.tsv",
        test_set_path="",
        valid_set_path="",
        root="data/SwissProtV3",
        # NOTE(review): misspelled key kept verbatim; presumably it matches the
        # attribute name the model code reads.
        enbale_gradient_checkpointing=False,
        # Checkpoints.
        init_checkpoint=ckpt_path,
        stage1_path="",
        stage2_path="",
        # Optimisation (inert at inference time).
        init_lr=1e-4,
        min_lr=1e-5,
        warmup_lr=1e-6,
        warmup_steps=0,
        lr_decay_rate=0.9,
        scheduler="None",
        weight_decay=0.05,
        reliability_lr=1e-4,
        max_epochs=1,
        # Miscellaneous.
        filename="predict_single",
        seed=42,
        reliability_binary=True,
    )
    return Namespace(**settings)
# ---------------------------------------------------------------------------
# One-time model bootstrap: runs at import time, before the UI is built.
# ---------------------------------------------------------------------------
print("Downloading checkpoint from HF hub (first run only) ...")
# hf_hub_download returns a path inside the local HF cache.
CKPT_PATH = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE)
print(f"Checkpoint ready at: {CKPT_PATH}")

print("Building model ...")
MODEL_ARGS = build_args(CKPT_PATH)
# NOTE(review): build_args also passes CKPT_PATH as init_checkpoint; if
# Blip2Stage2 already loads it there, the explicit load below is redundant.
MODEL = Blip2Stage2(MODEL_ARGS)

print("Loading weights ...")
# NOTE(review): torch.load unpickles the file. That is only acceptable because
# the checkpoint comes from the project's own hub repo. Recent torch releases
# default to weights_only=True, which can reject checkpoints that pickle
# non-tensor objects — confirm against the torch version the Space pins.
_ckpt = torch.load(CKPT_PATH, map_location="cpu")
# Lightning-style checkpoints nest weights under "state_dict"; otherwise the
# loaded object is treated as the state dict itself.
_state_dict = _ckpt.get("state_dict", _ckpt)
_model_sd = MODEL.state_dict()

# Keep only tensors the model can accept. Anything dropped is counted and
# reported instead of being discarded silently: a skipped tensor leaves that
# parameter at its freshly-initialised value.
_filtered = {}
_shape_mismatch = []
_n_not_in_model = 0
for _k, _v in _state_dict.items():
    if _k not in _model_sd:
        _n_not_in_model += 1
    elif _model_sd[_k].shape != _v.shape:
        _shape_mismatch.append((_k, tuple(_v.shape), tuple(_model_sd[_k].shape)))
    else:
        _filtered[_k] = _v
_load_result = MODEL.load_state_dict(_filtered, strict=False)
# getattr guard in case a subclass overrides load_state_dict with another return type.
_n_missing = len(getattr(_load_result, "missing_keys", ()))
print(
    f"Loaded {len(_filtered):,} checkpoint tensors; skipped {_n_not_in_model:,} keys "
    f"not present in the model and {len(_shape_mismatch):,} with mismatched shapes; "
    f"{_n_missing:,} model tensors were not in the load set."
)
for _k, _ckpt_shape, _model_shape in _shape_mismatch:
    print(f"[warn] skipped {_k}: checkpoint shape {_ckpt_shape} != model shape {_model_shape}")

MODEL.eval()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = MODEL.to(DEVICE)
BLIP2 = MODEL.blip2

# Turns (sequence, prompt, ...) tuples into the tokenized batch that predict() feeds to BLIP2.generate.
COLLATER = InferenceCollater(
    tokenizer=BLIP2.llm_tokenizer,
    prot_tokenizer=BLIP2.plm_tokenizer,
    text_max_len=MODEL_ARGS.text_max_len,
    prot_max_len=MODEL_ARGS.prot_max_len,
)
PROMPT = "Swiss-Prot description: "

# Startup diagnostics for the reliability head configuration.
print(f"[debug] MODEL_ARGS.reliability_binary = {getattr(MODEL_ARGS, 'reliability_binary', '<missing>')}")
print(f"[debug] BLIP2.reliability_binary = {getattr(BLIP2, 'reliability_binary', '<missing>')}")
_rh_w = BLIP2.reliability_head[1].weight
print(f"[debug] reliability_head output dim = {_rh_w.shape[0]} (binary=2, 4-class=4)")
print(f"[debug] ckpt mtime / size = {os.path.getmtime(CKPT_PATH):.0f}, {os.path.getsize(CKPT_PATH):,} bytes")
_ckpt_keys_head = [k for k in _state_dict if "reliability_head" in k]
for _k in _ckpt_keys_head:
    print(f"[debug] ckpt has {_k}: shape={tuple(_state_dict[_k].shape)}")

# load_state_dict copies values into MODEL's own parameters, so the raw
# checkpoint is no longer needed; drop the references so its memory can be freed.
del _ckpt, _state_dict, _model_sd, _filtered
print("Ready.")
def _clean_sequence(raw: str) -> str:
    """Normalise a pasted protein sequence to bare upper-case residues.

    FASTA header lines (first non-blank character is ``>``) are dropped and all
    whitespace (spaces, tabs, carriage returns, newlines) is removed, so both
    raw sequences and FASTA pastes are accepted. Multi-record input is
    concatenated, not split. ``None`` is treated as the empty string.
    """
    lines = (raw or "").splitlines()
    residues = "".join(line for line in lines if not line.lstrip().startswith(">"))
    return "".join(residues.split()).upper()


@torch.no_grad()
def predict(sequence: str):
    """Generate a function description and reliability estimate for one protein.

    Args:
        sequence: Text from the UI textbox, either a raw amino-acid sequence or
            a FASTA record. Header lines and whitespace are ignored; input longer
            than ``MODEL_ARGS.prot_max_len`` residues is truncated.

    Returns:
        A ``(description, reliability_label, p_positive)`` tuple of strings for
        the three output textboxes. If no residues remain after cleaning, the
        first element asks the user to paste a sequence and the others are empty.
    """
    sequence = _clean_sequence(sequence)
    if not sequence:
        return "Please paste a protein amino-acid sequence.", "", ""
    # Over-long inputs are truncated rather than rejected.
    sequence = sequence[: MODEL_ARGS.prot_max_len]

    # Single-item batch in the tuple layout InferenceCollater consumes; the
    # trailing fields are presumably placeholders for target text/labels that
    # do not exist at inference time — confirm against InferenceCollater.
    batch = [(sequence, PROMPT, "", 0.0, [], 0)]
    prot_tokens, prompt_tokens, r_tensor, _ = COLLATER(batch)
    prot_tokens = {k: (v.to(DEVICE) if torch.is_tensor(v) else v) for k, v in prot_tokens.items()}
    prompt_tokens = prompt_tokens.to(DEVICE)
    r_tensor = r_tensor.to(DEVICE)
    samples = {"prot_batch": prot_tokens, "prompt_batch": prompt_tokens, "reliability": r_tensor}

    pred_texts, r_pred, _, _, r_probs = BLIP2.generate(
        samples,
        do_sample=MODEL_ARGS.do_sample,
        num_beams=MODEL_ARGS.num_beams,
        max_length=MODEL_ARGS.max_inference_len,
        min_length=MODEL_ARGS.min_inference_len,
    )
    pred = pred_texts[0]
    # generate() may hand back tensors or plain Python sequences; accept both.
    r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
    if torch.is_tensor(r_probs):
        flat = r_probs.flatten().cpu().tolist()
    else:
        flat = [float(x) for sub in r_probs for x in (sub if isinstance(sub, (list, tuple)) else [sub])]
    print(f"[debug] r_probs raw flat = {flat}")  # remove after verifying
    # With a one-item batch the last probability is reported as P(positive);
    # this assumes the binary head orders classes [negative, positive] — confirm.
    p_pos = float(flat[-1])
    return pred, format_reliability(r), f"{p_pos:.4f}"
# Pre-filled value of the sequence textbox.
# NOTE(review): appears to be human preproinsulin (UniProt P01308) — confirm.
EXAMPLE_SEQ = (
    "MALWMRLLPLLALLALWGPDPAAA"
    "FVNQHLCGSHLVEALYLVCGERGFFYTPKT"
    "RREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKR"
    "GIVEQCCTSICSLYQLENYCN"
)
# Display names for the class predicted by the reliability head.
RELIABILITY_LABELS = {1.0: "Reliable", 0.0: "Unreliable"}


def format_reliability(r: float) -> str:
    """Render a reliability class as ``"<label> (class=<r>)"``.

    Classes missing from ``RELIABILITY_LABELS`` are labelled ``"Unknown"``;
    ``r`` is printed with the ``g`` format, so ``1.0`` shows as ``1``.
    """
    label = RELIABILITY_LABELS[r] if r in RELIABILITY_LABELS else "Unknown"
    return "{} (class={:g})".format(label, r)
# Gradio UI; `demo` is the Blocks app launched in the __main__ guard.
# NOTE(review): this section's original indentation was lost, so the Row/Column
# nesting below is reconstructed — confirm it matches the intended layout.
with gr.Blocks(title="ProtTale") as demo:
    # Logo banner at the top of the page.
    gr.Image(
        value="prottale_logo.png",
        show_label=False,
        show_download_button=False,
        container=False,
        height=120,
    )
    gr.Markdown(
        "Protein amino-acid sequence → Swiss-Prot-style function description, "
        "with a reliability score for the generated text.\n\n"
        "**Note:** this Space runs on CPU; a single beam-3 generation typically takes ~30–120 s."
    )
    with gr.Row():
        # Input, pre-filled with EXAMPLE_SEQ so the demo works with one click.
        seq_in = gr.Textbox(label="Protein sequence", lines=6, value=EXAMPLE_SEQ)
        run_btn = gr.Button("Predict", variant="primary")
        # Outputs, listed in the same order as predict()'s return tuple.
        with gr.Column():
            pred_out = gr.Textbox(label="Predicted function")
            r_out = gr.Textbox(label="Reliability")
            p_out = gr.Textbox(label="P(positive)")
    run_btn.click(predict, inputs=[seq_in], outputs=[pred_out, r_out, p_out])
if __name__ == "__main__":
    # Enable Gradio's request queue (max_size=8), then serve on 0.0.0.0:7860.
    app = demo.queue(max_size=8)
    app.launch(server_name="0.0.0.0", server_port=7860)
|