File size: 4,926 Bytes
6491864
 
 
 
 
 
 
 
 
 
 
268c49d
 
 
 
 
6491864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Hugging Face Spaces Gradio App for Antibody Non-Specificity Prediction

Simplified deployment version (no Hydra, no complex dependencies).
Works on HF Spaces free CPU tier.

Local app (src/antibody_training_esm/cli/app.py) remains unchanged.
"""

import logging
import os
import sys
from pathlib import Path

# Add src to Python path for local imports (HF Spaces doesn't install package)
sys.path.insert(0, str(Path(__file__).parent / "src"))

import gradio as gr
import torch
from pydantic import ValidationError

from antibody_training_esm.core.prediction import Predictor
from antibody_training_esm.models.prediction import PredictionRequest

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# HF Spaces environment detection.
# SPACE_ID is injected by the Hugging Face Spaces runtime; it is absent locally.
IS_HF_SPACE = os.getenv("SPACE_ID") is not None

# Model path (either local or downloaded from HF Hub).
# Overridable via the MODEL_PATH env var; defaults to the bundled checkpoint.
MODEL_PATH = os.getenv(
    "MODEL_PATH", "experiments/checkpoints/esm1v/logreg/boughter_vh_esm1v_logreg.pkl"
)

# ESM model name (backbone used for embeddings)
MODEL_NAME = "facebook/esm1v_t33_650M_UR90S_1"

# Force CPU for HF Spaces free tier
DEVICE = "cpu"

# Load model globally (HF Spaces best practice).
# NOTE: this runs at import time and may take a while (650M-param ESM backbone).
logger.info(f"Loading model from {MODEL_PATH}...")
predictor = Predictor(
    model_name=MODEL_NAME, classifier_path=MODEL_PATH, device=DEVICE, config_path=None
)

# Warm up model with a short dummy sequence so the first real request is fast.
# A failed warmup is logged but deliberately non-fatal — the app still starts.
try:
    logger.info("Warming up model...")
    predictor.predict_single("QVQL")
    logger.info("Model ready!")
except Exception as e:
    logger.warning(f"Warmup failed (non-fatal): {e}")


def predict_sequence(sequence: str) -> tuple[str, str]:
    """
    Prediction function for Gradio interface.

    Args:
        sequence: Antibody amino acid sequence

    Returns:
        Tuple of (prediction, probability)

    Raises:
        gr.Error: On validation failure, GPU OOM, or any unexpected
            inference error (Gradio surfaces the message to the user).
    """
    try:
        # Validate with Pydantic (raises ValidationError on bad input)
        request = PredictionRequest(sequence=sequence)

        # Log only the length — avoid logging the sequence itself
        logger.info(f"Processing sequence: length={len(request.sequence)}")

        # Pass the validated plain string, consistent with the warmup call
        # (predictor.predict_single("QVQL")) rather than the Pydantic model.
        result = predictor.predict_single(request.sequence)

        # Format probability as a percentage string, e.g. "73.2%"
        prob_percent = f"{result.probability:.1%}"

        return result.prediction, prob_percent

    except ValidationError as e:
        # Surface only the first validation message — user-friendly error
        error_msg = e.errors()[0]["msg"]
        raise gr.Error(error_msg) from e
    except torch.cuda.OutOfMemoryError as e:
        # Defensive: DEVICE is forced to "cpu" above, but keep the guard
        # so a future GPU deployment degrades gracefully.
        logger.error("GPU OOM during inference")
        raise gr.Error(
            "Server overloaded (GPU OOM). Please try again in a moment."
        ) from e
    except Exception as e:
        logger.exception("Unexpected prediction failure")
        raise gr.Error(f"Prediction failed: {str(e)}") from e


# Example sequences
examples = [
    [
        "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYNMHWVRQAPGQGLEWMGGIYPGDSDTRYSPSFQGQVTISADKSISTAYLQWSSLKASDTAMYYCARSTYYGGDWYFNVWGQGTLVTVSS"
    ],  # Standard VH
    [
        "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK"
    ],  # Standard VL
    [
        "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCARSWGQGTLVTVSS"
    ],  # Short VH
]

# Create Gradio interface.
# A single free-text input (the sequence) mapping to two read-only text outputs:
# the class label and the formatted probability.
iface = gr.Interface(
    fn=predict_sequence,
    inputs=gr.TextArea(
        lines=7,
        max_lines=20,
        max_length=2000,  # hard cap on pasted input length
        label="Antibody Sequence (VH or VL)",
        placeholder="Paste amino acid sequence here (e.g., QVQL...)",
        info="Supported characters: Standard amino acids (ACDEFGHIKLMNPQRSTVWY).",
        show_copy_button=True,
    ),
    outputs=[
        gr.Textbox(label="Prediction", show_copy_button=True),
        gr.Textbox(label="Probability of Non-Specificity", show_copy_button=True),
    ],
    title="🧬 Antibody Non-Specificity Predictor",
    description=(
        "Predict antibody polyreactivity (non-specificity) from Variable Heavy (VH) "
        "or Variable Light (VL) sequences using ESM-1v protein language models.\n\n"
        "**Model:** ESM-1v (650M parameters) + Logistic Regression\n"
        "**Training:** Boughter dataset (914 antibodies, ELISA polyreactivity)\n"
        "**Citation:** Sakhnini et al. (2025) - Prediction of Antibody Non-Specificity using PLMs"
    ),
    # Footer text rendered below the interface — reflects runtime configuration.
    article=(
        f"**Model:** {MODEL_NAME}\n"
        f"**Device:** {DEVICE}\n"
        f"**Environment:** {'Hugging Face Spaces' if IS_HF_SPACE else 'Local'}"
    ),
    examples=examples,
    cache_examples=False,  # Don't cache on HF Spaces (saves disk)
    flagging_mode="never",  # disable the flag/report button
    analytics_enabled=False,
    submit_btn="🔬 Predict Non-Specificity",
    clear_btn="🗑️ Clear",
)

# Enable queue for concurrency: at most 2 requests run at once,
# up to 10 queued before new submissions are rejected.
iface.queue(default_concurrency_limit=2, max_size=10)

# Launch app
# Launch app when run as a script (HF Spaces executes this file directly).
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",  # Required for HF Spaces
        server_port=7860,  # default port HF Spaces routes traffic to
        share=False,  # no gradio.live tunnel
        show_api=False,  # No public REST API
    )