kmlyyll commited on
Commit
9ceca05
·
verified ·
1 Parent(s): 45d205e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -4
app.py CHANGED
@@ -1,7 +1,76 @@
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ """
2
+ ZeroGPU Structure Prediction API
3
+ """
4
+
5
+ import spaces
6
  import gradio as gr
7
+ import torch
8
+ from transformers import EsmForProteinFolding, AutoTokenizer
9
+
10
+ print("Loading ESMFold model...")
11
+ MODEL_NAME = "facebook/esmfold_v1"
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ model = EsmForProteinFolding.from_pretrained(MODEL_NAME)
15
+
16
+ if torch.cuda.is_available():
17
+ model = model.cuda()
18
+ model.esm = model.esm.half()
19
+ print(f"Model loaded on GPU: {torch.cuda.get_device_name(0)}")
20
+ else:
21
+ print("Model loaded on CPU")
22
+
23
+
24
+ @spaces.GPU(duration=120)
25
+ def predict_structure(sequence: str) -> str:
26
+ sequence = sequence.strip().upper()
27
+ valid_aa = set("ACDEFGHIKLMNPQRSTVWY")
28
+
29
+ if not sequence:
30
+ return "Error: Empty sequence provided"
31
+
32
+ invalid_chars = set(sequence) - valid_aa
33
+ if invalid_chars:
34
+ return f"Error: Invalid amino acids found: {invalid_chars}"
35
+
36
+ if len(sequence) > 500:
37
+ return "Error: Sequence too long (max 500 residues)"
38
+
39
+ try:
40
+ inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=False)
41
+
42
+ if torch.cuda.is_available():
43
+ inputs = {k: v.cuda() for k, v in inputs.items()}
44
+
45
+ with torch.no_grad():
46
+ outputs = model(**inputs)
47
+
48
+ pdb_string = model.output_to_pdb(outputs)[0]
49
+ return pdb_string
50
+
51
+ except Exception as e:
52
+ return f"Error: {str(e)}"
53
+
54
 
55
+ with gr.Blocks(title="🧬 Antibody Structure API") as demo:
56
+ gr.Markdown("""
57
+ # 🧬 Antibody Structure Prediction API (ZeroGPU)
58
+
59
+ GPU-accelerated ESMFold structure prediction.
60
+
61
+ **API Usage:**
62
+ ```python
63
+ from gradio_client import Client
64
+ client = Client("kmlyyll/antibody-structure-api")
65
+ pdb = client.predict(sequence, api_name="/predict")
66
+ ```
67
+ """)
68
+
69
+ seq_input = gr.Textbox(label="Amino Acid Sequence", placeholder="Enter sequence...", lines=3)
70
+ predict_btn = gr.Button("Predict Structure", variant="primary")
71
+ pdb_output = gr.Textbox(label="PDB Output", lines=20)
72
+
73
+ predict_btn.click(fn=predict_structure, inputs=seq_input, outputs=pdb_output, api_name="predict")
74
 
75
+ if __name__ == "__main__":
76
+ demo.launch()