File size: 8,030 Bytes
129641e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_from_disk, Dataset
import os
import pandas as pd

class ClinicalCaseRetriever:
    """Semantic retriever over a clinical-case dataset.

    Queries are embedded with a sentence-transformers model and ranked
    against precomputed case embeddings using cosine similarity.
    """

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        """Load the case dataset and the encoder model.

        Args:
            dataset_path: either a Hugging Face ``Dataset`` object or a
                directory path previously written with ``save_to_disk``.
            model_name: sentence-transformers model identifier.

        Raises:
            ValueError: if the dataset cannot be loaded, lacks an
                'embeddings' column, or its embeddings are not 2-D.
        """
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")
        self.dataset = self._load_dataset(dataset_path)

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")

        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")

        # Pick the device first, then move the freshly loaded model onto it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(model_name)
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Materialize the embedding column as a dense NumPy matrix.
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e
        print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")

    def _load_dataset(self, dataset_path):
        # Accept either a ready-made Dataset object or an on-disk directory.
        if isinstance(dataset_path, Dataset):
            print("Using provided Hugging Face Dataset object.")
            return dataset_path
        if isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                loaded = load_from_disk(dataset_path)
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
            print(f"Dataset loaded successfully from disk: {dataset_path}")
            return loaded
        raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")

    def get_available_cases(self, n=5):
        """Return up to *n* random (index, clinical_presentation) pairs."""
        total = len(self.dataset)
        if not total:
            return []
        picks = np.random.choice(total, min(n, total), replace=False)
        sample = []
        for idx in picks:
            idx = int(idx)  # numpy ints -> plain int for dataset indexing
            sample.append((idx, self.dataset[idx]['clinical_presentation']))
        return sample

    def encode_query(self, query):
        """Embed *query*; returns a (1, dim) numpy array, or None on failure."""
        # Prefix the raw text so the embedding reflects a case-lookup intent.
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")
        try:
            return self.model.encode([search_query], convert_to_numpy=True, device=self.device, show_progress_bar=False)
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Rank stored cases against *query* by cosine similarity.

        Returns ``(cases, scores)`` when *return_scores* is True, otherwise
        a list of ``(case_dict, score)`` pairs. Empty on any failure.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1

        # Shape of the "nothing found / error" result mirrors the success shape.
        nothing = ([], []) if return_scores else []

        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return nothing

        try:
            # Single query row -> take the lone similarity row.
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0]
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return nothing

        # Never ask for more neighbors than there are cases.
        k = min(top_k, len(similarities))
        if not k:
            return nothing

        # Full descending sort, then keep the best k (fine for moderate N).
        top_indices = np.argsort(similarities)[::-1][:k].astype(int)
        top_scores = similarities[top_indices].tolist()

        try:
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return nothing
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return nothing

        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        return (retrieved_cases, top_scores) if return_scores else results_with_scores


class DummyRetriever:
    """Retriever stand-in that skips RAG and answers by exact-match lookup.

    Built from a pre-formatted DataFrame rather than an embedded dataset.
    """

    def __init__(self, df):
        """Group DataFrame rows into one case dict per presentation.

        The frame must provide 'clinical_presentation', 'turn_id',
        'question' and 'answer' columns; otherwise no cases are loaded.
        """
        self.dataset = []
        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        missing = [col for col in required_cols if col not in df.columns]
        if missing:
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        has_case_id = 'case_id' in df.columns
        for i, (scenario, group) in enumerate(grouped):
            # Q/A turns are stored in conversational order.
            ordered = group.sort_values('turn_id')
            self.dataset.append({
                "case_id": ordered['case_id'].iloc[0] if has_case_id else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": list(ordered["question"]),
                "answers": list(ordered["answer"]),
            })
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1):
        """Return ``[(case_dict, 1.0)]`` for an exact-match case, else ``[]``.

        *top_k* exists only for interface parity and is ignored.
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        hits = [c for c in self.dataset if c["clinical_presentation"] == scenario_query]
        if hits:
            case_dict = hits[0]
            print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
            return [(case_dict, 1.0)]

        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        return []