buddyzhu commited on
Commit
71932a8
·
verified ·
1 Parent(s): 3606075

Create demo

Browse files
Files changed (1) hide show
  1. demo +48 -0
demo ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from modelgenerator.tasks import Embed
5
+
6
+ # 加载模型(第一次会自动下载大权重,比较慢)
7
+ print("Loading model...")
8
+ model = Embed.from_config({
9
+ "model.backbone": "aido_protein_rag_16b"
10
+ }).eval()
11
+
12
+ # 支持超长序列
13
+ model.backbone.max_length = 12800
14
+
15
+ def predict_protein(sequence: str):
16
+ if not sequence or len(sequence) < 5:
17
+ return "请输入有效的蛋白质序列(至少5个氨基酸)"
18
+
19
+ # 简单输入(仅序列,MSA和结构可选)
20
+ data = {
21
+ 'sequences': [sequence],
22
+ # 'msa': [...], # 可选:多序列比对
23
+ # 'str_emb': np.random.randn(1, 50, 384) # 可选:结构嵌入
24
+ }
25
+
26
+ transformed_batch = model.transform(data)
27
+
28
+ with torch.no_grad():
29
+ embedding = model(transformed_batch)
30
+
31
+ # 返回 embedding 的形状和前几个值作为示例
32
+ emb = embedding.cpu().numpy()
33
+ return f"Embedding shape: {emb.shape}\n\n前10个值示例: {emb.flatten()[:10].tolist()}"
34
+
35
+ # Gradio 界面
36
+ iface = gr.Interface(
37
+ fn=predict_protein,
38
+ inputs=gr.Textbox(label="输入蛋白质序列 (e.g. ACDEFGHIKLMNPQRSTVWY)", lines=5, placeholder="请输入氨基酸序列..."),
39
+ outputs=gr.Textbox(label="模型输出 (Embedding)"),
40
+ title="AIDO.Protein-RAG-16B Demo",
41
+ description="输入蛋白序列,获取模型的嵌入表示。注意:16B 模型较大,首次加载需要时间。",
42
+ examples=[["ACDEFGHIKLMNPQRSTVWY"], ["MTEITAAMVKELRESTGAGA"]],
43
+ allow_flagging="never"
44
+ )
45
+
46
+ if __name__ == "__main__":
47
+ iface.launch()
48
+