# PeptiVerse / app.py
# Author: ynuozhang
# Commit: 728610a ("update models")
import gradio as gr
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import xgboost as xgb
from transformers import AutoTokenizer, AutoModel, AutoConfig, EsmModel, EsmTokenizer
import plotly.graph_objects as go
from pathlib import Path
import json
import time
from typing import List, Dict, Any, Tuple, Optional
import subprocess
from collections import defaultdict
from huggingface_hub import snapshot_download
from pathlib import Path
import os
from inference import (
PeptiVersePredictor,
read_best_manifest_csv,
BestRow,
canon_model,
)
# Optional dependency: BioPython's ProteinAnalysis gives accurate pI/charge/
# MW/GRAVY values; SequenceAnalyzer below falls back to hand-rolled formulas
# when it is missing.
try:
    from Bio.SeqUtils.ProtParam import ProteinAnalysis
    BIOPYTHON_AVAILABLE = True
except ImportError:
    BIOPYTHON_AVAILABLE = False
    print("BioPython not available. Using fallback for pI/charge calculations.")
def pick_assets_root() -> Path:
    """Return a writable directory for model assets and caches.

    Resolution order:
      1. Explicit ``HF_ASSETS_DIR`` override.
      2. HF Spaces container path ``/home/user/assets`` (detected via
         ``SPACE_ID`` or the parent directory existing).
      3. Local fallbacks: ``~/assets``, ``./assets``, ``/tmp/assets``.

    Returns:
        Path to an existing, writable assets directory.

    Raises:
        RuntimeError: if no candidate directory can be created.
    """
    # Bug fix: the manual override used to be consulted only AFTER the
    # Spaces path, which made HF_ASSETS_DIR a no-op inside a Space. It now
    # takes precedence everywhere, as the original comment intended.
    env = os.environ.get("HF_ASSETS_DIR")
    if env:
        p = Path(env)
        p.mkdir(parents=True, exist_ok=True)
        return p
    # HF Spaces container uses /home/user; detect via SPACE_ID or existence
    spaces_root = Path("/home/user/assets")
    if os.environ.get("SPACE_ID") or spaces_root.parent.exists():
        try:
            spaces_root.mkdir(parents=True, exist_ok=True)
            return spaces_root
        except OSError:
            pass  # fall through to local options
    # Local fallbacks
    for p in [Path.home() / "assets", Path.cwd() / "assets", Path("/tmp/assets")]:
        try:
            p.mkdir(parents=True, exist_ok=True)
            return p
        except OSError:
            continue
    raise RuntimeError("No writable assets directory found.")
ASSETS = pick_assets_root()

# Put all caches on the same writable disk
for k, v in {
    "HF_HOME": str(ASSETS / "hf"),
    "HUGGINGFACE_HUB_CACHE": str(ASSETS / "hf" / "cache"),
    "TRANSFORMERS_CACHE": str(ASSETS / "transformers"),
    "HF_DATASETS_CACHE": str(ASSETS / "hf" / "datasets"),
    "XDG_CACHE_HOME": str(ASSETS / "xdg"),
    "TMPDIR": str(ASSETS / "tmp"),
}.items():
    # setdefault preserves values already supplied by the environment; the
    # directory is created either way so downstream libraries can write to it.
    os.environ.setdefault(k, v)
    Path(v).mkdir(parents=True, exist_ok=True)

ASSETS_MODELS = ASSETS / "models"; ASSETS_MODELS.mkdir(parents=True, exist_ok=True)
ASSETS_DATA = ASSETS / "training_data_cleaned"; ASSETS_DATA.mkdir(parents=True, exist_ok=True)
MODEL_REPO = "ChatterjeeLab/PeptiVerse"  # model repo
DATASET_REPO = "ChatterjeeLab/PeptiVerse"  # dataset repo
def fetch_models_and_data():
    """Download model weights, tokenizer files and training CSVs from the Hub.

    Everything is materialized (no symlinks) under ASSETS_MODELS; only the
    patterns listed below are fetched.
    """
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=str(ASSETS_MODELS),
        local_dir_use_symlinks=False,
        allow_patterns=[
            # Model files
            "training_classifiers/**/best_model*.json",
            "training_classifiers/**/best_model*.pt",
            "training_classifiers/**/best_model*.joblib",
            # Tokenizer files
            "tokenizer/new_vocab.txt",
            "tokenizer/new_splits.txt",
            # Training data for distributions
            "training_data_cleaned/**/*.csv",
        ],
    )
# Fetch assets at import time so everything below can assume they exist.
fetch_models_and_data()

BEST_TXT = Path("best_models.txt")
TRAINING_ROOT = ASSETS_MODELS / "training_classifiers"
TOKENIZER_DIR = ASSETS_MODELS / "tokenizer"
# Banned models that should fall back to XGB
BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"}
# "lower is better" exceptions for classification labeling
LOWER_BETTER = {"hemolysis", "toxicity"}
# Property display names and descriptions.
# Schema per entry: 'display' (UI label), 'description', 'direction'
# ('↑' higher-is-better, '↓' lower-is-better, '' for assays with a fixed
# threshold), optional 'pass_label'/'fail_label' for classification output,
# optional 'unit', and optional numeric 'threshold'/'thresholds'.
PROPERTY_INFO = {
    'solubility': {
        'display': 'πŸ’§ Solubility',
        'description': 'Aqueous solubility',
        'direction': '↑',
        'pass_label': 'Soluble',
        'fail_label': 'Insoluble'
    },
    'permeability_penetrance': {
        'display': 'πŸ”¬ Permeability (Penetrance)',
        'description': 'Cell penetration capability',
        'direction': '↑',
        'pass_label': 'Permeable',
        'fail_label': 'Non-permeable'
    },
    'hemolysis': {
        'display': '🩸 Hemolysis',
        'description': 'Red blood cell membrane disruption',
        'direction': '↓',
        'pass_label': 'Non-hemolytic',
        'fail_label': 'Hemolytic'
    },
    'nf': {
        'display': 'πŸ‘― Non-Fouling',
        'description': 'Resistance to protein adsorption',
        'direction': '↑',
        'pass_label': 'Non-fouling',
        'fail_label': 'Fouling'
    },
    'halflife': {
        'display': '⏱️ Half-Life',
        'description': 'Serum stability',
        'direction': '↑',
        'unit': 'hours'
    },
    'toxicity': {
        'display': '☠️ Toxicity',
        'description': 'Cytotoxicity',
        'direction': '↓',
        'pass_label': 'Non-toxic',
        'fail_label': 'Toxic'
    },
    'permeability_pampa': {
        'display': 'πŸͺ£ Permeability (PAMPA)',
        'description': 'PAMPA assay permeability',
        'direction': '',
        'threshold': -6,  # Values > -6 are permeable
        'pass_label': 'Permeable',
        'fail_label': 'Non-permeable'
    },
    'permeability_caco2': {
        'display': 'πŸͺ£ Permeability (Caco-2)',
        'description': 'Caco-2 cell permeability',
        'direction': '',
        'threshold': -6,  # Values > -6 are permeable
        'pass_label': 'Permeable',
        'fail_label': 'Non-permeable'
    },
    'binding_affinity': {
        'display': 'πŸ”— Binding Affinity',
        'description': 'Protein-peptide binding strength',
        'direction': '↑',
        'thresholds': {'tight': 9, 'weak': 7}
    }
}

# Canonical ordering of the property checkboxes in the UI; index positions
# here must match the *checkbox_values varargs in the callbacks below.
PROP_ORDER = [
    'solubility',
    'permeability_penetrance',
    'hemolysis',
    'nf',
    'halflife',
    'toxicity',
    'permeability_pampa',
    'permeability_caco2',
    'binding_affinity',
]

# Distribution-only keys (training-data views with no predictor behind them)
DIST_KEYS = {
    "binding_affinity_wt": "πŸ”— Binding Affinity β€” WT (distribution)",
    "binding_affinity_smiles": "πŸ”— Binding Affinity β€” SMILES (distribution)",
    "binding_affinity_all": "πŸ”— Binding Affinity β€” WT+SMILES (distribution)",
    "halflife_wt": "⏱️ Half-life β€” WT (distribution)",
    "halflife_smiles": "⏱️ Half-life β€” SMILES (distribution)",
    "halflife_all": "⏱️ Half-life β€” WT+SMILES (distribution)",
}
def create_filtered_manifest(manifest_path: Path) -> Dict[str, BestRow]:
    """Read the best-model manifest and replace banned models with XGB.

    Also normalizes the half-life key ('half_life' -> 'halflife') so manifest
    keys line up with PROPERTY_INFO and the loaded model folders.

    Args:
        manifest_path: Path to the best-models manifest CSV.

    Returns:
        Mapping of normalized property key -> BestRow with any model in
        BANNED_MODELS swapped for "XGB"; all other fields pass through.
    """
    # The original had dead elif/else branches that both re-assigned the
    # manifest value, plus tautological conditionals when building BestRow;
    # this helper is the whole surviving rule.
    def _resolve(model_name):
        # canon_model canonicalizes the name; banned families fall back to
        # XGB, anything else keeps the manifest's original spelling.
        return "XGB" if canon_model(model_name) in BANNED_MODELS else model_name

    original = read_best_manifest_csv(manifest_path)
    filtered = {}
    for prop_key, row in original.items():
        # Normalize property key for half-life
        normalized_key = "halflife" if prop_key in ("halflife", "half_life") else prop_key
        filtered[normalized_key] = BestRow(
            property_key=normalized_key,
            best_wt=_resolve(row.best_wt),
            best_smiles=_resolve(row.best_smiles),
            task_type=row.task_type,
            thr_wt=row.thr_wt,
            thr_smiles=row.thr_smiles,
        )
    return filtered
class AppContext:
    """Holds the device, the filtered manifest and the loaded predictor.

    Constructed once (see initialize()); building it downloads/loads all
    best models, so it is expensive.
    """

    def __init__(self):
        # Prefer GPU when available; the predictor is told the same device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Manifest with banned model families swapped out and keys normalized.
        self.best = create_filtered_manifest(BEST_TXT)
        self.predictor = PeptiVersePredictor(
            manifest_path=BEST_TXT,
            classifier_weight_root=ASSETS_MODELS,
            esm_name="facebook/esm2_t33_650M_UR50D",
            clm_name="aaronfeller/PeptideCLM-23M-all",
            smiles_vocab=str(TOKENIZER_DIR / "new_vocab.txt"),
            smiles_splits=str(TOKENIZER_DIR / "new_splits.txt"),
            device=str(self.device),
        )
        # override manifest AND reload models so keys/folders match
        self.predictor.manifest = self.best
        self.predictor.models.clear()
        self.predictor.meta.clear()
        self.predictor._load_all_best_models()
# Module-level singleton; populated lazily by initialize().
CTX: AppContext | None = None

def initialize():
    """Create the AppContext on first call and return it (lazy singleton)."""
    global CTX
    if CTX is None:
        CTX = AppContext()
    return CTX
def get_available_properties(ctx, modality: str) -> Dict[str, bool]:
    """
    Returns dict of property -> bool indicating if available for the modality
    """
    placeholder_names = {"-", "β€”", "NA", "N/A", None}
    mode = "wt" if modality == "Sequence" else "smiles"
    available: Dict[str, bool] = {}
    for prop_key in PROPERTY_INFO:
        row = ctx.best.get(prop_key)
        if row is None:
            # No manifest entry at all for this property.
            available[prop_key] = False
            continue
        model = row.best_wt if modality == "Sequence" else row.best_smiles
        if not model or model in placeholder_names:
            # Manifest lists no usable model for this modality.
            available[prop_key] = False
        else:
            # Only report available if the weights were actually loaded.
            available[prop_key] = (prop_key, mode) in ctx.predictor.models
    return available
def get_threshold(ctx: AppContext, prop: str, modality: str) -> float | None:
    """Return the decision threshold for *prop* in the given modality, if any."""
    row = ctx.best.get(prop)
    if row is None:
        return None
    if modality == "Sequence":
        return row.thr_wt
    return row.thr_smiles
def get_best_models_table(ctx: AppContext) -> pd.DataFrame:
    """Generate a table showing best models and thresholds"""
    def _model(m):
        # Em-dash placeholder for missing model names.
        return m if m else 'β€”'

    def _thr(t):
        # Em-dash placeholder for missing thresholds.
        return 'β€”' if t is None else f"{t:.4f}"

    rows = [
        {
            'Property': PROPERTY_INFO.get(prop_key, {}).get('display', prop_key),
            'Best Model (Sequence)': _model(row.best_wt),
            'Threshold (Sequence)': _thr(row.thr_wt),
            'Best Model (SMILES)': _model(row.best_smiles),
            'Threshold (SMILES)': _thr(row.thr_smiles),
            'Task Type': row.task_type,
        }
        for prop_key, row in ctx.best.items()
    ]
    return pd.DataFrame(rows)
# Optional dependency: RDKit validates/handles SMILES input; without it the
# SMILES modality is disabled.
try:
    from rdkit import Chem
    from rdkit.Chem import Descriptors, AllChem
    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False
    print("RDKit not available. SMILES input will be disabled.")
import re
# Permissive amino-acid alphabet: the 20 canonical residues plus ambiguity
# codes (B/X/Z/J), the rare U/O, and '-' separators; case-insensitive.
AA_RE = re.compile(r'^[ACDEFGHIKLMNPQRSTVWYBXZJUO\-]+$', re.IGNORECASE)

def is_aa_sequence_like(s: str) -> bool:
    """Heuristic: does *s* look like a plain amino-acid sequence?"""
    cleaned = s.strip().replace(" ", "")
    if not cleaned:
        return False
    # Very lenient: allow AA letters + optional '-' for readability
    if not AA_RE.fullmatch(cleaned):
        return False
    return any(ch.isalpha() for ch in cleaned)

def is_smiles_like(s: str) -> bool:
    """Heuristic: does *s* plausibly look like a SMILES string?"""
    text = s.strip()
    if not text:
        return False
    # Heuristic: SMILES often contains these symbols; also reject if it looks like pure AA
    smiles_symbols = set("=#()[]+\\/-@1234567890")
    has_symbol = any(ch in smiles_symbols for ch in text)
    return (has_symbol or not is_aa_sequence_like(text)) and len(text) >= 2
# ==================== Sequence Analysis ====================
class SequenceAnalyzer:
    """Calculate physicochemical properties of peptide sequences.

    BioPython's ProteinAnalysis is used when available; every method falls
    back to a hand-rolled approximation when it is missing or raises.
    """

    # pKa values for amino acids (termini + ionizable side chains)
    PKA_VALUES = {
        'N_term': 9.6,
        'C_term': 2.3,
        'D': 3.9,   # Aspartic acid
        'E': 4.2,   # Glutamic acid
        'H': 6.0,   # Histidine
        'C': 8.3,   # Cysteine
        'Y': 10.1,  # Tyrosine
        'K': 10.5,  # Lysine
        'R': 12.5,  # Arginine
    }

    @classmethod
    def calculate_net_charge(cls, sequence: str, pH: float = 7.0) -> float:
        """Calculate net charge at given pH using Henderson-Hasselbalch equation."""
        if BIOPYTHON_AVAILABLE:
            try:
                return ProteinAnalysis(sequence).charge_at_pH(pH)
            # Bug fix: was a bare `except:` — would swallow KeyboardInterrupt.
            except Exception:
                pass
        # Fallback calculation
        charge = 0.0
        # N-terminus (positive) and C-terminus (negative)
        charge += 1 / (1 + 10 ** (pH - cls.PKA_VALUES['N_term']))
        charge -= 1 / (1 + 10 ** (cls.PKA_VALUES['C_term'] - pH))
        # Count charged residues
        for aa in sequence:
            if aa in 'KR':  # Positive
                charge += 1 / (1 + 10 ** (pH - cls.PKA_VALUES[aa]))
            elif aa in 'DE':  # Negative
                charge -= 1 / (1 + 10 ** (cls.PKA_VALUES[aa] - pH))
            elif aa == 'H':  # Histidine (positive when protonated)
                charge += 1 / (1 + 10 ** (pH - cls.PKA_VALUES['H']))
            elif aa == 'C':  # Cysteine (negative when deprotonated)
                charge -= 1 / (1 + 10 ** (cls.PKA_VALUES['C'] - pH))
            elif aa == 'Y':  # Tyrosine (negative when deprotonated)
                charge -= 1 / (1 + 10 ** (cls.PKA_VALUES['Y'] - pH))
        return round(charge, 2)

    @classmethod
    def calculate_isoelectric_point(cls, sequence: str) -> float:
        """Calculate theoretical pI using bisection method"""
        if BIOPYTHON_AVAILABLE:
            try:
                return ProteinAnalysis(sequence).isoelectric_point()
            except Exception:
                pass
        # Fallback: bisect on the net-charge curve until it crosses zero.
        pH_min, pH_max = 0.0, 14.0
        epsilon = 0.01
        while (pH_max - pH_min) > epsilon:
            pH_mid = (pH_min + pH_max) / 2
            charge = cls.calculate_net_charge(sequence, pH_mid)
            if abs(charge) < epsilon:
                return round(pH_mid, 2)
            if charge > 0:
                pH_min = pH_mid
            else:
                pH_max = pH_mid
        return round((pH_min + pH_max) / 2, 2)

    @classmethod
    def calculate_molecular_weight(cls, sequence: str) -> float:
        """Calculate molecular weight (Da)."""
        if BIOPYTHON_AVAILABLE:
            try:
                return ProteinAnalysis(sequence).molecular_weight()
            except Exception:
                pass
        # Bug fix: the old fallback returned +18.0 for an empty sequence
        # because len(sequence) - 1 == -1 in the water subtraction below.
        if not sequence:
            return 0.0
        # Fallback: approximate from average residue masses.
        weights = {
            'A': 89.1, 'C': 121.2, 'D': 133.1, 'E': 147.1, 'F': 165.2,
            'G': 75.1, 'H': 155.2, 'I': 131.2, 'K': 146.2, 'L': 131.2,
            'M': 149.2, 'N': 132.1, 'P': 115.1, 'Q': 146.2, 'R': 174.2,
            'S': 105.1, 'T': 119.1, 'V': 117.1, 'W': 204.2, 'Y': 181.2
        }
        mw = sum(weights.get(aa, 0) for aa in sequence)
        # Subtract one water per peptide bond
        mw -= 18.0 * (len(sequence) - 1)
        return round(mw, 1)

    @classmethod
    def calculate_hydrophobicity(cls, sequence: str) -> float:
        """Calculate GRAVY (grand average of hydropathy)"""
        if BIOPYTHON_AVAILABLE:
            try:
                return ProteinAnalysis(sequence).gravy()
            except Exception:
                pass
        # Kyte-Doolittle scale
        hydrophobicity = {
            'A': 1.8, 'C': 2.5, 'D': -3.5, 'E': -3.5, 'F': 2.8,
            'G': -0.4, 'H': -3.2, 'I': 4.5, 'K': -3.9, 'L': 3.8,
            'M': 1.9, 'N': -3.5, 'P': -1.6, 'Q': -3.5, 'R': -4.5,
            'S': -0.8, 'T': -0.7, 'V': 4.2, 'W': -0.9, 'Y': -1.3
        }
        if len(sequence) == 0:
            return 0
        total = sum(hydrophobicity.get(aa, 0) for aa in sequence)
        return round(total / len(sequence), 2)
# ==================== Data Management ====================
class TrainingDataManager:
    """Loads training-set CSVs and serves distribution plots / summary stats.

    On construction it resolves a data directory, loads every mapped CSV it
    can find, and caches per-property value arrays in ``self.statistics``.
    """

    def __init__(self, data_dir=None):
        # NOTE(review): the `data_dir` parameter is accepted but never used —
        # the directory is always resolved from the candidates below. Confirm
        # intent before wiring it through or removing it.
        possible_dirs = [
            ASSETS_MODELS / "training_data_cleaned",  # In HF downloaded location
            Path("training_data_cleaned"),  # Local relative path
            ASSETS_DATA,  # Original location
        ]
        self.data_dir = None
        # First existing candidate wins.
        for d in possible_dirs:
            if d.exists():
                self.data_dir = d
                print(f"Using data directory: {d}")
                break
        if self.data_dir is None:
            print(f"WARNING: No data directory found. Tried: {possible_dirs}")
            self.data_dir = ASSETS_DATA  # Fallback
            self.data_dir.mkdir(exist_ok=True)
        # Eagerly load all distributions once; see load_statistics().
        self.statistics = self.load_statistics()

    def load_csv_data(self, filepath: Path, value_column, is_binary: bool = False) -> Optional[Dict]:
        """Load data from a CSV file.
        value_column can be a string OR a list/tuple of candidate column names.

        Returns {"values": np.ndarray, "n_samples": int} or None when the
        file/column is missing or contains no numeric values.
        """
        if not filepath.exists():
            print(f"File not found: {filepath}")
            return None
        try:
            df = pd.read_csv(filepath, encoding="utf-8", on_bad_lines="skip")
            print(f"Columns in {filepath.name}: {df.columns.tolist()[:5]}...")
            # Case-insensitive column map
            col_lower = {col.lower(): col for col in df.columns}
            # allow list/tuple of candidates
            if isinstance(value_column, (list, tuple)):
                chosen = None
                for c in value_column:
                    if c is None:
                        continue
                    c_l = str(c).lower()
                    if c_l in col_lower:
                        chosen = col_lower[c_l]
                        break
                if chosen is None:
                    print(f"None of candidate columns {value_column} found. Available: {list(df.columns)[:10]}")
                    return None
                value_column = chosen
            else:
                # keep original behavior, but safe-cast to str
                vc_l = str(value_column).lower()
                if vc_l not in col_lower:
                    # Known alias spellings for the common column names.
                    alternatives = {
                        'label': ['label', 'labels', 'y', 'target'],
                        'affinity': ['affinity', 'pkd', 'pki', 'binding_affinity'],
                        'pampa': ['pampa', 'pampa_value', 'permeability'],
                        'caco2': ['caco2', 'caco-2', 'caco_2'],
                        'log_hour': ['log_hour', 'loghour', 'log_hours', 'loghours'],
                        'half_life_hours': ['half_life_hours', 'halflife_hours', 'hours'],
                        'half_life_seconds': ['half_life_seconds', 'halflife_seconds', 'seconds'],
                    }
                    found = False
                    for alt in alternatives.get(vc_l, []):
                        if alt.lower() in col_lower:
                            value_column = col_lower[alt.lower()]
                            found = True
                            break
                    if not found:
                        print(f"Column {value_column} not found. Available: {list(df.columns)[:10]}")
                        return None
                else:
                    value_column = col_lower[vc_l]
            # Coerce to numeric; unparsable entries are dropped.
            vals = pd.to_numeric(df[value_column], errors="coerce").dropna().to_numpy()
            if len(vals) == 0:
                print(f"No valid values found in column {value_column}")
                return None
            print(f"Loaded {len(vals)} values from {filepath.name}")
            if is_binary:
                # Binarize at 0.5 anything that is not already strictly {0, 1}.
                unique_vals = np.unique(vals)
                if not set(unique_vals).issubset({0, 1, 0.0, 1.0}):
                    vals = (vals > 0.5).astype(int)
            return {"values": vals, "n_samples": len(vals)}
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
            import traceback
            traceback.print_exc()
            return None

    def load_statistics(self):
        """Load pre-computed statistics for each property from actual data files"""
        stats = {}
        # Map properties to their data files and value columns
        data_mappings = {
            'hemolysis': {
                'files': [
                    'hemolysis/hemo_meta_with_split.csv',
                    'hemolysis/hemolysis_meta_with_split.csv',
                ],
                'column': 'label',
                'is_binary': True
            },
            'solubility': {
                'files': [
                    'solubility/sol_meta_with_split.csv',
                    'solubility/solubility_meta_with_split.csv',
                ],
                'column': 'label',
                'is_binary': True
            },
            "binding_affinity_wt": {
                "files": ["binding_affinity/binding_affinity_wt_meta_with_split.csv"],
                "column": "affinity",
                "is_binary": False
            },
            "binding_affinity_smiles": {
                "files": ["binding_affinity/binding_affinity_smiles_meta_with_split.csv"],
                "column": "affinity",
                "is_binary": False
            },
            "binding_affinity_all": {
                "files": [
                    "binding_affinity/binding_affinity_wt_meta_with_split.csv",
                    "binding_affinity/binding_affinity_smiles_meta_with_split.csv",
                ],
                "column": "affinity",
                "is_binary": False
            },
            "halflife_wt": {
                "files": [
                    "half_life/halflife_with_split.csv",
                    "half_life/halflife_meta_with_split.csv",
                ],
                "column": ["half_life_hours", "log_hour", "log_hours"],
                "is_binary": False
            },
            "halflife_smiles": {
                "files": [
                    "half_life/halflife_smiles_with_split.csv",
                    "half_life/halflife_smiles_with_splits.csv",
                    "half_life/halflife_smiles_meta_with_split.csv",
                ],
                "column": ["half_life_hours", "log_hour", "log_hours"],
                "is_binary": False
            },
            "halflife_all": {
                "files": [
                    "half_life/halflife_with_split.csv",
                    "half_life/halflife_meta_with_split.csv",
                    "half_life/halflife_smiles_with_split.csv",
                    "half_life/halflife_smiles_with_splits.csv",
                    "half_life/halflife_smiles_meta_with_split.csv",
                ],
                "column": ["half_life_hours", "log_hour", "log_hours"],
                "is_binary": False
            },
            'nf': {
                'files': [
                    'nonfouling/nf_meta_with_split.csv',
                    'nf/nf_meta_with_split.csv',
                ],
                'column': 'label',
                'is_binary': True
            },
            'permeability_penetrance': {
                'files': [
                    'permeability/perm_meta_with_split.csv',
                    'permeability_penetrance/permeability_meta_with_split.csv',
                ],
                'column': 'label',
                'is_binary': True
            },
            'permeability_pampa': {
                'files': [
                    'permeability_pampa/pampa_meta_with_split.csv',
                    'pampa/pampa_meta_with_split.csv',
                ],
                'column': 'PAMPA',
                'is_binary': False
            },
            'permeability_caco2': {
                'files': [
                    'permeability_caco2/caco2_meta_with_split.csv',
                    'caco2/caco2_meta_with_split.csv',
                ],
                'column': 'Caco2',
                'is_binary': False
            },
            'toxicity': {
                'files': [
                    'toxicity/tox_meta_with_split.csv',
                    'toxicity/toxicity_meta_with_split.csv',
                ],
                'column': 'label',
                'is_binary': True
            }
        }
        # Load actual data
        for prop_key, mapping in data_mappings.items():
            all_vals = []
            loaded_from = []
            # Every existing file contributes; values are concatenated.
            for file_path in mapping['files']:
                filepath = self.data_dir / file_path
                if not filepath.exists():
                    continue
                d = self.load_csv_data(
                    filepath,
                    mapping['column'],
                    mapping.get('is_binary', False)
                )
                if d:
                    all_vals.append(d["values"])
                    loaded_from.append(file_path)
            if all_vals:
                vals = np.concatenate(all_vals, axis=0)
                prop_info = PROPERTY_INFO.get(prop_key, {})
                stats[prop_key] = {
                    "values": vals,
                    "description": prop_info.get("description", ""),
                    "unit": "Probability" if mapping.get("is_binary") else prop_info.get("unit", "Score"),
                    "n_samples": int(vals.shape[0]),
                    "kind": "binary" if mapping.get("is_binary") else "continuous",
                    "loaded_from": loaded_from,  # optional: good for debugging
                }
                # thresholds / unit tweaks
                # NOTE(review): mapping keys are suffixed ("binding_affinity_wt",
                # "halflife_all", ...), so the exact-match branches for
                # "binding_affinity" and "halflife" below appear unreachable;
                # the startswith() branches after them do the real work.
                # Confirm before removing.
                if prop_key == "binding_affinity":
                    stats[prop_key]["threshold"] = 9
                    stats[prop_key]["threshold_secondary"] = 7
                    stats[prop_key]["unit"] = "pKd/pKi"
                elif prop_key in ["permeability_pampa", "permeability_caco2"]:
                    stats[prop_key]["threshold"] = -6
                    stats[prop_key]["unit"] = "log Peff" if prop_key == "permeability_pampa" else "log Papp"
                elif prop_key == "halflife":
                    stats[prop_key]["unit"] = "hours"
                # for distribution plotting
                if prop_key.startswith("binding_affinity"):
                    stats[prop_key]["threshold"] = 9
                    stats[prop_key]["threshold_secondary"] = 7
                    stats[prop_key]["unit"] = "pKd/pKi"
                elif prop_key.startswith("halflife"):
                    stats[prop_key]["unit"] = "hours"
                print(f"βœ“ Loaded {prop_key} from {loaded_from} ({len(vals)} samples)")
                continue
            # fallback synthetic
            # NOTE(review): despite the message below, no synthetic data is
            # generated — the property is simply absent from `stats`.
            print(f"⚠ Using synthetic data for {prop_key}")
        return stats

    def get_distribution_plot(self, property_name, current_value=None):
        """Build a Plotly figure of the training distribution for a property.

        Binary properties render as a two-bar class count; continuous ones as
        a histogram with optional threshold lines and the user's value.
        Returns None when the property has no loaded statistics.
        """
        if property_name not in self.statistics:
            return None
        s = self.statistics[property_name]
        vals = np.asarray(s["values"])
        kind = s.get("kind", "continuous")
        if kind == "binary":
            n0 = int((vals == 0).sum())
            n1 = int((vals == 1).sum())
            total = max(n0 + n1, 1)  # guard against division by zero
            fig = go.Figure()
            prop_info = PROPERTY_INFO.get(property_name, {})
            labels = [
                prop_info.get('fail_label', 'Negative (0)'),
                prop_info.get('pass_label', 'Positive (1)')
            ]
            fig.add_trace(go.Bar(x=labels, y=[n0, n1]))
            fig.update_layout(
                title=f"{prop_info.get('display', property_name)} β€” Class Distribution",
                xaxis_title="Class",
                yaxis_title="Count",
                height=400,
                showlegend=False,
                annotations=[
                    dict(x=labels[0], y=n0, text=f"{n0} ({n0/total:.1%})", showarrow=False, yshift=8),
                    dict(x=labels[1], y=n1, text=f"{n1} ({n1/total:.1%})", showarrow=False, yshift=8),
                ],
            )
            return fig
        # Continuous distribution
        fig = go.Figure()
        fig.add_trace(go.Histogram(x=vals, nbinsx=50, name="Training Data"))
        # Primary threshold (if any)
        if "threshold" in s and s["threshold"] is not None:
            fig.add_vline(
                x=float(s["threshold"]),
                line_dash="dash",
                line_color="purple" if property_name == "binding_affinity" else "red",
                annotation_text=(
                    "Tight threshold: {:.3f}".format(float(s["threshold"]))
                    if property_name == "binding_affinity"
                    else "Threshold: {:.3f}".format(float(s["threshold"]))
                ),
            )
        # Secondary threshold for binding (weak)
        if property_name == "binding_affinity" and "threshold_secondary" in s and s["threshold_secondary"] is not None:
            fig.add_vline(
                x=float(s["threshold_secondary"]),
                line_dash="dash",
                line_color="orange",
                annotation_text="Weak threshold: {:.3f}".format(float(s["threshold_secondary"])),
            )
        # Current value
        if current_value is not None:
            fig.add_vline(
                x=float(current_value),
                line_dash="solid",
                line_color="green",
                line_width=3,
                annotation_text=f"Your Result: {float(current_value):.3f}",
            )
        prop_info = PROPERTY_INFO.get(property_name, {})
        fig.update_layout(
            title=f"{prop_info.get('display', property_name)} Distribution",
            xaxis_title=s.get("unit", ""),
            yaxis_title="Count",
            height=400,
            showlegend=False,
        )
        return fig

    def get_property_info(self, property_name):
        """Return summary statistics for a property, or None if not loaded.

        Binary properties get class counts (n_pos/n_neg); continuous ones get
        percentiles in addition to mean/std/min/max.
        """
        if property_name not in self.statistics:
            return None
        s = self.statistics[property_name]
        vals = np.asarray(s["values"])
        kind = s.get("kind", "continuous")
        info = {
            "description": s.get("description", ""),
            "unit": s.get("unit", ""),
            "n_samples": int(len(vals)),
            "mean": float(np.mean(vals)),
            "std": float(np.std(vals)),
            "min": float(np.min(vals)),
            "max": float(np.max(vals)),
            "percentiles": {},
        }
        if kind == "binary":
            info["n_neg"] = int((vals == 0).sum())
            info["n_pos"] = int((vals == 1).sum())
        else:
            pct = np.percentile(vals, [10, 25, 50, 75, 90])
            info["percentiles"] = {
                "10%": float(pct[0]),
                "25%": float(pct[1]),
                "50% (median)": float(pct[2]),
                "75%": float(pct[3]),
                "90%": float(pct[4]),
            }
        return info
# ==================== Gradio Interface ====================
def predict_properties(
    input_text: str,
    input_type: str,  # "Sequence" or "SMILES"
    protein_text: str,  # For binding affinity
    selected_props: list[str],  # from individual checkboxes
    include_physicochemical: bool,
    pH_value: float,
    progress=gr.Progress()
):
    """Run the selected predictors over every non-empty input line.

    Returns (DataFrame | None, status_message). One result row is emitted per
    (line, property); physicochemical rows are appended per line when
    requested and the modality is Sequence. Returns (None, warning) on any
    input-validation failure.
    """
    if not input_text or not input_text.strip():
        return None, "⚠️ Please provide input."
    lines = [s.strip() for s in input_text.split("\n") if s.strip()]
    # Reject inputs that do not match the declared modality up front.
    if input_type == "Sequence":
        bad = [s for s in lines if not is_aa_sequence_like(s)]
        if bad:
            return None, f"⚠️ Input Type=Sequence but {len(bad)} line(s) don't look like AA sequences. Example: {bad[0][:60]}"
    else:
        bad = [s for s in lines if not is_smiles_like(s)]
        if bad:
            return None, f"⚠️ Input Type=SMILES but {len(bad)} line(s) don't look like SMILES. Example: {bad[0][:60]}"
    ctx = initialize()
    # Debug output: which manifest keys / model weights were actually loaded.
    print("keys in ctx.best:", sorted(ctx.best.keys()))
    print("loaded model keys:", sorted(ctx.predictor.models.keys()))
    print("halflife wt loaded?", ("halflife","wt") in ctx.predictor.models)
    print("halflife smiles loaded?", ("halflife","smiles") in ctx.predictor.models)
    if not selected_props:
        return None, "⚠️ Please select at least one property."
    results = []
    analyzer = SequenceAnalyzer()
    # Check availability
    available = get_available_properties(ctx, input_type)
    unavailable = [p for p in selected_props if not available.get(p, False)]
    if unavailable:
        unavailable_names = [PROPERTY_INFO.get(p, {}).get('display', p) for p in unavailable]
        return None, f"⚠️ These properties are not supported for {input_type}: {', '.join(unavailable_names)}"
    for i, s in enumerate(lines):
        progress((i + 1) / len(lines), f"Processing {i+1}/{len(lines)}")
        # Regular property predictions
        for prop in selected_props:
            if prop == "binding_affinity":
                # Handle binding affinity separately: it needs a protein target.
                if not protein_text or not protein_text.strip():
                    results.append({
                        "Input": s[:30] + "..." if len(s) > 30 else s,
                        "Property": PROPERTY_INFO[prop]['display'],
                        "Prediction": "N/A",
                        "Value": "Requires protein",
                        "Unit": "",
                    })
                    continue
                mode = "wt" if input_type == "Sequence" else "smiles"
                try:
                    result = ctx.predictor.predict_binding_affinity(mode, protein_text.strip(), s)
                    affinity = result["affinity"]
                    # Determine binding class based on thresholds
                    # (tight >= 9 pKd/pKi, medium >= 7, else weak).
                    if affinity >= 9:
                        class_label = "Tight binding"
                    elif affinity >= 7:
                        class_label = "Medium binding"
                    else:
                        class_label = "Weak binding"
                    results.append({
                        "Input": s[:30] + "..." if len(s) > 30 else s,
                        "Property": PROPERTY_INFO[prop]['display'],
                        "Prediction": class_label,
                        "Value": f"{affinity:.3f}",
                        "Unit": "pKd/pKi",
                    })
                except Exception as e:
                    print(f"Error predicting binding affinity: {e}")
                    results.append({
                        "Input": s[:30] + "..." if len(s) > 30 else s,
                        "Property": PROPERTY_INFO[prop]['display'],
                        "Prediction": "Error",
                        "Value": "Failed",
                        "Unit": "",
                    })
                # Skip the regular-property path below for binding affinity.
                continue
            # Regular properties
            mode = "wt" if input_type == "Sequence" else "smiles"
            try:
                result = ctx.predictor.predict_property(prop, mode, s)
                score = result["score"]
                prop_info = PROPERTY_INFO.get(prop, {})
                # Determine label based on property type
                if prop in ['permeability_pampa', 'permeability_caco2']:
                    # Special handling for permeability assays: fixed -6 cutoff.
                    label = prop_info['pass_label'] if score > -6 else prop_info['fail_label']
                    unit = "log Peff" if prop == 'permeability_pampa' else "log Papp"
                elif prop == 'halflife':
                    # Regression task, no pass/fail
                    label = "β€”"
                    unit = prop_info.get('unit', 'hours')
                else:
                    # Classification tasks: compare score to the manifest
                    # threshold, inverting for lower-is-better properties.
                    thr = get_threshold(ctx, prop, input_type)
                    if thr is not None:
                        if prop in LOWER_BETTER:
                            label = prop_info.get('pass_label', 'Pass') if score < thr else prop_info.get('fail_label', 'Fail')
                        else:
                            label = prop_info.get('pass_label', 'Pass') if score >= thr else prop_info.get('fail_label', 'Fail')
                    else:
                        label = "β€”"
                    unit = "Probability"
                results.append({
                    "Input": s[:30] + "..." if len(s) > 30 else s,
                    "Property": prop_info.get('display', prop),
                    "Prediction": label,
                    "Value": f"{score:.3f}",
                    "Unit": unit,
                })
            except Exception as e:
                # Best-effort: a failed property is skipped, not fatal.
                print(f"Error predicting {prop} for {s[:30]}: {e}")
                continue
        # physicochemical only for AA sequence modality
        if input_type == "Sequence" and include_physicochemical:
            analysis = {
                "length": len(s),
                "molecular_weight": analyzer.calculate_molecular_weight(s),
                "net_charge": analyzer.calculate_net_charge(s, pH_value),
                "isoelectric_point": analyzer.calculate_isoelectric_point(s),
                "hydrophobicity": analyzer.calculate_hydrophobicity(s),
            }
            short = s[:30] + "..." if len(s) > 30 else s
            results += [
                {"Input": short, "Property": "πŸ“ Length", "Prediction": "", "Value": str(analysis["length"]), "Unit": "aa"},
                {"Input": short, "Property": "βš–οΈ Molecular Weight", "Prediction": "", "Value": f"{analysis['molecular_weight']:.1f}", "Unit": "Da"},
                {"Input": short, "Property": f"⚑ Net Charge (pH {pH_value})", "Prediction": "", "Value": f"{analysis['net_charge']:.2f}", "Unit": ""},
                {"Input": short, "Property": "🎯 Isoelectric Point", "Prediction": "", "Value": f"{analysis['isoelectric_point']:.2f}", "Unit": "pH"},
                {"Input": short, "Property": "πŸ’¦ Hydrophobicity (GRAVY)", "Prediction": "", "Value": f"{analysis['hydrophobicity']:.2f}", "Unit": "GRAVY"},
            ]
    df = pd.DataFrame(results)
    status = f"βœ… Completed {len(df)} rows ({len(lines)} input(s), {len(selected_props)} selected properties)."
    return df, status
def show_distribution(property_name, predicted_value=None):
    """Show distribution plot + info for selected property."""
    # NOTE(review): a fresh TrainingDataManager re-reads every CSV on each
    # call — consider caching a module-level instance.
    data_manager = TrainingDataManager()
    if not property_name:
        return None, "Select a property to view its distribution."
    # Get the first property if a list was passed
    prop = property_name[0] if isinstance(property_name, list) else property_name
    # Generate the plot
    fig = data_manager.get_distribution_plot(prop, predicted_value)
    # Build info panel
    info = data_manager.get_property_info(prop)
    if not info:
        return fig, "No information available for this property."
    prop_info = PROPERTY_INFO.get(prop, {})
    # Distribution-only keys get their own titles; fall back to display name.
    title = DIST_KEYS.get(prop, PROPERTY_INFO.get(prop, {}).get("display", prop))
    kind = data_manager.statistics.get(prop, {}).get("kind", "continuous")
    if kind == "binary":
        n_pos = info.get("n_pos", 0)
        n_neg = info.get("n_neg", 0)
        total = max(n_pos + n_neg, 1)  # guard against division by zero
        info_text = f"""
    #### {title} Information
    **Description:** {info.get('description','')}
    **Statistics (Binary):**
    - Samples: {info['n_samples']:,}
    - {prop_info.get('pass_label', 'Positive')} (1): {n_pos:,} ({n_pos/total:.1%})
    - {prop_info.get('fail_label', 'Negative')} (0): {n_neg:,} ({n_neg/total:.1%})
    """
    else:
        p = info.get("percentiles", {})
        info_text = f"""
    #### {title} Information
    **Description:** {info.get('description','')}
    **Statistics:**
    - Samples: {info['n_samples']:,}
    - Mean: {info['mean']:.3f} {info['unit']}
    - Std Dev: {info['std']:.3f}
    - Range: [{info['min']:.3f}, {info['max']:.3f}]
    **Percentiles:**
    - 10%: {p.get('10%', float('nan')):.3f}
    - 25%: {p.get('25%', float('nan')):.3f}
    - 50% (median): {p.get('50% (median)', float('nan')):.3f}
    - 75%: {p.get('75%', float('nan')):.3f}
    - 90%: {p.get('90%', float('nan')):.3f}
    """
    return fig, info_text
def load_example(example_name):
    """Return the (binder_input, protein_target) pair for a named example.

    Unknown names fall back to a pair of empty strings.
    """
    t7_peptide = "HAIYPRH"
    insulin_a_chain = "GIVEQCCTSICSLYQLENYCN"
    hbb_fragment = "MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLST"
    cyclic_smiles = (
        "CC(C)C[C@@H]1NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@@H](C)N(C)C(=O)"
        "[C@H](Cc2ccccc2)NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@H]2CCCN2C1=O"
    )
    presets = {
        "T7 Peptide": (t7_peptide, ""),
        "Protein-Peptide": (insulin_a_chain, hbb_fragment),
        "Cyclic Peptide (SMILES)": (cyclic_smiles, ""),
        "Protein-Cyclic Peptide (SMILES)": (cyclic_smiles, hbb_fragment),
        "None": ("", ""),
    }
    return presets.get(example_name, ("", ""))
def on_example_change(name: str):
    """Populate the binder/protein textboxes from a named example preset.

    The protein box is only shown for the two protein-target examples.
    """
    if not name:
        # Nothing selected: leave both boxes untouched.
        return gr.update(), gr.update()
    binder, protein = load_example(name)
    needs_protein = name in ("Protein-Peptide", "Protein-Cyclic Peptide (SMILES)")
    binder_update = gr.update(value=binder)
    protein_update = gr.update(value=protein, visible=needs_protein)
    return binder_update, protein_update
def on_modality_change(modality, *checkbox_values):
    """Relabel/enable the property checkboxes when the modality changes.

    Returns one gr.update per entry in PROP_ORDER; unsupported properties
    are disabled and unchecked, supported ones keep their previous state.
    """
    ctx = initialize()
    available = get_available_properties(ctx, modality)
    updates = []
    for idx, prop_key in enumerate(PROP_ORDER):
        info = PROPERTY_INFO[prop_key]
        enabled = available.get(prop_key, False)
        label = f"{info['display']} {info.get('direction','')}".rstrip()
        if not enabled:
            label += " (Not supported)"
        elif prop_key == "binding_affinity":
            # Asterisk marks that a protein target is also required.
            label += " *"
        previous = checkbox_values[idx] if idx < len(checkbox_values) else False
        updates.append(gr.update(
            label=label,
            interactive=enabled,
            value=previous if enabled else False,
        ))
    return updates
def collect_selected_properties(*checkbox_values):
    """Map positional checkbox states back to property keys via PROP_ORDER."""
    # zip truncates to the shorter sequence, which matches the original
    # index-guarded loop: missing positions are treated as unchecked.
    return [key for key, checked in zip(PROP_ORDER, checkbox_values) if checked]
# ==================== Gradio App ====================
def load_custom_css():
    """Load CSS styling document"""
    css_path = "peptiverse_styles.css"
    # Minimal fallback CSS used when the stylesheet file is absent.
    fallback_css = """
    .gradio-container {
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        font-size: 16px !important;
    }
    """
    try:
        with open(css_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except FileNotFoundError:
        print(f"Warning: CSS file '{css_path}' not found. Using default styles.")
        return fallback_css
    except Exception as e:
        print(f"Error loading CSS: {e}")
        return ""

custom_css = load_custom_css()
def get_title_html():
    """Build the header HTML for the app title.

    Prefers the branded light/dark SVG logo files (swapped via CSS classes
    `logo-light` / `logo-dark`); if neither file exists, falls back to an
    inline gradient-text SVG.
    """
    import base64, os

    def _encode_svg(path):
        # Base64-encode an on-disk SVG, or None when the asset is absent.
        if not os.path.exists(path):
            return None
        with open(path, "rb") as handle:
            return base64.b64encode(handle.read()).decode("utf-8")

    variants = [
        ("logo-light", _encode_svg("peptiverse-light-withlogo.svg")),
        ("logo-dark", _encode_svg("peptiverse-dark-withlogo.svg")),
    ]
    tags = [
        f'''
        <img class="logo {css_class}"
             src="data:image/svg+xml;base64,{encoded}"
             alt="PeptiVerse"
             style="max-height: 200px;" />
        '''
        for css_class, encoded in variants
        if encoded
    ]
    if tags:
        return f'''
        <div class="svg-title-container">
            {''.join(tags)}
        </div>
        '''
    # ---------- Fallback: inline gradient title ----------
    return '''
    <div class="svg-title-container">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 700 140"
             style="width: 100%; max-width: 700px; height: auto;">
            <defs>
                <linearGradient id="titleGradient" x1="0%" y1="0%" x2="100%" y2="100%">
                    <stop offset="0%" style="stop-color:#667eea"/>
                    <stop offset="100%" style="stop-color:#764ba2"/>
                </linearGradient>
                <filter id="shadow">
                    <feDropShadow dx="0" dy="3" stdDeviation="4" flood-opacity="0.15"/>
                </filter>
            </defs>
            <text x="50%" y="50%"
                  text-anchor="middle"
                  dominant-baseline="middle"
                  style="font-size:72px;font-weight:bold;
                         fill:url(#titleGradient);filter:url(#shadow);">
                🌐 PeptiVerse
            </text>
        </svg>
    </div>
    '''
# ---------------------------------------------------------------------------
# Top-level Gradio UI. Built at import time so `demo` exists for launch().
# NOTE(review): indentation below was reconstructed from the Gradio nesting
# conventions visible in the statements — confirm against the original layout.
# ---------------------------------------------------------------------------
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as demo:
    # Shared predictor context for building the UI (also called again inside
    # handlers — presumably initialize() caches; verify).
    ctx = initialize()

    # Header with SVG title support
    title_html = get_title_html()
    gr.HTML(title_html)
    # Hidden markdown title kept alongside the HTML header.
    gr.Markdown(
        """
        # 🌐 PeptiVerse
        """,
        visible=False
    )

    with gr.Tabs():
        # Main Prediction Tab
        with gr.TabItem("🔬 Predict", elem_classes="predict-tab"):
            with gr.Row():
                # Input Section
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### 📝 Input")
                        input_type = gr.Radio(
                            ["Sequence", "SMILES"],
                            label="Input Type",
                            value="Sequence"
                        )
                        # Load T7 peptide by default
                        input_text = gr.Textbox(
                            label="Peptide Sequence(s) / SMILES",
                            placeholder="Enter amino acid sequence(s) or SMILES, one per line",
                            lines=6,
                            value="HAIYPRH"
                        )
                        # Hidden until the binding-affinity checkbox is ticked
                        # (see update_visibility wiring below).
                        protein_seq = gr.Textbox(
                            label="Protein Sequence (for binding prediction)",
                            placeholder="Enter target protein sequence",
                            lines=3,
                            visible=False
                        )
                        gr.Markdown("**Examples:**")
                        example_dropdown = gr.Dropdown(
                            choices=["None", "T7 Peptide", "Protein-Peptide", "Cyclic Peptide (SMILES)", "Protein-Cyclic Peptide (SMILES)"],
                            label="Load Example",
                            value="T7 Peptide",  # Set T7 as default
                            interactive=True,
                            allow_custom_value=False
                        )
                # Property Selection
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### ⚙️ Select Properties")
                        with gr.Accordion("Physicochemical Properties", open=True, elem_id="acc_phys"):
                            include_physicochemical = gr.Checkbox(
                                label="🧪 Calculate Basic Properties",
                                value=True,
                                info="MW, net charge, pI, hydrophobicity (Sequence only)"
                            )
                            pH_value = gr.Slider(
                                minimum=0,
                                maximum=14,
                                value=7.0,
                                step=0.1,
                                label="pH for Net Charge",
                                info="Physiological pH is ~7.4"
                            )
                        # Create individual checkboxes in fixed order so index i
                        # always maps to PROP_ORDER[i] in the event handlers.
                        with gr.Accordion("Prediction Properties", open=True, elem_id="acc_pred"):
                            property_checkboxes = []
                            available = get_available_properties(ctx, "Sequence")
                            for prop_key in PROP_ORDER:
                                prop_info = PROPERTY_INFO[prop_key]
                                is_available = available.get(prop_key, False)
                                label_text = f"{prop_info['display']} {prop_info.get('direction','')}".rstrip()
                                if not is_available:
                                    label_text += " (Not supported)"
                                if prop_key == "binding_affinity" and is_available:
                                    label_text += " *"
                                default_on = (prop_key in ["solubility", "hemolysis"])  # optional defaults
                                cb = gr.Checkbox(
                                    label=label_text,
                                    value=is_available and default_on,
                                    interactive=is_available,
                                    elem_id=f"checkbox_{prop_key}",
                                )
                                property_checkboxes.append(cb)
                        gr.Markdown("*Requires protein sequence input above", elem_classes="text-sm text-gray-500")

        # Best Models Tab
        with gr.TabItem("📋 Best Models", elem_classes="best-models-tab"):
            gr.Markdown("### Current Best Models Configuration")
            gr.Markdown("This table shows the models and thresholds currently being used for predictions:")
            best_models_df = gr.Dataframe(
                value=get_best_models_table(ctx),
                headers=["Property", "Best Model (Sequence)", "Threshold (Sequence)",
                         "Best Model (SMILES)", "Threshold (SMILES)", "Task Type"],
                interactive=False,
                elem_id="best_models_df"
            )
            gr.Markdown("""
            **Note:** Models marked as SVM, SVR, or ENET are automatically replaced with XGB
            as these models are not currently supported in the deployment environment.
            """)

        # Distribution Analysis Tab
        with gr.TabItem("📊 Distributions", elem_classes="distributions-tab"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Per-property distributions, excluding the two keys that
                    # have their own dedicated DIST_KEYS entries.
                    base_props = [
                        k for k in PROPERTY_INFO.keys()
                        if k not in {"halflife", "binding_affinity"}
                    ]
                    dist_choices = base_props + list(DIST_KEYS.keys())
                    property_selector = gr.Dropdown(
                        choices=dist_choices,
                        label="Select Property",
                        value="binding_affinity_all"
                    )
                    test_value = gr.Number(label="Test Value among Distribution", value=None)
                    show_dist_btn = gr.Button("Show Distribution")
                with gr.Column(scale=2):
                    dist_plot_tab = gr.Plot(label="Score Distribution")
                    dist_info_tab = gr.Markdown()

        # Data Documentation Tab
        with gr.TabItem("📚 Documentation", elem_classes="documentation-tab"):
            # Load documentation from disk at build time; best-effort fallback
            # text keeps the tab usable when the file is missing.
            doc_file_path = "description.md"
            try:
                with open(doc_file_path, "r", encoding="utf-8") as f:
                    markdown_content = f.read()
            except FileNotFoundError:
                print(f"Warning: Documentation file '{doc_file_path}' not found.")
                markdown_content = """
                # Documentation
                Documentation file not found. Please ensure `description.md` is in the same directory as the app.
                """
            except Exception as e:
                print(f"Error loading documentation: {e}")
                markdown_content = "# Error loading documentation"
            gr.Markdown(markdown_content)

    # Action Buttons (NOTE(review): placed at Blocks level, below the tab
    # strip — components from the Predict tab are still referenced by name).
    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        predict_btn = gr.Button("🚀 Predict Properties", variant="primary", scale=2)

    # Status
    status_output = gr.Markdown("")

    # Results Section
    with gr.Group():
        gr.Markdown("### 📊 Results")
        results_df = gr.Dataframe(
            headers=["Input", "Property", "Prediction", "Value", "Unit"],
            datatype=["str", "str", "str", "str", "str"],
            interactive=False,
            elem_id="results_df"
        )

    # Footer
    gr.Markdown(
        """
        ---
        <div style='text-align: center; color: #6b7280;'>
            <p>PeptiVerse - A Unified Platform for peptide therapeutic property prediction.</p>
            <p>Please cite our work if you use this tool in your research.</p>
        </div>
        """
    )

    # Event Handlers
    def update_visibility(binding_checked):
        # Mirrors the binding-affinity checkbox into protein_seq visibility.
        return gr.update(visible=binding_checked)

    # Update checkbox states when modality changes; current values are passed
    # in so still-available selections survive the switch.
    input_type.change(
        on_modality_change,
        inputs=[input_type] + property_checkboxes,
        outputs=property_checkboxes
    )

    # Show protein sequence input when binding affinity is selected
    BINDING_IDX = PROP_ORDER.index("binding_affinity")
    property_checkboxes[BINDING_IDX].change(
        update_visibility,
        inputs=[property_checkboxes[BINDING_IDX]],
        outputs=[protein_seq],
    )

    example_dropdown.change(
        on_example_change,
        inputs=[example_dropdown],
        outputs=[input_text, protein_seq]
    )

    # The lambda adapts Gradio's flat input list to predict_properties'
    # signature, collapsing the per-property checkboxes into a key list.
    predict_btn.click(
        lambda input_text, input_type, protein_text, include_physicochemical, pH_value, *checkbox_values:
            predict_properties(
                input_text, input_type, protein_text,
                collect_selected_properties(*checkbox_values),
                include_physicochemical, pH_value
            ),
        inputs=[input_text, input_type, protein_seq, include_physicochemical, pH_value] + property_checkboxes,
        outputs=[results_df, status_output]
    )

    # Clear resets: inputs, dropdown to "None", results table to None,
    # status text to empty, and every property checkbox to unchecked.
    clear_btn.click(
        lambda: ["", "", "None", None, ""] + [False] * len(property_checkboxes),
        outputs=[input_text, protein_seq, example_dropdown, results_df, status_output] + property_checkboxes
    )

    show_dist_btn.click(
        show_distribution,
        inputs=[property_selector, test_value],
        outputs=[dist_plot_tab, dist_info_tab]
    )
if __name__ == "__main__":
    # Warm the predictor before serving so the first request doesn't pay the
    # model-loading cost.
    print("Initializing models...")
    initialize()
    print("Ready!")
    # share=True asks Gradio for a public tunnel link — NOTE(review): likely
    # redundant when running on HF Spaces; confirm intended deployment target.
    demo.launch(share=True)