"""Clinical-case retrieval utilities.

Provides an embedding-based retriever (:class:`ClinicalCaseRetriever`) and an
exact-match stand-in (:class:`DummyRetriever`) that mimics the same
``retrieve_relevant_case`` return structure of ``[(case_dict, score)]``.
"""

import os

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_from_disk
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class ClinicalCaseRetriever:
    """Retrieves relevant clinical cases based on user input using sentence-transformers embeddings."""

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        """Load the case dataset and the embedding model.

        Args:
            dataset_path: Either an already-loaded Hugging Face ``Dataset``
                object or a path to a dataset directory saved with
                ``Dataset.save_to_disk``. Must contain an 'embeddings' column.
            model_name: sentence-transformers model used to embed queries.
                NOTE(review): this should be the same model that produced the
                stored 'embeddings' column — not verifiable from here.

        Raises:
            ValueError: if the dataset cannot be loaded, lacks an
                'embeddings' column, or the embeddings are malformed.
        """
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")

        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
        else:
            raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")

        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")

        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Ensure embeddings are loaded as a NumPy array (asarray avoids a
            # copy when the column already materializes as an ndarray).
            self.case_embeddings = np.asarray(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            # Row count must match the dataset, otherwise the similarity
            # indices computed later would not map back to the right case.
            if len(self.case_embeddings) != len(self.dataset):
                raise ValueError(f"Embeddings row count ({len(self.case_embeddings)}) does not match dataset size ({len(self.dataset)}).")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e

    def get_available_cases(self, n=5):
        """Return a sample of available cases for user selection.

        Args:
            n: maximum number of cases to sample (without replacement).

        Returns:
            List of ``(index, clinical_presentation)`` tuples; empty list if
            the dataset is empty.
        """
        num_cases = len(self.dataset)
        if num_cases == 0:
            return []
        sample_size = min(n, num_cases)
        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Cast to int: numpy integer types are not valid dataset indices.
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Generate an embedding for a query string.

        The query is wrapped in a short template to better match the style of
        the stored case texts.

        Returns:
            A ``(1, dim)`` NumPy array, or ``None`` if encoding fails.
        """
        # Create a better search query structure
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")
        # Generate embedding using sentence-transformers
        try:
            query_embedding = self.model.encode(
                [search_query],
                convert_to_numpy=True,
                device=self.device,
                show_progress_bar=False,
            )
            return query_embedding
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None  # Or raise an error

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) given a query.

        Args:
            query: free-text description of the case being sought.
            top_k: number of cases to return (coerced to 1 if invalid).
            return_scores: if True, return ``(cases, scores)`` as two lists;
                otherwise return a list of ``(case_dict, score)`` tuples.

        Returns:
            See ``return_scores``. Empty result on any failure.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1

        # Get query embedding
        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return [] if not return_scores else ([], [])

        # Calculate similarity scores
        try:
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0]  # Get the single row of similarities
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return [] if not return_scores else ([], [])

        # Get indices of top-k most similar cases.
        # Ensure we don't request more indices than available cases.
        k_actual = min(top_k, len(similarities))
        if k_actual == 0:  # Should not happen if dataset loaded, but safe check
            return [] if not return_scores else ([], [])

        # argsort is ascending, so take the last k and reverse for descending
        # similarity; fine for moderate dataset sizes.
        top_indices = np.argsort(similarities)[-k_actual:][::-1].astype(int)
        top_scores = similarities[top_indices].tolist()

        # Return the most relevant case(s)
        try:
            # Retrieve cases safely using integer indices
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return [] if not return_scores else ([], [])
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return [] if not return_scores else ([], [])

        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            # Safely access presentation, provide default if missing
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        if return_scores:
            return retrieved_cases, top_scores
        # Return list of tuples (case_dict, score)
        return results_with_scores


class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame."""

    def __init__(self, df):
        """Build the in-memory case list from a pre-processed DataFrame.

        Expects ``df`` to have columns 'clinical_presentation', 'turn_id',
        'question', 'answer' (and optionally 'case_id'). Rows are grouped by
        presentation and ordered by turn_id. On invalid input the retriever
        is left empty rather than raising.
        """
        self.dataset = []
        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        for i, (scenario, group) in enumerate(grouped):
            # Order the dialogue turns before flattening to parallel lists.
            group_sorted = group.sort_values('turn_id')
            case_dict = {
                "case_id": group_sorted['case_id'].iloc[0] if 'case_id' in group_sorted.columns else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": group_sorted["question"].tolist(),
                "answers": group_sorted["answer"].tolist(),
            }
            self.dataset.append(case_dict)
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1, return_scores=False):
        """
        Finds the case matching the query string exactly.
        Ignores 'top_k' but mimics the return structure [(case_dict, score)].
        With return_scores=True, mirrors ClinicalCaseRetriever and returns
        ([case_dict], [1.0]) / ([], []) instead.
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        for case_dict in self.dataset:
            if case_dict["clinical_presentation"] == scenario_query:
                print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
                if return_scores:
                    return [case_dict], [1.0]
                return [(case_dict, 1.0)]
        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        return [] if not return_scores else ([], [])