import os
import re

import pandas as pd
import numpy as np
import torch
from docx import Document
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from tqdm import tqdm


def read_docx(file_path):
    """Reads text content from a .docx file."""
    try:
        doc = Document(file_path)
        return '\n'.join(para.text for para in doc.paragraphs)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return ""


def extract_qa_pairs(text):
    """Extracts alternating Examiner and Examinee Q&A pairs from text."""
    pattern = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)(?=\n\n\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    return [{"question": q.strip(), "answer": a.strip()} for q, a in pattern.findall(text)]


def parse_filename(filename):
    """Parses case ID and topic from BTK filename format."""
    # Example: BTK_-_77A___Burn.docx -> case_id = 77A, clinical_presentation = Burn
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Handle potential variations or log unknown formats if needed
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic


def process_all_cases(folder_path):
    """Reads all .docx files in a folder and structures them into a DataFrame."""
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)
    print(f"Processing case files from: {folder_path}")
    for filename in os.listdir(folder_path):
        if filename.lower().endswith('.docx') and not filename.startswith('~'):  # Avoid temp files
            file_path = os.path.join(folder_path, filename)
            text = read_docx(file_path)
            if text:
                qa_pairs = extract_qa_pairs(text)
                case_id, presentation = parse_filename(filename)
                if not qa_pairs:
                    print(f"Warning: No Q&A pairs extracted from {filename}")
                for i, pair in enumerate(qa_pairs):
                    rows.append({
                        "case_id": case_id,
                        "clinical_presentation": presentation,
                        "turn_id": i + 1,
                        "question": pair["question"],
                        "answer": pair["answer"]
                    })
            else:
                print(f"Warning: Empty content for file {filename}")
    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")
    return pd.DataFrame(rows)
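

# For reference: process_all_cases returns one row per Q&A turn with columns
# case_id, clinical_presentation, turn_id, question, answer (e.g. case_id="77A",
# clinical_presentation="Burn", turn_id=1). This flat layout is exactly what
# ClinicalCaseProcessor.preprocess_data validates and regroups below.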


# --- ClinicalCaseProcessor Class ---
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using sentence-transformers."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized Hugging Face dataset.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset: The processed dataset with embeddings.
        """
        # Load data
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None

        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None
        print(f"Raw data shape: {df.shape}")

        # Validate necessary columns
        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None

        # Group by case_id to get all Q&A pairs for each case
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)

        # Create a new dataframe with one row per case
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id to ensure correct order
            group = group.sort_values('turn_id')
            # Extract questions and answers in order
            questions = group['question'].tolist()
            answers = group['answer'].tolist()
            # Handle potential NaN/None in presentation if groupby didn't drop them
            presentation_str = str(presentation) if pd.notna(presentation) else "Unknown Presentation"
            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers
            })

        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Create a searchable summary of each case (handle empty question lists)
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )

        # Generate embeddings using sentence-transformers
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []
        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i + batch_size]
                # Generate embeddings for the batch
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)
            # Combine all batch embeddings
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")
        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None

        # Convert to HF Dataset and add embeddings
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists)
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        # Save processed dataset
        try:
            os.makedirs(output_path, exist_ok=True)  # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None  # Return None if saving failed

        return dataset
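

if __name__ == "__main__":
    # Minimal end-to-end sketch (the input folder path below is a placeholder
    # assumption, not part of the original pipeline): parse every BTK .docx
    # case file, then embed and persist the grouped cases for retrieval.
    cases_df = process_all_cases("./btk_cases")  # hypothetical folder
    if not cases_df.empty:
        processor = ClinicalCaseProcessor()
        processor.preprocess_data(cases_df, output_path="./processed_clinical_cases")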