# NeoSyn / data_processor.py
# vivekprojects-GIT
# Final data_processor.py with all functions and SDV API fixes
# bade928
"""
NeoSyn - data_processor.py
Handles synthetic data generation using CTGAN,
distribution plots, and summaries for both original and synthetic data.
Author: Saivivek Katkuri
Date: June 2025
"""
import pandas as pd
import matplotlib.pyplot as plt
import os
from datetime import datetime
from sdv.single_table import CTGANSynthesizer
from sdv.metadata import SingleTableMetadata
def generate_synthetic_data(original_data: pd.DataFrame, num_samples: int = 10, output_dir: str = "generated") -> pd.DataFrame:
    """
    Generates synthetic data using CTGAN and saves it as a Parquet file.

    Inputs are validated up front so that an unusable call fails before the
    synthesizer is fitted. The result is written to
    ``<output_dir>/synthetic_data_<YYYYmmddHHMMSS>.parquet``.

    Args:
        original_data (pd.DataFrame): Original data to fit the synthesizer on.
            Must contain at least one row and one column.
        num_samples (int): Number of synthetic rows to generate. Must not be
            negative.
        output_dir (str): Directory to save synthetic data. Created if it
            does not exist.

    Returns:
        pd.DataFrame: Synthetic data.

    Raises:
        ValueError: If original_data is empty or num_samples is negative.
    """
    # Fail fast: these checks run before any metadata detection or fitting.
    if original_data.empty:
        raise ValueError("original_data must contain at least one row and one column.")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}.")

    # Create metadata for the table
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data=original_data)

    # Create the CTGAN synthesizer and fit it on the original data
    synthesizer = CTGANSynthesizer(metadata)
    synthesizer.fit(original_data)

    # Sample synthetic data
    synthetic_data = synthesizer.sample(num_rows=num_samples)

    # Save as Parquet.
    # NOTE: the timestamp has one-second resolution, so two calls that reach
    # this point within the same second write to the same path.
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_path = os.path.join(output_dir, f"synthetic_data_{timestamp}.parquet")
    synthetic_data.to_parquet(output_path)
    return synthetic_data
def plot_distribution(original_data: pd.DataFrame, synthetic_data: pd.DataFrame, column: str, output_dir: str = "plots") -> str:
    """
    Plots the distribution of a specified column in both original and synthetic data.

    Two semi-transparent 30-bin histograms (original and synthetic) are drawn
    on the same axes and saved to ``<output_dir>/<column>_distribution.png``.

    Args:
        original_data (pd.DataFrame): Original data.
        synthetic_data (pd.DataFrame): Synthetic data.
        column (str): Column to plot. Must exist in both DataFrames.
        output_dir (str): Directory to save the plot. Created if it does not
            exist.

    Returns:
        str: Path to the saved plot image.

    Raises:
        ValueError: If the column is missing from the original or the
            synthetic data.
    """
    # Validate both inputs before creating the directory or a figure, so a
    # bad call has no side effects.
    for label, frame in (("original", original_data), ("synthetic", synthetic_data)):
        if column not in frame.columns:
            raise ValueError(f"Column '{column}' not found in {label} data.")

    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, f"{column}_distribution.png")

    # Work on an explicit Figure/Axes pair and close it in `finally`, so the
    # figure is released even if plotting or saving raises.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(original_data[column], bins=30, alpha=0.5, label="Original")
        ax.hist(synthetic_data[column], bins=30, alpha=0.5, label="Synthetic")
        ax.set_title(f"Distribution of {column}")
        ax.set_xlabel(column)
        ax.set_ylabel("Frequency")
        ax.legend()
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    return plot_path
def summarize_data(data: pd.DataFrame) -> str:
    """
    Builds a plain-text table of descriptive statistics for a DataFrame.

    Args:
        data (pd.DataFrame): Data to summarize.

    Returns:
        str: The result of pandas ``describe()`` rendered as a string.
    """
    stats_table = data.describe()
    summary_text = stats_table.to_string()
    return summary_text