# Prescript_1 / demo
# buddyzhu's picture
# Create demo
# 71932a8 verified
import gradio as gr
import torch
import numpy as np
from modelgenerator.tasks import Embed
# Load the model once at import time. The first run downloads the large
# weights automatically, so it is slow.
print("Loading model...")
model = Embed.from_config({"model.backbone": "aido_protein_rag_16b"})
model = model.eval()

# Raise the backbone's length limit to support very long sequences.
model.backbone.max_length = 12800
def predict_protein(sequence: str) -> str:
    """Embed a protein sequence and return a short text summary.

    All whitespace (spaces, tabs, line breaks) is removed from the input
    before validation, so a sequence pasted across several lines of the
    multi-line textbox is treated as one contiguous sequence and padding
    characters do not count towards the minimum length.

    Args:
        sequence: Amino-acid sequence as typed by the user; may be empty
            or ``None``.

    Returns:
        A validation message when fewer than 5 non-whitespace characters
        were supplied; otherwise a string reporting the embedding's shape
        and its first 10 flattened values.
    """
    # Drop every whitespace character; `str.split()` with no argument
    # splits on any run of whitespace, including newlines.
    cleaned = "".join(sequence.split()) if sequence else ""
    if len(cleaned) < 5:
        return "请输入有效的蛋白质序列(至少5个氨基酸)"

    # Sequence-only input. The optional 'msa' (multiple sequence alignment)
    # and 'str_emb' (structure embedding) entries are not supplied here.
    data = {"sequences": [cleaned]}
    transformed_batch = model.transform(data)

    with torch.no_grad():
        embedding = model(transformed_batch)

    # NumPy has no bfloat16 dtype, so `Tensor.numpy()` raises on such
    # tensors; upcast only in that case so other dtypes are left untouched.
    if embedding.dtype == torch.bfloat16:
        embedding = embedding.float()

    # Report the embedding's shape and its first few values as a sample.
    emb = embedding.cpu().numpy()
    return f"Embedding shape: {emb.shape}\n\n前10个值示例: {emb.flatten()[:10].tolist()}"
# Gradio interface: one multi-line text input for the sequence and one text
# output for the embedding summary produced by `predict_protein`.
iface = gr.Interface(
    fn=predict_protein,
    title="AIDO.Protein-RAG-16B Demo",
    description="输入蛋白序列,获取模型的嵌入表示。注意:16B 模型较大,首次加载需要时间。",
    inputs=gr.Textbox(
        label="输入蛋白质序列 (e.g. ACDEFGHIKLMNPQRSTVWY)",
        lines=5,
        placeholder="请输入氨基酸序列...",
    ),
    outputs=gr.Textbox(label="模型输出 (Embedding)"),
    examples=[
        ["ACDEFGHIKLMNPQRSTVWY"],
        ["MTEITAAMVKELRESTGAGA"],
    ],
    allow_flagging="never",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()