| import torch |
| import torch.nn as nn |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel |
| import os |
| from pathlib import Path |
| import logging |
| import numpy as np |
|
|
| |
| |
| |
|
|
# Resolve all paths relative to this file so the app works from any CWD
# (important on Hugging Face Spaces, where the working dir is not guaranteed).
BASE_DIR = Path(__file__).resolve().parent
MODEL_DIR = BASE_DIR / "models" / "biogpt-lora-finetuned"  # BioGPT LoRA adapter directory
QML_MODEL_PATH = BASE_DIR / "models" / "qml_model.pth"     # quantum-ML classifier checkpoint


# Project-local helpers: hybrid quantum model architecture and SMILES
# featurization / repair utilities.
from app.model_arch import HybridQMLModel
from app.utils import smiles_to_features, repair_smiles


logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
# Curated fallback SMILES keyed by lowercase disease keyword.
# Used when the LLM fails to produce enough valid molecules for a request.
# Entries appear to be known drug-like structures loosely associated with each
# indication (e.g. the first oncology entry resembles imatinib) — presumably
# hand-picked; not verified here.
FALLBACK_SMILES = {
    # --- oncology ---
    "cancer": [
        "Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C",
        "COc1cc2ncnc(Nc3cccc(Cl)c3F)c2cc1OCCCN4CCOCC4",
        "CC(C)(C)OC(=O)Nc1ccc(cc1)C(=O)Nc2ccc(cc2)N3CCOCC3",
    ],
    "tumor": [
        "Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C",
        "O=C(Nc1ccc(F)cc1)c1cc(Cl)cc(Cl)c1",
        "CC1=C(C(=O)Nc2ccccc2)c2cc(Cl)sc2N1C(=O)c1ccccc1",
    ],
    "leukemia": [
        "Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C",
        "CN1CCN(CC1)c1ccc(cc1)C(=O)Nc2cc(cnc2N)c3ccccc3",
        "COc1ccc(cc1OC)c1cc2c(NC(=O)c3ccc(F)cc3)ncnc2[nH]1",
    ],

    # --- neurology / psychiatry ---
    "alzheimer": [
        "CCc1nn(C)c2ccc(nc12)C(=O)Nc3ccc(cc3)N3CCN(CC3)C",
        "O=C(O)c1ccccc1Oc1ccccc1",
        "CN(C)CCc1c[nH]c2ccc(CS(=O)(=O)N3CCCC3)cc12",
    ],
    "parkinson": [
        "NCCc1ccc(O)c(O)c1",
        "CC(N)Cc1ccc(O)c(O)c1",
        "O=C(O)C(N)Cc1ccc(O)c(O)c1",
    ],
    "depression": [
        "CNCCC(Oc1ccc(C(F)(F)F)cc1)c1ccccc1",
        "CN(C)CCCN1c2ccccc2CCc2ccccc21",
        "OC(=O)c1ccc(Cl)cc1",
    ],
    "epilepsy": [
        "NC(=O)c1cccnc1",
        "CC(C)CC1(CC(=O)N)C(=O)NC(=O)N1",
        "NCC(O)c1ccc(O)c(O)c1",
    ],

    # --- metabolic / endocrine ---
    "diabetes": [
        "CN(C)C(=N)NC(=N)N",
        "CCOC(=O)c1cnc2cc(S(N)(=O)=O)ccc2c1NCc1ccc(F)cc1",
        "CC1=CC(=C(C=C1)OCC(O)=O)CCOC2=CC=C(C=C2)CC3C(=O)NC(=O)S3",
    ],
    "obesity": [
        "O=C(O)CCCCCCC(=O)O",
        "CC(C)(C)NCC(O)c1ccc(O)c(O)c1",
        "COc1cc2c(cc1OC)CC(N)CC2",
    ],
    "thyroid": [
        "Ic1cc(CC(N)C(=O)O)cc(I)c1Oc1cc(I)c(O)c(I)c1",
        "CC(N)Cc1ccc(O)c(O)c1",
        "O=C(O)C(N)Cc1ccc(O)c(O)c1",
    ],

    # --- cardiovascular ---
    "hypertension": [
        "CCOC(=O)C1=C(COCCN)NC(C)=C(C1c1cccc(Cl)c1)C(=O)OCC",
        "CC(C)(C)NCC(O)c1ccc(O)c(O)c1",
        "OC(=O)c1ccccc1Oc1ccccc1",
    ],
    "heart disease": [
        "CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C",
        "O=C(O)c1ccccc1OC(=O)C",
        "CCCC1=NN(C2=CC(=C(C=C12)S(=O)(=O)N1CCN(CC1)C)OCC)C",
    ],
    "stroke": [
        "OC(=O)c1ccccc1Oc1ccccc1",
        "CC(=O)Oc1ccccc1C(=O)O",
        "O=C(O)CCCc1ccc(O)cc1",
    ],

    # --- infectious disease ---
    "infection": [
        "CC1(C)SC2C(NC(=O)Cc3ccccc3)C(=O)N2C1C(=O)O",
        "OC(=O)c1ccc(N)cc1",
        "Cc1onc(c1C(=O)Nc2ccc(O)cc2)c1ccccc1",
    ],
    "influenza": [
        "CCOC(=O)C1=C[C@@H](OC(CC)CC)[C@@H](NC(C)=O)[C@@H](N)C1",
        "OC(=O)[C@@H](N)Cc1ccc(O)cc1",
        "CC(=O)Nc1ccc(O)cc1",
    ],
    "hiv": [
        "Cc1ccc(cc1)S(=O)(=O)Nc1ccc(N2CCOCC2)nn1",
        "O=C(Nc1ccc(F)cc1)c1ccncc1",
        "CC(C)(C)OC(=O)N1CCC(CC1)n1cnc2ccccc21",
    ],
    "malaria": [
        "CCN(CC)CCCC(C)Nc1ccnc2cc(Cl)ccc12",
        "OCC1OC(O)C(O)C(O)C1O",
        "Clc1ccc(cc1)C(c1ccccc1)(c1ccccc1)O",
    ],
    "tuberculosis": [
        "OC(=O)c1ccncc1",
        "CC1=CN=C(C=C1)C(=O)NN",
        "O=C(O)c1ccc(N)cc1",
    ],

    # --- autoimmune / inflammatory ---
    "arthritis": [
        "CC(=O)Oc1ccccc1C(=O)O",
        "CC(C)Cc1ccc(cc1)C(C)C(=O)O",
        "COc1ccc(cc1OC)C(=O)c1ccc(O)cc1",
    ],
    "inflammation": [
        "CC(C)Cc1ccc(cc1)C(C)C(=O)O",
        "CC(=O)Oc1ccccc1C(=O)O",
        "O=C(O)c1ccc(Cl)cc1",
    ],
    "lupus": [
        "CCN(CC)CCCC(C)Nc1ccnc2cc(Cl)ccc12",
        "COc1ccc(cc1)S(=O)(=O)Nc1ccc(N)cc1",
        "O=C(Nc1ccccc1)c1cccnc1",
    ],

    # --- respiratory ---
    "asthma": [
        "CC(C)(C)NCC(O)c1ccc(O)c(O)c1",
        "COc1ccc(CCNCc2ccc(O)c(CO)c2)cc1",
        "O=C(O)c1ccc(N)cc1S(=O)(=O)N",
    ],
    "copd": [
        "CC(C)(C)NCC(O)c1ccc(O)c(O)c1",
        "O=C(O)CCc1ccc(O)cc1",
        "Cn1ccnc1SCC1=C(N2C(=O)[C@@H](NC(=O)Cc3ccccc3)[C@H]2SC)CS1",
    ],

    # Generic drug-like molecules used when no keyword matches.
    "default": [
        "CC(=O)Oc1ccccc1C(=O)O",
        "CC(C)Cc1ccc(cc1)C(C)C(=O)O",
        "CC(=O)Nc1ccc(O)cc1",
        "Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C",
        "CCOC(=O)C1=C[C@@H](OC(CC)CC)[C@@H](NC(C)=O)[C@@H](N)C1",
    ],
}


# How many LLM generation passes to try before falling back to the table above.
MAX_GENERATION_ATTEMPTS = 3
|
|
|
|
def _get_fallback(disease: str, n: int) -> list[str]:
    """
    Return up to n curated fallback SMILES for a disease keyword.

    Matching is a case-insensitive substring test in both directions, so
    'breast cancer' matches the 'cancer' key and 'flu' can match 'influenza'.
    Falls back to the generic 'default' list when nothing matches.

    Args:
        disease: free-text disease name (any case, may contain whitespace).
        n: maximum number of SMILES to return.

    Returns:
        list of at most n SMILES strings (may be shorter than n if the
        matched list is shorter).
    """
    disease_lower = disease.lower().strip()

    # BUG FIX: an empty query would substring-match every key ('"" in s' is
    # always True), returning whichever key iterates first ('cancer') instead
    # of the neutral defaults.
    if not disease_lower:
        return FALLBACK_SMILES["default"][:n]

    for key, smiles_list in FALLBACK_SMILES.items():
        if key == "default":
            continue  # 'default' is the explicit fallthrough, never a match
        if key in disease_lower or disease_lower in key:
            return smiles_list[:n]
    return FALLBACK_SMILES["default"][:n]
|
|
|
|
| class ModelInference: |
| def __init__(self): |
| """Initialize all models on startup (CPU-only, HF Spaces safe)""" |
|
|
| self.llm_device = torch.device("cpu") |
| self.qml_device = torch.device("cpu") |
|
|
| logger.info("π₯οΈ Running on CPU (Hugging Face Spaces)") |
|
|
| self._load_llm() |
| self._load_qml() |
|
|
| |
| |
| |
| def _load_llm(self): |
| try: |
| model_path = MODEL_DIR.resolve() |
| base_model_name = "microsoft/biogpt" |
|
|
| if not model_path.exists(): |
| raise FileNotFoundError(f"LLM path not found: {model_path}") |
|
|
| logger.info(f"Loading BioGPT LoRA from {model_path}") |
|
|
| self.tokenizer = AutoTokenizer.from_pretrained( |
| model_path, |
| local_files_only=True, |
| trust_remote_code=True |
| ) |
|
|
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
| base_model = AutoModelForCausalLM.from_pretrained( |
| base_model_name, |
| torch_dtype=torch.float32, |
| low_cpu_mem_usage=True, |
| trust_remote_code=True |
| ).to(self.llm_device) |
|
|
| self.llm_model = PeftModel.from_pretrained( |
| base_model, |
| model_path, |
| local_files_only=True |
| ).to(self.llm_device) |
|
|
| self.llm_model.eval() |
| logger.info("β
BioGPT + LoRA loaded") |
|
|
| except Exception: |
| logger.error("β Failed to load LLM", exc_info=True) |
| raise |
|
|
| |
| |
| |
| def _load_qml(self): |
| try: |
| if not QML_MODEL_PATH.exists(): |
| raise FileNotFoundError(f"QML model not found: {QML_MODEL_PATH}") |
|
|
| checkpoint = torch.load(QML_MODEL_PATH, map_location="cpu", weights_only=False) |
|
|
| n_qubits = checkpoint.get("n_qubits", 8) |
| n_layers = checkpoint.get("n_layers", 3) |
| feature_dim = checkpoint.get("feature_dim", 64) |
|
|
| self.pca_components = checkpoint.get("pca_components", None) |
| self.pca_mean = checkpoint.get("pca_mean", None) |
| self.scaler_mean = checkpoint.get("scaler_mean", None) |
| self.scaler_scale = checkpoint.get("scaler_scale", None) |
| self.fingerprint_bits = checkpoint.get("fingerprint_bits", 2048) |
|
|
| logger.info( |
| f"QML config β n_qubits={n_qubits}, n_layers={n_layers}, " |
| f"feature_dim={feature_dim}, fingerprint_bits={self.fingerprint_bits}" |
| ) |
|
|
| self.qml_model = HybridQMLModel( |
| n_qubits=n_qubits, |
| n_layers=n_layers, |
| feature_dim=feature_dim, |
| quantum_layer=None |
| ) |
|
|
| self.qml_model.load_state_dict(checkpoint["model_state_dict"]) |
| self.qml_model.to(self.qml_device) |
| self.qml_model.eval() |
|
|
| logger.info("β
QML model loaded") |
| logger.info(f"Scaler mean shape: {self.scaler_mean.shape}") |
| logger.info(f"PCA components shape: {self.pca_components.shape}") |
|
|
| except Exception: |
| logger.error("β Failed to load QML model", exc_info=True) |
| raise |
|
|
| |
| |
| |
| def _prepare_features(self, smiles: str) -> np.ndarray: |
| """ |
| Full preprocessing pipeline matching training: |
| 1. smiles_to_features β raw features (Morgan fingerprint + descriptors) |
| 2. StandardScaler (scaler_mean / scaler_scale) |
| 3. PCA (pca_mean / pca_components) β feature_dim floats |
| Returns a float32 array of shape (feature_dim,). |
| """ |
| raw = smiles_to_features(smiles, n_bits=self.fingerprint_bits) |
|
|
| if raw is None: |
| logger.warning(f"smiles_to_features returned None for: {smiles} β using zeros") |
| feature_dim = ( |
| self.pca_components.shape[0] |
| if self.pca_components is not None |
| else 64 |
| ) |
| return np.zeros(feature_dim, dtype=np.float32) |
|
|
| raw = np.array(raw, dtype=np.float32) |
|
|
| if self.scaler_mean is not None and self.scaler_scale is not None: |
| scale = np.where(self.scaler_scale == 0, 1.0, self.scaler_scale) |
| raw = (raw - self.scaler_mean) / scale |
|
|
| if self.pca_components is not None and self.pca_mean is not None: |
| raw = raw - self.pca_mean |
| raw = raw @ self.pca_components.T |
|
|
| return raw.astype(np.float32) |
|
|
| |
| |
| |
| def _generate_once( |
| self, |
| disease: str, |
| max_new_tokens: int, |
| temperature: float, |
| num_return_sequences: int, |
| ) -> list[str]: |
| """ |
| Run one LLM generation pass and return raw decoded SMILES strings. |
| No validation β caller handles that. |
| """ |
| prompt = f"Disease: {disease} -> SMILES:" |
| inputs = self.tokenizer( |
| prompt, |
| return_tensors="pt", |
| truncation=True, |
| max_length=256 |
| ) |
| inputs = {k: v.to(self.llm_device) for k, v in inputs.items()} |
|
|
| with torch.no_grad(): |
| outputs = self.llm_model.generate( |
| **inputs, |
| max_new_tokens=max_new_tokens, |
| do_sample=True, |
| temperature=temperature, |
| top_k=50, |
| top_p=0.95, |
| num_return_sequences=num_return_sequences, |
| pad_token_id=self.tokenizer.pad_token_id, |
| eos_token_id=self.tokenizer.eos_token_id, |
| ) |
|
|
| results = [] |
| for output in outputs: |
| text = self.tokenizer.decode(output, skip_special_tokens=True) |
| smiles = text.split("SMILES:", 1)[-1].strip().replace(" ", "") |
| if smiles: |
| results.append(smiles) |
| return results |
|
|
| |
| |
| |
    def generate_molecules(self, disease: str, num_candidates: int = 3) -> list:
        """
        Generate valid, repaired SMILES strings for a disease.

        Strategy:
          1. Up to MAX_GENERATION_ATTEMPTS LLM passes; each pass over-samples
             (num_candidates * 2 sequences) to offset invalid outputs.
          2. Each raw string is repaired via repair_smiles; irreparable and
             duplicate SMILES are dropped.
          3. Sampling temperature is nudged up after each incomplete attempt
             to diversify the next batch.
          4. Any remaining shortfall is filled from the curated
             FALLBACK_SMILES table, tagged so the frontend can distinguish
             generated from curated molecules.

        Args:
            disease: free-text disease name; must be non-empty.
            num_candidates: number of molecules to return (default 3).

        Returns:
            list of dicts: [{"smiles": str, "source": "generated"|"fallback"}, ...]

        Raises:
            ValueError: if disease is empty or whitespace-only.
        """
        if not disease or not disease.strip():
            raise ValueError("Disease name cannot be empty")

        # Generation knobs are env-overridable (defaults: 80 new tokens,
        # temperature 0.85 β 80 tokens covers typical drug-like SMILES).
        max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "80"))
        temperature = float(os.getenv("TEMPERATURE", "0.85"))

        valid_smiles = []  # accumulated result dicts
        seen = set()       # repaired SMILES already accepted (dedup)

        for attempt in range(1, MAX_GENERATION_ATTEMPTS + 1):
            logger.info(
                f"Generation attempt {attempt}/{MAX_GENERATION_ATTEMPTS} "
                f"for disease='{disease}'"
            )

            # Over-sample: request twice as many sequences as still needed
            # overall, since many raw outputs won't parse.
            raw_batch = self._generate_once(
                disease,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                num_return_sequences=num_candidates * 2,
            )

            for raw in raw_batch:
                if len(valid_smiles) >= num_candidates:
                    break

                # repair_smiles returns None when the string can't be made
                # into a parseable molecule.
                repaired = repair_smiles(raw)
                if repaired is None:
                    logger.debug(f"Irreparable SMILES (attempt {attempt}): {raw[:60]}")
                    continue

                if repaired in seen:
                    continue

                seen.add(repaired)
                valid_smiles.append({
                    "smiles": repaired,
                    "source": "generated",
                })
                logger.info(f"β Valid generated SMILES: {repaired[:60]}")

            if len(valid_smiles) >= num_candidates:
                break

            # Raise temperature slightly for the next attempt (capped at 1.2)
            # to encourage different samples.
            temperature = min(temperature + 0.05, 1.2)

        # Fill any shortfall from the curated table; skip fallbacks that
        # duplicate an already-generated molecule.
        if len(valid_smiles) < num_candidates:
            needed = num_candidates - len(valid_smiles)
            fallbacks = _get_fallback(disease, needed)
            for fb_smiles in fallbacks:
                if fb_smiles not in seen:
                    valid_smiles.append({
                        "smiles": fb_smiles,
                        "source": "fallback",
                    })
                    seen.add(fb_smiles)
            logger.warning(
                f"LLM generated {num_candidates - needed}/{num_candidates} valid SMILES "
                f"for '{disease}' β used {needed} curated fallback(s)"
            )

        result = valid_smiles[:num_candidates]
        logger.info(
            f"Returning {len(result)} molecules for '{disease}': "
            f"{[r['source'] for r in result]}"
        )
        return result
|
|
| |
| |
| |
| def predict_from_features(self, features: np.ndarray) -> float: |
| """ |
| Run the QML model directly on a pre-computed feature vector. |
| Used by SHAP KernelExplainer. |
| Accepts either raw features or post-PCA features β detects automatically. |
| Returns: float drug-likeness score in [0, 1] |
| """ |
| try: |
| features = np.array(features, dtype=np.float32).flatten() |
| feature_dim = ( |
| self.pca_components.shape[0] |
| if self.pca_components is not None |
| else 64 |
| ) |
|
|
| if features.shape[0] != feature_dim: |
| if self.scaler_mean is not None and self.scaler_scale is not None: |
| scale = np.where(self.scaler_scale == 0, 1.0, self.scaler_scale) |
| features = (features - self.scaler_mean) / scale |
| if self.pca_components is not None and self.pca_mean is not None: |
| features = features - self.pca_mean |
| features = features @ self.pca_components.T |
|
|
| features = features.astype(np.float32) |
| x = torch.tensor(features, dtype=torch.float32).unsqueeze(0) |
| with torch.no_grad(): |
| logit = self.qml_model(x) |
| probability = torch.sigmoid(logit).squeeze().item() |
| return float(probability) |
|
|
| except Exception as e: |
| logger.warning(f"predict_from_features failed: {e}") |
| return 0.0 |
|
|
| |
| |
| |
| def predict_drug_potential(self, smiles: str) -> dict: |
| """ |
| Predict drug-likeness for a SMILES string. |
| Validates, repairs, extracts features, then runs QML inference. |
| """ |
| try: |
| fixed_smiles = repair_smiles(smiles) |
|
|
| if fixed_smiles is None: |
| return { |
| "prediction": "invalid", |
| "probability": 0.0, |
| "score": 0.0, |
| "is_promising": False, |
| "confidence": "low", |
| "error": "Invalid SMILES β could not be repaired", |
| "original_smiles": smiles, |
| "repaired_smiles": None, |
| } |
|
|
| features = self._prepare_features(fixed_smiles) |
| features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0) |
|
|
| with torch.no_grad(): |
| logit = self.qml_model(features_tensor) |
| probability = torch.sigmoid(logit).squeeze().item() |
|
|
| is_promising = probability >= 0.5 |
| confidence_score = abs(probability - 0.5) |
| confidence = ( |
| "high" if confidence_score > 0.30 else |
| "medium" if confidence_score > 0.15 else |
| "low" |
| ) |
|
|
| return { |
| "prediction": "drug" if is_promising else "not drug", |
| "probability": round(probability, 4), |
| "score": round(probability, 4), |
| "is_promising": is_promising, |
| "confidence": confidence, |
| "original_smiles": smiles, |
| "repaired_smiles": fixed_smiles if fixed_smiles != smiles else None, |
| } |
|
|
| except Exception as e: |
| logger.error("β Prediction failure", exc_info=True) |
| return { |
| "prediction": "unknown", |
| "probability": 0.5, |
| "score": 0.5, |
| "is_promising": False, |
| "confidence": "low", |
| "error": str(e), |
| "original_smiles": smiles, |
| "repaired_smiles": None, |
| } |