Spaces:
Running
Running
| """Biomedical and chemical embedding extraction utilities. | |
| Production goals: | |
| - deterministic embedding cache for large corpora | |
| - safe CPU/GPU fallback | |
| - batch extraction for both text and SMILES | |
| - simple integration with training and inference pipelines | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Optional | |
| import numpy as np | |
| try: | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch | |
| except Exception: # pragma: no cover - informative fallback | |
| AutoModel = None # type: ignore | |
| AutoTokenizer = None # type: ignore | |
| torch = None # type: ignore | |
# Module-level logger shared by the helpers and the service in this file.
logger = logging.getLogger("medcare_ddi.embeddings")

# Short alias -> model identifier for the supported biomedical text encoders.
# Aliases are looked up case-insensitively by ``EmbeddingService.resolve_model``.
DEFAULT_TEXT_MODELS: Dict[str, str] = dict(
    biobert='dmis-lab/biobert-base-cased-v1.1',
    pubmedbert='microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext',
    sapbert='cambridgeltl/SapBERT-from-PubMedBERT-fulltext',
)

# Model identifier used for SMILES strings when the caller does not pick one.
DEFAULT_CHEM_MODEL = 'seyonec/ChemBERTa-zinc-base-v1'
def _hash_inputs(inputs: List[str]) -> str:
    """Return a 32-hex-character SHA-256 fingerprint of an ordered list of strings.

    Each item is fed to the digest as an 8-byte big-endian length followed by
    its UTF-8 bytes. The length prefix makes item boundaries part of the
    fingerprint, so lists such as ``['ab', 'c']`` and ``['a', 'bc']`` (same
    length, same concatenation) no longer collide and share a cache entry.

    Args:
        inputs: Strings to fingerprint; order is significant.

    Returns:
        The first 32 hexadecimal characters of the SHA-256 digest.
    """
    digest = hashlib.sha256()
    for item in inputs:
        data = item.encode('utf-8')
        # Length prefix first: without it only the concatenation is hashed.
        digest.update(len(data).to_bytes(8, 'big'))
        digest.update(data)
    return digest.hexdigest()[:32]
def _stable_model_name(model_name: str) -> str:
    """Return ``model_name`` with every ``/`` and ``:`` replaced by ``_``.

    The result is used as a component of cache file names.
    """
    # One translation pass covers both characters.
    separators = {ord('/'): '_', ord(':'): '_'}
    return model_name.translate(separators)
def _normalize_texts(items: Iterable[str]) -> List[str]:
    """Stringify each item and collapse its whitespace.

    Every item is converted with ``str()``, trimmed, and any run of whitespace
    inside it is replaced by a single space. Order and length are preserved.
    """

    def _collapse(item) -> str:
        tokens = str(item).strip().split()
        return ' '.join(tokens)

    return list(map(_collapse, items))
def _safe_text(value: str | None) -> str:
    """Return ``value`` as a trimmed string, mapping blanks and ``nan`` to ``''``.

    Falsy values (``None``, ``''``, ``0``) and any value whose trimmed string
    form equals ``'nan'`` case-insensitively produce the empty string.
    """
    if not value:
        return ''
    cleaned = str(value).strip()
    if not cleaned or cleaned.lower() == 'nan':
        return ''
    return cleaned
class EmbeddingService:
    """Extract and cache transformer embeddings for biomedical text and SMILES.

    Tokenizer/model pairs are loaded on first use and kept in an in-memory
    registry keyed by the resolved model name. Computed embeddings are stored
    as ``.npy`` files directly inside ``cache_dir``; the file name encodes the
    modality, the resolved model name, and a key derived from the modality,
    model, number of inputs and a hash of the normalized inputs.
    """

    def __init__(self, cache_dir: Optional[Path] = None, device: Optional[str] = None):
        """Create the cache directory and choose the compute device.

        Args:
            cache_dir: Directory for cached ``.npy`` files. Defaults to
                ``<current working directory>/embeddings_cache``. It is created
                if missing.
            device: Device string handed to ``model.to`` / ``tensor.to``. When
                falsy, ``"cuda"`` is used if torch imported successfully and
                reports CUDA as available, otherwise ``"cpu"``.
        """
        self.cache_dir = Path(cache_dir or Path.cwd() / "embeddings_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # resolved model name -> (tokenizer, model)
        self._models: Dict[str, tuple[Any, Any]] = {}
        if device:
            self.device = device
        else:
            cuda_ok = torch is not None and torch.cuda.is_available()
            self.device = "cuda" if cuda_ok else "cpu"

    def resolve_model(self, model_name: str) -> str:
        """Map a short alias from ``DEFAULT_TEXT_MODELS`` to its full model id.

        The lookup is case-insensitive and ignores surrounding whitespace.
        Names that are not known aliases are returned unchanged.
        """
        key = model_name.strip().lower()
        return DEFAULT_TEXT_MODELS.get(key, model_name)

    def _load_model(self, model_name: str):
        """Return ``(tokenizer, model)`` for ``model_name``, loading it once.

        The model is moved to ``self.device`` and put in eval mode before it is
        stored in the registry.

        Raises:
            RuntimeError: If transformers/torch could not be imported.
        """
        if AutoModel is None or AutoTokenizer is None or torch is None:
            raise RuntimeError("transformers not available; install 'transformers' and 'torch'")
        resolved = self.resolve_model(model_name)
        if resolved in self._models:
            return self._models[resolved]
        tokenizer = AutoTokenizer.from_pretrained(resolved, use_fast=True)
        model = AutoModel.from_pretrained(resolved)
        model.to(self.device)
        model.eval()
        self._models[resolved] = (tokenizer, model)
        return tokenizer, model

    def _batch_tokenize(self, tokenizer, texts: List[str], max_length: Optional[int] = 128):
        """Tokenize ``texts`` into padded, truncated PyTorch tensors.

        ``max_length=None`` leaves the truncation length to the tokenizer.
        """
        return tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')

    def _cache_file(self, modality: str, model_name: str, values: List[str]) -> Path:
        """Return the cache path for this modality, model and input list."""
        resolved = self.resolve_model(model_name)
        key_meta = {
            'modality': modality,
            'model': resolved,
            'n': len(values),
            'hash': _hash_inputs(values),
        }
        key = hashlib.sha256(json.dumps(key_meta, sort_keys=True).encode('utf-8')).hexdigest()[:40]
        return self.cache_dir / f'{modality}_{_stable_model_name(resolved)}_{key}.npy'

    def _mean_pool(self, output, attention_mask):
        """Return the attention-mask-weighted mean of ``output.last_hidden_state``.

        The sum runs over dimension 1 (assumed to be the token dimension of a
        ``(batch, tokens, hidden)`` tensor). The divisor is clamped to at least
        1 so a row with an all-zero mask does not divide by zero.
        """
        last = output.last_hidden_state
        mask = attention_mask.unsqueeze(-1)
        summed = (last * mask).sum(1)
        counts = mask.sum(1).clamp(min=1)
        return summed / counts

    def _load_cache(self, cache_path: Path) -> Optional[np.ndarray]:
        """Return the cached array at ``cache_path``, or ``None`` if unusable.

        A missing file and a file that cannot be parsed (for example one left
        truncated by an interrupted run) are both treated as a cache miss so
        the caller recomputes instead of failing.
        """
        if not cache_path.exists():
            return None
        try:
            return np.load(cache_path, allow_pickle=False)
        except (OSError, ValueError, EOFError) as exc:
            logger.warning("Ignoring unreadable embedding cache %s: %s", cache_path, exc)
            return None

    def _save_cache(self, cache_path: Path, embs: np.ndarray) -> None:
        """Write ``embs`` to ``cache_path`` via a temp file and an atomic rename.

        Writing in place would let an interrupted process leave a partial file
        that later calls would try to load. The temp file lives next to the
        target so the final rename stays on one filesystem.
        """
        import os  # local import: only needed for the per-process temp name

        tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
        try:
            # Passing a file object stops np.save from appending ".npy".
            with tmp_path.open('wb') as handle:
                np.save(handle, embs)
            tmp_path.replace(cache_path)
        finally:
            # Only present here if the write or rename failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def _encode(
        self,
        values: List[str],
        model_name: str,
        batch_size: int,
        max_length: Optional[int],
    ) -> np.ndarray:
        """Run the model over ``values`` in batches and mean-pool each batch.

        Returns:
            A float32 array with one row per input string.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If transformers/torch are unavailable.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        tokenizer, model = self._load_model(model_name)
        chunks = []
        for start in range(0, len(values), batch_size):
            batch = values[start:start + batch_size]
            toks = self._batch_tokenize(tokenizer, batch, max_length=max_length)
            input_ids = toks['input_ids'].to(self.device)
            attention_mask = toks['attention_mask'].to(self.device)
            with torch.no_grad():
                out = model(input_ids=input_ids, attention_mask=attention_mask)
                pooled = self._mean_pool(out, attention_mask)
            chunks.append(pooled.cpu().numpy())
        return np.vstack(chunks).astype(np.float32)

    def get_text_embeddings(self, texts: List[str], model_name: str, batch_size: int = 32) -> np.ndarray:
        """Return mean-pooled last-hidden-state embeddings for ``texts``.

        Inputs are whitespace-normalized first and tokenized with a maximum
        length of 128. Results are cached in ``cache_dir`` as
        ``text_<model>_<key>.npy``.

        Args:
            texts: Input strings.
            model_name: Alias from ``DEFAULT_TEXT_MODELS`` or a full model id.
            batch_size: Number of strings per forward pass; must be >= 1.

        Returns:
            A float32 array with one row per input, or a ``(0, 0)`` array when
            ``texts`` is empty.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (on a cache miss).
            RuntimeError: If transformers/torch are unavailable (on a cache miss).
        """
        values = _normalize_texts(texts)
        if not values:
            return np.zeros((0, 0), dtype=np.float32)
        cache_path = self._cache_file('text', model_name, values)
        cached = self._load_cache(cache_path)
        if cached is not None:
            logger.info("Loading cached text embeddings: %s", cache_path)
            return cached
        embs = self._encode(values, model_name, batch_size, max_length=128)
        self._save_cache(cache_path, embs)
        logger.info("Saved embeddings to cache: %s", cache_path)
        return embs

    def get_smiles_embeddings(self, smiles: List[str], model_name: str = DEFAULT_CHEM_MODEL, batch_size: int = 32) -> np.ndarray:
        """Return mean-pooled embeddings for SMILES strings.

        The model must be loadable through ``AutoModel``/``AutoTokenizer``.
        Truncation length is left to the tokenizer. Results are cached in
        ``cache_dir`` as ``smiles_<model>_<key>.npy``.

        Args:
            smiles: SMILES strings.
            model_name: Model id; a falsy value selects ``DEFAULT_CHEM_MODEL``.
            batch_size: Number of strings per forward pass; must be >= 1.

        Returns:
            A float32 array with one row per input, or a ``(0, 0)`` array when
            ``smiles`` is empty.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (on a cache miss).
            RuntimeError: If transformers/torch are unavailable (on a cache miss).
        """
        values = _normalize_texts(smiles)
        if not values:
            return np.zeros((0, 0), dtype=np.float32)
        resolved = model_name or DEFAULT_CHEM_MODEL
        cache_path = self._cache_file('smiles', resolved, values)
        cached = self._load_cache(cache_path)
        if cached is not None:
            logger.info("Loading cached SMILES embeddings: %s", cache_path)
            return cached
        embs = self._encode(values, resolved, batch_size, max_length=None)
        self._save_cache(cache_path, embs)
        logger.info("Saved SMILES embeddings to cache: %s", cache_path)
        return embs

    def get_drug_profile_embeddings(
        self,
        names: List[str],
        active_ingredients: Optional[List[str]] = None,
        descriptions: Optional[List[str]] = None,
        smiles: Optional[List[str]] = None,
        text_model: str = 'pubmedbert',
        smiles_model: str = DEFAULT_CHEM_MODEL,
        batch_size: int = 32,
    ) -> Dict[str, np.ndarray]:
        """Build multimodal per-drug embeddings from name/ingredient/description/SMILES fields.

        Missing (``None`` or empty) optional fields are replaced by a list of
        empty strings of the same length as ``names``. Every value is cleaned
        with ``_safe_text`` before embedding.

        Args:
            names: Drug names; defines the number of drugs.
            active_ingredients: Optional per-drug ingredient text.
            descriptions: Optional per-drug description text.
            smiles: Optional per-drug SMILES strings.
            text_model: Model used for the three text fields.
            smiles_model: Model used for SMILES.
            batch_size: Batch size forwarded to the embedding calls.

        Returns:
            A dict with keys ``'name'``, ``'active_ingredient'``,
            ``'description'``, ``'smiles'``, ``'text_concat'`` (the three text
            embeddings stacked column-wise) and ``'fused'`` (``text_concat``
            plus the SMILES columns). When no drug has a non-empty SMILES,
            ``'smiles'`` has zero columns and ``'fused'`` is ``text_concat``.

        Raises:
            ValueError: If the provided fields differ in length.
        """
        n = len(names)
        active_ingredients = active_ingredients or [''] * n
        descriptions = descriptions or [''] * n
        smiles = smiles or [''] * n
        if not (len(active_ingredients) == len(descriptions) == len(smiles) == n):
            raise ValueError('All profile fields must have equal length')

        names_clean = [_safe_text(x) for x in names]
        ingredients_clean = [_safe_text(x) for x in active_ingredients]
        descriptions_clean = [_safe_text(x) for x in descriptions]
        smiles_clean = [_safe_text(x) for x in smiles]

        name_emb = self.get_text_embeddings(names_clean, model_name=text_model, batch_size=batch_size)
        ingredient_emb = self.get_text_embeddings(ingredients_clean, model_name=text_model, batch_size=batch_size)
        description_emb = self.get_text_embeddings(descriptions_clean, model_name=text_model, batch_size=batch_size)

        # Skip the chemistry model entirely when no drug carries a SMILES string.
        if any(smiles_clean):
            smiles_emb = self.get_smiles_embeddings(smiles_clean, model_name=smiles_model, batch_size=batch_size)
        else:
            smiles_emb = np.zeros((n, 0), dtype=np.float32)

        text_concat = np.hstack([name_emb, ingredient_emb, description_emb]).astype(np.float32)
        fused = np.hstack([text_concat, smiles_emb]).astype(np.float32) if smiles_emb.size else text_concat
        return {
            'name': name_emb,
            'active_ingredient': ingredient_emb,
            'description': description_emb,
            'smiles': smiles_emb,
            'text_concat': text_concat,
            'fused': fused,
        }
def benchmark_embedding_models(
    service: EmbeddingService,
    texts: List[str],
    labels: List[int],
    model_names: List[str],
    batch_size: int = 64,
) -> Dict[str, Any]:
    """Quick embedding quality benchmark using 1-NN leave-one-out accuracy.

    For every model, each text is assigned the label of its most
    cosine-similar *other* text; the score is the fraction of texts whose
    assigned label equals their own.

    Args:
        service: Object providing
            ``get_text_embeddings(texts, model_name=..., batch_size=...)``.
        texts: Input strings.
        labels: Integer class label per text.
        model_names: Models to evaluate; each becomes a key in the report.
        batch_size: Batch size forwarded to the embedding call.

    Returns:
        ``{'num_samples': int, 'models': {name: {'embedding_dim': int,
        'nn_accuracy': float}}}``.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length or fewer than
            three samples are given.
    """
    if len(texts) != len(labels):
        raise ValueError('texts and labels length mismatch')
    if len(texts) < 3:
        raise ValueError('Need at least 3 samples for benchmark')

    y = np.array(labels, dtype=np.int64)
    report: Dict[str, Any] = {'num_samples': len(texts), 'models': {}}
    for model_name in model_names:
        emb = service.get_text_embeddings(texts, model_name=model_name, batch_size=batch_size)
        norms = np.linalg.norm(emb, axis=1, keepdims=True)
        # Clip the norm product so all-zero embeddings do not divide by zero.
        cosine = (emb @ emb.T) / np.clip(norms @ norms.T, 1e-9, None)
        # Exclude self-matches with -inf. -1.0 is a reachable cosine value, so
        # it could tie with real neighbours and let a sample pick itself.
        np.fill_diagonal(cosine, -np.inf)
        nn_idx = np.argmax(cosine, axis=1)
        score = float((y[nn_idx] == y).mean())
        report['models'][model_name] = {
            'embedding_dim': int(emb.shape[1]),
            'nn_accuracy': score,
        }
    return report
# Convenience initializer: one lazily created service shared by the module.
_default_service: Optional[EmbeddingService] = None


def init_embedding_service(cache_dir: Optional[str] = None, device: Optional[str] = None) -> EmbeddingService:
    """Return the module-wide ``EmbeddingService``, creating it on first call.

    ``cache_dir`` and ``device`` are only used when the service is first
    created; later calls return the existing instance and ignore them.
    """
    global _default_service
    if _default_service is not None:
        return _default_service
    cache_path = Path(cache_dir) if cache_dir else None
    _default_service = EmbeddingService(cache_dir=cache_path, device=device)
    return _default_service
if __name__ == '__main__':
    # Smoke test: embed two drug names with the 'biobert' alias and report the
    # resulting array shape, or the failure reason if anything goes wrong.
    service = init_embedding_service()
    try:
        vectors = service.get_text_embeddings(['Aspirin', 'Warfarin'], 'biobert', batch_size=2)
        print('Embeddings shape:', vectors.shape)
    except Exception as exc:
        print('Embedding test failed:', exc)