# psychology-tutor-engine / normalize_psych_data.py
# (uploaded by adfras — initial commit: "Psychology tutor engine and data pipelines", rev 1da14e1)
# normalize_psych_data.py
# FINAL CORRECTED VERSION
import os
import requests
import re
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
import warnings
from langdetect import detect, LangDetectException # Import the new library
warnings.simplefilter(action='ignore', category=FutureWarning)
# --- Configuration ---
RAW_DATA_DIR = "data/raw_psych_data"
NORMALIZED_DATA_DIR = "data/processed"
os.makedirs(RAW_DATA_DIR, exist_ok=True)
os.makedirs(NORMALIZED_DATA_DIR, exist_ok=True)
SCHEMA = ['question', 'answer', 'source', 'licence']
# Helper functions and the per-dataset process_* functions are defined below.
# --- Helper Functions ---
def download_file(url, local_filename):
    """Download *url* into RAW_DATA_DIR with a tqdm progress bar.

    Returns the local path on success, or immediately if the file is
    already present.  Returns None on any request failure.

    Fix: a failed mid-stream download previously left a truncated file
    behind; the next run would find it via os.path.exists and treat it
    as complete.  We now remove the partial file before returning None
    so a later run retries cleanly.
    """
    local_path = os.path.join(RAW_DATA_DIR, local_filename)
    if os.path.exists(local_path):
        return local_path
    print(f"Downloading {url} to {local_path}...")
    try:
        with requests.get(url, stream=True, timeout=120) as r:
            r.raise_for_status()
            # content-length may be absent; tqdm handles total=0 gracefully.
            total_size = int(r.headers.get('content-length', 0))
            with open(local_path, 'wb') as f, tqdm(
                total=total_size, unit='iB', unit_scale=True, desc=local_filename
            ) as pbar:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
                        pbar.update(len(chunk))
        return local_path
    except requests.exceptions.RequestException as e:
        print(f"ERROR: Download failed for {url}. Error: {e}")
        # Never leave a partial file behind — it would be mistaken for a
        # completed download on the next invocation.
        if os.path.exists(local_path):
            os.remove(local_path)
        return None
def save_normalized_df(df, filename):
    """Clean *df* and write it as parquet into NORMALIZED_DATA_DIR.

    This function only drops missing/blank question-answer rows and
    saves; quality filtering (length, language) happens later on the
    combined dataset.

    Raises:
        ValueError: if 'question' or 'answer' columns are missing.
    """
    # Use a real exception rather than assert: asserts are stripped under
    # `python -O`, which would silently skip this validation.  Also report
    # the actual filename instead of the old hard-coded "(unknown)".
    if not set(df.columns) >= {"question", "answer"}:
        raise ValueError(
            f"DataFrame for {filename} is missing 'question' or 'answer' "
            f"columns. Found: {list(df.columns)}"
        )
    df = df.dropna(subset=['question', 'answer'])
    # Drop rows whose question or answer is empty/whitespace-only.
    df = df[df['question'].astype(str).str.strip() != '']
    df = df[df['answer'].astype(str).str.strip() != '']
    df = df[SCHEMA].copy()
    output_path = os.path.join(NORMALIZED_DATA_DIR, filename)
    df.to_parquet(output_path, index=False)
# --- Dataset Processing Functions ---
def process_boltmonkey():
    """Fetch the BoltMonkey psychology Q-A dataset and tag provenance.

    Returns an empty DataFrame when the download fails.
    """
    print("\n--- Processing: BoltMonkey ---")
    url = "https://huggingface.co/datasets/BoltMonkey/psychology-question-answer/resolve/main/data/train/train.json?download=true"
    local = download_file(url, "boltmonkey.json")
    if local is None:
        return pd.DataFrame()
    frame = pd.read_json(local)
    frame['source'] = 'BoltMonkey/psychology-question-answer'
    frame['licence'] = 'CC-BY-NC'
    return frame
def process_gragroo():
    """Fetch Gragroo's conversational psych dataset and flatten each
    human -> assistant exchange into one question/answer row.

    A human turn is held as the pending question until an assistant turn
    answers it; the pending question resets at each conversation boundary.
    Returns an empty DataFrame when the download fails or yields no pairs.
    """
    print("\n--- Processing: Gragroo ---")
    url = "https://huggingface.co/datasets/Gragroo/psychology-question-answer_psygpt_with_validation/resolve/main/data/train-00000-of-00001.parquet?download=true"
    local = download_file(url, "gragroo_train.parquet")
    if local is None:
        return pd.DataFrame()
    rows = []
    for conversation in pd.read_parquet(local)["conversations"]:
        pending_question = None
        for turn in conversation:
            role, text = turn["from"], turn["value"]
            if role == "human":
                pending_question = text.strip()
            elif role == "assistant" and pending_question:
                rows.append({"question": pending_question, "answer": text.strip()})
                pending_question = None
    if not rows:
        return pd.DataFrame()
    result = pd.DataFrame(rows)
    result["source"] = "Gragroo/psychology-question-answer_psygpt_with_validation"
    result["licence"] = "CC-BY-NC"
    return result
def process_psycholexqa():
    """Load PsychoLexQA from the Hugging Face Hub and normalise columns.

    The dataset is gated; loading fails until the licence is accepted on
    the Hub, in which case an empty DataFrame is returned.
    """
    print("\n--- Processing: PsychoLexQA ---")
    try:
        # The whole pipeline stays inside the try so any failure (gated
        # access, schema drift) degrades to an empty frame, as before.
        frame = load_dataset("aminabbasi/PsychoLexQA", split="train").to_pandas()
        frame = frame.rename(columns={"instruction": "question", "output": "answer"})
        frame["source"] = "PsychoLexQA"
        frame["licence"] = "CC-BY-NC"
        return frame
    except Exception as e:
        print(f"ERROR: Could not load PsychoLexQA. Accept the licence on Hugging Face first. Error: {e}")
        return pd.DataFrame()
def process_mmlu():
    """Convert the two MMLU psychology test splits into Q-A rows.

    Each multiple-choice item becomes a question containing lettered
    options (A., B., ...) and an answer sentence naming the correct
    letter and option text.  Splits that fail to load are skipped.
    """
    print("\n--- Processing: MMLU Psychology ---")
    frames = []
    for subject in ("high_school_psychology", "professional_psychology"):
        try:
            df = load_dataset("cais/mmlu", name=subject, split="test").to_pandas()

            def build_qa(row):
                # Letters the options A, B, C, ... in order.
                lettered = "\n".join(
                    f"{chr(65 + idx)}. {option}"
                    for idx, option in enumerate(row['choices'])
                )
                right = row['choices'][row['answer']]
                question = f"{row['question']}\n\n{lettered}"
                answer = f"The correct answer is {chr(65 + row['answer'])}: {right}"
                return question, answer

            df['question'], df['answer'] = zip(*df.apply(build_qa, axis=1))
            df['source'] = f'MMLU/{subject}'
            df['licence'] = 'MIT'
            frames.append(df)
        except Exception as e:
            print(f"ERROR: Could not process MMLU split {subject}. Error: {e}")
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
# --- Main Execution ---
# Run every dataset processor, combine the results, apply quality filters
# (minimum question length, English-only), and write one parquet file.
if __name__ == "__main__":
    all_dataframes = []
    # Each processor returns a DataFrame with the SCHEMA columns (or empty on failure).
    processing_functions = [process_boltmonkey, process_gragroo, process_psycholexqa, process_mmlu]
    for func in processing_functions:
        try:
            df = func()
            if not df.empty:
                # Keep only the canonical columns so concat() aligns cleanly.
                all_dataframes.append(df[SCHEMA])
        except Exception as e:
            print(f"A critical error occurred during execution of {func.__name__}: {e}")
    if all_dataframes:
        print("\n--- Combining all datasets ---")
        final_df = pd.concat(all_dataframes, ignore_index=True)
        # --- APPLYING FINAL QUALITY FILTERS ---
        print(f"\nApplying final quality filters to {len(final_df)} combined rows...")
        original_rows = len(final_df)
        # Filter 1: drop trivially short questions (< 10 characters).
        final_df = final_df[final_df['question'].str.len() >= 10].copy()
        # Filter 2: keep only questions langdetect identifies as English.
        def is_english(text):
            try:
                # True iff the detected language code is 'en'.
                return detect(text) == 'en'
            except LangDetectException:
                # Undetectable text (too short, symbols only) is dropped.
                return False
        print("Detecting language for each question... (This might take a moment)")
        # NOTE(review): one langdetect call per row — this is the slow step
        # for large corpora.
        mask = final_df['question'].apply(is_english)
        final_df = final_df[mask]
        rows_removed = original_rows - len(final_df)
        if rows_removed > 0:
            print(f"-> SUCCESS: Filtered out {rows_removed} rows due to length or language.")
        # Save the final, clean dataset.
        final_output_path = os.path.join(NORMALIZED_DATA_DIR, "ALL_PSYCHOLOGY_DATA_normalized.parquet")
        final_df.to_parquet(final_output_path, index=False)
        print(f"\nSaved final combined data to {final_output_path} ({len(final_df)} rows)")
        print("\n--- Final Summary ---")
        print("Breakdown by source:")
        print(final_df['source'].value_counts())
    else:
        print("\nNo data was processed successfully. Check logs for errors.")