# boardllm/src/retriever.py
# Uploaded via huggingface_hub by melmoheb (commit 2247e66, verified).
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_from_disk, Dataset
import os
import pandas as pd
class ClinicalCaseRetriever:
    """Retrieves relevant clinical cases based on user input using sentence-transformers embeddings.

    Holds a Hugging Face dataset whose rows contain a pre-computed 'embeddings'
    vector plus case fields (e.g. 'clinical_presentation'), and ranks rows by
    cosine similarity against an on-the-fly query embedding.
    """

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        """Load the case dataset and the sentence-transformers encoder.

        Args:
            dataset_path: Either an in-memory ``datasets.Dataset`` object or a
                path to a dataset saved with ``save_to_disk``.
            model_name: sentence-transformers model used to embed queries;
                presumably the same model that produced the stored
                'embeddings' column — TODO confirm against the indexing code.

        Raises:
            ValueError: If the dataset cannot be loaded, lacks an 'embeddings'
                column, or the embeddings do not form a 2-D array.
        """
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")
        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
        else:
            raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")
        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")
        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)
        try:
            # Materialize all embeddings once as a dense NumPy matrix so each
            # query's similarity search is a single vectorized operation.
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e

    def get_available_cases(self, n=5):
        """Return up to *n* random ``(index, clinical_presentation)`` pairs for user selection."""
        num_cases = len(self.dataset)
        if num_cases == 0:
            return []
        sample_size = min(n, num_cases)
        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Cast to plain int: NumPy integer scalars are not reliably accepted
        # as Dataset row indices.
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Return a ``(1, dim)`` embedding for *query*, or None on failure."""
        # Prefix the raw query so it better resembles the clinical-case texts
        # the corpus embeddings were presumably built from.
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")
        try:
            query_embedding = self.model.encode([search_query], convert_to_numpy=True, device=self.device, show_progress_bar=False)
            return query_embedding
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None  # callers treat None as "no results"

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) for *query*.

        Args:
            query: Free-text description of the case of interest.
            top_k: Number of cases to return; coerced to 1 if invalid.
            return_scores: When True, return ``(cases, scores)`` as two
                parallel lists; otherwise a list of ``(case_dict, score)``
                tuples sorted by descending similarity.

        Returns:
            See ``return_scores``. On any internal failure the empty form of
            the selected return shape is produced instead of raising.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1
        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return [] if not return_scores else ([], [])
        try:
            # encode() returned a (1, dim) batch; score its single row.
            similarities = self._cosine_similarities(query_embedding[0])
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return [] if not return_scores else ([], [])
        # Never request more indices than there are cases.
        k_actual = min(top_k, len(similarities))
        if k_actual == 0:  # defensive: empty dataset slipped through
            return [] if not return_scores else ([], [])
        top_indices = self._top_k_indices(similarities, k_actual)
        top_scores = similarities[top_indices].tolist()
        try:
            # Retrieve rows with plain-int indices (see get_available_cases).
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return [] if not return_scores else ([], [])
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return [] if not return_scores else ([], [])
        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            # Safely access presentation, provide default if missing
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")
        if return_scores:
            return retrieved_cases, top_scores
        else:
            return results_with_scores

    def _cosine_similarities(self, query_vector):
        """Cosine similarity of *query_vector* against every case embedding (1-D array)."""
        query_vector = np.asarray(query_vector, dtype=np.float64)
        norms = np.linalg.norm(self.case_embeddings, axis=1) * np.linalg.norm(query_vector)
        # Guard against zero-norm vectors to avoid division warnings / NaNs.
        return self.case_embeddings @ query_vector / np.maximum(norms, 1e-12)

    def _top_k_indices(self, similarities, k):
        """Indices of the *k* highest similarities, sorted descending."""
        if k < len(similarities):
            # argpartition selects the top-k in O(n); only those k are sorted.
            candidates = np.argpartition(similarities, -k)[-k:]
        else:
            candidates = np.arange(len(similarities))
        order = np.argsort(similarities[candidates])[::-1]
        return candidates[order].astype(int)
class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame."""

    def __init__(self, df):
        self.dataset = []
        if not (isinstance(df, pd.DataFrame) and not df.empty):
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return
        # The frame must already be flattened to one row per dialogue turn,
        # with these four columns present.
        needed = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if any(col not in df.columns for col in needed):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {needed}")
            return
        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        for pos, (presentation, turns) in enumerate(grouped):
            # Keep the dialogue in turn order within each case.
            ordered = turns.sort_values('turn_id')
            if 'case_id' in ordered.columns:
                cid = ordered['case_id'].iloc[0]
            else:
                cid = f"dummy_{pos}"
            self.dataset.append({
                "case_id": cid,
                "clinical_presentation": presentation,
                "questions": ordered["question"].tolist(),
                "answers": ordered["answer"].tolist(),
            })
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1):
        """
        Finds the case matching the query string exactly.
        Ignores 'top_k' but mimics the return structure [(case_dict, score)].
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        hit = next(
            (entry for entry in self.dataset
             if entry["clinical_presentation"] == scenario_query),
            None,
        )
        if hit is not None:
            print(f"DummyRetriever found match: {hit['clinical_presentation']}")
            return [(hit, 1.0)]
        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        return []