|
|
import os |
|
|
import re |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
from docx import Document |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from datasets import Dataset |
|
|
from tqdm import tqdm |
|
|
|
|
|
def read_docx(file_path):
    """Return the full text of a .docx file, paragraphs joined by newlines.

    Any read/parse failure is printed and an empty string is returned, so
    callers can simply skip unreadable files.
    """
    try:
        document = Document(file_path)
        paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
        return '\n'.join(paragraph_texts)
    except Exception as exc:
        print(f"Error reading {file_path}: {exc}")
        return ""
|
|
|
|
|
def extract_qa_pairs(text):
    """Split marked-up dialogue text into Examiner/Examinee Q&A dicts.

    Matches alternating "**Examiner:**" / "**Examinee:**" segments; each
    answer runs until the next examiner marker or the end of the text.
    Returns a list of {"question": ..., "answer": ...} dicts, whitespace
    stripped.
    """
    qa_regex = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)(?=\n\n\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    pairs = []
    for question, answer in qa_regex.findall(text):
        pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs
|
|
|
|
|
def parse_filename(filename):
    """Parse the case ID and topic out of a BTK-style filename.

    Expected format: ``BTK_-_<ID>___<Topic_With_Underscores>.docx`` where
    the ID is digits optionally followed by one capital letter (e.g. "12A").

    Args:
        filename: Base name of the case file (extension is ignored).

    Returns:
        tuple[str, str]: (case_id, topic); ("Unknown", "Unknown") when the
        name does not match the expected pattern.
    """
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Bug fix: the warning previously printed a hard-coded "(unknown)"
        # placeholder instead of the offending filename.
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic
|
|
|
|
|
def process_all_cases(folder_path):
    """Read every .docx case file in *folder_path* into a flat DataFrame.

    Each row is one Q&A turn with columns: case_id, clinical_presentation,
    turn_id, question, answer. Non-.docx files and Word's "~..." temporary
    lock files are skipped.

    Args:
        folder_path: Directory containing the BTK .docx case files.

    Returns:
        pd.DataFrame: One row per Q&A turn; empty if the folder is missing
        or no usable content was found.
    """
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)

    print(f"Processing case files from: {folder_path}")
    for filename in os.listdir(folder_path):
        # Skip non-docx files and Word's "~$..." temporary lock files.
        if not filename.lower().endswith('.docx') or filename.startswith('~'):
            continue
        file_path = os.path.join(folder_path, filename)
        text = read_docx(file_path)
        if not text:
            # Bug fix: this warning previously printed a hard-coded
            # "(unknown)" placeholder instead of the actual filename.
            print(f"Warning: Empty content for file {filename}")
            continue
        qa_pairs = extract_qa_pairs(text)
        case_id, presentation = parse_filename(filename)
        if not qa_pairs:
            # Bug fix: same placeholder issue as above.
            print(f"Warning: No Q&A pairs extracted from {filename}")
        for i, pair in enumerate(qa_pairs):
            rows.append({
                "case_id": case_id,
                "clinical_presentation": presentation,
                "turn_id": i + 1,
                "question": pair["question"],
                "answer": pair["answer"]
            })

    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")

    return pd.DataFrame(rows)
|
|
|
|
|
|
|
|
|
|
|
class ClinicalCaseProcessor:
    """Preprocesses clinical cases into an embedded Hugging Face dataset.

    Wraps a sentence-transformers model and turns per-turn Q&A rows into one
    record per case, each carrying an embedding of a short case summary that
    downstream RAG retrieval can search over.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the sentence-transformer and move it to GPU when available.

        Args:
            model_name: Any model id accepted by SentenceTransformer.
        """
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def _load_dataframe(self, input_data):
        """Resolve *input_data* (DataFrame or CSV path) to a DataFrame; None on error."""
        if isinstance(input_data, pd.DataFrame):
            print("Using provided DataFrame.")
            return input_data
        if isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
                return df
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        print(f"Error: Invalid input_data type or path does not exist: {input_data}")
        return None

    def _group_cases(self, df):
        """Collapse per-turn rows into one dict per (case_id, presentation) pair."""
        # dropna=False keeps cases whose id/presentation parsed as NaN;
        # they are relabeled "Unknown ..." below instead of being dropped.
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            group = group.sort_values('turn_id')  # preserve dialogue order
            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': str(presentation) if pd.notna(presentation) else "Unknown Presentation",
                'questions': group['question'].tolist(),
                'answers': group['answer'].tolist()
            })
        return case_data

    def _embed_summaries(self, texts_to_embed, batch_size):
        """Encode summary strings in batches; returns a 2-D array or None on error."""
        all_embeddings = []
        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i+batch_size]
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")
            return final_embeddings
        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """Convert raw case data into a vectorized Hugging Face dataset.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset | None: The processed dataset with embeddings,
            or None when any stage fails (errors are printed, not raised).
        """
        df = self._load_dataframe(input_data)
        if df is None:
            return None
        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None
        print(f"Raw data shape: {df.shape}")

        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None

        case_data = self._group_cases(df)
        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Build a short retrieval key: presentation plus the opening question.
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )

        final_embeddings = self._embed_summaries(processed_df['case_summary'].tolist(), batch_size)
        if final_embeddings is None:
            return None

        try:
            dataset = Dataset.from_pandas(processed_df)
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        try:
            os.makedirs(output_path, exist_ok=True)
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None

        return dataset