import os
import re
import pandas as pd
import numpy as np
import torch
from docx import Document
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from tqdm import tqdm

def read_docx(file_path):
    """Reads text content from a .docx file."""
    try:
        doc = Document(file_path)
        return '\n'.join(para.text for para in doc.paragraphs)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return ""

def extract_qa_pairs(text):
    """Extracts alternating Examiner and Examinee Q&A pairs from text."""
    # Allow one or more newlines between turns: read_docx joins paragraphs with
    # a single '\n', so requiring a strict blank line could silently drop pairs.
    pattern = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n+\*\*Examinee:\*\*(.*?)(?=\n+\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    return [{"question": q.strip(), "answer": a.strip()} for q, a in pattern.findall(text)]
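
# A minimal usage sketch for extract_qa_pairs. The transcript text below is an
# illustrative assumption, not taken from a real case file:
#
#   sample = ("**Examiner:** What is your initial assessment?\n\n"
#             "**Examinee:** I would start with the primary survey.")
#   extract_qa_pairs(sample)
#   # -> [{'question': 'What is your initial assessment?',
#   #      'answer': 'I would start with the primary survey.'}]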

def parse_filename(filename):
    """Parses case ID and topic from BTK filename format."""
    # Example: BTK_-_77A___Burn.docx -> case_id = 77A, clinical_presentation = Burn
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Handle potential variations or log unknown formats if needed
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic

def process_all_cases(folder_path):
    """Reads all .docx files in a folder and structures them into a DataFrame."""
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)

    print(f"Processing case files from: {folder_path}")
    for filename in sorted(os.listdir(folder_path)):  # Sort for a deterministic processing order
        if filename.lower().endswith('.docx') and not filename.startswith('~'):  # Skip Word temp/lock files
            file_path = os.path.join(folder_path, filename)
            text = read_docx(file_path)
            if text:
                qa_pairs = extract_qa_pairs(text)
                case_id, presentation = parse_filename(filename)
                if not qa_pairs:
                    print(f"Warning: No Q&A pairs extracted from {filename}")
                for i, pair in enumerate(qa_pairs):
                    rows.append({
                        "case_id": case_id,
                        "clinical_presentation": presentation,
                        "turn_id": i + 1,
                        "question": pair["question"],
                        "answer": pair["answer"]
                    })
            else:
                print(f"Warning: Empty content for file {filename}")

    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")

    return pd.DataFrame(rows)


# --- ClinicalCaseProcessor Class ---
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using sentence-transformers."""
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        # Pick the device first so the model is loaded straight onto it rather
        # than being loaded on CPU and moved afterwards.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model = SentenceTransformer(model_name, device=str(self.device))

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized Hugging Face dataset.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset: The processed dataset with embeddings.
        """
        # Load data
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None

        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None

        print(f"Raw data shape: {df.shape}")

        # Validate necessary columns
        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None

        # Group by case_id to get all Q&A pairs for each case
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)

        # Create a new dataframe with one row per case
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id to ensure correct order
            group = group.sort_values('turn_id')

            # Extract questions and answers in order
            questions = group['question'].tolist()
            answers = group['answer'].tolist()

            # Handle potential NaN/None in presentation if groupby didn't drop them
            presentation_str = str(presentation) if pd.notna(presentation) else "Unknown Presentation"

            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers
            })

        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Create a searchable summary of each case (handle empty question lists)
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )

        # Generate embeddings using sentence-transformers
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []

        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i+batch_size]
                # Generate embeddings for the batch
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)

            # Combine all batch embeddings
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")

        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None

        # Convert to HF Dataset and add embeddings
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists)
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        # Save processed dataset
        try:
            os.makedirs(output_path, exist_ok=True)  # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None  # Return None if saving failed

        return dataset
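

# --- Example pipeline run ---
# A minimal end-to-end sketch. The input folder "./btk_cases" below is a
# placeholder assumption, not a path from the original project; point it at
# the directory holding the BTK .docx transcripts.
if __name__ == "__main__":
    cases_df = process_all_cases("./btk_cases")  # hypothetical input folder
    if not cases_df.empty:
        processor = ClinicalCaseProcessor()
        processed = processor.preprocess_data(cases_df, output_path="./processed_clinical_cases")
        if processed is not None:
            # The saved dataset can later be reloaded with datasets.load_from_disk(output_path).
            print(f"Done: {len(processed)} cases embedded and saved.")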