|
|
import numpy as np |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from datasets import load_from_disk, Dataset |
|
|
import os |
|
|
import pandas as pd |
|
|
|
|
|
class ClinicalCaseRetriever:
    """Retrieve clinical cases similar to a free-text query.

    The retriever holds a Hugging Face ``Dataset`` whose ``embeddings`` column
    stores one precomputed vector per case, plus a ``SentenceTransformer``
    model used to embed incoming queries. Relevance is the cosine similarity
    between the query embedding and every stored case embedding.

    Attributes set by ``__init__``:
        dataset: the loaded ``Dataset``.
        model: the ``SentenceTransformer`` used for query encoding.
        device: ``torch.device`` the model was moved to (CUDA when available).
        case_embeddings: 2-D ``numpy`` array built from the ``embeddings`` column.
    """

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        """Load the dataset, the embedding model and the case embeddings.

        Args:
            dataset_path: either an in-memory ``Dataset`` object or the path of
                a directory previously written with ``Dataset.save_to_disk``.
            model_name: name/path passed to ``SentenceTransformer``.

        Raises:
            ValueError: if ``dataset_path`` is neither a ``Dataset`` nor an
                existing directory, if loading fails, if the loaded object is
                not a single ``Dataset``, if the ``embeddings`` column is
                missing, or if the embeddings do not form a 2-D array.
        """
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")

        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
            else:
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            # load_from_disk may also return a DatasetDict (several splits);
            # fail with an accurate message instead of the misleading
            # "missing 'embeddings' column" error it would otherwise trigger.
            if not isinstance(self.dataset, Dataset):
                raise ValueError(
                    f"Expected a single Dataset at {dataset_path}, "
                    f"got {type(self.dataset).__name__}."
                )
        else:
            raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")

        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")

        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Materialise the whole column once so every query is a single
            # matrix operation against this array.
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e

    @staticmethod
    def _empty_result(return_scores):
        """Return the "nothing found" value matching the requested return shape."""
        return ([], []) if return_scores else []

    def get_available_cases(self, n=5):
        """Return a random sample of cases for user selection.

        Args:
            n: maximum number of cases to sample. Values below 1 yield an
                empty list; values above the dataset size are capped.

        Returns:
            A list of ``(index, clinical_presentation)`` tuples, sampled
            without replacement. Each row is expected to expose a
            ``clinical_presentation`` key.
        """
        num_cases = len(self.dataset)
        # Clamp at zero: a negative size would make np.random.choice raise.
        sample_size = max(0, min(n, num_cases))
        if sample_size == 0:
            return []

        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Cast to plain int: the dataset is indexed with Python integers.
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Embed a query string with the sentence-transformer model.

        The query is wrapped as ``"Clinical case about {query}"`` before
        encoding.

        Args:
            query: free-text description to search for.

        Returns:
            The array returned by ``model.encode`` for the single wrapped
            query, or ``None`` if encoding failed (the error is printed,
            not raised).
        """
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")

        try:
            query_embedding = self.model.encode(
                [search_query],
                convert_to_numpy=True,
                device=self.device,
                show_progress_bar=False,
            )
        except Exception as e:
            # Best effort by design: callers treat None as "no results".
            print(f"Error encoding query '{query}': {e}")
            return None
        return query_embedding

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) for a query.

        Args:
            query: free-text description to search for.
            top_k: number of cases to return; anything that is not a positive
                integer (Python or NumPy) falls back to 1 with a warning.
                Capped at the number of stored cases.
            return_scores: selects the return shape (see below).

        Returns:
            If ``return_scores`` is false: a list of ``(case, score)`` tuples,
            best match first (scores are included in this shape too).
            If ``return_scores`` is true: a ``(cases, scores)`` pair of
            parallel lists, best match first.
            On any failure (encoding, similarity computation, row lookup) the
            error is printed and the empty equivalent is returned:
            ``[]`` or ``([], [])``.
        """
        # Accept NumPy integers too (e.g. values coming out of arrays).
        if not isinstance(top_k, (int, np.integer)) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1
        top_k = int(top_k)

        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return self._empty_result(return_scores)

        try:
            # Row 0 is the only row: a single query is encoded at a time.
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0]
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return self._empty_result(return_scores)

        k_actual = min(top_k, len(similarities))
        if k_actual == 0:
            return self._empty_result(return_scores)

        # argsort is ascending: keep the last k entries and reverse them so
        # the highest similarity comes first.
        top_indices = np.argsort(similarities)[-k_actual:][::-1].astype(int)
        top_scores = similarities[top_indices].tolist()

        try:
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return self._empty_result(return_scores)
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return self._empty_result(return_scores)

        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        if return_scores:
            return retrieved_cases, top_scores
        return results_with_scores
|
|
|
|
|
|
|
|
class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame.

    Rows are grouped by ``clinical_presentation``; each group becomes one case
    dict with the keys ``case_id``, ``clinical_presentation``, ``questions``
    and ``answers`` (questions/answers ordered by ``turn_id``). Lookup is by
    exact string match on the presentation, with no embedding model involved.

    Attributes:
        dataset: list of the case dicts described above; empty when the
            input DataFrame was unusable.
    """

    def __init__(self, df):
        """Build the in-memory case list from ``df``.

        Args:
            df: ``pandas.DataFrame`` with at least the columns
                ``clinical_presentation``, ``turn_id``, ``question`` and
                ``answer``. An optional ``case_id`` column supplies the case
                identifier (the value from the group's lowest ``turn_id`` row
                is used); otherwise ``dummy_<n>`` is generated.

        An invalid, empty or incomplete DataFrame does not raise: a warning
        is printed and ``dataset`` stays empty.
        """
        self.dataset = []

        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        has_case_id = 'case_id' in df.columns
        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")

        for i, (scenario, group) in enumerate(grouped):
            # Stable sort keeps the input order of rows sharing a turn_id.
            turns = group.sort_values('turn_id', kind='stable')
            self.dataset.append({
                "case_id": turns['case_id'].iloc[0] if has_case_id else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": turns["question"].tolist(),
                "answers": turns["answer"].tolist(),
            })

        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1, return_scores=False):
        """Find the case whose presentation equals the query string exactly.

        Args:
            scenario_query: presentation text to look up (exact, case-sensitive
                comparison).
            top_k: accepted for interface compatibility and ignored; at most
                one case is ever returned.
            return_scores: mirrors ``ClinicalCaseRetriever.retrieve_relevant_case``
                so both retrievers can be called the same way.

        Returns:
            If ``return_scores`` is false: ``[(case_dict, 1.0)]`` on a match,
            ``[]`` otherwise.
            If ``return_scores`` is true: ``([case_dict], [1.0])`` on a match,
            ``([], [])`` otherwise.
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")

        for case_dict in self.dataset:
            if case_dict["clinical_presentation"] == scenario_query:
                print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
                if return_scores:
                    return [case_dict], [1.0]
                return [(case_dict, 1.0)]

        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        return ([], []) if return_scores else []