# boardgpt-llm / src / data_processing.py
# (uploaded by melmoheb — "Upload 6 files", commit 129641e, verified)
import os
import re
import pandas as pd
import numpy as np
import torch
from docx import Document
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from tqdm import tqdm
def read_docx(file_path):
    """Return the plain text of a .docx file, paragraphs joined by newlines.

    Any failure (missing file, corrupt document, ...) is reported to stdout
    and an empty string is returned instead of raising.
    """
    try:
        document = Document(file_path)
        paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
        return '\n'.join(paragraph_texts)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return ""
def extract_qa_pairs(text):
    """Extract alternating Examiner/Examinee turns from markdown-style text.

    Expects the markup '**Examiner:** ...\n\n**Examinee:** ...' repeated;
    returns a list of {"question": ..., "answer": ...} dicts in order.
    """
    qa_regex = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)(?=\n\n\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    pairs = []
    for question, answer in qa_regex.findall(text):
        pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs
def parse_filename(filename):
    """Parse (case_id, topic) from a BTK-format filename.

    Expected format: 'BTK_-_<id>___<topic>.docx', e.g.
    'BTK_-_77A___Burn.docx' -> ('77A', 'Burn'). Underscores in the topic
    part stand in for spaces. Returns ('Unknown', 'Unknown') when the
    name does not match the pattern.
    """
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Bug fix: the warning previously printed the literal text
        # '(unknown)' instead of the offending filename.
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic
def process_all_cases(folder_path):
    """Read every .docx case file in *folder_path* into a flat DataFrame.

    Each row is one Q&A turn with columns: case_id, clinical_presentation,
    turn_id (1-based within a case), question, answer. Returns an empty
    DataFrame when the folder is missing or yields no data; problems with
    individual files are reported to stdout but do not abort processing.
    """
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)
    print(f"Processing case files from: {folder_path}")
    for filename in os.listdir(folder_path):
        # Skip non-docx files and Word lock/temp files ('~$...').
        if not (filename.lower().endswith('.docx') and not filename.startswith('~')):
            continue
        file_path = os.path.join(folder_path, filename)
        text = read_docx(file_path)
        if not text:
            # Bug fix: report the actual filename (was the literal '(unknown)').
            print(f"Warning: Empty content for file {filename}")
            continue
        qa_pairs = extract_qa_pairs(text)
        case_id, presentation = parse_filename(filename)
        if not qa_pairs:
            # Bug fix: report the actual filename (was the literal '(unknown)').
            print(f"Warning: No Q&A pairs extracted from {filename}")
        for i, pair in enumerate(qa_pairs):
            rows.append({
                "case_id": case_id,
                "clinical_presentation": presentation,
                "turn_id": i + 1,  # 1-based turn order within the case
                "question": pair["question"],
                "answer": pair["answer"],
            })
    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")
    return pd.DataFrame(rows)
# --- ClinicalCaseProcessor Class ---
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using sentence-transformers."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Load the sentence-transformers encoder and move it to GPU when available.
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized Hugging Face dataset.

        The input rows (one per Q&A turn) are grouped into one record per case,
        a one-line summary is built per case, embedded with sentence-transformers,
        and the result is saved to disk with `datasets.Dataset.save_to_disk`.
        Returns None on any failure (invalid input, missing columns, embedding
        or save errors) after printing a diagnostic.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.
        Returns:
            datasets.Dataset: The processed dataset with embeddings, or None on failure.
        """
        # Load data: accept either an in-memory DataFrame or a path to a CSV file.
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None
        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None
        print(f"Raw data shape: {df.shape}")
        # Validate necessary columns (the schema produced by process_all_cases).
        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None
        # Group by case_id to get all Q&A pairs for each case.
        # dropna=False keeps groups whose keys are NaN; those are handled below.
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)
        # Create a new dataframe with one row per case
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id to ensure correct order
            group = group.sort_values('turn_id')
            # Extract questions and answers in order
            questions = group['question'].tolist()
            answers = group['answer'].tolist()
            # Handle potential NaN/None in presentation if groupby didn't drop them
            presentation_str = str(presentation) if pd.notna(presentation) else "Unknown Presentation"
            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers
            })
        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None
        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")
        # Create a searchable summary of each case (handle empty question lists).
        # Only the summary — not the full transcript — is embedded for retrieval.
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )
        # Generate embeddings using sentence-transformers, in manual batches so
        # progress is reported per batch via tqdm.
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []
        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i+batch_size]
                # Generate embeddings for the batch
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)
            # Combine all batch embeddings
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")
        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None
        # Convert to HF Dataset and add embeddings
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists)
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None
        # Save processed dataset
        try:
            os.makedirs(output_path, exist_ok=True)  # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None  # Return None if saving failed
        return dataset