diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -16,6 +16,13 @@ from collections import defaultdict from huggingface_hub import snapshot_download from pathlib import Path import os +from inference import ( + PeptiVersePredictor, + read_best_manifest_csv, + BestRow, + canon_model, +) + try: from Bio.SeqUtils.ProtParam import ProteinAnalysis BIOPYTHON_AVAILABLE = True @@ -63,7 +70,7 @@ for k, v in { Path(v).mkdir(parents=True, exist_ok=True) ASSETS_MODELS = ASSETS / "models"; ASSETS_MODELS.mkdir(parents=True, exist_ok=True) -ASSETS_DATA = ASSETS / "training_data"; ASSETS_DATA.mkdir(parents=True, exist_ok=True) +ASSETS_DATA = ASSETS / "training_data_cleaned"; ASSETS_DATA.mkdir(parents=True, exist_ok=True) MODEL_REPO = "ChatterjeeLab/Classifier_Weight" # model repo DATASET_REPO = "ChatterjeeLab/Classifier_Weight" # dataset repo (create this) @@ -72,21 +79,243 @@ def fetch_models_and_data(): snapshot_download( repo_id=MODEL_REPO, local_dir=str(ASSETS_MODELS), - local_dir_use_symlinks=True, + local_dir_use_symlinks=False, allow_patterns=[ - "models/*.pt","models/*.pth","models/*.ckpt","models/*.safetensors", - "models/*.json","models/*.yaml","models/*.yml", + # Model files + "training_classifiers/**/best_model*.json", + "training_classifiers/**/best_model*.pt", + "training_classifiers/**/best_model*.joblib", + # Tokenizer files + "tokenizer/new_vocab.txt", + "tokenizer/new_splits.txt", + # Training data for distributions + "training_data_cleaned/**/*.csv", ], ) - snapshot_download( - repo_id=DATASET_REPO, # <-- no repo_type here - local_dir=str(ASSETS_DATA), - local_dir_use_symlinks=True, - allow_patterns=["training_data/*.csv","training_data/*.npz","training_data/*.md"], - ) fetch_models_and_data() +BEST_TXT = Path("best_models.txt") +TRAINING_ROOT = ASSETS_MODELS / "training_classifiers" +TOKENIZER_DIR = ASSETS_MODELS / "tokenizer" + +# Banned models that should fall back to XGB +BANNED_MODELS = {"svm", "svr", "enet", "svm_gpu", "enet_gpu"} + +# 
"lower is better" exceptions for classification labeling +LOWER_BETTER = {"hemolysis", "toxicity"} + +# Property display names and descriptions +PROPERTY_INFO = { + 'solubility': { + 'display': 'πŸ’§ Solubility', + 'description': 'Aqueous solubility', + 'direction': '↑', + 'pass_label': 'Soluble', + 'fail_label': 'Insoluble' + }, + 'permeability_penetrance': { + 'display': 'πŸ”¬ Permeability (Penetrance)', + 'description': 'Cell penetration capability', + 'direction': '↑', + 'pass_label': 'Permeable', + 'fail_label': 'Non-permeable' + }, + 'hemolysis': { + 'display': '🩸 Hemolysis', + 'description': 'Red blood cell membrane disruption', + 'direction': '↓', + 'pass_label': 'Non-hemolytic', + 'fail_label': 'Hemolytic' + }, + 'nf': { + 'display': 'πŸ‘― Non-Fouling', + 'description': 'Resistance to protein adsorption', + 'direction': '↑', + 'pass_label': 'Non-fouling', + 'fail_label': 'Fouling' + }, + 'halflife': { + 'display': '⏱️ Half-Life', + 'description': 'Serum stability', + 'direction': '↑', + 'unit': 'hours' + }, + 'toxicity': { + 'display': '☠️ Toxicity', + 'description': 'Cytotoxicity', + 'direction': '↓', + 'pass_label': 'Non-toxic', + 'fail_label': 'Toxic' + }, + 'permeability_pampa': { + 'display': 'πŸͺ£ Permeability (PAMPA)', + 'description': 'PAMPA assay permeability', + 'direction': '', + 'threshold': -6, # Values > -6 are permeable + 'pass_label': 'Permeable', + 'fail_label': 'Non-permeable' + }, + 'permeability_caco2': { + 'display': 'πŸͺ£ Permeability (Caco-2)', + 'description': 'Caco-2 cell permeability', + 'direction': '', + 'threshold': -6, # Values > -6 are permeable + 'pass_label': 'Permeable', + 'fail_label': 'Non-permeable' + }, + 'binding_affinity': { + 'display': 'πŸ”— Binding Affinity', + 'description': 'Protein-peptide binding strength', + 'direction': '↑', + 'thresholds': {'tight': 9, 'weak': 7} + } +} +PROP_ORDER = [ + 'solubility', + 'permeability_penetrance', + 'hemolysis', + 'nf', + 'halflife', + 'toxicity', + 'permeability_pampa', + 
'permeability_caco2', + 'binding_affinity', +] + + +# Distribution-only keys +DIST_KEYS = { + "binding_affinity_wt": "πŸ”— Binding Affinity β€” WT (distribution)", + "binding_affinity_smiles": "πŸ”— Binding Affinity β€” SMILES (distribution)", + "binding_affinity_all": "πŸ”— Binding Affinity β€” WT+SMILES (distribution)", + "halflife_wt": "⏱️ Half-life β€” WT (distribution)", + "halflife_smiles": "⏱️ Half-life β€” SMILES (distribution)", + "halflife_all": "⏱️ Half-life β€” WT+SMILES (distribution)", +} + +def create_filtered_manifest(manifest_path: Path) -> Dict[str, BestRow]: + """Read manifest and replace banned models with XGB""" + original = read_best_manifest_csv(manifest_path) + filtered = {} + + for prop_key, row in original.items(): + # Normalize property key for half-life + normalized_key = prop_key + if prop_key in ['halflife', 'half_life']: + normalized_key = 'halflife' # Use consistent key + + # Check and potentially replace WT model + wt_model = canon_model(row.best_wt) + if wt_model in BANNED_MODELS: + wt_model = "XGB" + elif wt_model is None: + wt_model = row.best_wt + else: + wt_model = row.best_wt + + # Check and potentially replace SMILES model + smiles_model = canon_model(row.best_smiles) + if smiles_model in BANNED_MODELS: + smiles_model = "XGB" + elif smiles_model is None: + smiles_model = row.best_smiles + else: + smiles_model = row.best_smiles + + # Create modified row + filtered[normalized_key] = BestRow( + property_key=normalized_key, + best_wt=wt_model if wt_model != row.best_wt else row.best_wt, + best_smiles=smiles_model if smiles_model != row.best_smiles else row.best_smiles, + task_type=row.task_type, + thr_wt=row.thr_wt, + thr_smiles=row.thr_smiles, + ) + + return filtered + +class AppContext: + def __init__(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.best = create_filtered_manifest(BEST_TXT) + + self.predictor = PeptiVersePredictor( + manifest_path=BEST_TXT, + 
classifier_weight_root=ASSETS_MODELS, + esm_name="facebook/esm2_t33_650M_UR50D", + clm_name="aaronfeller/PeptideCLM-23M-all", + smiles_vocab=str(TOKENIZER_DIR / "new_vocab.txt"), + smiles_splits=str(TOKENIZER_DIR / "new_splits.txt"), + device=str(self.device), + ) + + # βœ… override manifest AND reload models so keys/folders match + self.predictor.manifest = self.best + self.predictor.models.clear() + self.predictor.meta.clear() + self.predictor._load_all_best_models() + + +CTX: AppContext | None = None + +def initialize(): + global CTX + if CTX is None: + CTX = AppContext() + return CTX + +def get_available_properties(ctx, modality: str) -> Dict[str, bool]: + """ + Returns dict of property -> bool indicating if available for the modality + """ + available = {} + for prop_key in PROPERTY_INFO.keys(): + if prop_key not in ctx.best: + available[prop_key] = False + continue + + row = ctx.best[prop_key] + if modality == "Sequence": + model = row.best_wt + else: + model = row.best_smiles + + # Check if model exists and is not empty/dash + if not model or model in {"-", "β€”", "NA", "N/A", None}: + available[prop_key] = False + else: + # Check if we actually have the model loaded + mode = "wt" if modality == "Sequence" else "smiles" + available[prop_key] = (prop_key, mode) in ctx.predictor.models + + return available + +def get_threshold(ctx: AppContext, prop: str, modality: str) -> float | None: + row = ctx.best.get(prop) + if row is None: + return None + return row.thr_wt if modality == "Sequence" else row.thr_smiles + +def get_best_models_table(ctx: AppContext) -> pd.DataFrame: + """Generate a table showing best models and thresholds""" + data = [] + for prop_key, row in ctx.best.items(): + prop_info = PROPERTY_INFO.get(prop_key, {}) + display_name = prop_info.get('display', prop_key) + + data.append({ + 'Property': display_name, + 'Best Model (Sequence)': row.best_wt if row.best_wt else 'β€”', + 'Threshold (Sequence)': f"{row.thr_wt:.4f}" if row.thr_wt is not None 
else 'β€”', + 'Best Model (SMILES)': row.best_smiles if row.best_smiles else 'β€”', + 'Threshold (SMILES)': f"{row.thr_smiles:.4f}" if row.thr_smiles is not None else 'β€”', + 'Task Type': row.task_type + }) + + return pd.DataFrame(data) + try: from rdkit import Chem from rdkit.Chem import Descriptors, AllChem @@ -242,525 +471,275 @@ class SequenceAnalyzer: total = sum(hydrophobicity.get(aa, 0) for aa in sequence) return round(total / len(sequence), 2) -# ==================== Model Classes ==================== - -# --- add this utility somewhere above UnifiedPeptidePredictor --- -def load_cnn_weights_safely(model: nn.Module, ckpt_path: Path, device: torch.device): - """ - Load a CNN checkpoint that might include old ESM weights, DDP prefixes, or different wrappers. - Strips unknown prefixes and ignores non-matching keys gracefully. - """ - ckpt = torch.load(ckpt_path, map_location=device) - - # 1) Extract a state dict from various formats - if isinstance(ckpt, dict) and any(k in ckpt for k in ["state_dict", "model_state_dict", "weights"]): - sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("weights") - elif isinstance(ckpt, dict): - # Probably already a state_dict - sd = ckpt - else: - # Possibly a full pickled model; try to read its state_dict - try: - sd = ckpt.state_dict() - except Exception as e: - raise RuntimeError(f"Unsupported checkpoint format at {ckpt_path}: {type(ckpt)}") from e - - # 2) Normalize keys: strip DDP 'module.' 
and drop old ESM-containing parameters - cleaned = {} - for k, v in sd.items(): - k2 = k - if k2.startswith("module."): - k2 = k2[len("module."):] - # drop anything from the embedded ESM or other now-missing submodules - if k2.startswith("esm_model.") or k2.startswith("esm.") or k2.startswith("backbone.esm."): - continue - cleaned[k2] = v - - # 3) Load non-strictly so extra/missing heads don't crash - missing, unexpected = model.load_state_dict(cleaned, strict=False) - - # Optional: log what happened so you can verify - if unexpected: - print(f"[load_cnn_weights_safely] Unexpected keys ignored: {sorted(unexpected)[:6]}{'...' if len(unexpected)>6 else ''}") - if missing: - print(f"[load_cnn_weights_safely] Missing keys not found in checkpoint: {sorted(missing)[:6]}{'...' if len(missing)>6 else ''}") - - -# ====== PeptideCLM SMILES featurizer ====== -from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer -from transformers import AutoModelForMaskedLM +# ==================== Data Management ==================== -class PeptideCLMFeaturizer: - """ - Mean-pool hidden states from PeptideCLM-23M-all for SMILES tokens produced by SMILES_SPE_Tokenizer. - Use the SAME tokenizer files, max_length, and pooling you used in training XGB models. 
- """ - def __init__(self, vocab_path: str, splits_path: str, device: torch.device, max_length: int = 256): - self.device = device - self.max_length = max_length - self.tok = SMILES_SPE_Tokenizer(vocab_path, splits_path) - self.model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer.to(device).eval() - - @torch.no_grad() - def embed_list(self, smiles_list: list[str]) -> np.ndarray: - feats = [] - for s in smiles_list: - toks = self.tok(s, return_tensors="pt", truncation=True, padding=True) - toks = {k: v.to(self.device) for k, v in toks.items()} - out = self.model(**toks).last_hidden_state # [1, L, H] - mask = toks["attention_mask"].unsqueeze(-1) # [1, L, 1] - pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) - feats.append(pooled.squeeze(0).float().cpu().numpy()) - return np.stack(feats, axis=0) # [N, H] - - -class UnpooledBindingPredictor(nn.Module): - """Binding affinity predictor with cross-attention mechanism""" - def __init__(self, - esm_model_name="facebook/esm2_t33_650M_UR50D", - hidden_dim=512, - kernel_sizes=[3, 5, 7], - n_heads=8, - n_layers=3, - dropout=0.1, - freeze_esm=True): - super().__init__() - - # Use these everywhere for consistency - self.tight_threshold = 7.5 - self.weak_threshold = 6.0 - - self.esm_model = AutoModel.from_pretrained(esm_model_name) - self.config = AutoConfig.from_pretrained(esm_model_name) - if freeze_esm: - for p in self.esm_model.parameters(): - p.requires_grad = False +class TrainingDataManager: + def __init__(self, data_dir=None): + # Try multiple possible locations for data + possible_dirs = [ + ASSETS_MODELS / "training_data_cleaned", # In HF downloaded location + Path("training_data_cleaned"), # Local relative path + ASSETS_DATA, # Original location + ] - esm_dim = self.config.hidden_size - out_ch = 64 - self.protein_conv_layers = nn.ModuleList([ - nn.Conv1d(esm_dim, out_ch, k, padding='same') for k in kernel_sizes - ]) - self.binder_conv_layers = nn.ModuleList([ - 
nn.Conv1d(esm_dim, out_ch, k, padding='same') for k in kernel_sizes - ]) - total = out_ch * len(kernel_sizes) * 2 + self.data_dir = None + for d in possible_dirs: + if d.exists(): + self.data_dir = d + print(f"Using data directory: {d}") + break - self.protein_projection = nn.Linear(total, hidden_dim) - self.binder_projection = nn.Linear(total, hidden_dim) - self.protein_norm = nn.LayerNorm(hidden_dim) - self.binder_norm = nn.LayerNorm(hidden_dim) + if self.data_dir is None: + print(f"WARNING: No data directory found. Tried: {possible_dirs}") + self.data_dir = ASSETS_DATA # Fallback + self.data_dir.mkdir(exist_ok=True) - self.cross_attention_layers = nn.ModuleList([ - nn.ModuleDict({ - 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), - 'norm1': nn.LayerNorm(hidden_dim), - 'ffn': nn.Sequential( - nn.Linear(hidden_dim, hidden_dim * 4), - nn.ReLU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim * 4, hidden_dim), - ), - 'norm2': nn.LayerNorm(hidden_dim), - }) for _ in range(n_layers) - ]) - - self.shared_head = nn.Sequential( - nn.Linear(hidden_dim * 2, hidden_dim), - nn.ReLU(), - nn.Dropout(dropout), - ) - self.regression_head = nn.Linear(hidden_dim, 1) - self.classification_head = nn.Linear(hidden_dim, 3) - - def get_binding_class(self, affinity: torch.Tensor | float) -> torch.LongTensor | int: - """ - 0: tight (>= tight_threshold) - 1: medium [weak_threshold, tight_threshold) - 2: weak (< weak_threshold) - """ - if isinstance(affinity, torch.Tensor): - tight = affinity >= self.tight_threshold - weak = affinity < self.weak_threshold - medium = ~(tight | weak) - classes = torch.zeros_like(affinity, dtype=torch.long) - classes[medium] = 1 - classes[weak] = 2 - return classes - else: - if affinity >= self.tight_threshold: - return 0 - elif affinity < self.weak_threshold: - return 2 - else: - return 1 - - def compute_embeddings(self, input_ids, attention_mask=None): - out = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, 
return_dict=True) - return out.last_hidden_state - - def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None): - x = unpooled_emb.transpose(1, 2) # [B, C_in=E, L] - conv_outputs = [F.relu(conv(x)) for conv in conv_layers] # list of [B, C_out, L] - conv_output = torch.cat(conv_outputs, dim=1) # [B, sumC, L] - if attention_mask is not None: - mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1) - masked = conv_output.masked_fill(mask == 0, float('-inf')) - max_pooled = masked.max(dim=2)[0] - sum_pooled = (conv_output * mask).sum(dim=2) - denom = mask.sum(dim=2).clamp(min=1.0) - avg_pooled = sum_pooled / denom - else: - max_pooled = conv_output.max(dim=2)[0] - avg_pooled = conv_output.mean(dim=2) - return torch.cat([max_pooled, avg_pooled], dim=1) # [B, 2*sumC] - - def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None): - protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask) - binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask) - protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask) - binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask) - protein = self.protein_norm(self.protein_projection(protein_features)) - binder = self.binder_norm(self.binder_projection(binder_features)) - - # make them "sequence length 1" for MHA (L,B,D) - protein = protein.unsqueeze(0).transpose(0,1) - binder = binder.unsqueeze(0).transpose(0,1) - for layer in self.cross_attention_layers: - p_attn = layer['attention'](protein, binder, binder)[0] - protein = layer['norm1'](protein + p_attn) - protein = layer['norm2'](protein + layer['ffn'](protein)) - b_attn = layer['attention'](binder, protein, protein)[0] - binder = layer['norm1'](binder + b_attn) - binder = layer['norm2'](binder + layer['ffn'](binder)) - - protein_pool = protein.mean(dim=0).squeeze(0) - binder_pool = 
binder.mean(dim=0).squeeze(0) - shared = self.shared_head(torch.cat([protein_pool, binder_pool], dim=-1)) - reg = self.regression_head(shared) # [1] - logits= self.classification_head(shared) # [3] - return reg, logits - - -# ------- SMILES + Protein binding model (reg + 3-class) ------- -class ImprovedBindingPredictor(nn.Module): - def __init__(self, esm_dim=1280, smiles_dim=768, hidden_dim=512, n_heads=8, n_layers=3, dropout=0.1): - super().__init__() - self.tight_threshold = 7.5 - self.weak_threshold = 6.0 - - self.smiles_projection = nn.Linear(smiles_dim, hidden_dim) - self.protein_projection = nn.Linear(esm_dim, hidden_dim) - self.protein_norm = nn.LayerNorm(hidden_dim) - self.smiles_norm = nn.LayerNorm(hidden_dim) - - self.cross_attention_layers = nn.ModuleList([ - nn.ModuleDict({ - 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout), - 'norm1': nn.LayerNorm(hidden_dim), - 'ffn': nn.Sequential( - nn.Linear(hidden_dim, hidden_dim * 4), - nn.ReLU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim * 4, hidden_dim), - ), - 'norm2': nn.LayerNorm(hidden_dim), - }) for _ in range(n_layers) - ]) - - self.shared_head = nn.Sequential( - nn.Linear(hidden_dim * 2, hidden_dim), - nn.ReLU(), - nn.Dropout(dropout), - ) - self.regression_head = nn.Linear(hidden_dim, 1) - self.classification_head = nn.Linear(hidden_dim, 3) - - def get_binding_class(self, affinity): - """Convert affinity values to class indices - 0: tight binding (>= 7.5) - 1: medium binding (6.0-7.5) - 2: weak binding (< 6.0) - """ - if isinstance(affinity, torch.Tensor): - tight_mask = affinity >= self.tight_threshold - weak_mask = affinity < self.weak_threshold - medium_mask = ~(tight_mask | weak_mask) - - classes = torch.zeros_like(affinity, dtype=torch.long) - classes[medium_mask] = 1 - classes[weak_mask] = 2 - return classes - else: - if affinity >= self.tight_threshold: - return 0 # tight binding - elif affinity < self.weak_threshold: - return 2 # weak binding - else: - return 1 # 
medium binding - - def forward(self, protein_emb: torch.Tensor, smiles_emb: torch.Tensor): - # protein_emb: [1, E], smiles_emb: [1, H] - protein = self.protein_norm(self.protein_projection(protein_emb)) # [1, D] - smiles = self.smiles_norm(self.smiles_projection(smiles_emb)) # [1, D] - - # Treat as "sequence length"=1 tokens; mha still works (QKV dims match) - protein = protein.unsqueeze(0) # [1, 1, D] -> (L, B, D) expected, we’ll keep batch in 2nd dim: - smiles = smiles.unsqueeze(0) # [1, 1, D] - protein = protein.transpose(0, 1) # [B=1, L=1, D] -> MHA wants [L, B, D] - smiles = smiles.transpose(0, 1) - - for layer in self.cross_attention_layers: - attn_p = layer['attention'](protein, smiles, smiles)[0] - protein = layer['norm1'](protein + attn_p) - protein = layer['norm2'](protein + layer['ffn'](protein)) - - attn_s = layer['attention'](smiles, protein, protein)[0] - smiles = layer['norm1'](smiles + attn_s) - smiles = layer['norm2'](smiles + layer['ffn'](smiles)) - - # pool over L (it's 1, so mean==squeeze) - protein_pool = protein.mean(dim=0).squeeze(0) # [D] - smiles_pool = smiles.mean(dim=0).squeeze(0) # [D] - - combined = torch.cat([protein_pool, smiles_pool], dim=-1) # [2D] - shared = self.shared_head(combined) - reg = self.regression_head(shared) # scalar pKd/pKi - logits = self.classification_head(shared) # 3-class - return reg, logits - - -class PeptideCNN(nn.Module): - """CNN model for single peptide property prediction""" - def __init__(self, input_dim=1280, hidden_dims=None, output_dim=160, dropout_rate=0.3): - super().__init__() - if hidden_dims is None: - hidden_dims = [input_dim // 2, input_dim // 4] - - self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1) - self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1) - self.fc = nn.Linear(hidden_dims[1], output_dim) - self.dropout = nn.Dropout(dropout_rate) - self.predictor = nn.Linear(output_dim, 1) - - def forward(self, esm_embeddings, 
return_features=False): - x = esm_embeddings.permute(0, 2, 1) - x = F.relu(self.conv1(x)) - x = self.dropout(x) - x = F.relu(self.conv2(x)) - x = self.dropout(x) - x = x.permute(0, 2, 1) - x = x.mean(dim=1) - features = self.fc(x) - if return_features: - return features - return self.predictor(features) - - -# ==================== Data Management ==================== - -class TrainingDataManager: - def __init__(self, data_dir=ASSETS_DATA): - self.data_dir = Path(data_dir) - self.data_dir.mkdir(exist_ok=True) self.statistics = self.load_statistics() - def _load_half_life_csv(self): - csv_path = self.data_dir / "training_data/half_life_smiles.csv" - if not csv_path.exists(): - return None - try: - df = pd.read_csv(csv_path) - if "log_hour" in df.columns: - vals = pd.to_numeric(df["log_hour"], errors="coerce").dropna().to_numpy() - else: - if "half_life_hours" not in df.columns: - if "half_life" in df.columns: - df["half_life_hours"] = pd.to_numeric(df["half_life"], errors="coerce") / 3600.0 - else: - raise ValueError("CSV must contain 'log_hour' or 'half_life_hours' (or 'half_life').") - hh = pd.to_numeric(df["half_life_hours"], errors="coerce") - vals = np.log10(hh.replace(0, np.nan)).dropna().to_numpy() - if len(vals) == 0: - return None - return { - "values": vals, - "unit": "log10(hours)", - "threshold": float(np.median(vals)), # median on log scale - "kind": "continuous", - } - except Exception as e: - print(f"[TrainingDataManager] half-life load error: {e}") - return None - - def _load_binary_pair(self, prefix: str): - """ - Load binary labels from -positive.npz and -negative.npz - Returns: {'values': y, 'unit': 'Class (0=neg, 1=pos)', 'kind': 'binary', 'n_pos': int, 'n_neg': int} - or None if missing. 
- """ - pos_path = self.data_dir / f"training_data/{prefix}-positive.npz" - neg_path = self.data_dir / f"training_data/{prefix}-negative.npz" - if not pos_path.exists() or not neg_path.exists(): - return None - try: - with np.load(pos_path) as pos: - pos_data = pos["arr_0"] - with np.load(neg_path) as neg: - neg_data = neg["arr_0"] - y = np.concatenate( - [np.ones(len(pos_data), dtype=int), np.zeros(len(neg_data), dtype=int)], - axis=0 - ) - return { - "values": y, - "unit": "Class (0=neg, 1=pos)", - "kind": "binary", - "n_pos": int(len(pos_data)), - "n_neg": int(len(neg_data)), - } - except Exception as e: - print(f"[TrainingDataManager] binary load error for '{prefix}': {e}") - return None - - def _load_binding_affinity_csv(self): - """ - Read c-binding.csv and return the raw affinity values (pKd/pKi-like, i.e., -log scale). - No filtering/clipping β€” only numeric conversion with NaNs dropped so plotting works. + def load_csv_data(self, filepath: Path, value_column, is_binary: bool = False) -> Optional[Dict]: + """Load data from a CSV file. + value_column can be a string OR a list/tuple of candidate column names. 
""" - csv_path = self.data_dir / "training_data/c-binding.csv" - if not csv_path.exists(): + if not filepath.exists(): + print(f"File not found: {filepath}") return None try: - df = pd.read_csv(csv_path) - if "affinity" not in df.columns: - raise ValueError("CSV must contain an 'affinity' column.") - - vals = pd.to_numeric(df["affinity"], errors="coerce").dropna().to_numpy() + df = pd.read_csv(filepath, encoding="utf-8", on_bad_lines="skip") + print(f"Columns in {filepath.name}: {df.columns.tolist()[:5]}...") + + # Case-insensitive column map + col_lower = {col.lower(): col for col in df.columns} + + # βœ… NEW: allow list/tuple of candidates + if isinstance(value_column, (list, tuple)): + chosen = None + for c in value_column: + if c is None: + continue + c_l = str(c).lower() + if c_l in col_lower: + chosen = col_lower[c_l] + break + if chosen is None: + print(f"None of candidate columns {value_column} found. Available: {list(df.columns)[:10]}") + return None + value_column = chosen + else: + # keep original behavior, but safe-cast to str + vc_l = str(value_column).lower() + if vc_l not in col_lower: + alternatives = { + 'label': ['label', 'labels', 'y', 'target'], + 'affinity': ['affinity', 'pkd', 'pki', 'binding_affinity'], + 'pampa': ['pampa', 'pampa_value', 'permeability'], + 'caco2': ['caco2', 'caco-2', 'caco_2'], + 'log_hour': ['log_hour', 'loghour', 'log_hours', 'loghours'], + 'half_life_hours': ['half_life_hours', 'halflife_hours', 'hours'], + 'half_life_seconds': ['half_life_seconds', 'halflife_seconds', 'seconds'], + } + found = False + for alt in alternatives.get(vc_l, []): + if alt.lower() in col_lower: + value_column = col_lower[alt.lower()] + found = True + break + if not found: + print(f"Column {value_column} not found. 
Available: {list(df.columns)[:10]}") + return None + else: + value_column = col_lower[vc_l] + + vals = pd.to_numeric(df[value_column], errors="coerce").dropna().to_numpy() if len(vals) == 0: + print(f"No valid values found in column {value_column}") return None - return { - "values": vals, - "description": "Protein–ligand binding affinity normalized", - "unit": "score", - "threshold": 7.5, # main threshold (tight) - "threshold_secondary": 6.0, # weak threshold - "kind": "continuous", - "download_link": str(csv_path), - } - except Exception as e: - print(f"[TrainingDataManager] binding-affinity load error: {e}") - return None + print(f"Loaded {len(vals)} values from {filepath.name}") - def _load_permeability_pampa_csv(self): - """ - Load PAMPA permeability values from training_data/nc-CPP-processed.csv. - Expects columns: 'SMILES','PAMPA'. We only parse PAMPA as float; NaNs are dropped. - No filtering/clipping. - """ - csv_path = self.data_dir / "training_data/nc-CPP-processed.csv" - if not csv_path.exists(): - return None - try: - df = pd.read_csv(csv_path) - if "PAMPA" not in df.columns: - raise ValueError("CSV must contain a 'PAMPA' column.") + if is_binary: + unique_vals = np.unique(vals) + if not set(unique_vals).issubset({0, 1, 0.0, 1.0}): + vals = (vals > 0.5).astype(int) - vals = pd.to_numeric(df["PAMPA"], errors="coerce").dropna().to_numpy() - if len(vals) == 0: - return None + return {"values": vals, "n_samples": len(vals)} - threshold_default = float(np.median(vals)) - return { - "values": vals, - "description": "Cell membrane permeability measurements", - "unit": "log Peff", - "threshold": threshold_default, - "kind": "continuous", - "download_link": str(csv_path), - } except Exception as e: - print(f"[TrainingDataManager] permeability PAMPA load error: {e}") + print(f"Error loading {filepath}: {e}") + import traceback + traceback.print_exc() return None + def load_statistics(self): - """Load pre-computed statistics for each property""" - stats = { + 
"""Load pre-computed statistics for each property from actual data files""" + stats = {} + + # Map properties to their data files and value columns + data_mappings = { 'hemolysis': { - 'values': np.random.beta(2, 5, 1000), - 'description': 'Probability of peptide disrupting red blood cell membranes.', - 'unit': 'Probability', - 'threshold': 0.5, - 'download_link': '#' + 'files': [ + 'hemolysis/hemo_meta_with_split.csv', + 'hemolysis/hemolysis_meta_with_split.csv', + ], + 'column': 'label', + 'is_binary': True }, 'solubility': { - 'values': np.random.normal(5, 2, 1000), - 'description': 'Probability of peptide remaining dissolved in aqueous conditions.', - 'unit': 'Probability', - 'threshold': 0.5, - 'download_link': '#' + 'files': [ + 'solubility/sol_meta_with_split.csv', + 'solubility/solubility_meta_with_split.csv', + ], + 'column': 'label', + 'is_binary': True }, - 'binding_affinity': { - 'values': np.random.normal(7, 1.5, 1000), - 'description': 'Protein-peptide binding affinity', - 'unit': 'Probability', - 'threshold': 7.5, - 'download_link': '#' + "binding_affinity_wt": { + "files": ["binding_affinity/binding_affinity_wt_meta_with_split.csv"], + "column": "affinity", + "is_binary": False }, - 'half_life (smiles)': { - # will be overwritten below if CSV exists - 'values': np.random.lognormal(2, 1, 1000), - 'description': 'Serum half-life from clinical and preclinical studies', - 'unit': 'Hours', - 'threshold': 2.0, # hours (default fallback) - 'download_link': '#' + "binding_affinity_smiles": { + "files": ["binding_affinity/binding_affinity_smiles_meta_with_split.csv"], + "column": "affinity", + "is_binary": False }, - 'nonfouling': { - 'values': np.random.lognormal(4, 1, 1000), - 'description': 'A nonfouling peptide resists nonspecific interactions and protein adsorption.', - 'unit': 'Probability', - 'threshold': 0.5, - 'download_link': '#' + "binding_affinity_all": { + "files": [ + "binding_affinity/binding_affinity_wt_meta_with_split.csv", + 
"binding_affinity/binding_affinity_smiles_meta_with_split.csv", + ], + "column": "affinity", + "is_binary": False }, - 'permeability': { - 'values': np.random.normal(-4, 1, 1000), - 'description': 'Cell membrane permeability measurements', - 'unit': 'Probability of peptide penetrating the cell membrane.', - 'threshold': 0.5, - 'download_link': '#' + + "halflife_wt": { + "files": [ + "half_life/halflife_with_split.csv", + "half_life/halflife_meta_with_split.csv", + ], + "column": ["half_life_hours", "log_hour", "log_hours"], + "is_binary": False + }, + "halflife_smiles": { + "files": [ + "half_life/halflife_smiles_with_split.csv", + "half_life/halflife_smiles_with_splits.csv", + "half_life/halflife_smiles_meta_with_split.csv", + ], + "column": ["half_life_hours", "log_hour", "log_hours"], + "is_binary": False + }, + "halflife_all": { + "files": [ + "half_life/halflife_with_split.csv", + "half_life/halflife_meta_with_split.csv", + "half_life/halflife_smiles_with_split.csv", + "half_life/halflife_smiles_with_splits.csv", + "half_life/halflife_smiles_meta_with_split.csv", + ], + "column": ["half_life_hours", "log_hour", "log_hours"], + "is_binary": False + }, + 'nf': { + 'files': [ + 'nonfouling/nf_meta_with_split.csv', + 'nf/nf_meta_with_split.csv', + ], + 'column': 'label', + 'is_binary': True + }, + 'permeability_penetrance': { + 'files': [ + 'permeability/perm_meta_with_split.csv', + 'permeability_penetrance/permeability_meta_with_split.csv', + ], + 'column': 'label', + 'is_binary': True + }, + 'permeability_pampa': { + 'files': [ + 'permeability_pampa/pampa_meta_with_split.csv', + 'pampa/pampa_meta_with_split.csv', + ], + 'column': 'PAMPA', + 'is_binary': False + }, + 'permeability_caco2': { + 'files': [ + 'permeability_caco2/caco2_meta_with_split.csv', + 'caco2/caco2_meta_with_split.csv', + ], + 'column': 'Caco2', + 'is_binary': False + }, + 'toxicity': { + 'files': [ + 'toxicity/tox_meta_with_split.csv', + 'toxicity/toxicity_meta_with_split.csv', + ], + 
'column': 'label', + 'is_binary': True } } - - # Overlay real half-life - hl = self._load_half_life_csv() - if hl is not None: - stats["half_life (smiles)"].update(hl) - - # Overlay real solubility from sol-* (binary) - sol = self._load_binary_pair("sol") - if sol is not None: - stats["solubility"].update(sol) - - # Overlay real non-fouling from nf-* (binary) - nf = self._load_binary_pair("nf") - if nf is not None: - stats["nonfouling"].update(nf) - - hemo = self._load_binary_pair("hemo") - if hemo is not None: - stats["hemolysis"].update(hemo) - - ba = self._load_binding_affinity_csv() - if ba is not None: - stats["binding_affinity"].update(ba) - pampa = self._load_permeability_pampa_csv() - if pampa is not None: - stats["permeability"].update(pampa) + # Load actual data + for prop_key, mapping in data_mappings.items(): + all_vals = [] + loaded_from = [] + + for file_path in mapping['files']: + filepath = self.data_dir / file_path + if not filepath.exists(): + continue + + d = self.load_csv_data( + filepath, + mapping['column'], + mapping.get('is_binary', False) + ) + if d: + all_vals.append(d["values"]) + loaded_from.append(file_path) + + if all_vals: + vals = np.concatenate(all_vals, axis=0) + + prop_info = PROPERTY_INFO.get(prop_key, {}) + stats[prop_key] = { + "values": vals, + "description": prop_info.get("description", ""), + "unit": "Probability" if mapping.get("is_binary") else prop_info.get("unit", "Score"), + "n_samples": int(vals.shape[0]), + "kind": "binary" if mapping.get("is_binary") else "continuous", + "loaded_from": loaded_from, # optional: good for debugging + } + + # thresholds / unit tweaks + if prop_key == "binding_affinity": + stats[prop_key]["threshold"] = 9 + stats[prop_key]["threshold_secondary"] = 7 + stats[prop_key]["unit"] = "pKd/pKi" + + elif prop_key in ["permeability_pampa", "permeability_caco2"]: + stats[prop_key]["threshold"] = -6 + stats[prop_key]["unit"] = "log Peff" if prop_key == "permeability_pampa" else "log Papp" + + elif 
prop_key == "halflife": + stats[prop_key]["unit"] = "hours" + # for distribution plotting + if prop_key.startswith("binding_affinity"): + stats[prop_key]["threshold"] = 9 + stats[prop_key]["threshold_secondary"] = 7 + stats[prop_key]["unit"] = "pKd/pKi" + + elif prop_key.startswith("halflife"): + stats[prop_key]["unit"] = "hours" + print(f"βœ“ Loaded {prop_key} from {loaded_from} ({len(vals)} samples)") + continue - return stats + # fallback synthetic (unchanged) + print(f"⚠ Using synthetic data for {prop_key}") + return stats def get_distribution_plot(self, property_name, current_value=None): if property_name not in self.statistics: @@ -774,21 +753,28 @@ class TrainingDataManager: n1 = int((vals == 1).sum()) total = max(n0 + n1, 1) fig = go.Figure() - fig.add_trace(go.Bar(x=["Negative (0)", "Positive (1)"], y=[n0, n1])) + + prop_info = PROPERTY_INFO.get(property_name, {}) + labels = [ + prop_info.get('fail_label', 'Negative (0)'), + prop_info.get('pass_label', 'Positive (1)') + ] + + fig.add_trace(go.Bar(x=labels, y=[n0, n1])) fig.update_layout( - title=f"{property_name.replace('_',' ').title()} β€” Class Balance", + title=f"{prop_info.get('display', property_name)} β€” Class Distribution", xaxis_title="Class", yaxis_title="Count", height=400, showlegend=False, annotations=[ - dict(x="Negative (0)", y=n0, text=f"{n0} ({n0/total:.1%})", showarrow=False, yshift=8), - dict(x="Positive (1)", y=n1, text=f"{n1} ({n1/total:.1%})", showarrow=False, yshift=8), + dict(x=labels[0], y=n0, text=f"{n0} ({n0/total:.1%})", showarrow=False, yshift=8), + dict(x=labels[1], y=n1, text=f"{n1} ({n1/total:.1%})", showarrow=False, yshift=8), ], ) return fig - # continuous + # Continuous distribution fig = go.Figure() fig.add_trace(go.Histogram(x=vals, nbinsx=50, name="Training Data")) @@ -824,8 +810,9 @@ class TrainingDataManager: annotation_text=f"Your Result: {float(current_value):.3f}", ) + prop_info = PROPERTY_INFO.get(property_name, {}) fig.update_layout( - 
title=f"{property_name.replace('_', ' ').title()} Distribution", + title=f"{prop_info.get('display', property_name)} Distribution", xaxis_title=s.get("unit", ""), yaxis_title="Count", height=400, @@ -833,7 +820,6 @@ class TrainingDataManager: ) return fig - def get_property_info(self, property_name): if property_name not in self.statistics: return None @@ -864,789 +850,189 @@ class TrainingDataManager: "75%": float(pct[3]), "90%": float(pct[4]), } - return info - - - -def _base_stat_key(model_key: str) -> str: - # strip modality suffixes to find stats in TrainingDataManager - for suf in ("_seq", "_smiles"): - if model_key.endswith(suf): - return model_key[:-len(suf)] - return model_key -# ==================== Unified Predictor ==================== - -class UnifiedPeptidePredictor: - """Main predictor handling all model types""" - - def __init__(self, model_dir="models"): - self.model_dir = Path(model_dir) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Initialize tokenizer and ESM model - print("Loading ESM model...") - self.tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") - self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D") - self.esm_model.to(self.device) - self.esm_model.eval() - - self.tokenizer_dir = Path("tokenizer") - self.smiles_featurizer = PeptideCLMFeaturizer( - vocab_path=f"{self.tokenizer_dir}/new_vocab.txt", - splits_path=f"{self.tokenizer_dir}/new_splits.txt", - device=self.device, - ) - # Model registry - self.models = {} - self.model_configs = self.get_model_configs() - self.sequence_analyzer = SequenceAnalyzer() - # Data manager - self.data_manager = TrainingDataManager(data_dir=ASSETS_DATA) - self._protein_cache = {} - # Load models - self.load_all_models() - - def get_model_configs(self): - """Define model configurations""" - return { - 'hemolysis_seq': { - 'type': 'xgboost', - 'input': 'sequence', - 'path': 'models/best_model_hemolysis.json', - 'inverse_score': 
False, - 'unit': 'Probability', - 'display_name': '🩸 Hemolysis', - 'positive_label': 'Non-hemolytic', - 'negative_label': 'Hemolytic' - }, - 'hemolysis_smiles': { - 'type': 'xgboost', - 'input': 'smiles', - 'path': 'models/hemolysis-xgboost_smiles.json', - 'inverse_score': False, - 'unit': 'Probability', - 'display_name': '🩸 Hemolysis', - 'positive_label': 'Non-hemolytic', - 'negative_label': 'Hemolytic' - }, - 'solubility_seq': { - 'type': 'xgboost', - 'input': 'sequence', - 'path': 'models/best_model_solubility.json', - 'unit': 'Probability', - 'display_name': 'πŸ’§ Solubility', - 'positive_label': 'Soluble', - 'negative_label': 'Insoluble' - }, - 'solubility_smiles': { - 'type': 'xgboost', - 'input': 'smiles', - 'path': 'models/solubility-xgboost_smiles.json', - 'unit': 'Probability', - 'display_name': 'πŸ’§ Solubility', - 'positive_label': 'Soluble', - 'negative_label': 'Insoluble' - }, - 'permeability_smiles': { - 'type': 'xgboost', - 'input': 'smiles', - 'path': 'models/permeability-xgboost_smiles.json', - 'unit': 'Probability', - 'display_name': 'πŸͺ£ Permeability', - 'positive_label': 'Permeable', - 'negative_label': 'Impermeable' - }, - 'half_life_seq': { - 'type': 'pytorch_cnn', - 'input': 'sequence', - 'path': 'models/best_model_half_life.pth', - 'transform': lambda x: 10**x, - 'unit': 'hours', - 'display_name': '⏱️ Half-life', - 'positive_label': 'Stable', - 'negative_label': 'Unstable' - }, - 'nonfouling_seq': { - 'type': 'xgboost', - 'input': 'sequence', - 'path': 'models/best_model_nonfouling.json', - 'unit': 'Probability', - 'display_name': 'πŸ‘― Non-Fouling', - 'positive_label': 'Non-toxic', - 'negative_label': 'Toxic' - }, - 'nonfouling_smiles': { - 'type': 'xgboost', - 'input': 'smiles', - 'path': 'models/nonfouling-xgboost_smiles.json', - 'unit': 'Probability', - 'display_name': 'πŸ‘― Non-Fouling', - 'positive_label': 'Stable', - 'negative_label': 'Unstable' - }, - 'binding_affinity': { - 'type': 'binding', - 'input': 'dual_sequence', - 'path': 
'models/binding_affinity_unpooled.pt', - 'unit': 'Probability', - 'display_name': 'πŸ”— Binding Affinity' - }, - 'binding_affinity_smiles': { - 'type': 'binding_smiles', - 'input': 'sequence+smiles', - 'path': 'models/binding-affinity_smiles.pt', - 'unit': 'Probability', - 'display_name': 'πŸ”— Binding Affinity (SMILES)' - }, - } - def analyze_sequence(self, sequence: str, pH: float = 7.0) -> Dict[str, Any]: - """Comprehensive sequence analysis including charge, pI, and aggregation""" - results = {} - # Basic properties - results['length'] = len(sequence) - results['molecular_weight'] = self.sequence_analyzer.calculate_molecular_weight(sequence) - results['net_charge'] = self.sequence_analyzer.calculate_net_charge(sequence, pH) - results['isoelectric_point'] = self.sequence_analyzer.calculate_isoelectric_point(sequence) - results['hydrophobicity'] = self.sequence_analyzer.calculate_hydrophobicity(sequence) - return results - def load_all_models(self): - """Load all available models""" - for name, config in self.model_configs.items(): - model_path = self.model_dir / config['path'] - - if not model_path.exists(): - print(f"Warning: Model {name} not found at {model_path}") - continue - - try: - if config['type'] == 'xgboost': - self.models[name] = xgb.Booster(model_file=str(model_path)) - - elif config['type'] == 'pytorch_cnn': - model = PeptideCNN().to(self.device) - ckpt_path = model_path # Path from config - load_cnn_weights_safely(model, ckpt_path, self.device) - model.eval() - self.models[name] = model - - elif config['type'] == 'binding': - checkpoint = torch.load(model_path, map_location=self.device, weights_only=False) - model = UnpooledBindingPredictor( - hidden_dim=384, - kernel_sizes=[3, 5, 7], - n_heads=8, - n_layers=4, - dropout=0.14561457009902096, - freeze_esm=True - ).to(self.device) - model.load_state_dict(checkpoint['model_state_dict']) - model.eval() - self.models[name] = model - elif config['type'] == 'binding_smiles': - ckpt = 
torch.load(model_path, map_location=self.device, weights_only=False) - model = ImprovedBindingPredictor( - esm_dim=1280, smiles_dim=768, hidden_dim=512, n_heads=8, n_layers=3, dropout=0.1 - ).to(self.device) - model.load_state_dict(ckpt['model_state_dict']) - model.eval() - self.models[name] = model - - print(f"βœ“ Loaded {name}") - - except Exception as e: - print(f"Error loading {name}: {e}") - - def _protein_embed_mean(self, protein_seq: str) -> torch.Tensor: - """Mean-pool ESM last_hidden_state -> [1, 1280]""" - toks = self.tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=1024) - toks = {k: v.to(self.device) for k, v in toks.items()} - with torch.no_grad(): - out = self.esm_model(**toks).last_hidden_state # [1, L, E] - mask = toks['attention_mask'].unsqueeze(-1) # [1, L, 1] - pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) # [1, E] - return pooled - - def _get_protein_vec(self, protein_seq: str) -> torch.Tensor: - key = protein_seq.strip() - if key in self._protein_cache: - return self._protein_cache[key] - vec = self._protein_embed_mean(key) - self._protein_cache[key] = vec - return vec - - def _smiles_embed_mean(self, smiles: str) -> torch.Tensor: - vec = self.smiles_featurizer.embed_list([smiles])[0] # np [H] - return torch.from_numpy(vec).to(self.device).unsqueeze(0) # [1, H] - - def predict_property(self, model, config, value: str, input_type: str): - """ - value: either AA sequence (Sequence mode) or SMILES (SMILES mode) - """ - if config['type'] == 'xgboost': - if input_type == 'SMILES': - if config.get('input') != 'smiles': - raise RuntimeError(f"Model {config['display_name']} expects sequence, not SMILES.") - feats = self._features_from_smiles_peptclm(value)[None, ...] 
# [1, D] - else: - if config.get('input') == 'smiles': - raise RuntimeError(f"Model {config['display_name']} expects SMILES, not sequence.") - # ESM mean-pooled features - toks = self.tokenizer(value, return_tensors="pt", padding=True, truncation=True, max_length=512) - toks = {k: v.to(self.device) for k, v in toks.items()} - with torch.no_grad(): - out = self.esm_model(**toks).last_hidden_state - mask = toks["attention_mask"].unsqueeze(-1) - pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) - feats = pooled.float().cpu().numpy() # [1, 1280] - # Optional safety check - expected = model.num_features() - if feats.shape[1] != expected: - raise RuntimeError(f"Feature dim mismatch: got {feats.shape[1]}, booster expects {expected}") - dmat = xgb.DMatrix(feats) - pred = model.predict(dmat)[0] - if config.get('inverse_score', False): - pred = 1 - pred - return float(pred) - - elif config['type'] == 'pytorch_cnn': - if input_type == 'SMILES': - raise RuntimeError(f"{config['display_name']} (CNN) expects AA sequence, not SMILES.") - toks = self.tokenizer(value, return_tensors="pt", padding=True, truncation=True, max_length=512) - toks = {k: v.to(self.device) for k, v in toks.items()} - with torch.no_grad(): - out = self.esm_model(**toks).last_hidden_state - y = model(out).squeeze().item() - if 'transform' in config: - y = config['transform'](y) - return float(y) - - else: - raise NotImplementedError(config['type']) - - def predict_sequence_property(self, model, config, sequence): - """Predict property from sequence""" - inputs = self.tokenizer( - sequence, - return_tensors="pt", - padding=True, - truncation=True, - max_length=512 - ) - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.esm_model(**inputs) - embeddings = outputs.last_hidden_state - - if config['type'] == 'xgboost': - attention_mask = inputs['attention_mask'] - masked_embeddings = embeddings * attention_mask.unsqueeze(-1) - sum_embeddings = 
masked_embeddings.sum(dim=1) - seq_lengths = attention_mask.sum(dim=1, keepdim=True) - mean_embeddings = sum_embeddings / seq_lengths - features = mean_embeddings.cpu().numpy() - - dmatrix = xgb.DMatrix(features) - prediction = model.predict(dmatrix)[0] - - if config.get('inverse_score', False): - prediction = 1 - prediction - - elif config['type'] == 'pytorch_cnn': - prediction = model(embeddings).squeeze().item() - - if 'transform' in config: - prediction = config['transform'](prediction) - - return prediction - - def predict_binding(self, model, protein_seq, binder_seq, prefer_thresholds: bool = True): - """Predict (affinity, class_label). If prefer_thresholds=True, label is derived from model.tight/weak thresholds.""" - protein_tokens = self.tokenizer( - protein_seq, return_tensors="pt", - padding="max_length", max_length=1024, truncation=True - ) - binder_tokens = self.tokenizer( - binder_seq, return_tensors="pt", - padding="max_length", max_length=1024, truncation=True - ) - protein_ids = protein_tokens['input_ids'].to(self.device) - protein_mask= protein_tokens['attention_mask'].to(self.device) - binder_ids = binder_tokens['input_ids'].to(self.device) - binder_mask = binder_tokens['attention_mask'].to(self.device) - - with torch.no_grad(): - reg, logits = model(protein_ids, binder_ids, protein_mask, binder_mask) - affinity = float(reg.squeeze().item()) - # 1) threshold-based class: - cls_by_thr = int(model.get_binding_class(affinity)) - # 2) logits-based class: - cls_by_logit = int(torch.argmax(logits, dim=-1).item()) - - class_names = ['Tight', 'Medium', 'Weak'] - # choose which one you want to show - cls_idx = cls_by_thr if prefer_thresholds else cls_by_logit - - # decorate with explicit cutoffs for UI clarity - if cls_idx == 0: - label = f"Tight (β‰₯ {model.tight_threshold:.1f})" - elif cls_idx == 1: - label = f"Medium ({model.weak_threshold:.1f}–{model.tight_threshold:.1f})" - else: - label = f"Weak (< {model.weak_threshold:.1f})" - - return affinity, 
label - - - def predict_binding_smiles(self, model, protein_seq: str, smiles_str: str, prefer_thresholds: bool = True) -> tuple[float, str]: - prot_vec = self._get_protein_vec(protein_seq) # [1, 1280] - smiles_vec = self._smiles_embed_mean(smiles_str) # [1, 768] - with torch.no_grad(): - reg, logits = model(prot_vec, smiles_vec) - affinity = float(reg.squeeze().item()) - cls_by_thr = int(model.get_binding_class(affinity)) - cls_by_logit = int(torch.argmax(logits, dim=-1).item()) - - cls_idx = cls_by_thr if prefer_thresholds else cls_by_logit - - if cls_idx == 0: - label = f"Tight (β‰₯ {model.tight_threshold:.1f})" - elif cls_idx == 1: - label = f"Medium ({model.weak_threshold:.1f}–{model.tight_threshold:.1f})" - else: - label = f"Weak (< {model.weak_threshold:.1f})" - return affinity, label - - - def _features_from_smiles_peptclm(self, s: str) -> np.ndarray: - return self.smiles_featurizer.embed_list([s])[0] - - @staticmethod - def affinity_to_nM(affinity: float) -> float: - """ - Convert model affinity score (pKd / pKi / pIC50 style: -log10(K [M])) - to an approximate concentration in nM. 
- """ - # K [M] = 10^(-affinity); then convert M -> nM (1e9 factor) - return 10.0 ** (-float(affinity)) * 1e9 - - -# ==================== TANGO INTEGRATION ==================== - -# TANGO executable: same folder as this script -try: - HERE = Path(__file__).resolve().parent -except NameError: - HERE = Path(".").resolve() - -TANGO_EXE = str(HERE / "tango_x86_64_release") - -# Default params (adjust if you like) -DEFAULT_TANGO_PARAMS = { - "nt": "N", - "ct": "N", - "ph": "7.0", - "te": "310", # Kelvin (~37 Β°C) - "io": "0.05", - "tf": "0", - "stab": "-10", - "conc": "0.0001", -} - -def _parse_tango_keyvals(text: str) -> dict: - """ - Parse lines like: - 'AGG 0 AMYLO 6.41e-13 TURN 7.06 HELIX 0 HELAGG 0 BETA 19.67' - into {'AMYLO': [...], 'BETA': [...], ...} - """ - buckets = defaultdict(list) - for line in text.splitlines(): - pairs = re.findall( - r'\b(AGG|AMYLO|TURN|HELIX|HELAGG|BETA)\s+([+-]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)\b', - line - ) - for k, v in pairs: - try: - buckets[k].append(float(v)) - except ValueError: - pass - return dict(buckets) - -def _agg(vals, how="sum"): - if not vals: - return None - if how == "sum": - return float(sum(vals)) - if how == "max": - return float(max(vals)) - if how == "mean": - return float(sum(vals) / len(vals)) - return None - -def run_tango_for_sequence( - seq: str, - pH_value: str, - ident: str = "seq", - params: dict | None = None, - exe: str = TANGO_EXE, -) -> dict: - """ - Run TANGO on a single sequence and return: - - amyloid aggregation (AMYLO sum/max) - - Ξ²-sheet aggregation (BETA sum/max) - """ - params = {**DEFAULT_TANGO_PARAMS, **(params or {})} - params["ph"] = pH_value - cmd = [exe, ident] + [f'{k}="{v}"' for k, v in params.items()] + [f'seq="{seq}"'] - - # TANGO likes a single shell command - p = subprocess.run(" ".join(cmd), shell=True, capture_output=True, text=True) - out = (p.stdout or "") + (("\n[STDERR]\n" + p.stderr) if p.stderr else "") - - buckets = _parse_tango_keyvals(out) - - amylo_vals = 
buckets.get("AMYLO", []) - beta_vals = buckets.get("BETA", []) - agg_vals = buckets.get("AGG", []) - - tango_amylo_max = _agg(amylo_vals, "max") - tango_amylo_sum = _agg(amylo_vals, "sum") - tango_beta_max = _agg(beta_vals, "max") - tango_beta_sum = _agg(beta_vals, "sum") - tango_agg_sum = _agg(agg_vals, "sum") - - return { - "tango_amylo_max": tango_amylo_max, - "tango_amylo_sum": tango_amylo_sum, - "tango_beta_max": tango_beta_max, - "tango_beta_sum": tango_beta_sum, - "tango_agg_sum": tango_agg_sum, - "raw_output": out.strip(), - } + return info # ==================== Gradio Interface ==================== -# Global predictor -predictor = None - -def initialize(): - """Initialize the predictor""" - global predictor - if predictor is None: - predictor = UnifiedPeptidePredictor(model_dir=ASSETS_MODELS) - return predictor - - def predict_properties( input_text: str, - input_type: str, - protein_seq: str, - # Individual property checkboxes - hemolysis: bool, - solubility: bool, - permeability: bool, - half_life: bool, - nonfouling: bool, - binding_affinity: bool, - tango_amyloid: bool, - tango_beta: bool, + input_type: str, # "Sequence" or "SMILES" + protein_text: str, # For binding affinity + selected_props: list[str], # from individual checkboxes include_physicochemical: bool, pH_value: float, progress=gr.Progress() ): - """Main prediction function""" - if not input_text or not input_text.strip(): - return None, "⚠️ Please provide an input sequence" + return None, "⚠️ Please provide input." lines = [s.strip() for s in input_text.split("\n") if s.strip()] - - if input_type == "SMILES": - bad = [s for s in lines if not is_smiles_like(s)] + if input_type == "Sequence": + bad = [s for s in lines if not is_aa_sequence_like(s)] if bad: - return None, f"⚠️ You selected SMILES but {len(bad)} input line(s) don't look like SMILES. 
Example bad line: {bad[0][:60]}" - if binding_affinity and not protein_seq: - return None, "⚠️ For SMILES binding, please provide a protein sequence in the 'Protein Sequence' box." + return None, f"⚠️ Input Type=Sequence but {len(bad)} line(s) don't look like AA sequences. Example: {bad[0][:60]}" else: - bad = [s for s in lines if not is_aa_sequence_like(s)] + bad = [s for s in lines if not is_smiles_like(s)] if bad: - return None, f"⚠️ You selected Sequence but {len(bad)} input line(s) don't look like amino-acid sequences. Example bad line: {bad[0][:60]}" - pred = initialize() + return None, f"⚠️ Input Type=SMILES but {len(bad)} line(s) don't look like SMILES. Example: {bad[0][:60]}" + + ctx = initialize() + print("keys in ctx.best:", sorted(ctx.best.keys())) + print("loaded model keys:", sorted(ctx.predictor.models.keys())) + print("halflife wt loaded?", ("halflife","wt") in ctx.predictor.models) + print("halflife smiles loaded?", ("halflife","smiles") in ctx.predictor.models) + if not selected_props: + return None, "⚠️ Please select at least one property." 
+ results = [] + analyzer = SequenceAnalyzer() - # Collect selected properties - selected_properties = [] - - checkbox_to_keys = { - 'hemolysis': ['hemolysis_seq', 'hemolysis_smiles'], - 'solubility': ['solubility_seq', 'solubility_smiles'], - 'permeability': ['permeability_smiles'], - 'half_life': ['half_life_seq', 'binding_affinity_smiles'], - 'nonfouling': ['nonfouling_seq', 'nonfouling_smiles'], # adjust if you have a real cytotox model - } - selected_properties = [] - for ui_name, is_selected in { - 'hemolysis': hemolysis, - 'solubility': solubility, - 'permeability': permeability, - 'half_life': half_life, - 'nonfouling': nonfouling, - }.items(): - if not is_selected: - continue - # choose the variant that matches the current input type - keys = checkbox_to_keys.get(ui_name, []) - for key in keys: - if key in pred.model_configs: - expected_input = pred.model_configs[key].get('input', 'sequence') - if (input_type == 'SMILES' and expected_input == 'smiles') or \ - (input_type == 'Sequence' and expected_input == 'sequence'): - if key in pred.models: - selected_properties.append(key) - - # Process sequences for regular properties - if selected_properties: - sequences = [s.strip() for s in input_text.split('\n') if s.strip()] - - for seq_idx, seq in enumerate(sequences): - progress((seq_idx + 1) / len(sequences), f"Processing sequence {seq_idx + 1}/{len(sequences)}") - - for prop in selected_properties: - config = pred.model_configs[prop] - model = pred.models[prop] - + # Check availability + available = get_available_properties(ctx, input_type) + unavailable = [p for p in selected_props if not available.get(p, False)] + if unavailable: + unavailable_names = [PROPERTY_INFO.get(p, {}).get('display', p) for p in unavailable] + return None, f"⚠️ These properties are not supported for {input_type}: {', '.join(unavailable_names)}" + + for i, s in enumerate(lines): + progress((i + 1) / len(lines), f"Processing {i+1}/{len(lines)}") + + # Regular property predictions + 
for prop in selected_props: + if prop == "binding_affinity": + # Handle binding affinity separately + if not protein_text or not protein_text.strip(): + results.append({ + "Input": s[:30] + "..." if len(s) > 30 else s, + "Property": PROPERTY_INFO[prop]['display'], + "Prediction": "N/A", + "Value": "Requires protein", + "Unit": "", + }) + continue + + mode = "wt" if input_type == "Sequence" else "smiles" try: - value = pred.predict_property(model, config, seq, input_type) + result = ctx.predictor.predict_binding_affinity(mode, protein_text.strip(), s) + affinity = result["affinity"] - stat_key = _base_stat_key(prop) - threshold = pred.data_manager.statistics.get(stat_key, {}).get('threshold') - if threshold is not None: - # which direction? - if stat_key in ['hemolysis']: # lower is better - label = config['positive_label'] if value < threshold else config['negative_label'] - else: # higher is better by default for these examples - label = config['positive_label'] if value > threshold else config['negative_label'] + # Determine binding class based on thresholds + if affinity >= 9: + class_label = "Tight binding" + elif affinity >= 7: + class_label = "Medium binding" else: - label = "" - - # Create clickable property name - prop_display = f'{config["display_name"]}' + class_label = "Weak binding" results.append({ - 'Sequence': seq[:30] + '...' if len(seq) > 30 else seq, - 'Property': config["display_name"], - 'Prediction': label, - 'Value': f"{value:.3f}", - 'Unit': config['unit'] + "Input": s[:30] + "..." if len(s) > 30 else s, + "Property": PROPERTY_INFO[prop]['display'], + "Prediction": class_label, + "Value": f"{affinity:.3f}", + "Unit": "pKd/pKi", }) except Exception as e: - print(f"Error predicting {prop}: {e}") - if input_type == "Sequence": - if include_physicochemical: - seq_display = seq[:30] + '...' 
if len(seq) > 30 else seq - progress((seq_idx + 0.3) / len(lines), f"Calculating physicochemical properties...") - analysis = pred.analyze_sequence(seq, pH_value) - - results.append({ - 'Sequence': seq_display, - 'Property': 'πŸ“ Length', - 'Prediction': '', - 'Value': str(analysis['length']), - 'Unit': 'aa' - }) - results.append({ - 'Sequence': seq_display, - 'Property': 'βš–οΈ Molecular Weight', - 'Prediction': '', - 'Value': f"{analysis['molecular_weight']:.1f}", - 'Unit': 'Da' - }) + print(f"Error predicting binding affinity: {e}") results.append({ - 'Sequence': seq_display, - 'Property': f'⚑ Net Charge (pH {pH_value})', - 'Prediction': '', - 'Value': f"{analysis['net_charge']:.2f}", - 'Unit': '' + "Input": s[:30] + "..." if len(s) > 30 else s, + "Property": PROPERTY_INFO[prop]['display'], + "Prediction": "Error", + "Value": "Failed", + "Unit": "", }) - results.append({ - 'Sequence': seq_display, - 'Property': '🎯 Isoelectric Point', - 'Prediction': '', - 'Value': f"{analysis['isoelectric_point']:.2f}", - 'Unit': 'pH' - }) - hydro = analysis['hydrophobicity'] - if hydro <= -0.5: - hydro_label = "Hydrophilic" - elif hydro >= 0.5: - hydro_label = "Hydrophobic" - else: - hydro_label = "Intermediate" - - results.append({ - 'Sequence': seq_display, - 'Property': 'πŸ’¦ Hydrophobicity (GRAVY)', - 'Prediction': hydro_label, - 'Value': f"{hydro:.2f}", - 'Unit': 'GRAVY (Kyte-Doolittle)', - }) - if input_type == "Sequence" and (tango_amyloid or tango_beta): - try: - # Run once per sequence - tango_res = run_tango_for_sequence( - seq, - pH_value=pH_value, - ident=f"seq{seq_idx+1}", - params=None # override pH/te here if you want - ) - - short_seq = seq[:30] + '...' 
if len(seq) > 30 else seq - - if tango_amyloid and tango_res["tango_amylo_sum"] is not None: - results.append({ - 'Sequence': short_seq, - 'Property': "🧱 TANGO Amyloid Aggregation", - 'Prediction': "", - 'Value': f"{tango_res['tango_amylo_sum']:.3f}", - 'Unit': "TANGO (sum)" - }) - - if tango_beta and tango_res["tango_beta_sum"] is not None: - results.append({ - 'Sequence': short_seq, - 'Property': "🧬 TANGO Ξ²-sheet Aggregation", - 'Prediction': "", - 'Value': f"{tango_res['tango_beta_sum']:.3f}", - 'Unit': "TANGO (sum)" - }) - - except Exception as e: - print(f"Error running TANGO for sequence {seq_idx+1}: {e}") - - # Handle binding affinity separately - if binding_affinity and input_text: - # Sequence–Sequence binding - if input_type == "Sequence": - if 'binding_affinity' in pred.models: - progress(0.9, "Predicting binding affinity (sequence) ...") - if not protein_seq: - return None, "⚠️ Please provide a protein sequence for binding prediction." - try: - binder_seqs = [s.strip() for s in input_text.split('\n') if s.strip()] - for binder_seq in binder_seqs: - affinity, binding_class = pred.predict_binding( - pred.models['binding_affinity'], - protein_seq, - binder_seq - ) - kd_nM = pred.affinity_to_nM(affinity) - - seq_label = f"Protein–{binder_seq[:20]}..." 
- prop_base = pred.model_configs['binding_affinity']['display_name'] - - # Row 1: affinity score (pKd-like) - results.append({ - 'Sequence': seq_label, - 'Property': f"{prop_base} (score)", - 'Prediction': binding_class, - 'Value': f"{affinity:.3f}", - 'Unit': "Affinity score (pKd-like)", - }) - - # Row 2: converted Kd in nM - results.append({ - 'Sequence': seq_label, - 'Property': f"{prop_base} (Kd est.)", - 'Prediction': binding_class, - 'Value': f"{kd_nM:.3g}", - 'Unit': "nM (Kd/Ki/IC50)", - }) - except Exception as e: - print(f"Error in sequence binding prediction: {e}") - - # Sequence + SMILES binding - else: # input_type == "SMILES" - if 'binding_affinity_smiles' not in pred.models: - return None, "⚠️ SMILES binding model not loaded. Please add the checkpoint to models/ and restart." - if not protein_seq: - return None, "⚠️ For SMILES binding, please provide a protein sequence." - # quick AA check for protein_seq - if not is_aa_sequence_like(protein_seq): - return None, "⚠️ The provided protein sequence does not look like an amino-acid sequence." - progress(0.9, "Predicting binding affinity (SMILES) ...") + continue + + # Regular properties + mode = "wt" if input_type == "Sequence" else "smiles" + try: - smiles_list = [s.strip() for s in input_text.split('\n') if s.strip()] - for smi in smiles_list: - affinity, label = pred.predict_binding_smiles( - pred.models['binding_affinity_smiles'], - protein_seq, - smi - ) - kd_nM = pred.affinity_to_nM(affinity) - - seq_label = f"Protein–{smi[:20]}..." 
- prop_base = pred.model_configs['binding_affinity_smiles']['display_name'] - - # Row 1: affinity score (pKd-like) - results.append({ - 'Sequence': seq_label, - 'Property': f"{prop_base} (score)", - 'Prediction': label, # Tight / Medium / Weak - 'Value': f"{affinity:.3f}", - 'Unit': "Affinity score (pKd-like)", - }) - - # Row 2: converted Kd in nM - results.append({ - 'Sequence': seq_label, - 'Property': f"{prop_base} (Kd est.)", - 'Prediction': label, - 'Value': f"{kd_nM:.3g}", - 'Unit': "nM (Kd/Ki/IC50)", - }) + result = ctx.predictor.predict_property(prop, mode, s) + score = result["score"] + + prop_info = PROPERTY_INFO.get(prop, {}) + + # Determine label based on property type + if prop in ['permeability_pampa', 'permeability_caco2']: + # Special handling for permeability assays + label = prop_info['pass_label'] if score > -6 else prop_info['fail_label'] + unit = "log Peff" if prop == 'permeability_pampa' else "log Papp" + elif prop == 'halflife': + # Regression task, no pass/fail + label = "β€”" + unit = prop_info.get('unit', 'hours') + else: + # Classification tasks + thr = get_threshold(ctx, prop, input_type) + if thr is not None: + if prop in LOWER_BETTER: + label = prop_info.get('pass_label', 'Pass') if score < thr else prop_info.get('fail_label', 'Fail') + else: + label = prop_info.get('pass_label', 'Pass') if score >= thr else prop_info.get('fail_label', 'Fail') + else: + label = "β€”" + unit = "Probability" + + results.append({ + "Input": s[:30] + "..." 
if len(s) > 30 else s, + "Property": prop_info.get('display', prop), + "Prediction": label, + "Value": f"{score:.3f}", + "Unit": unit, + }) except Exception as e: - print(f"Error in SMILES binding prediction: {e}") - - if not results: - return None, "⚠️ Please select at least one property to predict" - - # Create summary - n_sequences = len(set(r['Sequence'] for r in results)) - n_properties = len(set(r['Property'] for r in results)) - - status = f"βœ… Completed {len(results)} predictions ({n_sequences} sequence(s), {n_properties} properties)" - if binding_affinity: - status += " \n**Binding class cutoffs:** Tight β‰₯ 7.5, Medium 6.0–7.5, Weak < 6.0" - - return pd.DataFrame(results), status + print(f"Error predicting {prop} for {s[:30]}: {e}") + continue + # optional physicochemical only for AA sequence modality + if input_type == "Sequence" and include_physicochemical: + analysis = { + "length": len(s), + "molecular_weight": analyzer.calculate_molecular_weight(s), + "net_charge": analyzer.calculate_net_charge(s, pH_value), + "isoelectric_point": analyzer.calculate_isoelectric_point(s), + "hydrophobicity": analyzer.calculate_hydrophobicity(s), + } + short = s[:30] + "..." 
if len(s) > 30 else s + results += [ + {"Input": short, "Property": "πŸ“ Length", "Prediction": "", "Value": str(analysis["length"]), "Unit": "aa"}, + {"Input": short, "Property": "βš–οΈ Molecular Weight", "Prediction": "", "Value": f"{analysis['molecular_weight']:.1f}", "Unit": "Da"}, + {"Input": short, "Property": f"⚑ Net Charge (pH {pH_value})", "Prediction": "", "Value": f"{analysis['net_charge']:.2f}", "Unit": ""}, + {"Input": short, "Property": "🎯 Isoelectric Point", "Prediction": "", "Value": f"{analysis['isoelectric_point']:.2f}", "Unit": "pH"}, + {"Input": short, "Property": "πŸ’¦ Hydrophobicity (GRAVY)", "Prediction": "", "Value": f"{analysis['hydrophobicity']:.2f}", "Unit": "GRAVY"}, + ] + + df = pd.DataFrame(results) + status = f"βœ… Completed {len(df)} rows ({len(lines)} input(s), {len(selected_props)} selected properties)." + return df, status def show_distribution(property_name, predicted_value=None): """Show distribution plot + info for selected property.""" - pred = initialize() + data_manager = TrainingDataManager() + if not property_name: return None, "Select a property to view its distribution." # Get the first property if a list was passed prop = property_name[0] if isinstance(property_name, list) else property_name - # Generate the plot (works for both binary & continuous) - fig = pred.data_manager.get_distribution_plot(prop, predicted_value) + # Generate the plot + fig = data_manager.get_distribution_plot(prop, predicted_value) - # Build info panel with correct fields per kind - stats = pred.data_manager.statistics.get(prop, {}) - kind = stats.get("kind", "continuous") - info = pred.data_manager.get_property_info(prop) + # Build info panel + info = data_manager.get_property_info(prop) if not info: return fig, "No information available for this property." 
- title = prop.replace('_', ' ').title() + prop_info = PROPERTY_INFO.get(prop, {}) + title = DIST_KEYS.get(prop, PROPERTY_INFO.get(prop, {}).get("display", prop)) + kind = data_manager.statistics.get(prop, {}).get("kind", "continuous") + if kind == "binary": - n_pos = info.get("n_pos", int((stats.get("values") == 1).sum() if "values" in stats else 0)) - n_neg = info.get("n_neg", int((stats.get("values") == 0).sum() if "values" in stats else 0)) + n_pos = info.get("n_pos", 0) + n_neg = info.get("n_neg", 0) total = max(n_pos + n_neg, 1) info_text = f""" ### {title} Information @@ -1655,8 +1041,8 @@ def show_distribution(property_name, predicted_value=None): **Statistics (Binary):** - Samples: {info['n_samples']:,} -- Positives (1): {n_pos:,} ({n_pos/total:.1%}) -- Negatives (0): {n_neg:,} ({n_neg/total:.1%}) +- {prop_info.get('pass_label', 'Positive')} (1): {n_pos:,} ({n_pos/total:.1%}) +- {prop_info.get('fail_label', 'Negative')} (0): {n_neg:,} ({n_neg/total:.1%}) """ else: p = info.get("percentiles", {}) @@ -1681,19 +1067,18 @@ def show_distribution(property_name, predicted_value=None): return fig, info_text - def load_example(example_name): """Load example sequences""" examples = { "T7": ("HAIYPRH", ""), - "Protein-Peptide": ("MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLST", "GIVEQCCTSICSLYQLENYCN") + "Protein-Peptide": ("GIVEQCCTSICSLYQLENYCN", "MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLST"), + "Cyclic Peptide (SMILES)": ("CC(C)C[C@@H]1NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@@H](C)N(C)C(=O)[C@H](Cc2ccccc2)NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@H]2CCCN2C1=O", ""), + "Protein-Cyclic Peptide (SMILES)": ("CC(C)C[C@@H]1NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@@H](C)N(C)C(=O)[C@H](Cc2ccccc2)NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@H]2CCCN2C1=O", "MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLST") + } if example_name in examples: - if example_name == "Protein-Peptide": - return examples[example_name][1], examples[example_name][0] # Binder, Protein - else: - return 
examples[example_name][0], "" + return examples[example_name] return "", "" def on_example_change(name: str): @@ -1704,13 +1089,37 @@ def on_example_change(name: str): gr.update(value=protein, visible=show_protein) # protein_seq (and toggle visibility) ) -def on_example_load(name: str): - binder, protein = load_example(name) - show_protein = (name == "Protein-Peptide") - return ( - gr.update(value=binder), # input_text - gr.update(value=protein, visible=show_protein) # protein_seq + visibility - ) +def on_modality_change(modality, *checkbox_values): + ctx = initialize() + available = get_available_properties(ctx, modality) + + updates = [] + for i, prop_key in enumerate(PROP_ORDER): + is_available = available.get(prop_key, False) + prop_info = PROPERTY_INFO[prop_key] + label_text = f"{prop_info['display']} {prop_info.get('direction','')}".rstrip() + if not is_available: + label_text += " (Not supported)" + if prop_key == "binding_affinity" and is_available: + label_text += " *" + + current_value = checkbox_values[i] if i < len(checkbox_values) else False + updates.append(gr.update( + label=label_text, + interactive=is_available, + value=False if not is_available else current_value + )) + return updates + + +def collect_selected_properties(*checkbox_values): + selected = [] + for i, prop_key in enumerate(PROP_ORDER): + if i < len(checkbox_values) and checkbox_values[i]: + selected.append(prop_key) + return selected + + # ==================== Gradio App ==================== custom_css = """ @@ -1759,6 +1168,7 @@ table { """ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as demo: + ctx = initialize() # Header gr.Markdown( @@ -1784,40 +1194,35 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as de ) input_text = gr.Textbox( - label="Peptide Sequence(s) / Binder", - placeholder="Enter amino acid sequence(s), one per line", + label="Peptide Sequence(s) / SMILES", + placeholder="Enter amino acid sequence(s) or 
SMILES, one per line", lines=6 ) protein_seq = gr.Textbox( label="Protein Sequence (for binding prediction)", - placeholder="Enter protein sequence for binding affinity prediction", + placeholder="Enter target protein sequence", lines=3, visible=False ) gr.Markdown("**Examples:**") example_dropdown = gr.Dropdown( - choices=["T7","Protein-Peptide"], + choices=["T7", "Protein-Peptide", "Cyclic Peptide (SMILES)", "Protein-Cyclic Peptide (SMILES)"], label="Load Example", interactive=True ) - - file_input = gr.File( - label="Or Upload File", - file_types=[".txt", ".fasta", ".fa"], - visible=False - ) - # Property Selection + # Property Selection - Fixed order to prevent checkbox mapping issues with gr.Column(scale=1): with gr.Group(): gr.Markdown("### βš™οΈ Select Properties") + with gr.Accordion("Physicochemical Properties", open=True): include_physicochemical = gr.Checkbox( label="πŸ§ͺ Calculate Basic Properties", value=True, - info="MW, net charge, pI, hydrophobicity" + info="MW, net charge, pI, hydrophobicity (Sequence only)" ) pH_value = gr.Slider( @@ -1828,34 +1233,67 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as de label="pH for Net Charge", info="Physiological pH is ~7.4" ) - with gr.Accordion("Sequence Properties", open=True): - hemolysis = gr.Checkbox(label="🩸 Hemolysis ↓", value=True) - solubility = gr.Checkbox(label="πŸ’§ Solubility ↑", value=True) - permeability = gr.Checkbox(label="πŸͺ£ Permeability ↑", value=False) - half_life = gr.Checkbox(label="⏱️ Half-life ↑", value=False) - nonfouling = gr.Checkbox(label="πŸ‘― Non-Fouling ↑", value=False) - tango_amyloid = gr.Checkbox(label="🧱 TANGO Amyloid Aggregation ↓", value=False) - tango_beta = gr.Checkbox(label="🧬 TANGO Ξ²-sheet Aggregation ↓", value=False) - with gr.Accordion("Binding Prediction", open=False): - binding_affinity = gr.Checkbox(label="πŸ”— Binding Affinity ↑", value=False) - gr.Markdown("*Requires protein sequence input*") + + # Create individual 
checkboxes in fixed order + with gr.Accordion("Prediction Properties", open=True): + property_checkboxes = [] + available = get_available_properties(ctx, "Sequence") + + for prop_key in PROP_ORDER: + prop_info = PROPERTY_INFO[prop_key] + is_available = available.get(prop_key, False) + + label_text = f"{prop_info['display']} {prop_info.get('direction','')}".rstrip() + if not is_available: + label_text += " (Not supported)" + if prop_key == "binding_affinity" and is_available: + label_text += " *" + + default_on = (prop_key in ["solubility", "hemolysis"]) # optional defaults + cb = gr.Checkbox( + label=label_text, + value=is_available and default_on, + interactive=is_available, + elem_id=f"checkbox_{prop_key}", + ) + property_checkboxes.append(cb) + + gr.Markdown("*Requires protein sequence input above", elem_classes="text-sm text-gray-500") + + + # Best Models Tab + with gr.TabItem("πŸ“‹ Best Models"): + gr.Markdown("### Current Best Models Configuration") + gr.Markdown("This table shows the models and thresholds currently being used for predictions:") + best_models_df = gr.Dataframe( + value=get_best_models_table(ctx), + headers=["Property", "Best Model (Sequence)", "Threshold (Sequence)", + "Best Model (SMILES)", "Threshold (SMILES)", "Task Type"], + interactive=False + ) + gr.Markdown(""" + **Note:** Models marked as SVM, SVR, or ENET are automatically replaced with XGB + as these models are not currently supported in the deployment environment. 
+ """) + # Distribution Analysis Tab with gr.TabItem("πŸ“Š Distributions"): with gr.Row(): with gr.Column(scale=1): + dist_choices = list(PROPERTY_INFO.keys()) + list(DIST_KEYS.keys()) + property_selector = gr.Dropdown( - choices=["hemolysis", "solubility", "permeability", "half_life (smiles)", - "nonfouling", "binding_affinity", "tango_amyloid", "tango_beta"], + choices=dist_choices, label="Select Property", - value="hemolysis" + value="binding_affinity_all" ) test_value = gr.Number(label="Test Value among Distribution", value=None) show_dist_btn = gr.Button("Show Distribution") with gr.Column(scale=2): dist_plot_tab = gr.Plot(label="Score Distribution") - dist_info_tab = gr.Markdown() - + dist_info_tab = gr.Markdown() + # Data Documentation Tab with gr.TabItem("πŸ“š Documentation"): file_path = "description.md" @@ -1869,7 +1307,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as de gr.Markdown( markdown_content ) - # Action Buttons with gr.Row(): clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary") @@ -1883,25 +1320,18 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as de gr.Markdown("### πŸ“Š Results") results_df = gr.Dataframe( - headers=["Sequence", "Property", "Prediction", "Value", "Unit"], + headers=["Input", "Property", "Prediction", "Value", "Unit"], datatype=["str", "str", "str", "str", "str"], interactive=False ) - # Hidden components for distribution modal - with gr.Row(visible=False) as distribution_row: - with gr.Column(): - selected_property = gr.Textbox(visible=False) - dist_plot_modal = gr.Plot() # <-- renamed - dist_info_modal = gr.Markdown() # <-- renamed - close_btn = gr.Button("Close") - # Footer gr.Markdown( """ ---
-

Please Cite Us.

+

PeptiVerse - Advanced Peptide Property Predictions

+

Please cite our work if you use this tool in your research.

""" ) @@ -1910,10 +1340,20 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as de def update_visibility(binding_checked): return gr.update(visible=binding_checked) - binding_affinity.change( + # Update checkbox states when modality changes + input_type.change( + on_modality_change, + inputs=[input_type] + property_checkboxes, + outputs=property_checkboxes + ) + + # Show protein sequence input when binding affinity is selected + BINDING_IDX = PROP_ORDER.index("binding_affinity") + + property_checkboxes[BINDING_IDX].change( update_visibility, - inputs=[binding_affinity], - outputs=[protein_seq] + inputs=[property_checkboxes[BINDING_IDX]], + outputs=[protein_seq], ) example_dropdown.change( @@ -1921,29 +1361,23 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as de inputs=[example_dropdown], outputs=[input_text, protein_seq] ) + predict_btn.click( - predict_properties, - inputs=[ - input_text, input_type, protein_seq, - hemolysis, solubility, permeability, - half_life, nonfouling, - binding_affinity, tango_amyloid, tango_beta, include_physicochemical, pH_value, - ], + lambda input_text, input_type, protein_text, include_physicochemical, pH_value, *checkbox_values: + predict_properties( + input_text, input_type, protein_text, + collect_selected_properties(*checkbox_values), + include_physicochemical, pH_value + ), + inputs=[input_text, input_type, protein_seq, include_physicochemical, pH_value] + property_checkboxes, outputs=[results_df, status_output] ) - + clear_btn.click( - lambda: ("", "", None, ""), - outputs=[input_text, protein_seq, results_df, status_output] + lambda: ["", "", None, ""] + [False] * len(property_checkboxes), + outputs=[input_text, protein_seq, results_df, status_output] + property_checkboxes ) - # Add JavaScript for clickable property names - demo.load(js=""" - function show_distribution(property, value) { - // This would open a modal with the distribution - console.log('Show distribution 
for', property, 'with value', value); - } - """) show_dist_btn.click( show_distribution, inputs=[property_selector, test_value], @@ -1954,4 +1388,4 @@ if __name__ == "__main__": print("Initializing models...") initialize() print("Ready!") - demo.launch(share=True) + demo.launch(share=True) \ No newline at end of file