# ProtTale-demo / app.py — Hugging Face Space (user: Mulah)
# Last commit: "Debug: log reliability_binary flag + reliability_head shape + ckpt info" (7bdb15b)
"""
ProtTale Gradio demo — runs on a Hugging Face Space (CPU).
On startup the app pulls the checkpoint from `Mulah/ProtTale` (HF model hub)
and keeps it in the Space's local cache. Inference is then done with the
standard `Blip2Stage2` pipeline exactly like `predict_single.py`.
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
from argparse import Namespace
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from model.blip2_stage2 import Blip2Stage2
from data_provider.stage2_dm import InferenceCollater
CKPT_REPO = "Mulah/ProtTale"
CKPT_FILE = "checkpoint.ckpt"
def build_args(ckpt_path: str) -> Namespace:
    """Assemble the fixed hyper-parameter namespace used for inference.

    Mirrors the CLI arguments of `predict_single.py`; only the checkpoint
    path varies between runs.

    Args:
        ckpt_path: Local filesystem path of the downloaded ``.ckpt`` file.

    Returns:
        An ``argparse.Namespace`` carrying every field the ``Blip2Stage2``
        pipeline expects.
    """
    cfg: dict = {}
    # Protein language model (encoder) + its LoRA adapter.
    cfg.update(
        plm_model="esmc_300m",
        encoder_type="auto",
        num_query_token=4,
        plm_tune="lora",
        plm_lora_r=4,
        plm_lora_alpha=8,
        plm_lora_dropout=0.1,
    )
    # Q-Former text encoder.
    cfg.update(
        bert_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        cross_attention_freq=2,
    )
    # LLM decoder + its LoRA adapter.
    cfg.update(
        llm_name="facebook/galactica-1.3b",
        llm_tune="lora",
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        peft_dir="",
        peft_config="",
    )
    # Generation / tokenization limits.
    cfg.update(
        num_beams=3,
        do_sample=False,
        max_inference_len=128,
        min_inference_len=1,
        text_max_len=128,
        prot_max_len=1024,
    )
    # Evaluation / bookkeeping flags (kept because the module reads them,
    # even though they are irrelevant for single-sequence inference).
    cfg.update(
        caption_eval_epoch=1,
        inference_on_training_data=False,
        train_reliability_head_only=False,
        report_go_wang_on_test=False,
        report_go_wang_on_val=False,
        save_predictions=False,
        ia_path="evals/tools/IA.txt",
        go_files_tsv_path="evals/tools/go_files.tsv",
        test_set_path="",
        valid_set_path="",
        root="data/SwissProtV3",
        # NOTE: spelling ("enbale") intentionally matches the training
        # code's argument name — do not "fix" it here.
        enbale_gradient_checkpointing=False,
    )
    # Checkpoint paths + optimizer leftovers required by the Lightning module.
    cfg.update(
        init_checkpoint=ckpt_path,
        stage1_path="",
        stage2_path="",
        init_lr=1e-4,
        min_lr=1e-5,
        warmup_lr=1e-6,
        warmup_steps=0,
        lr_decay_rate=0.9,
        scheduler="None",
        weight_decay=0.05,
        reliability_lr=1e-4,
        max_epochs=1,
        filename="predict_single",
        seed=42,
        reliability_binary=True,
    )
    return Namespace(**cfg)
# ---------------------------------------------------------------------------
# One-time startup: fetch the checkpoint, build the model, load weights, and
# prepare the collater. Runs at import time so the Space is ready before the
# first request arrives.
# ---------------------------------------------------------------------------
print("Downloading checkpoint from HF hub (first run only) ...")
# hf_hub_download caches into the Space's local cache (see module docstring),
# so restarts reuse the already-downloaded file.
CKPT_PATH = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE)
print(f"Checkpoint ready at: {CKPT_PATH}")
print("Building model ...")
MODEL_ARGS = build_args(CKPT_PATH)
MODEL = Blip2Stage2(MODEL_ARGS)
print("Loading weights ...")
# NOTE(review): torch.load unpickles arbitrary objects; acceptable only
# because the checkpoint comes from our own repo (CKPT_REPO). Consider
# weights_only=True if the checkpoint format allows it — TODO confirm.
_ckpt = torch.load(CKPT_PATH, map_location="cpu")
# Lightning-style checkpoints wrap weights under "state_dict"; fall back to
# treating the whole object as the state dict otherwise.
_state_dict = _ckpt.get("state_dict", _ckpt)
_model_sd = MODEL.state_dict()
# Keep only tensors whose name AND shape match the freshly built model, so a
# checkpoint from a slightly different config loads without raising.
_filtered = {k: v for k, v in _state_dict.items() if k in _model_sd and _model_sd[k].shape == v.shape}
# strict=False: parameters missing from _filtered keep their init values —
# mismatches are silently skipped, hence the debug prints below.
MODEL.load_state_dict(_filtered, strict=False)
MODEL.eval()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = MODEL.to(DEVICE)
BLIP2 = MODEL.blip2  # underlying generation pipeline used by predict()
COLLATER = InferenceCollater(
    tokenizer=BLIP2.llm_tokenizer,
    prot_tokenizer=BLIP2.plm_tokenizer,
    text_max_len=MODEL_ARGS.text_max_len,
    prot_max_len=MODEL_ARGS.prot_max_len,
)
# Prompt prefix prepended to every generation request.
PROMPT = "Swiss-Prot description: "
# --- temporary debug output: sanity-check which reliability head loaded ----
print(f"[debug] MODEL_ARGS.reliability_binary = {getattr(MODEL_ARGS, 'reliability_binary', '<missing>')}")
print(f"[debug] BLIP2.reliability_binary = {getattr(BLIP2, 'reliability_binary', '<missing>')}")
# Output width of the reliability head distinguishes the head variants.
_rh_w = BLIP2.reliability_head[1].weight
print(f"[debug] reliability_head output dim = {_rh_w.shape[0]} (binary=2, 4-class=4)")
print(f"[debug] ckpt mtime / size = {os.path.getmtime(CKPT_PATH):.0f}, {os.path.getsize(CKPT_PATH):,} bytes")
_ckpt_keys_head = [k for k in _state_dict if "reliability_head" in k]
for _k in _ckpt_keys_head:
    print(f"[debug] ckpt has {_k}: shape={tuple(_state_dict[_k].shape)}")
print("Ready.")
@torch.no_grad()
def predict(sequence: str):
    """Run one amino-acid sequence through the ProtTale pipeline.

    Args:
        sequence: Raw sequence as pasted by the user; whitespace is
            stripped and letters are upper-cased before tokenization.

    Returns:
        Tuple of three strings for the Gradio outputs:
        (generated description, formatted reliability label, P(positive)).
    """
    # Remove ALL whitespace before tokenizing. str.split() with no argument
    # splits on any whitespace run, which also covers tabs and Windows
    # "\r\n" pastes — the previous replace(" ", "").replace("\n", "")
    # leaked literal "\r" / "\t" characters into the sequence.
    sequence = "".join((sequence or "").split()).upper()
    if not sequence:
        return "Please paste a protein amino-acid sequence.", "", ""
    # Hard-truncate to the protein LM's context length instead of failing.
    if len(sequence) > MODEL_ARGS.prot_max_len:
        sequence = sequence[: MODEL_ARGS.prot_max_len]
    # Collater expects (sequence, prompt, target_text, reliability, go_terms,
    # label) tuples; dummy values fill the training-only slots.
    batch = [(sequence, PROMPT, "", 0.0, [], 0)]
    prot_tokens, prompt_tokens, r_tensor, _ = COLLATER(batch)
    # Move everything to the inference device; prot_tokens may mix tensors
    # with non-tensor entries, so convert selectively.
    prot_tokens = {k: (v.to(DEVICE) if torch.is_tensor(v) else v) for k, v in prot_tokens.items()}
    prompt_tokens = prompt_tokens.to(DEVICE)
    r_tensor = r_tensor.to(DEVICE)
    samples = {"prot_batch": prot_tokens, "prompt_batch": prompt_tokens, "reliability": r_tensor}
    pred_texts, r_pred, _, _, r_probs = BLIP2.generate(
        samples,
        do_sample=MODEL_ARGS.do_sample,
        num_beams=MODEL_ARGS.num_beams,
        max_length=MODEL_ARGS.max_inference_len,
        min_length=MODEL_ARGS.min_inference_len,
    )
    pred = pred_texts[0]
    # r_pred may be a tensor or a plain sequence depending on the model path.
    r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
    if torch.is_tensor(r_probs):
        flat = r_probs.flatten().cpu().tolist()
    else:
        # Fallback: flatten one level of nesting from a python list.
        flat = [float(x) for sub in r_probs for x in (sub if isinstance(sub, (list, tuple)) else [sub])]
    print(f"[debug] r_probs raw flat = {flat}")  # remove after verifying
    # Assumes the LAST probability is the positive class (binary head order
    # [neg, pos]) — TODO confirm against Blip2Stage2's reliability head.
    p_pos = float(flat[-1])
    return pred, format_reliability(r), f"{p_pos:.4f}"
# Pre-filled example sequence shown in the input textbox on page load.
EXAMPLE_SEQ = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"

# Maps the model's numeric reliability class to a human-readable label.
RELIABILITY_LABELS = {
    1.0: "Reliable",
    0.0: "Unreliable",
}


def format_reliability(r: float) -> str:
    """Render a reliability class as ``"<Label> (class=<r>)"``.

    Any class value not present in ``RELIABILITY_LABELS`` falls back to
    the label ``"Unknown"``.
    """
    return f"{RELIABILITY_LABELS.get(r, 'Unknown')} (class={r:g})"
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="ProtTale") as demo:
    # Static logo banner; the image file lives alongside app.py in the Space.
    gr.Image(
        value="prottale_logo.png",
        show_label=False,
        show_download_button=False,
        container=False,
        height=120,
    )
    gr.Markdown(
        "Protein amino-acid sequence → Swiss-Prot-style function description, "
        "with a reliability score for the generated text.\n\n"
        "**Note:** this Space runs on CPU; a single beam-3 generation typically takes ~30–120 s."
    )
    with gr.Row():
        # Left side: sequence input + trigger button.
        seq_in = gr.Textbox(label="Protein sequence", lines=6, value=EXAMPLE_SEQ)
        run_btn = gr.Button("Predict", variant="primary")
        # Right side: the three result fields, matching predict()'s 3-tuple.
        with gr.Column():
            pred_out = gr.Textbox(label="Predicted function")
            r_out = gr.Textbox(label="Reliability")
            p_out = gr.Textbox(label="P(positive)")
    run_btn.click(predict, inputs=[seq_in], outputs=[pred_out, r_out, p_out])


if __name__ == "__main__":
    # queue(max_size=8) bounds concurrent requests; port 7860 is the HF
    # Spaces convention.
    demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)