yassinekolsi commited on
Commit
5770d80
·
1 Parent(s): b25a5e5

fix: PR review fixes - dockerfile, encoders, orchestrator, paths

Browse files

- dockerfile: use conda run for pip installs, remove hardcoded proxy, add COPY, fix ENV PATH syntax
- orchestrator: remove duplicate Optional import, fix retriever API (modality param), safe getattr
- encoders: batch_encode honors pooling setting in protein/text encoders
- qdrant_retriever: safe modality mapping for legacy values
- plugins/__init__: replace print with logging
- run_summary.json: convert absolute paths to relative (5 files)
- BIOFLOW_README: fix license badge (Apache->MIT)
- USE_POLICY: fix spelling (sperate->separate)
- model_customization.ipynb: fix config path
- enhanced_search: sort by score after MMR, qdrant API fallback
- UI: database filter state management, result display improvements

BIOFLOW_README.md CHANGED
@@ -1,7 +1,7 @@
1
  # BioFlow - AI-Powered Drug Discovery Platform
2
 
3
  [![Version](https://img.shields.io/badge/version-2.0.0-blue.svg)]()
4
- [![License](https://img.shields.io/badge/license-Apache%202.0-green.svg)](LICENSE)
5
 
6
  **BioFlow** is a unified AI platform for drug discovery, combining molecular encoding, protein analysis, and drug-target interaction prediction in a modern web interface.
7
 
 
1
  # BioFlow - AI-Powered Drug Discovery Platform
2
 
3
  [![Version](https://img.shields.io/badge/version-2.0.0-blue.svg)]()
4
+ [![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
5
 
6
  **BioFlow** is a unified AI platform for drug discovery, combining molecular encoding, protein analysis, and drug-target interaction prediction in a modern web interface.
7
 
USE_POLICY.md CHANGED
@@ -6,7 +6,7 @@ You have the right to use BioMedGPT pursuant to relevant agreements, but you can
6
 
7
  1. inciting to resist or undermine the implementation of the Constitution, laws and administrative regulations;
8
  2. inciting to subvert the state power and the overthrow of the political system;
9
- 3. inciting to sperate the state or undermine unity of the country;
10
  4. inciting national enmity or discrimination, undermine the unity of nations;
11
  5. content involving discrimination on the basis of race, sex, religion, geographical content, etc.;
12
  6. fabricating or distorting facts, spreading disinformation, or disturbing the public order;
 
6
 
7
  1. inciting to resist or undermine the implementation of the Constitution, laws and administrative regulations;
8
  2. inciting to subvert the state power and the overthrow of the political system;
9
+ 3. inciting to separate the state or undermine unity of the country;
10
  4. inciting national enmity or discrimination, undermine the unity of nations;
11
  5. content involving discrimination on the basis of race, sex, religion, geographical content, etc.;
12
  6. fabricating or distorting facts, spreading disinformation, or disturbing the public order;
bioflow/api/server.py CHANGED
@@ -539,7 +539,12 @@ async def enhanced_search(request: dict = None):
539
  collection = request.get("collection")
540
  use_mmr = request.get("use_mmr", True)
541
  lambda_param = request.get("lambda_param", 0.7)
542
- filters = request.get("filters")
 
 
 
 
 
543
 
544
  # Map old type names to new modality names
545
  type_to_modality = {
@@ -584,9 +589,9 @@ async def enhanced_search(request: dict = None):
584
  _log_event(
585
  "search",
586
  request_id,
587
- query=request.query[:200],
588
- top_k=request.top_k,
589
- use_mmr=request.use_mmr,
590
  returned=payload.get("returned"),
591
  total_found=payload.get("total_found"),
592
  duration_ms=round((time.perf_counter() - start) * 1000, 2),
 
539
  collection = request.get("collection")
540
  use_mmr = request.get("use_mmr", True)
541
  lambda_param = request.get("lambda_param", 0.7)
542
+ filters = request.get("filters") or {}
543
+ dataset = request.get("dataset") # Optional dataset filter (davis, kiba)
544
+
545
+ # Add dataset filter if specified
546
+ if dataset:
547
+ filters["source"] = dataset.lower()
548
 
549
  # Map old type names to new modality names
550
  type_to_modality = {
 
589
  _log_event(
590
  "search",
591
  request_id,
592
+ query=query[:200] if query else "",
593
+ top_k=top_k,
594
+ use_mmr=use_mmr,
595
  returned=payload.get("returned"),
596
  total_found=payload.get("total_found"),
597
  duration_ms=round((time.perf_counter() - start) * 1000, 2),
bioflow/core/orchestrator.py CHANGED
@@ -12,14 +12,10 @@ from dataclasses import dataclass, field
12
  from datetime import datetime
13
  from collections import defaultdict
14
 
15
- from typing import Optional as OptionalType
16
  from bioflow.core.base import BioEncoder, BioPredictor, BioGenerator, Modality
17
  from bioflow.core.config import NodeConfig, WorkflowConfig, NodeType
18
  from bioflow.core.registry import ToolRegistry
19
 
20
- # Re-import Optional with a different name to avoid conflicts
21
- from typing import Optional
22
-
23
  logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger(__name__)
25
 
@@ -174,10 +170,11 @@ class BioFlowOrchestrator:
174
  if self._retriever is None:
175
  raise ValueError("No retriever configured. Call set_retriever() first.")
176
  limit = node.params.get("limit", 5)
177
- modality = node.params.get("modality", "text")
 
178
  return self._retriever.search(
179
  query=node_input,
180
- query_modality=modality,
181
  limit=limit
182
  )
183
 
@@ -191,7 +188,13 @@ class BioFlowOrchestrator:
191
  threshold = node.params.get("threshold", 0.5)
192
  key = node.params.get("key", "score")
193
  if isinstance(node_input, list):
194
- return [x for x in node_input if getattr(x, key, x.get(key, 0)) >= threshold]
 
 
 
 
 
 
195
  return node_input
196
 
197
  elif node.type == NodeType.CUSTOM:
 
12
  from datetime import datetime
13
  from collections import defaultdict
14
 
 
15
  from bioflow.core.base import BioEncoder, BioPredictor, BioGenerator, Modality
16
  from bioflow.core.config import NodeConfig, WorkflowConfig, NodeType
17
  from bioflow.core.registry import ToolRegistry
18
 
 
 
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
 
170
  if self._retriever is None:
171
  raise ValueError("No retriever configured. Call set_retriever() first.")
172
  limit = node.params.get("limit", 5)
173
+ modality_name = node.params.get("modality", "text")
174
+ modality = Modality(modality_name)
175
  return self._retriever.search(
176
  query=node_input,
177
+ modality=modality,
178
  limit=limit
179
  )
180
 
 
188
  threshold = node.params.get("threshold", 0.5)
189
  key = node.params.get("key", "score")
190
  if isinstance(node_input, list):
191
+ def _get_value(item: Any) -> float:
192
+ if hasattr(item, key):
193
+ return getattr(item, key)
194
+ if isinstance(item, dict):
195
+ return item.get(key, 0)
196
+ return 0
197
+ return [x for x in node_input if _get_value(x) >= threshold]
198
  return node_input
199
 
200
  elif node.type == NodeType.CUSTOM:
bioflow/plugins/__init__.py CHANGED
@@ -46,13 +46,21 @@ def register_all(registry=None):
46
 
47
  Args:
48
  registry: ToolRegistry instance (uses global if None)
 
 
 
49
  """
 
 
 
50
  from bioflow.core import ToolRegistry
51
  registry = registry or ToolRegistry
52
 
53
- # Note: Encoders are lazy-loaded, so we don't instantiate here
54
- # They will be registered when first used
55
- print("Plugins available for registration:")
56
- print(" Encoders: OBMEncoder, TextEncoder, MoleculeEncoder, ProteinEncoder")
57
- print(" Retrievers: QdrantRetriever")
58
- print(" Predictors: DeepPurposePredictor")
 
 
 
46
 
47
  Args:
48
  registry: ToolRegistry instance (uses global if None)
49
+
50
+ Returns:
51
+ dict: Available plugin classes by category
52
  """
53
+ import logging
54
+ logger = logging.getLogger(__name__)
55
+
56
  from bioflow.core import ToolRegistry
57
  registry = registry or ToolRegistry
58
 
59
+ available = {
60
+ "encoders": ["OBMEncoder", "TextEncoder", "MoleculeEncoder", "ProteinEncoder"],
61
+ "retrievers": ["QdrantRetriever"],
62
+ "predictors": ["DeepPurposePredictor"],
63
+ }
64
+
65
+ logger.info(f"Plugins available for registration: {available}")
66
+ return available
bioflow/plugins/encoders/protein_encoder.py CHANGED
@@ -173,8 +173,13 @@ class ProteinEncoder(BioEncoder):
173
  with torch.no_grad():
174
  outputs = self.model(**inputs)
175
  hidden_states = outputs.last_hidden_state
176
- attention_mask = inputs["attention_mask"].unsqueeze(-1)
177
- embeddings = (hidden_states * attention_mask).sum(1) / attention_mask.sum(1)
 
 
 
 
 
178
 
179
  results = []
180
  for i, emb in enumerate(embeddings):
 
173
  with torch.no_grad():
174
  outputs = self.model(**inputs)
175
  hidden_states = outputs.last_hidden_state
176
+
177
+ # Apply same pooling strategy as encode()
178
+ if self.pooling == "cls":
179
+ embeddings = hidden_states[:, 0, :]
180
+ else: # mean pooling
181
+ attention_mask = inputs["attention_mask"].unsqueeze(-1)
182
+ embeddings = (hidden_states * attention_mask).sum(1) / attention_mask.sum(1)
183
 
184
  results = []
185
  for i, emb in enumerate(embeddings):
bioflow/plugins/encoders/text_encoder.py CHANGED
@@ -159,11 +159,16 @@ class TextEncoder(BioEncoder):
159
  outputs = self.model(**inputs)
160
  hidden_states = outputs.last_hidden_state
161
 
162
- if self.pooling == "mean":
 
 
 
163
  attention_mask = inputs["attention_mask"].unsqueeze(-1)
164
  embeddings = (hidden_states * attention_mask).sum(1) / attention_mask.sum(1)
 
 
165
  else:
166
- embeddings = hidden_states[:, 0, :]
167
 
168
  results = []
169
  for i, emb in enumerate(embeddings):
 
159
  outputs = self.model(**inputs)
160
  hidden_states = outputs.last_hidden_state
161
 
162
+ # Apply same pooling strategy as encode()
163
+ if self.pooling == "cls":
164
+ embeddings = hidden_states[:, 0, :]
165
+ elif self.pooling == "mean":
166
  attention_mask = inputs["attention_mask"].unsqueeze(-1)
167
  embeddings = (hidden_states * attention_mask).sum(1) / attention_mask.sum(1)
168
+ elif self.pooling == "max":
169
+ embeddings = hidden_states.max(dim=1).values
170
  else:
171
+ raise ValueError(f"Unknown pooling: {self.pooling}")
172
 
173
  results = []
174
  for i, emb in enumerate(embeddings):
bioflow/plugins/qdrant_retriever.py CHANGED
@@ -193,13 +193,29 @@ class QdrantRetriever(BioRetriever):
193
  query_filter=qdrant_filter
194
  )
195
 
196
- # Convert to RetrievalResult
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  return [
198
  RetrievalResult(
199
  id=str(r.id),
200
  score=r.score,
201
  content=r.payload.get("content", ""),
202
- modality=Modality(r.payload.get("modality", "text")),
203
  payload=r.payload
204
  )
205
  for r in results
 
193
  query_filter=qdrant_filter
194
  )
195
 
196
+ # Convert to RetrievalResult with safe modality mapping
197
+ def _safe_modality(payload: dict) -> Modality:
198
+ raw = payload.get("modality")
199
+ if isinstance(raw, Modality):
200
+ return raw
201
+ if not isinstance(raw, str):
202
+ return Modality.TEXT
203
+ norm = raw.strip().lower()
204
+ # Map legacy/synonym values
205
+ synonym_map = {"molecule": "smiles", "drug": "smiles"}
206
+ if norm in synonym_map:
207
+ norm = synonym_map[norm]
208
+ try:
209
+ return Modality(norm)
210
+ except ValueError:
211
+ return Modality.TEXT
212
+
213
  return [
214
  RetrievalResult(
215
  id=str(r.id),
216
  score=r.score,
217
  content=r.payload.get("content", ""),
218
+ modality=_safe_modality(r.payload),
219
  payload=r.payload
220
  )
221
  for r in results
bioflow/runs/20260125_080409_BindingDB_Kd/run_summary.json CHANGED
@@ -20,12 +20,12 @@
20
  "ci_approx": 0.8053618329014657
21
  },
22
  "files": {
23
- "predictions_test_csv": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\predictions_test.csv",
24
- "scatter_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\scatter.png",
25
- "curves_sorted_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\curves_sorted.png",
26
- "residuals_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\residuals.png",
27
- "hist_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\hist_true_pred.png",
28
- "ecdf_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 6569.407529115677,
 
20
  "ci_approx": 0.8053618329014657
21
  },
22
  "files": {
23
+ "predictions_test_csv": "predictions_test.csv",
24
+ "scatter_png": "scatter.png",
25
+ "curves_sorted_png": "curves_sorted.png",
26
+ "residuals_png": "residuals.png",
27
+ "hist_true_pred_png": "hist_true_pred.png",
28
+ "ecdf_true_pred_png": "ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 6569.407529115677,
bioflow/runs/20260125_104915_KIBA/run_summary.json CHANGED
@@ -20,12 +20,12 @@
20
  "ci_approx": 0.7031028951074637
21
  },
22
  "files": {
23
- "predictions_test_csv": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\predictions_test.csv",
24
- "scatter_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\scatter.png",
25
- "curves_sorted_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\curves_sorted.png",
26
- "residuals_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\residuals.png",
27
- "hist_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\hist_true_pred.png",
28
- "ecdf_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 26236.644289016724,
 
20
  "ci_approx": 0.7031028951074637
21
  },
22
  "files": {
23
+ "predictions_test_csv": "predictions_test.csv",
24
+ "scatter_png": "scatter.png",
25
+ "curves_sorted_png": "curves_sorted.png",
26
+ "residuals_png": "residuals.png",
27
+ "hist_true_pred_png": "hist_true_pred.png",
28
+ "ecdf_true_pred_png": "ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 26236.644289016724,
bioflow/search/enhanced_search.py CHANGED
@@ -252,6 +252,12 @@ class EnhancedSearchService:
252
  diversity_score = None
253
  enhanced_results = self._raw_to_enhanced(raw_results[:top_k])
254
 
 
 
 
 
 
 
255
  # Apply evidence linking
256
  enhanced_results = self._add_evidence_links(enhanced_results)
257
 
@@ -413,14 +419,26 @@ class EnhancedSearchService:
413
 
414
  for coll in collections:
415
  try:
416
- results = client.query_points(
417
- collection_name=coll,
418
- query=query_embedding,
419
- limit=limit,
420
- query_filter=query_filter,
421
- with_payload=True,
422
- with_vectors=with_vectors,
423
- ).points
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
  for r in results:
426
  payload_modality = r.payload.get('modality', 'unknown')
 
252
  diversity_score = None
253
  enhanced_results = self._raw_to_enhanced(raw_results[:top_k])
254
 
255
+ # Sort by original score for display (MMR selection already done)
256
+ enhanced_results.sort(key=lambda x: x.score, reverse=True)
257
+ # Update ranks after sorting
258
+ for i, r in enumerate(enhanced_results):
259
+ r.rank = i + 1
260
+
261
  # Apply evidence linking
262
  enhanced_results = self._add_evidence_links(enhanced_results)
263
 
 
419
 
420
  for coll in collections:
421
  try:
422
+ # Use search() for qdrant-client < 1.10, query_points() for >= 1.10
423
+ try:
424
+ results = client.query_points(
425
+ collection_name=coll,
426
+ query=query_embedding,
427
+ limit=limit,
428
+ query_filter=query_filter,
429
+ with_payload=True,
430
+ with_vectors=with_vectors,
431
+ ).points
432
+ except AttributeError:
433
+ # Fallback to older API (qdrant-client < 1.10)
434
+ results = client.search(
435
+ collection_name=coll,
436
+ query_vector=query_embedding,
437
+ limit=limit,
438
+ query_filter=query_filter,
439
+ with_payload=True,
440
+ with_vectors=with_vectors,
441
+ )
442
 
443
  for r in results:
444
  payload_modality = r.payload.get('modality', 'unknown')
dockerfile CHANGED
@@ -28,33 +28,36 @@ RUN conda init bash \
28
  && conda activate OpenBioMed \
29
  && pip install --upgrade pip setuptools
30
 
31
- # Installing PyTorch and torchvision
32
- RUN pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 \
33
- && pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.13.1+cu117.html \
34
- && pip install pytorch_lightning==2.0.8 peft==0.9.0 accelerate==1.3.0 --no-deps -i https://pypi.tuna.tsinghua.edu.cn/simple
 
 
 
 
35
 
36
  # Install additional packages from requirements.txt
37
- RUN pip install -r requirements.txt
38
 
39
  # Install visualization tools
40
- RUN conda install -c conda-forge pymol-open-source -y \
41
- && pip install imageio
42
 
43
- # Install AutoDockVina tools
44
- RUN git config --global http.proxy http://100.68.173.241:3128 \
45
- && git config --global https.proxy http://100.68.173.241:3128 \
46
- && pip install meeko==0.1.dev3 pdb2pqr vina==1.2.2 \
47
- && pip install git+https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3
48
 
49
  # Install NLTK
50
- RUN pip install spacy rouge_score nltk \
51
- && python -c "import nltk; nltk.download('wordnet'); nltk.download('omw-1.4')"
52
-
53
- # Set working directory
54
- WORKDIR /app
55
 
56
  # Activate the OpenBioMed environment by default
57
  RUN echo "source activate OpenBioMed" >> ~/.bashrc
 
 
 
 
58
 
59
  # Set default command
60
  ENTRYPOINT ["./scripts/run_server.sh"]
 
28
  && conda activate OpenBioMed \
29
  && pip install --upgrade pip setuptools
30
 
31
+ # Installing PyTorch and torchvision (using conda run to install in OpenBioMed env)
32
+ RUN conda run -n OpenBioMed pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 \
33
+ && conda run -n OpenBioMed pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.13.1+cu117.html \
34
+ && conda run -n OpenBioMed pip install pytorch_lightning==2.0.8 peft==0.9.0 accelerate==1.3.0 --no-deps
35
+
36
+ # Set working directory and copy application files
37
+ WORKDIR /app
38
+ COPY . /app/
39
 
40
  # Install additional packages from requirements.txt
41
+ RUN conda run -n OpenBioMed pip install -r requirements.txt
42
 
43
  # Install visualization tools
44
+ RUN conda install -n OpenBioMed -c conda-forge pymol-open-source -y \
45
+ && conda run -n OpenBioMed pip install imageio
46
 
47
+ # Install AutoDockVina tools (proxy removed - use Docker build args if needed)
48
+ RUN conda run -n OpenBioMed pip install meeko==0.1.dev3 pdb2pqr vina==1.2.2 \
49
+ && conda run -n OpenBioMed pip install git+https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3
 
 
50
 
51
  # Install NLTK
52
+ RUN conda run -n OpenBioMed pip install spacy rouge_score nltk \
53
+ && conda run -n OpenBioMed python -c "import nltk; nltk.download('wordnet'); nltk.download('omw-1.4')"
 
 
 
54
 
55
  # Activate the OpenBioMed environment by default
56
  RUN echo "source activate OpenBioMed" >> ~/.bashrc
57
+ ENV PATH="/root/miniconda3/envs/OpenBioMed/bin:$PATH"
58
+
59
+ # Make entrypoint executable
60
+ RUN chmod +x ./scripts/run_server.sh || true
61
 
62
  # Set default command
63
  ENTRYPOINT ["./scripts/run_server.sh"]
examples/model_customization.ipynb CHANGED
@@ -9,7 +9,8 @@
9
  "from open_biomed.core.pipeline import InferencePipeline\n",
10
  "from open_biomed.data import Molecule, Text\n",
11
  "\n",
12
- "cfg_path = \"./configs/text_based_molecule_editing/molt5.json\"\n",
 
13
  "pipeline = InferencePipeline(cfg_path)\n",
14
  "mol = [Molecule.from_smiles(\"CCCCC\")]\n",
15
  "text = [Text.from_str(\"wow\")]\n",
 
9
  "from open_biomed.core.pipeline import InferencePipeline\n",
10
  "from open_biomed.data import Molecule, Text\n",
11
  "\n",
12
+ "# Path relative to repo root (run from examples/ directory)\n",
13
+ "cfg_path = \"../configs/text_based_molecule_editing/molt5.json\"\n",
14
  "pipeline = InferencePipeline(cfg_path)\n",
15
  "mol = [Molecule.from_smiles(\"CCCCC\")]\n",
16
  "text = [Text.from_str(\"wow\")]\n",
ingest_dti_data.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ingest KIBA/DAVIS Drug-Target Interaction datasets into Qdrant.
3
+
4
+ Uses OBMEncoder (768-dim) to create searchable vectors from real DTI data.
5
+ """
6
+ import sys
7
+ import os
8
+ import argparse
9
+
10
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
11
+ sys.path.insert(0, ROOT_DIR)
12
+
13
+ import pandas as pd
14
+ from tqdm import tqdm
15
+
16
+ from bioflow.api.qdrant_service import get_qdrant_service
17
+
18
+
19
def load_dataset(dataset_name: str, limit: int = None) -> pd.DataFrame:
    """Load the KIBA or DAVIS dataset from a local tab-separated file.

    Args:
        dataset_name: Dataset identifier ("kiba" or "davis"); case-insensitive,
            lowered to build the `data/<name>.tab` path under ROOT_DIR.
        limit: If truthy, keep only the first `limit` rows after de-duplication.

    Returns:
        DataFrame with columns: drug_id, smiles, target_id, target_seq, affinity.

    Raises:
        FileNotFoundError: If the expected .tab file does not exist.
        ValueError: If the file does not have exactly the 5 expected columns.
    """
    filepath = os.path.join(ROOT_DIR, "data", f"{dataset_name.lower()}.tab")

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Dataset not found: {filepath}")

    print(f"Loading {dataset_name} from {filepath}...")
    df = pd.read_csv(filepath, sep='\t')

    # Expected format: ID1, X1 (SMILES), ID2, X2 (sequence), Y (affinity).
    # Fail fast on schema drift instead of silently mislabeling columns
    # with the blind positional rename below.
    expected_cols = ['drug_id', 'smiles', 'target_id', 'target_seq', 'affinity']
    if len(df.columns) != len(expected_cols):
        raise ValueError(
            f"Expected {len(expected_cols)} columns in {filepath}, "
            f"got {len(df.columns)}: {list(df.columns)}"
        )
    df.columns = expected_cols

    # Keep one row per unique drug-target pair.
    df = df.drop_duplicates(subset=['smiles', 'target_id'])

    if limit:
        df = df.head(limit)

    print(f" Loaded {len(df)} unique drug-target pairs")
    return df
41
+
42
+
43
def get_affinity_class(affinity: float, dataset: str) -> str:
    """Bucket an affinity value into "high", "medium", or "low".

    Lower values are treated as stronger binding for both datasets.
    KIBA uses cut-offs of 6 and 8; any other dataset name (i.e. DAVIS)
    uses 6 and 7. NOTE(review): whether "lower is better" matches the
    units actually stored in the .tab files should be confirmed against
    the data source.
    """
    # (high_cutoff, medium_cutoff) per dataset family.
    high_cut, medium_cut = (6, 8) if dataset.upper() == "KIBA" else (6, 7)

    if affinity < high_cut:
        return "high"
    if affinity < medium_cut:
        return "medium"
    return "low"
61
+
62
+
63
def get_drug_name(drug_id, smiles: str) -> str:
    """Return a human-readable drug name derived from *drug_id*.

    Purely numeric IDs are assumed to be PubChem CIDs and rendered as
    "CID-<id>"; anything else is returned verbatim as a string. The
    *smiles* argument is currently unused but kept for interface
    compatibility with callers.
    """
    name = str(drug_id)
    # All-digit IDs get the PubChem-style prefix for readability.
    return f"CID-{name}" if name.isdigit() else name
71
+
72
+
73
def ingest_molecules(qdrant, df: pd.DataFrame, dataset: str, batch_size: int = 50):
    """Ingest unique drugs (one point per distinct SMILES) into Qdrant.

    Args:
        qdrant: Service object exposing ingest(content, modality, metadata).
        df: DTI frame with smiles/drug_id/affinity/target_id columns
            (as produced by load_dataset).
        dataset: Dataset name ("kiba"/"davis"); stored in metadata and used
            to pick affinity-classification thresholds.
        batch_size: Currently unused; kept for interface compatibility.

    Returns:
        int: Number of molecules successfully ingested.
    """
    print("\n[1/2] Ingesting molecules (drugs)...")

    # Aggregate per unique SMILES: best (minimum) affinity and target count.
    unique_drugs = df.groupby('smiles').agg({
        'drug_id': 'first',
        'affinity': 'min',  # Best affinity
        'target_id': 'count'  # Number of targets
    }).reset_index()
    unique_drugs.columns = ['smiles', 'drug_id', 'best_affinity', 'num_targets']

    print(f" Found {len(unique_drugs)} unique molecules")

    success_count = 0
    error_count = 0
    for idx, row in tqdm(unique_drugs.iterrows(), total=len(unique_drugs), desc=" Molecules"):
        try:
            affinity_class = get_affinity_class(row['best_affinity'], dataset)
            drug_name = get_drug_name(row['drug_id'], row['smiles'])

            qdrant.ingest(
                content=row['smiles'],
                modality="molecule",
                metadata={
                    "name": drug_name,
                    "drug_id": str(row['drug_id']),  # Keep original ID
                    "smiles": row['smiles'],
                    "description": f"Drug from {dataset.upper()} dataset",
                    "source": dataset.lower(),
                    "dataset": dataset.lower(),
                    "best_affinity": float(row['best_affinity']),
                    "affinity_class": affinity_class,
                    "num_targets": int(row['num_targets']),
                }
            )
            success_count += 1
        except Exception as e:
            # Best-effort ingestion: keep going, but count every failure
            # instead of silently dropping errors after the first success.
            error_count += 1
            if error_count == 1:
                print(f"\n First error: {e}")  # Show first error for debugging

    if error_count:
        print(f" ⚠ {error_count} molecules failed to ingest (first error shown above)")
    print(f" ✓ Ingested {success_count}/{len(unique_drugs)} molecules")
    return success_count
115
+
116
+
117
def ingest_proteins(qdrant, df: pd.DataFrame, dataset: str, batch_size: int = 50):
    """Ingest unique proteins (one point per distinct target_id) into Qdrant.

    Args:
        qdrant: Service object exposing ingest(content, modality, metadata).
        df: DTI frame with smiles/target_id/target_seq/affinity columns
            (as produced by load_dataset).
        dataset: Dataset name ("kiba"/"davis"); stored in metadata and used
            to pick affinity-classification thresholds.
        batch_size: Currently unused; kept for interface compatibility.

    Returns:
        int: Number of proteins successfully ingested.
    """
    print("\n[2/2] Ingesting proteins (targets)...")

    # Aggregate per unique target: best (minimum) affinity and drug count.
    unique_targets = df.groupby('target_id').agg({
        'target_seq': 'first',
        'affinity': 'min',  # Best affinity
        'smiles': 'count'  # Number of drugs
    }).reset_index()
    unique_targets.columns = ['target_id', 'target_seq', 'best_affinity', 'num_drugs']

    print(f" Found {len(unique_targets)} unique proteins")

    success_count = 0
    error_count = 0
    for idx, row in tqdm(unique_targets.iterrows(), total=len(unique_targets), desc=" Proteins"):
        try:
            # Truncate very long sequences for embedding; keep the original
            # length in metadata so truncation is detectable downstream.
            sequence = str(row['target_seq'])[:1000]
            affinity_class = get_affinity_class(row['best_affinity'], dataset)

            qdrant.ingest(
                content=sequence,
                modality="protein",
                metadata={
                    "name": row['target_id'],
                    "uniprot_id": row['target_id'],
                    "sequence": sequence,
                    "full_length": len(str(row['target_seq'])),
                    "description": f"Target from {dataset.upper()} dataset",
                    "source": dataset.lower(),
                    "dataset": dataset.lower(),
                    "best_affinity": float(row['best_affinity']),
                    "affinity_class": affinity_class,
                    "num_drugs": int(row['num_drugs']),
                }
            )
            success_count += 1
        except Exception as e:
            # Best-effort ingestion: keep going, but count every failure
            # instead of silently dropping errors after the first success.
            error_count += 1
            if error_count == 1:
                print(f"\n First error: {e}")  # Show first error for debugging

    if error_count:
        print(f" ⚠ {error_count} proteins failed to ingest (first error shown above)")
    print(f" ✓ Ingested {success_count}/{len(unique_targets)} proteins")
    return success_count
161
+
162
+
163
def main() -> None:
    """CLI entry point: load KIBA/DAVIS .tab files and ingest them into Qdrant.

    Parses --dataset/--limit/--clear, optionally wipes existing collections,
    then ingests molecules and proteins per dataset, printing a summary.
    """
    parser = argparse.ArgumentParser(description="Ingest KIBA/DAVIS datasets into Qdrant")
    parser.add_argument("--dataset", choices=["kiba", "davis", "both"], default="davis",
                        help="Dataset to ingest (default: davis)")
    parser.add_argument("--limit", type=int, default=1000,
                        help="Limit number of records per dataset (default: 1000, 0 for all)")
    parser.add_argument("--clear", action="store_true",
                        help="Clear existing collections before ingesting")
    args = parser.parse_args()

    print("=" * 60)
    print(" KIBA/DAVIS -> QDRANT INGESTION")
    print("=" * 60)

    qdrant = get_qdrant_service()

    if args.clear:
        print("\nClearing existing collections...")
        # NOTE(review): reaches into private members of the Qdrant service
        # (_get_client, _initialized_collections) — consider a public
        # clear/reset API on the service instead.
        try:
            client = qdrant._get_client()
            for coll in qdrant.list_collections():
                client.delete_collection(coll)
                print(f" Deleted: {coll}")
            # Clear the cache so collections will be recreated
            qdrant._initialized_collections.clear()
        except Exception as e:
            # Best-effort cleanup: a failed clear should not abort ingestion.
            print(f" Warning: {e}")

    datasets = ["kiba", "davis"] if args.dataset == "both" else [args.dataset]
    # --limit 0 means "no limit" (load_dataset skips truncation on None).
    limit = args.limit if args.limit > 0 else None

    total_molecules = 0
    total_proteins = 0

    for dataset in datasets:
        print(f"\n{'='*60}")
        print(f" Processing {dataset.upper()}")
        print("=" * 60)

        try:
            df = load_dataset(dataset, limit=limit)
            total_molecules += ingest_molecules(qdrant, df, dataset)
            total_proteins += ingest_proteins(qdrant, df, dataset)
        except FileNotFoundError as e:
            # A missing .tab file skips that dataset rather than aborting the run.
            print(f" ERROR: {e}")
            continue

    print("\n" + "=" * 60)
    print(" INGESTION COMPLETE")
    print("=" * 60)
    print(f" Total molecules: {total_molecules}")
    print(f" Total proteins: {total_proteins}")
    print(f"\nSearch at: http://localhost:3000/dashboard/discovery")
216
+
217
+
218
+ if __name__ == "__main__":
219
+ main()
220
+
runs/20260125_080409_BindingDB_Kd/run_summary.json CHANGED
@@ -20,12 +20,12 @@
20
  "ci_approx": 0.8053618329014657
21
  },
22
  "files": {
23
- "predictions_test_csv": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\predictions_test.csv",
24
- "scatter_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\scatter.png",
25
- "curves_sorted_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\curves_sorted.png",
26
- "residuals_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\residuals.png",
27
- "hist_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\hist_true_pred.png",
28
- "ecdf_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_080409_BindingDB_Kd\\ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 6569.407529115677,
 
20
  "ci_approx": 0.8053618329014657
21
  },
22
  "files": {
23
+ "predictions_test_csv": "predictions_test.csv",
24
+ "scatter_png": "scatter.png",
25
+ "curves_sorted_png": "curves_sorted.png",
26
+ "residuals_png": "residuals.png",
27
+ "hist_true_pred_png": "hist_true_pred.png",
28
+ "ecdf_true_pred_png": "ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 6569.407529115677,
runs/20260125_104915_KIBA/run_summary.json CHANGED
@@ -20,12 +20,12 @@
20
  "ci_approx": 0.7031028951074637
21
  },
22
  "files": {
23
- "predictions_test_csv": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\predictions_test.csv",
24
- "scatter_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\scatter.png",
25
- "curves_sorted_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\curves_sorted.png",
26
- "residuals_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\residuals.png",
27
- "hist_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\hist_true_pred.png",
28
- "ecdf_true_pred_png": "c:\\Users\\hamza\\Downloads\\Free Spitfire 3D printed RC plane - 2118624\\files\\runs\\20260125_104915_KIBA\\ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 26236.644289016724,
 
20
  "ci_approx": 0.7031028951074637
21
  },
22
  "files": {
23
+ "predictions_test_csv": "predictions_test.csv",
24
+ "scatter_png": "scatter.png",
25
+ "curves_sorted_png": "curves_sorted.png",
26
+ "residuals_png": "residuals.png",
27
+ "hist_true_pred_png": "hist_true_pred.png",
28
+ "ecdf_true_pred_png": "ecdf_true_pred.png"
29
  },
30
  "timing": {
31
  "total_seconds": 26236.644289016724,
ui/app/api/agents/workflow/route.ts CHANGED
@@ -1,4 +1,3 @@
1
- <<<<<<< HEAD
2
  import { NextResponse } from "next/server"
3
  import { API_CONFIG } from "@/config/api.config"
4
 
@@ -67,44 +66,3 @@ export async function POST(request: Request) {
67
  const mockResult = generateMockWorkflowResult(query, num_candidates)
68
  return NextResponse.json(mockResult)
69
  }
70
- =======
71
- import { NextResponse } from "next/server"
72
-
73
- import { API_CONFIG } from "@/config/api.config"
74
-
75
- export async function POST(request: Request) {
76
- const body = await request.json().catch(() => ({}))
77
-
78
- try {
79
- const response = await fetch(`${API_CONFIG.baseUrl}/api/agents/workflow`, {
80
- method: "POST",
81
- headers: { "Content-Type": "application/json" },
82
- body: JSON.stringify(body),
83
- cache: "no-store",
84
- })
85
-
86
- const data = await response.json().catch(() => null)
87
- if (!response.ok) {
88
- return NextResponse.json(
89
- { error: data?.detail || data?.error || `Backend returned ${response.status}` },
90
- { status: response.status }
91
- )
92
- }
93
-
94
- return NextResponse.json(data)
95
- } catch (error) {
96
- console.warn("Workflow API error, using mock response:", error)
97
- return NextResponse.json({
98
- success: true,
99
- status: "mock",
100
- steps_completed: 0,
101
- total_steps: 0,
102
- execution_time_ms: 0,
103
- top_candidates: [],
104
- all_outputs: {},
105
- errors: ["Backend unavailable"],
106
- })
107
- }
108
- }
109
-
110
- >>>>>>> Rami
 
 
1
  import { NextResponse } from "next/server"
2
  import { API_CONFIG } from "@/config/api.config"
3
 
 
66
  const mockResult = generateMockWorkflowResult(query, num_candidates)
67
  return NextResponse.json(mockResult)
68
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/app/dashboard/discovery/page.tsx CHANGED
@@ -16,15 +16,24 @@ const API_BASE = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000";
16
  interface SearchResult {
17
  id: string;
18
  score: number;
19
- smiles: string;
20
- target_seq: string;
21
- label: number;
22
- affinity_class: string;
 
 
 
 
 
 
 
 
23
  }
24
 
25
  export default function DiscoveryPage() {
26
  const [query, setQuery] = React.useState("")
27
  const [searchType, setSearchType] = React.useState("Similarity")
 
28
  const [isSearching, setIsSearching] = React.useState(false)
29
  const [step, setStep] = React.useState(0)
30
  const [results, setResults] = React.useState<SearchResult[]>([])
@@ -70,7 +79,8 @@ export default function DiscoveryPage() {
70
  body: JSON.stringify({
71
  query: query.trim(),
72
  type: apiType,
73
- limit: 10
 
74
  })
75
  });
76
 
@@ -146,13 +156,14 @@ export default function DiscoveryPage() {
146
  </div>
147
  <div className="space-y-2">
148
  <Label>Database</Label>
149
- <Select defaultValue="KIBA">
150
  <SelectTrigger>
151
  <SelectValue placeholder="Select database" />
152
  </SelectTrigger>
153
  <SelectContent>
154
- <SelectItem value="KIBA">KIBA (23.5K pairs)</SelectItem>
155
- <SelectItem value="DAVIS">DAVIS Kinase</SelectItem>
 
156
  </SelectContent>
157
  </Select>
158
  </div>
@@ -221,19 +232,40 @@ export default function DiscoveryPage() {
221
  <Card key={result.id}>
222
  <CardContent className="p-4 flex items-center justify-between">
223
  <div className="flex-1">
224
- <div className="font-mono text-sm font-medium">
225
- {result.smiles?.slice(0, 50)}{result.smiles?.length > 50 ? '...' : ''}
226
  </div>
227
- <div className="flex gap-4 text-sm text-muted-foreground mt-1">
228
- <span>Affinity: {result.affinity_class}</span>
229
- <span>Label: {result.label?.toFixed(2)}</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  </div>
231
  </div>
232
  <div className="text-right">
233
  <div className="text-sm text-muted-foreground">Similarity</div>
234
  <div className={`text-xl font-bold ${
235
  result.score >= 0.9 ? 'text-green-600' :
236
- result.score >= 0.7 ? 'text-green-500' : 'text-amber-500'
 
237
  }`}>
238
  {result.score.toFixed(3)}
239
  </div>
 
16
  interface SearchResult {
17
  id: string;
18
  score: number;
19
+ mmr_score?: number;
20
+ content: string;
21
+ modality: string;
22
+ metadata: {
23
+ name?: string;
24
+ smiles?: string;
25
+ description?: string;
26
+ source?: string;
27
+ label_true?: number;
28
+ affinity_class?: string;
29
+ [key: string]: unknown;
30
+ };
31
  }
32
 
33
  export default function DiscoveryPage() {
34
  const [query, setQuery] = React.useState("")
35
  const [searchType, setSearchType] = React.useState("Similarity")
36
+ const [database, setDatabase] = React.useState("both")
37
  const [isSearching, setIsSearching] = React.useState(false)
38
  const [step, setStep] = React.useState(0)
39
  const [results, setResults] = React.useState<SearchResult[]>([])
 
79
  body: JSON.stringify({
80
  query: query.trim(),
81
  type: apiType,
82
+ limit: 10,
83
+ dataset: database !== "both" ? database.toLowerCase() : undefined
84
  })
85
  });
86
 
 
156
  </div>
157
  <div className="space-y-2">
158
  <Label>Database</Label>
159
+ <Select value={database} onValueChange={setDatabase}>
160
  <SelectTrigger>
161
  <SelectValue placeholder="Select database" />
162
  </SelectTrigger>
163
  <SelectContent>
164
+ <SelectItem value="both">All Datasets</SelectItem>
165
+ <SelectItem value="kiba">KIBA (Kinase Inhibitors)</SelectItem>
166
+ <SelectItem value="davis">DAVIS (Kinase Targets)</SelectItem>
167
  </SelectContent>
168
  </Select>
169
  </div>
 
232
  <Card key={result.id}>
233
  <CardContent className="p-4 flex items-center justify-between">
234
  <div className="flex-1">
235
+ <div className="font-semibold text-base mb-1">
236
+ {result.metadata?.name || `Result ${i + 1}`}
237
  </div>
238
+ <div className="font-mono text-sm text-muted-foreground">
239
+ {(result.metadata?.smiles || result.content)?.slice(0, 60)}
240
+ {(result.metadata?.smiles || result.content)?.length > 60 ? '...' : ''}
241
+ </div>
242
+ {result.metadata?.description && (
243
+ <div className="text-sm text-muted-foreground mt-1">
244
+ {result.metadata.description}
245
+ </div>
246
+ )}
247
+ <div className="flex gap-4 text-xs text-muted-foreground mt-2">
248
+ {result.metadata?.affinity_class && (
249
+ <span className="bg-muted px-2 py-0.5 rounded">
250
+ Affinity: {result.metadata.affinity_class}
251
+ </span>
252
+ )}
253
+ {result.metadata?.label_true != null && (
254
+ <span className="bg-muted px-2 py-0.5 rounded">
255
+ Label: {result.metadata.label_true.toFixed(2)}
256
+ </span>
257
+ )}
258
+ <span className="bg-muted px-2 py-0.5 rounded">
259
+ {result.modality}
260
+ </span>
261
  </div>
262
  </div>
263
  <div className="text-right">
264
  <div className="text-sm text-muted-foreground">Similarity</div>
265
  <div className={`text-xl font-bold ${
266
  result.score >= 0.9 ? 'text-green-600' :
267
+ result.score >= 0.7 ? 'text-green-500' :
268
+ result.score >= 0.5 ? 'text-amber-500' : 'text-muted-foreground'
269
  }`}>
270
  {result.score.toFixed(3)}
271
  </div>