import os

import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk, Dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class ClinicalCaseRetriever:
    """Retrieves relevant clinical cases based on user input using sentence-transformers embeddings."""

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")
        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
        else:
            raise ValueError(f"Invalid dataset_path: must be a Dataset object or a valid directory path. Got: {dataset_path}")

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")
        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")

        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Ensure embeddings are loaded as a 2-D NumPy array of shape (n_cases, dim)
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e

    def get_available_cases(self, n=5):
        """Return a random sample of (index, clinical_presentation) pairs for user selection."""
        num_cases = len(self.dataset)
        if num_cases == 0:
            return []
        sample_size = min(n, num_cases)
        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Cast to Python int: Hugging Face datasets indexing rejects NumPy integers
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Generate an embedding for a query string; returns None on failure."""
        # Prefix the query so it better matches the phrasing of the stored case texts
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")
        # Generate the embedding with sentence-transformers; returns a (1, dim) array
        try:
            query_embedding = self.model.encode([search_query], convert_to_numpy=True,
                                                device=self.device, show_progress_bar=False)
            return query_embedding
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) for a query.

        Returns a list of (case_dict, score) tuples by default, or a
        (cases, scores) pair of lists when return_scores=True.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1

        # Get the query embedding
        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return ([], []) if return_scores else []

        # Cosine similarity between the single query row and all case embeddings
        try:
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0]
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return ([], []) if return_scores else []

        # Never request more indices than there are cases
        k_actual = min(top_k, len(similarities))
        if k_actual == 0:  # Defensive: should not happen once the dataset is loaded
            return ([], []) if return_scores else []

        # argsort is simple and fast enough for moderate N; np.argpartition would
        # pay off only when k is much smaller than a very large N
        top_indices = np.argsort(similarities)[-k_actual:][::-1].astype(int)  # top k, descending
        top_scores = similarities[top_indices].tolist()

        try:
            # Index with Python ints; datasets rejects NumPy integer types
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return ([], []) if return_scores else []
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return ([], []) if return_scores else []

        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            # Safely access the presentation; fall back to a default if missing
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        if return_scores:
            return retrieved_cases, top_scores
        # Default: list of (case_dict, score) tuples
        return results_with_scores
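
# Illustrative usage sketch (an addition, not part of the original pipeline):
# assumes a dataset was preprocessed and saved to './processed_clinical_cases'
# with 'embeddings' and 'clinical_presentation' columns; the query string is
# a placeholder.
#
#   retriever = ClinicalCaseRetriever('./processed_clinical_cases')
#   for case, score in retriever.retrieve_relevant_case("chest pain with dyspnea", top_k=3):
#       print(f"{score:.3f}  {case['clinical_presentation']}")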


class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame."""

    def __init__(self, df):
        self.dataset = []
        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        # Expects df to be pre-processed with columns:
        # 'clinical_presentation', 'turn_id', 'question', 'answer'
        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        for i, (scenario, group) in enumerate(grouped):
            # Order the question/answer turns within each case by turn_id
            group_sorted = group.sort_values('turn_id')
            case_dict = {
                "case_id": group_sorted['case_id'].iloc[0] if 'case_id' in group_sorted.columns else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": group_sorted["question"].tolist(),
                "answers": group_sorted["answer"].tolist(),
            }
            self.dataset.append(case_dict)
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1):
        """
        Finds the case matching the query string exactly.
        Ignores 'top_k' but mimics the return structure [(case_dict, score)].
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        for case_dict in self.dataset:
            if case_dict["clinical_presentation"] == scenario_query:
                print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
                return [(case_dict, 1.0)]
        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        return []
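

# Minimal runnable sketch (an addition for illustration, not part of the original
# module): builds a tiny invented DataFrame in the shape DummyRetriever expects,
# then performs one exact-match lookup.
if __name__ == "__main__":
    demo_df = pd.DataFrame({
        "clinical_presentation": ["Chest pain with dyspnea", "Chest pain with dyspnea"],
        "turn_id": [1, 2],
        "question": ["What is the first diagnostic step?", "What does the ECG show?"],
        "answer": ["Obtain a 12-lead ECG.", "ST-segment elevation in the anterior leads."],
    })
    dummy = DummyRetriever(demo_df)
    # Expect a single [(case_dict, 1.0)] result for the exact presentation string
    print(dummy.retrieve_relevant_case("Chest pain with dyspnea"))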