"""Pipeline that converts BTK clinical-case .docx transcripts into a
vectorized Hugging Face dataset for a retrieval-augmented-generation system.

Flow: read .docx files -> extract Examiner/Examinee Q&A turns -> flat
DataFrame -> group per case -> embed case summaries with
sentence-transformers -> save as a `datasets.Dataset` with an
'embeddings' column.
"""

import os
import re

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from docx import Document
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


def read_docx(file_path):
    """Read all paragraph text from a .docx file, joined with newlines.

    Returns "" on any read/parse failure — deliberately best-effort so a
    single corrupt file does not abort batch processing.
    """
    try:
        doc = Document(file_path)
        return '\n'.join(para.text for para in doc.paragraphs)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return ""


def extract_qa_pairs(text):
    """Extract alternating **Examiner:** / **Examinee:** turns from text.

    Returns a list of {"question": ..., "answer": ...} dicts in document
    order; both sides are whitespace-stripped. A pair is delimited by the
    next "**Examiner:**" marker or end-of-text.
    """
    pattern = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)"
        r"(?=\n\n\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    return [{"question": q.strip(), "answer": a.strip()}
            for q, a in pattern.findall(text)]


def parse_filename(filename):
    """Parse (case_id, topic) from the BTK filename format.

    Example: "BTK_-_77A___Burn.docx" -> ("77A", "Burn"). Underscores in
    the topic become spaces. Returns ("Unknown", "Unknown") when the name
    does not match the expected pattern.
    """
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Fix: the warning previously printed the literal text "(unknown)";
        # include the actual offending filename so it can be diagnosed.
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic


def process_all_cases(folder_path):
    """Read all .docx files in *folder_path* into a flat Q&A DataFrame.

    Columns: case_id, clinical_presentation, turn_id (1-based),
    question, answer. Returns an empty DataFrame when the folder is
    missing or no rows could be produced (a warning is printed).
    """
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)
    print(f"Processing case files from: {folder_path}")
    for filename in os.listdir(folder_path):
        # Skip Word's "~$..." lock/temp files alongside real documents.
        if filename.lower().endswith('.docx') and not filename.startswith('~'):
            file_path = os.path.join(folder_path, filename)
            text = read_docx(file_path)
            if text:
                qa_pairs = extract_qa_pairs(text)
                case_id, presentation = parse_filename(filename)
                if not qa_pairs:
                    # Fix: include the filename (previously printed "(unknown)").
                    print(f"Warning: No Q&A pairs extracted from {filename}")
                for i, pair in enumerate(qa_pairs):
                    rows.append({
                        "case_id": case_id,
                        "clinical_presentation": presentation,
                        "turn_id": i + 1,
                        "question": pair["question"],
                        "answer": pair["answer"],
                    })
            else:
                # Fix: include the filename (previously printed "(unknown)").
                print(f"Warning: Empty content for file {filename}")
    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")
    return pd.DataFrame(rows)


# --- ClinicalCaseProcessor Class ---
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using
    sentence-transformers."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the sentence-transformer and move it to GPU when available."""
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases",
                        batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized
        Hugging Face dataset.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset: The processed dataset with embeddings, or
            None on any validation/processing/saving failure (errors are
            printed rather than raised).
        """
        # ---- Load data (accept an in-memory DataFrame or a CSV path) ----
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None

        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None
        print(f"Raw data shape: {df.shape}")

        # ---- Validate necessary columns ----
        required_cols = ['case_id', 'clinical_presentation', 'turn_id',
                         'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            # Fix: this message was severed mid-string by a line collapse;
            # reconstructed as one coherent error line.
            print(f"Error: DataFrame missing required columns. "
                  f"Found: {df.columns}. Required: {required_cols}")
            return None

        # ---- Group by case_id to get all Q&A pairs for each case ----
        # dropna=False keeps cases whose id/presentation is NaN; they are
        # normalized to "Unknown ..." below instead of being silently dropped.
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)

        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id to ensure correct conversational order.
            group = group.sort_values('turn_id')
            questions = group['question'].tolist()
            answers = group['answer'].tolist()
            # Handle potential NaN/None in presentation if groupby didn't drop them.
            presentation_str = (str(presentation) if pd.notna(presentation)
                                else "Unknown Presentation")
            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers,
            })

        if not case_data:
            print("Error: No cases could be processed after grouping. "
                  "Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Create a searchable summary of each case (handle empty question
        # lists). Fix: this f-string was severed mid-string by a line
        # collapse; reconstructed as one coherent summary line.
        processed_df['case_summary'] = processed_df.apply(
            lambda x: (f"Clinical case: {x['clinical_presentation']}. "
                       f"First question: "
                       f"{x['questions'][0] if x['questions'] else 'No questions available'}"),
            axis=1,
        )

        # ---- Generate embeddings using sentence-transformers ----
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []
        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size),
                          desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i + batch_size]
                batch_embeddings = self.model.encode(
                    batch_texts,
                    convert_to_numpy=True,
                    device=self.device,
                    show_progress_bar=False,
                )
                all_embeddings.append(batch_embeddings)
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")
        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None

        # ---- Convert to HF Dataset and add embeddings ----
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists).
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        # ---- Save processed dataset ----
        try:
            os.makedirs(output_path, exist_ok=True)  # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None  # Return None if saving failed

        return dataset