Spaces:
Sleeping
Sleeping
| """ | |
| ZeroGPU Structure Prediction API | |
| """ | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| from transformers import EsmForProteinFolding, AutoTokenizer | |
MODEL_NAME = "facebook/esmfold_v1"

print("Loading ESMFold model...")
# Tokenizer and folding model are both loaded from the same Hub checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = EsmForProteinFolding.from_pretrained(MODEL_NAME)

if not torch.cuda.is_available():
    print("Model loaded on CPU")
else:
    # Move the whole model to the GPU, then cast only the ESM language-model
    # stem (`model.esm`) to float16; the rest of the model keeps its dtype.
    # NOTE(review): presumably done to reduce GPU memory use — confirm.
    model = model.cuda()
    model.esm = model.esm.half()
    print(f"Model loaded on GPU: {torch.cuda.get_device_name(0)}")
# The 20 standard amino acids (one-letter codes) accepted as input.
_VALID_AA = frozenset("ACDEFGHIKLMNPQRSTVWY")
# Upper limit on accepted sequence length, in residues.
_MAX_RESIDUES = 500


@spaces.GPU
def _fold_to_pdb(sequence: str) -> str:
    """Run ESMFold on an already-validated sequence and return PDB text.

    Decorated with ``@spaces.GPU`` so that on ZeroGPU hardware a GPU is
    attached for the duration of this call only. Validation is kept out of
    this function so rejected inputs never request a GPU.
    """
    inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=False)
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # output_to_pdb returns one PDB string per batch entry; we fold exactly one.
    return model.output_to_pdb(outputs)[0]


def predict_structure(sequence: str) -> str:
    """Predict the 3D structure of a single protein chain with ESMFold.

    Args:
        sequence: One-letter amino-acid sequence. Case-insensitive. Any
            whitespace, including line breaks from a multi-line paste, is
            ignored. Only the 20 standard residues are accepted, up to
            ``_MAX_RESIDUES`` residues.

    Returns:
        The predicted structure as PDB-format text, or a string starting with
        ``"Error: "`` explaining why no prediction was produced. Errors are
        returned rather than raised so API clients always receive a string.
    """
    # Remove *all* whitespace, not just the ends: sequences pasted into the
    # multi-line textbox are often wrapped. `or ""` guards a None API input.
    sequence = "".join((sequence or "").split()).upper()
    if not sequence:
        return "Error: Empty sequence provided"

    invalid_chars = set(sequence) - _VALID_AA
    if invalid_chars:
        # Sorted so the message is deterministic (set order is not).
        return f"Error: Invalid amino acids found: {', '.join(sorted(invalid_chars))}"

    if len(sequence) > _MAX_RESIDUES:
        return f"Error: Sequence too long (max {_MAX_RESIDUES} residues)"

    try:
        return _fold_to_pdb(sequence)
    except Exception as e:
        # Request boundary: report the failure to the caller as text, and keep
        # the traceback in the server log so the failure can be diagnosed.
        import traceback

        traceback.print_exc()
        return f"Error: {str(e)}"
# Header / usage text shown at the top of the UI.
_API_USAGE_MD = """
# 🧬 Antibody Structure Prediction API (ZeroGPU)
GPU-accelerated ESMFold structure prediction.
**API Usage:**
```python
from gradio_client import Client
client = Client("kmlyyll/antibody-structure-api")
pdb = client.predict(sequence, api_name="/predict")
```
"""

with gr.Blocks(title="🧬 Antibody Structure API") as demo:
    gr.Markdown(_API_USAGE_MD)

    seq_input = gr.Textbox(
        label="Amino Acid Sequence",
        placeholder="Enter sequence...",
        lines=3,
    )
    predict_btn = gr.Button("Predict Structure", variant="primary")
    pdb_output = gr.Textbox(label="PDB Output", lines=20)

    # Button click runs the predictor; api_name="predict" also exposes the same
    # handler to gradio_client callers as "/predict" (see usage text above).
    predict_btn.click(
        api_name="predict",
        fn=predict_structure,
        inputs=seq_input,
        outputs=pdb_output,
    )

if __name__ == "__main__":
    demo.launch()