"""
Ingest KIBA/DAVIS Drug-Target Interaction datasets into Qdrant.
Uses OBMEncoder (768-dim) to create searchable vectors from real DTI data.
"""
import argparse
import os
import sys
from typing import Optional

# Make the project root importable before pulling in local packages.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR)

import pandas as pd
from tqdm import tqdm

from bioflow.api.qdrant_service import get_qdrant_service
def load_dataset(dataset_name: str, limit: Optional[int] = None) -> pd.DataFrame:
    """Load the KIBA or DAVIS dataset from a local tab-separated file.

    Args:
        dataset_name: Dataset key ("kiba" or "davis", case-insensitive);
            resolved to ``data/<name>.tab`` under ``ROOT_DIR``.
        limit: If truthy, keep only the first ``limit`` rows after
            de-duplication; ``None`` (or 0) loads everything.

    Returns:
        DataFrame with columns ``drug_id, smiles, target_id, target_seq,
        affinity``, de-duplicated on (smiles, target_id).

    Raises:
        FileNotFoundError: If the expected ``.tab`` file does not exist.
    """
    filepath = os.path.join(ROOT_DIR, "data", f"{dataset_name.lower()}.tab")
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Dataset not found: {filepath}")
    print(f"Loading {dataset_name} from {filepath}...")
    df = pd.read_csv(filepath, sep='\t')
    # Source column order is: ID1, X1 (SMILES), ID2, X2 (sequence), Y (affinity).
    # Rename for consistency with the rest of the pipeline.
    df.columns = ['drug_id', 'smiles', 'target_id', 'target_seq', 'affinity']
    # Keep one row per unique drug-target pair.
    df = df.drop_duplicates(subset=['smiles', 'target_id'])
    if limit:
        df = df.head(limit)
    print(f" Loaded {len(df)} unique drug-target pairs")
    return df
def get_affinity_class(affinity: float, dataset: str) -> str:
    """Bucket an affinity value into "high", "medium", or "low".

    Thresholds are dataset-specific: KIBA uses cut points at 6 and 8,
    DAVIS (Kd values) uses 6 and 7. In both schemes a lower value maps
    to a stronger ("high") class.
    """
    # (upper bound for "high", upper bound for "medium") per dataset.
    high_cut, medium_cut = (6, 8) if dataset.upper() == "KIBA" else (6, 7)
    if affinity < high_cut:
        return "high"
    if affinity < medium_cut:
        return "medium"
    return "low"
def get_drug_name(drug_id, smiles: str) -> str:
    """Return a human-readable drug name derived from *drug_id*.

    Purely numeric IDs are treated as PubChem CIDs and rendered as
    "CID-<n>"; anything else is returned unchanged (as a string).
    *smiles* is accepted for interface compatibility but is not used.
    """
    name = str(drug_id)
    return f"CID-{name}" if name.isdigit() else name
def ingest_molecules(qdrant, df: pd.DataFrame, dataset: str, batch_size: int = 50):
    """Ingest each unique drug (SMILES) from *df* into Qdrant.

    Rows are grouped by SMILES; for each drug we keep the first drug_id,
    the best (minimum) affinity across all its targets, and the number
    of targets it was measured against.

    Args:
        qdrant: Service exposing ``ingest(content, modality, metadata)``.
        df: DataFrame with smiles/drug_id/affinity/target_id columns.
        dataset: Dataset name ("kiba"/"davis"); used for affinity
            classification and metadata tags.
        batch_size: Currently unused; kept for interface compatibility.

    Returns:
        Number of molecules ingested successfully.
    """
    print("\n[1/2] Ingesting molecules (drugs)...")
    # One row per unique SMILES, with aggregate stats across its targets.
    unique_drugs = df.groupby('smiles').agg({
        'drug_id': 'first',
        'affinity': 'min',     # best (strongest) affinity
        'target_id': 'count',  # number of targets tested
    }).reset_index()
    unique_drugs.columns = ['smiles', 'drug_id', 'best_affinity', 'num_targets']
    print(f" Found {len(unique_drugs)} unique molecules")
    success_count = 0
    error_count = 0
    for idx, row in tqdm(unique_drugs.iterrows(), total=len(unique_drugs), desc=" Molecules"):
        try:
            affinity_class = get_affinity_class(row['best_affinity'], dataset)
            drug_name = get_drug_name(row['drug_id'], row['smiles'])
            qdrant.ingest(
                content=row['smiles'],
                modality="molecule",
                metadata={
                    "name": drug_name,
                    "drug_id": str(row['drug_id']),  # keep original ID
                    "smiles": row['smiles'],
                    "description": f"Drug from {dataset.upper()} dataset",
                    "source": dataset.lower(),
                    "dataset": dataset.lower(),
                    "best_affinity": float(row['best_affinity']),
                    "affinity_class": affinity_class,
                    "num_targets": int(row['num_targets']),
                }
            )
            success_count += 1
        except Exception as e:
            # Previously errors were only surfaced when nothing had
            # succeeded yet; now the first failure is always reported.
            error_count += 1
            if error_count == 1:
                print(f"\n First error: {e}")
    print(f" ✓ Ingested {success_count}/{len(unique_drugs)} molecules")
    return success_count
def ingest_proteins(qdrant, df: pd.DataFrame, dataset: str, batch_size: int = 50):
    """Ingest each unique protein target from *df* into Qdrant.

    Rows are grouped by target_id; for each target we keep the first
    sequence, the best (minimum) affinity across all drugs, and the
    number of drugs it was measured against.

    Args:
        qdrant: Service exposing ``ingest(content, modality, metadata)``.
        df: DataFrame with target_id/target_seq/affinity/smiles columns.
        dataset: Dataset name ("kiba"/"davis"); used for affinity
            classification and metadata tags.
        batch_size: Currently unused; kept for interface compatibility.

    Returns:
        Number of proteins ingested successfully.
    """
    print("\n[2/2] Ingesting proteins (targets)...")
    # One row per unique target, with aggregate stats across its drugs.
    unique_targets = df.groupby('target_id').agg({
        'target_seq': 'first',
        'affinity': 'min',  # best (strongest) affinity
        'smiles': 'count',  # number of drugs tested
    }).reset_index()
    unique_targets.columns = ['target_id', 'target_seq', 'best_affinity', 'num_drugs']
    print(f" Found {len(unique_targets)} unique proteins")
    success_count = 0
    error_count = 0
    for idx, row in tqdm(unique_targets.iterrows(), total=len(unique_targets), desc=" Proteins"):
        try:
            # Truncate very long sequences before embedding.
            sequence = str(row['target_seq'])[:1000]
            affinity_class = get_affinity_class(row['best_affinity'], dataset)
            qdrant.ingest(
                content=sequence,
                modality="protein",
                metadata={
                    "name": row['target_id'],
                    "uniprot_id": row['target_id'],
                    "sequence": sequence,
                    "full_length": len(str(row['target_seq'])),
                    "description": f"Target from {dataset.upper()} dataset",
                    "source": dataset.lower(),
                    "dataset": dataset.lower(),
                    "best_affinity": float(row['best_affinity']),
                    "affinity_class": affinity_class,
                    "num_drugs": int(row['num_drugs']),
                }
            )
            success_count += 1
        except Exception as e:
            # Previously errors were only surfaced when nothing had
            # succeeded yet; now the first failure is always reported.
            error_count += 1
            if error_count == 1:
                print(f"\n First error: {e}")
    print(f" ✓ Ingested {success_count}/{len(unique_targets)} proteins")
    return success_count
def main():
    """Command-line entry point: parse options and run the ingestion."""
    parser = argparse.ArgumentParser(description="Ingest KIBA/DAVIS datasets into Qdrant")
    parser.add_argument("--dataset", choices=["kiba", "davis", "both"], default="davis",
                        help="Dataset to ingest (default: davis)")
    parser.add_argument("--limit", type=int, default=1000,
                        help="Limit number of records per dataset (default: 1000, 0 for all)")
    parser.add_argument("--clear", action="store_true",
                        help="Clear existing collections before ingesting")
    args = parser.parse_args()

    print("=" * 60)
    print(" KIBA/DAVIS -> QDRANT INGESTION")
    print("=" * 60)

    qdrant = get_qdrant_service()

    if args.clear:
        print("\nClearing existing collections...")
        try:
            # NOTE(review): relies on private service internals
            # (_get_client, _initialized_collections) — confirm these are
            # stable before depending on them.
            client = qdrant._get_client()
            for coll in qdrant.list_collections():
                client.delete_collection(coll)
                print(f" Deleted: {coll}")
            # Reset the service's cache so collections are recreated lazily.
            qdrant._initialized_collections.clear()
        except Exception as e:
            print(f" Warning: {e}")

    if args.dataset == "both":
        selected = ["kiba", "davis"]
    else:
        selected = [args.dataset]
    # A limit of 0 (or less) means "no limit".
    limit = None if args.limit <= 0 else args.limit

    total_molecules, total_proteins = 0, 0
    for dataset in selected:
        print(f"\n{'='*60}")
        print(f" Processing {dataset.upper()}")
        print("=" * 60)
        try:
            df = load_dataset(dataset, limit=limit)
            total_molecules += ingest_molecules(qdrant, df, dataset)
            total_proteins += ingest_proteins(qdrant, df, dataset)
        except FileNotFoundError as e:
            print(f" ERROR: {e}")
            continue

    print("\n" + "=" * 60)
    print(" INGESTION COMPLETE")
    print("=" * 60)
    print(f" Total molecules: {total_molecules}")
    print(f" Total proteins: {total_proteins}")
    print(f"\nSearch at: http://localhost:3000/dashboard/discovery")
# Run the ingestion only when executed as a script, not on import.
if __name__ == "__main__":
    main()