""" ProtTale Gradio demo — runs on a Hugging Face Space (CPU). On startup the app pulls the checkpoint from `Mulah/ProtTale` (HF model hub) and keeps it in the Space's local cache. Inference is then done with the standard `Blip2Stage2` pipeline exactly like `predict_single.py`. """ import os os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" from argparse import Namespace import gradio as gr import torch from huggingface_hub import hf_hub_download from model.blip2_stage2 import Blip2Stage2 from data_provider.stage2_dm import InferenceCollater CKPT_REPO = "Mulah/ProtTale" CKPT_FILE = "checkpoint.ckpt" def build_args(ckpt_path: str) -> Namespace: return Namespace( plm_model="esmc_300m", encoder_type="auto", num_query_token=4, plm_tune="lora", plm_lora_r=4, plm_lora_alpha=8, plm_lora_dropout=0.1, bert_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", cross_attention_freq=2, llm_name="facebook/galactica-1.3b", llm_tune="lora", lora_r=16, lora_alpha=32, lora_dropout=0.1, peft_dir="", peft_config="", num_beams=3, do_sample=False, max_inference_len=128, min_inference_len=1, text_max_len=128, prot_max_len=1024, caption_eval_epoch=1, inference_on_training_data=False, train_reliability_head_only=False, report_go_wang_on_test=False, report_go_wang_on_val=False, save_predictions=False, ia_path="evals/tools/IA.txt", go_files_tsv_path="evals/tools/go_files.tsv", test_set_path="", valid_set_path="", root="data/SwissProtV3", enbale_gradient_checkpointing=False, init_checkpoint=ckpt_path, stage1_path="", stage2_path="", init_lr=1e-4, min_lr=1e-5, warmup_lr=1e-6, warmup_steps=0, lr_decay_rate=0.9, scheduler="None", weight_decay=0.05, reliability_lr=1e-4, max_epochs=1, filename="predict_single", seed=42, reliability_binary=True, ) print("Downloading checkpoint from HF hub (first run only) ...") CKPT_PATH = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE) print(f"Checkpoint ready at: {CKPT_PATH}") print("Building model ...") MODEL_ARGS = build_args(CKPT_PATH) MODEL = Blip2Stage2(MODEL_ARGS) print("Loading weights ...") _ckpt = torch.load(CKPT_PATH, map_location="cpu") _state_dict = _ckpt.get("state_dict", _ckpt) _model_sd = MODEL.state_dict() _filtered = {k: v for k, v in _state_dict.items() if k in _model_sd and _model_sd[k].shape == v.shape} MODEL.load_state_dict(_filtered, strict=False) MODEL.eval() DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") MODEL = MODEL.to(DEVICE) BLIP2 = MODEL.blip2 COLLATER = InferenceCollater( tokenizer=BLIP2.llm_tokenizer, prot_tokenizer=BLIP2.plm_tokenizer, text_max_len=MODEL_ARGS.text_max_len, prot_max_len=MODEL_ARGS.prot_max_len, ) PROMPT = "Swiss-Prot description: " print(f"[debug] MODEL_ARGS.reliability_binary = {getattr(MODEL_ARGS, 'reliability_binary', '')}") print(f"[debug] BLIP2.reliability_binary = {getattr(BLIP2, 'reliability_binary', '')}") _rh_w = BLIP2.reliability_head[1].weight print(f"[debug] reliability_head output dim = {_rh_w.shape[0]} (binary=2, 4-class=4)") print(f"[debug] ckpt mtime / size = {os.path.getmtime(CKPT_PATH):.0f}, {os.path.getsize(CKPT_PATH):,} bytes") _ckpt_keys_head = [k for k in _state_dict if "reliability_head" in k] for _k in _ckpt_keys_head: print(f"[debug] ckpt has {_k}: shape={tuple(_state_dict[_k].shape)}") print("Ready.") @torch.no_grad() def predict(sequence: str): sequence = (sequence or "").strip().upper().replace(" ", "").replace("\n", "") if not sequence: return "Please paste a protein amino-acid sequence.", "", "" if len(sequence) > 
        sequence = sequence[: MODEL_ARGS.prot_max_len]

    # Single-item batch in the tuple layout InferenceCollater expects.
    batch = [(sequence, PROMPT, "", 0.0, [], 0)]
    prot_tokens, prompt_tokens, r_tensor, _ = COLLATER(batch)
    prot_tokens = {
        k: (v.to(DEVICE) if torch.is_tensor(v) else v) for k, v in prot_tokens.items()
    }
    prompt_tokens = prompt_tokens.to(DEVICE)
    r_tensor = r_tensor.to(DEVICE)

    samples = {
        "prot_batch": prot_tokens,
        "prompt_batch": prompt_tokens,
        "reliability": r_tensor,
    }
    pred_texts, r_pred, _, _, r_probs = BLIP2.generate(
        samples,
        do_sample=MODEL_ARGS.do_sample,
        num_beams=MODEL_ARGS.num_beams,
        max_length=MODEL_ARGS.max_inference_len,
        min_length=MODEL_ARGS.min_inference_len,
    )
    pred = pred_texts[0]
    r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])

    # Flatten the reliability probabilities regardless of tensor/list shape.
    if torch.is_tensor(r_probs):
        flat = r_probs.flatten().cpu().tolist()
    else:
        flat = [
            float(x)
            for sub in r_probs
            for x in (sub if isinstance(sub, (list, tuple)) else [sub])
        ]
    print(f"[debug] r_probs raw flat = {flat}")  # remove after verifying
    p_pos = float(flat[-1])  # probability of the positive ("reliable") class

    return pred, format_reliability(r), f"{p_pos:.4f}"


# Default example shown in the UI (human preproinsulin).
EXAMPLE_SEQ = (
    "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAED"
    "LQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
)

RELIABILITY_LABELS = {
    1.0: "Reliable",
    0.0: "Unreliable",
}


def format_reliability(r: float) -> str:
    label = RELIABILITY_LABELS.get(r, "Unknown")
    return f"{label} (class={r:g})"


with gr.Blocks(title="ProtTale") as demo:
    gr.Image(
        value="prottale_logo.png",
        show_label=False,
        show_download_button=False,
        container=False,
        height=120,
    )
    gr.Markdown(
        "Protein amino-acid sequence → Swiss-Prot-style function description, "
        "with a reliability score for the generated text.\n\n"
        "**Note:** this Space runs on CPU; a single beam-3 generation typically takes ~30–120 s."
    )
    with gr.Row():
        seq_in = gr.Textbox(label="Protein sequence", lines=6, value=EXAMPLE_SEQ)
    run_btn = gr.Button("Predict", variant="primary")
    with gr.Column():
        pred_out = gr.Textbox(label="Predicted function")
        r_out = gr.Textbox(label="Reliability")
        p_out = gr.Textbox(label="P(positive)")
    run_btn.click(predict, inputs=[seq_in], outputs=[pred_out, r_out, p_out])

if __name__ == "__main__":
    demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)
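
# ---------------------------------------------------------------------------
# Quick programmatic check (a minimal sketch, left commented out so it never
# runs on the Space): `predict` can also be exercised without the UI, e.g.
# from a Python REPL after importing this module. The three return values
# match the Gradio outputs wired above.
#
#   description, reliability, p_positive = predict(EXAMPLE_SEQ)
#   print(description)
#   print(reliability, p_positive)
# ---------------------------------------------------------------------------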