"""
NeoSyn - data_processor.py
Handles synthetic data generation using CTGAN,
distribution plots, and summaries for both original and synthetic data.
Author: Saivivek Katkuri
Date: June 2025
"""
import pandas as pd
import matplotlib.pyplot as plt
import os
from datetime import datetime
from sdv.single_table import CTGANSynthesizer
from sdv.metadata import SingleTableMetadata
def generate_synthetic_data(original_data: pd.DataFrame, num_samples: int = 10, output_dir: str = "generated") -> pd.DataFrame:
    """
    Generates synthetic data using CTGAN and saves it as a Parquet file.

    Table metadata is detected from ``original_data``, a CTGAN synthesizer is
    fitted on it, ``num_samples`` rows are sampled, and the result is written
    to ``<output_dir>/synthetic_data_<YYYYmmddHHMMSS>.parquet``.

    Args:
        original_data (pd.DataFrame): Original data used for metadata
            detection and fitting. Must not be empty.
        num_samples (int): Number of synthetic rows to sample. Must not be
            negative.
        output_dir (str): Directory to save synthetic data; created if it
            does not exist.

    Returns:
        pd.DataFrame: Synthetic data.

    Raises:
        ValueError: If ``original_data`` is empty or ``num_samples`` is
            negative.
    """
    # Fail fast with a clear message before any metadata detection or
    # fitting work is attempted on unusable input.
    if original_data.empty:
        raise ValueError("original_data is empty; cannot fit a synthesizer on it.")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}.")

    # Create metadata for the table
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data=original_data)

    # Create the CTGAN synthesizer and fit it on the original data
    synthesizer = CTGANSynthesizer(metadata)
    synthesizer.fit(original_data)

    # Sample synthetic data
    synthetic_data = synthesizer.sample(num_rows=num_samples)

    # Save as Parquet.
    # NOTE: the timestamp has one-second resolution, so two calls that finish
    # within the same second target the same output path.
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_path = os.path.join(output_dir, f"synthetic_data_{timestamp}.parquet")
    synthetic_data.to_parquet(output_path)

    return synthetic_data
def plot_distribution(original_data: pd.DataFrame, synthetic_data: pd.DataFrame, column: str, output_dir: str = "plots") -> str:
    """
    Plots the distribution of a specified column in both original and synthetic data.

    Two semi-transparent 30-bin histograms (original and synthetic) are drawn
    on the same axes and saved to ``<output_dir>/<column>_distribution.png``.

    Args:
        original_data (pd.DataFrame): Original data.
        synthetic_data (pd.DataFrame): Synthetic data.
        column (str): Column to plot; must exist in both DataFrames.
        output_dir (str): Directory to save the plot; created if it does not
            exist.

    Returns:
        str: Path to the saved plot image.

    Raises:
        ValueError: If ``column`` is missing from either DataFrame.
    """
    # Validate both inputs before touching the filesystem or opening a
    # figure, so a bad column name leaves no side effects behind.
    if column not in original_data.columns:
        raise ValueError(f"Column '{column}' not found in original data.")
    if column not in synthetic_data.columns:
        raise ValueError(f"Column '{column}' not found in synthetic data.")

    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, f"{column}_distribution.png")

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(original_data[column], bins=30, alpha=0.5, label="Original")
        plt.hist(synthetic_data[column], bins=30, alpha=0.5, label="Synthetic")
        plt.title(f"Distribution of {column}")
        plt.xlabel(column)
        plt.ylabel("Frequency")
        plt.legend()
        plt.savefig(plot_path)
    finally:
        # Always release the figure, even if drawing or saving fails;
        # otherwise it stays registered with pyplot.
        plt.close(fig)

    return plot_path
def summarize_data(data: pd.DataFrame) -> str:
    """
    Builds a plain-text summary of a DataFrame.

    Args:
        data (pd.DataFrame): Data to summarize.

    Returns:
        str: The table produced by pandas ``describe``, rendered as text.
    """
    stats_table = data.describe()
    summary_text = stats_table.to_string()
    return summary_text