| """ | |
| ProtTale Gradio demo — runs on a Hugging Face Space (CPU). | |
| On startup the app pulls the checkpoint from `Mulah/ProtTale` (HF model hub) | |
| and keeps it in the Space's local cache. Inference is then done with the | |
| standard `Blip2Stage2` pipeline exactly like `predict_single.py`. | |
| """ | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" | |
| from argparse import Namespace | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from model.blip2_stage2 import Blip2Stage2 | |
| from data_provider.stage2_dm import InferenceCollater | |
| CKPT_REPO = "Mulah/ProtTale" | |
| CKPT_FILE = "checkpoint.ckpt" | |
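# Optional (a sketch, not part of the original app): pin the checkpoint to a fixed
# model-repo revision so Space restarts stay reproducible. `revision` is a standard
# `hf_hub_download` parameter; the commit hash below is a placeholder.
# CKPT_PATH = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE, revision="main")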
def build_args(ckpt_path: str) -> Namespace:
    """Replicate the CLI flags that `predict_single.py` passes to the model."""
    return Namespace(
        plm_model="esmc_300m",
        encoder_type="auto",
        num_query_token=4,
        plm_tune="lora",
        plm_lora_r=4,
        plm_lora_alpha=8,
        plm_lora_dropout=0.1,
        bert_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        cross_attention_freq=2,
        llm_name="facebook/galactica-1.3b",
        llm_tune="lora",
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        peft_dir="",
        peft_config="",
        num_beams=3,
        do_sample=False,
        max_inference_len=128,
        min_inference_len=1,
        text_max_len=128,
        prot_max_len=1024,
        caption_eval_epoch=1,
        inference_on_training_data=False,
        train_reliability_head_only=False,
        report_go_wang_on_test=False,
        report_go_wang_on_val=False,
        save_predictions=False,
        ia_path="evals/tools/IA.txt",
        go_files_tsv_path="evals/tools/go_files.tsv",
        test_set_path="",
        valid_set_path="",
        root="data/SwissProtV3",
        enbale_gradient_checkpointing=False,  # (sic) spelling follows the upstream argument name
        init_checkpoint=ckpt_path,
        stage1_path="",
        stage2_path="",
        init_lr=1e-4,
        min_lr=1e-5,
        warmup_lr=1e-6,
        warmup_steps=0,
        lr_decay_rate=0.9,
        scheduler="None",
        weight_decay=0.05,
        reliability_lr=1e-4,
        max_epochs=1,
        filename="predict_single",
        seed=42,
        reliability_binary=True,
    )
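# Example (sketch): the Namespace mirrors training-time CLI flags, so generation
# behaviour can be tweaked without touching the model code, e.g.:
#   args = build_args("/path/to/checkpoint.ckpt")
#   args.num_beams = 5           # slower, sometimes better captions
#   args.max_inference_len = 256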
| print("Downloading checkpoint from HF hub (first run only) ...") | |
| CKPT_PATH = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE) | |
| print(f"Checkpoint ready at: {CKPT_PATH}") | |
| print("Building model ...") | |
| MODEL_ARGS = build_args(CKPT_PATH) | |
| MODEL = Blip2Stage2(MODEL_ARGS) | |
| print("Loading weights ...") | |
| _ckpt = torch.load(CKPT_PATH, map_location="cpu") | |
| _state_dict = _ckpt.get("state_dict", _ckpt) | |
| _model_sd = MODEL.state_dict() | |
| _filtered = {k: v for k, v in _state_dict.items() if k in _model_sd and _model_sd[k].shape == v.shape} | |
| MODEL.load_state_dict(_filtered, strict=False) | |
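# Diagnostic (added here, purely informational): report how much of the checkpoint
# survived the shape filter, so silent architecture mismatches show up in the logs.
_skipped = [k for k in _state_dict if k not in _filtered]
print(f"Loaded {len(_filtered)}/{len(_state_dict)} checkpoint tensors; skipped {len(_skipped)}.")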
MODEL.eval()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = MODEL.to(DEVICE)
BLIP2 = MODEL.blip2
COLLATER = InferenceCollater(
    tokenizer=BLIP2.llm_tokenizer,
    prot_tokenizer=BLIP2.plm_tokenizer,
    text_max_len=MODEL_ARGS.text_max_len,
    prot_max_len=MODEL_ARGS.prot_max_len,
)
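# The collater consumes (sequence, prompt, target_text, reliability, go_terms, index)
# tuples -- the layout used by the dummy batch in `predict` below. The field
# semantics are an assumption inferred from that call site, not from collater docs.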
PROMPT = "Swiss-Prot description: "

print(f"[debug] MODEL_ARGS.reliability_binary = {getattr(MODEL_ARGS, 'reliability_binary', '<missing>')}")
print(f"[debug] BLIP2.reliability_binary = {getattr(BLIP2, 'reliability_binary', '<missing>')}")
_rh_w = BLIP2.reliability_head[1].weight
print(f"[debug] reliability_head output dim = {_rh_w.shape[0]} (binary=2, 4-class=4)")
print(f"[debug] ckpt mtime / size = {os.path.getmtime(CKPT_PATH):.0f}, {os.path.getsize(CKPT_PATH):,} bytes")
_ckpt_keys_head = [k for k in _state_dict if "reliability_head" in k]
for _k in _ckpt_keys_head:
    print(f"[debug] ckpt has {_k}: shape={tuple(_state_dict[_k].shape)}")
print("Ready.")
def predict(sequence: str):
    # Sanitize: uppercase and drop whitespace so sequences pasted with line breaks work.
    sequence = (sequence or "").strip().upper().replace(" ", "").replace("\n", "")
    if not sequence:
        return "Please paste a protein amino-acid sequence.", "", ""
    if len(sequence) > MODEL_ARGS.prot_max_len:
        sequence = sequence[: MODEL_ARGS.prot_max_len]

    # Single-item dummy batch: (sequence, prompt, target_text, reliability, go_terms, index).
    batch = [(sequence, PROMPT, "", 0.0, [], 0)]
    prot_tokens, prompt_tokens, r_tensor, _ = COLLATER(batch)
    prot_tokens = {k: (v.to(DEVICE) if torch.is_tensor(v) else v) for k, v in prot_tokens.items()}
    prompt_tokens = prompt_tokens.to(DEVICE)
    r_tensor = r_tensor.to(DEVICE)

    samples = {"prot_batch": prot_tokens, "prompt_batch": prompt_tokens, "reliability": r_tensor}
    with torch.no_grad():  # inference only; avoids building autograd graphs on CPU
        pred_texts, r_pred, _, _, r_probs = BLIP2.generate(
            samples,
            do_sample=MODEL_ARGS.do_sample,
            num_beams=MODEL_ARGS.num_beams,
            max_length=MODEL_ARGS.max_inference_len,
            min_length=MODEL_ARGS.min_inference_len,
        )

    pred = pred_texts[0]
    r = float(r_pred.cpu().tolist()[0] if torch.is_tensor(r_pred) else r_pred[0])
    if torch.is_tensor(r_probs):
        flat = r_probs.flatten().cpu().tolist()
    else:
        flat = [float(x) for sub in r_probs for x in (sub if isinstance(sub, (list, tuple)) else [sub])]
    print(f"[debug] r_probs raw flat = {flat}")  # remove after verifying
    p_pos = float(flat[-1])  # last entry = probability of the positive ("Reliable") class
    return pred, format_reliability(r), f"{p_pos:.4f}"
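# Hypothetical helper (not wired into the UI): users often paste FASTA records, whose
# ">" header line would otherwise pollute the sequence. A minimal sketch for
# single-record input; call it on the textbox value before `predict` if desired.
def strip_fasta_header(text: str) -> str:
    lines = [ln for ln in (text or "").splitlines() if not ln.startswith(">")]
    return "".join(lines)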
# Demo input: human insulin precursor (Swiss-Prot P01308).
EXAMPLE_SEQ = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"

RELIABILITY_LABELS = {
    1.0: "Reliable",
    0.0: "Unreliable",
}

def format_reliability(r: float) -> str:
    label = RELIABILITY_LABELS.get(r, "Unknown")
    return f"{label} (class={r:g})"
with gr.Blocks(title="ProtTale") as demo:
    gr.Image(
        value="prottale_logo.png",
        show_label=False,
        show_download_button=False,
        container=False,
        height=120,
    )
    gr.Markdown(
        "Protein amino-acid sequence → Swiss-Prot-style function description, "
        "with a reliability score for the generated text.\n\n"
        "**Note:** this Space runs on CPU; a single beam-3 generation typically takes ~30–120 s."
    )
    with gr.Row():
        seq_in = gr.Textbox(label="Protein sequence", lines=6, value=EXAMPLE_SEQ)
        run_btn = gr.Button("Predict", variant="primary")
    with gr.Column():
        pred_out = gr.Textbox(label="Predicted function")
        r_out = gr.Textbox(label="Reliability")
        p_out = gr.Textbox(label="P(positive)")
    run_btn.click(predict, inputs=[seq_in], outputs=[pred_out, r_out, p_out])
if __name__ == "__main__":
    demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)
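# Local smoke test (sketch, assuming this file is the Space's app.py):
# `python app.py` serves the UI at http://localhost:7860. To exercise the
# pipeline without Gradio:
#   from app import predict, EXAMPLE_SEQ
#   desc, reliability, p_pos = predict(EXAMPLE_SEQ)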