reab5555 commited on
Commit
4395fda
·
verified ·
1 Parent(s): 8887e6d

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +124 -0
  2. clean.py +275 -0
  3. report.py +271 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pyspark.sql import SparkSession
3
+ import os
4
+ import pandas as pd
5
+ from datetime import datetime
6
+ from clean import clean_data, get_numeric_columns
7
+ from report import create_full_report, REPORT_DIR
8
+
9
+
10
def clean_and_visualize(file, primary_key_column, progress=gr.Progress()):
    """Clean an uploaded CSV with Spark and build the visualization report.

    Args:
        file: Uploaded file object (Gradio); ``file.name`` is the CSV path.
        primary_key_column: Column used for deduplication / outlier exclusion.
        progress: Gradio progress tracker.

    Returns:
        Tuple of (path to the cleaned CSV file, list of report image paths).
    """
    # Create a Spark session
    spark = SparkSession.builder.appName("DataCleaner").getOrCreate()

    try:
        # Read the CSV file
        progress(0.05, desc="Reading CSV file")
        df = spark.read.csv(file.name, header=True, inferSchema=True)

        # Clean the data
        progress(0.1, desc="Starting data cleaning")
        cleaned_df, nonconforming_cells_before, process_times = clean_data(
            spark, df, primary_key_column, progress
        )
        progress(0.8, desc="Data cleaning completed")

        # Calculate removed columns and rows
        removed_columns = len(df.columns) - len(cleaned_df.columns)
        removed_rows = df.count() - cleaned_df.count()

        # Generate full visualization report
        progress(0.9, desc="Generating report")
        create_full_report(
            df,
            cleaned_df,
            nonconforming_cells_before,
            process_times,
            removed_columns,
            removed_rows,
            primary_key_column,
        )

        # Convert PySpark DataFrame to pandas and save as CSV, stamped with the
        # current date/time (fix: the original wrapped the name in a pointless
        # single-argument os.path.join).
        progress(0.95, desc="Saving cleaned data")
        pandas_df = cleaned_df.toPandas()
        current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        cleaned_csv_path = f"cleaned_data_{current_time}.csv"
        pandas_df.to_csv(cleaned_csv_path, index=False)

        # Collect all generated report images
        image_files = [
            os.path.join(REPORT_DIR, f)
            for f in os.listdir(REPORT_DIR)
            if f.endswith('.png')
        ]
    finally:
        # Robustness fix: always release the Spark session, even when cleaning
        # or report generation raises, so repeated runs don't leak sessions.
        spark.stop()

    progress(1.0, desc="Process completed")
    return cleaned_csv_path, image_files
57
+
58
+
59
def launch_app():
    """Build and launch the Gradio UI for the data cleaner.

    Layout: file upload -> primary-key dropdown -> clean button -> cleaned-CSV
    download -> (initially hidden) gallery of report images.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Data Cleaner")

        with gr.Row():
            file_input = gr.File(label="Upload CSV File", file_count="single", file_types=[".csv"])

        with gr.Row():
            # Choices are filled in by update_primary_key_options on upload.
            primary_key_dropdown = gr.Dropdown(label="Select Primary Key Column", choices=[], interactive=True)

        with gr.Row():
            clean_button = gr.Button("Start Cleaning")

        with gr.Row():
            # NOTE(review): gr.Progress() is not a layout component; placing it
            # inside a Row is likely a no-op. Gradio progress is normally
            # injected via a handler's default argument — confirm this tracker
            # actually reports progress in the UI.
            progress_bar = gr.Progress()

        with gr.Row():
            cleaned_file_output = gr.File(label="Cleaned CSV", visible=True)

        with gr.Row():
            # Hidden until results exist; revealed by process_and_show_results.
            output_gallery = gr.Gallery(
                label="Visualization Results",
                show_label=True,
                elem_id="gallery",
                columns=[3],
                rows=[3],
                object_fit="contain",
                height="auto",
                visible=False
            )

        def update_primary_key_options(file):
            # Populate the dropdown with the numeric columns of the uploaded CSV.
            if file is None:
                return gr.Dropdown(choices=[])

            spark = SparkSession.builder.appName("DataCleaner").getOrCreate()
            df = spark.read.csv(file.name, header=True, inferSchema=True)
            numeric_columns = get_numeric_columns(df)
            spark.stop()

            return gr.Dropdown(choices=numeric_columns)

        def process_and_show_results(file, primary_key_column):
            # Run the cleaning pipeline, then reveal the gallery with the images.
            cleaned_csv_path, image_files = clean_and_visualize(file, primary_key_column, progress=progress_bar)
            return (
                cleaned_csv_path,
                gr.Gallery(visible=True, value=image_files)
            )

        file_input.change(
            fn=update_primary_key_options,
            inputs=file_input,
            outputs=primary_key_dropdown
        )

        clean_button.click(
            fn=process_and_show_results,
            inputs=[file_input, primary_key_dropdown],
            outputs=[cleaned_file_output, output_gallery]
        )

    app.launch()
121
+
122
+
123
# Script entry point: build the UI and start serving.
if __name__ == "__main__":
    launch_app()
clean.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from pyspark.sql import SparkSession
4
+ from pyspark.sql.functions import col, isnan, when, count, lower, regexp_replace, to_date, to_timestamp, udf, \
5
+ levenshtein, array, lit, trim, size, coalesce
6
+ from pyspark.sql.types import DoubleType, IntegerType, StringType, DateType, TimestampType, ArrayType
7
+ from pyspark.sql.utils import AnalysisException
8
+ import time
9
+ from time import perf_counter
10
+
11
+ # Constants
12
+ EMPTY_THRESHOLD = 0.5
13
+ LOW_COUNT_THRESHOLD = 2
14
+ VALID_DATA_THRESHOLD = 0.5
15
+
16
def print_dataframe_info(df, step=""):
    """Print column, row, and total-cell counts for *df*, prefixed with *step*."""
    cols = len(df.columns)
    rows = df.count()
    print(f"{step}Dataframe info:")
    print(f" Number of columns: {cols}")
    print(f" Number of rows: {rows}")
    print(f" Total number of cells: {cols * rows}")
24
+
25
+
26
def check_and_normalize_column_headers(df):
    """Rename every column to snake_case ASCII: lowercase, spaces become
    underscores, and all remaining non-alphanumeric characters are stripped."""
    print("Checking and normalizing column headers...")
    for original in df.columns:
        normalized = re.sub(r'[^0-9a-zA-Z_]', '', original.lower().replace(' ', '_'))
        df = df.withColumnRenamed(original, normalized)
    print("Column names have been normalized.")
    return df
41
+
42
+
43
def remove_empty_columns(df, threshold=EMPTY_THRESHOLD):
    """Drop columns whose fraction of non-null values is below *threshold*."""
    print(f"Removing columns with less than {threshold * 100}% valid data...")
    # One boolean per column: True when the column meets the validity threshold.
    flags = df.select(
        [((count(when(col(c).isNotNull(), c)) / count('*')) >= threshold).alias(c)
         for c in df.columns]
    )
    flag_row = flags.first()
    keep = [c for c in flags.columns if flag_row[c]]
    return df.select(keep)
52
+
53
+
54
def remove_empty_rows(df, threshold=EMPTY_THRESHOLD):
    """Drop rows whose fraction of non-null cells is below *threshold*."""
    print(f"Removing rows with less than {threshold * 100}% valid data...")
    # Per-row count of non-null cells: sum of 1/0 indicator expressions.
    indicators = [when(col(c).isNotNull(), lit(1)).otherwise(lit(0)) for c in df.columns]
    valid_expr = indicators[0]
    for indicator in indicators[1:]:
        valid_expr = valid_expr + indicator
    with_count = df.withColumn('valid_count', valid_expr)
    # Keep rows meeting the threshold, then discard the helper column.
    kept = with_count.filter(col('valid_count') >= threshold * len(df.columns))
    print('count of valid rows:', kept.count())
    return kept.drop('valid_count')
68
+
69
+
70
def drop_rows_with_nas(df, threshold=VALID_DATA_THRESHOLD):
    """Drop rows containing nulls, but only in mostly-populated columns.

    A column participates when its fraction of non-null values exceeds
    *threshold*; any row with a null in a participating column is removed.
    """
    print(f"Dropping rows with NAs for columns with more than {threshold * 100}% valid data...")
    # Fraction of non-null values per column.
    ratios = df.select(
        [((count(when(col(c).isNotNull(), c)) / count('*'))).alias(c) for c in df.columns]
    )
    ratio_row = ratios.first()
    mostly_valid = [c for c in ratios.columns if ratio_row[c] > threshold]
    # Filter out rows that are null in any qualifying column.
    for column in mostly_valid:
        df = df.filter(col(column).isNotNull())
    return df
84
+
85
def check_typos(df, column_name, threshold=2, top_n=100):
    """Flag values in *column_name* within edit distance *threshold* of one of
    the column's *top_n* most frequent values.

    Returns:
        A DataFrame of (value, possible_typos) rows, or None when the column is
        not a string column, no typos are found, or analysis fails.
    """
    # Check if the column is of StringType
    if not isinstance(df.schema[column_name].dataType, StringType):
        print(f"Skipping typo check for column {column_name} as it is not a string type.")
        return None

    print(f"Checking for typos in column: {column_name}")

    def _edit_distance(a, b):
        # Pure-Python Levenshtein distance. BUG FIX: the original called
        # pyspark.sql.functions.levenshtein inside the UDF, but that builds a
        # Column expression and cannot compare plain Python strings at runtime.
        if a == b:
            return 0
        if len(a) < len(b):
            a, b = b, a
        previous = list(range(len(b) + 1))
        for i, ca in enumerate(a, start=1):
            current = [i]
            for j, cb in enumerate(b, start=1):
                current.append(min(previous[j] + 1,                  # deletion
                                   current[j - 1] + 1,               # insertion
                                   previous[j - 1] + (ca != cb)))    # substitution
            previous = current
        return previous[-1]

    try:
        # Get value counts for the specific column
        value_counts = df.groupBy(column_name).count().orderBy("count", ascending=False)

        # Take top N most frequent values
        top_values = [row[column_name] for row in value_counts.limit(top_n).collect()]

        # Broadcast the top values to all nodes
        broadcast_top_values = df.sparkSession.sparkContext.broadcast(top_values)

        # Define UDF to find similar strings
        @udf(returnType=ArrayType(StringType()))
        def find_similar_strings(value):
            if value is None:
                return []
            # Skip None entries in the frequent-value list (a null group key
            # can appear there and would crash the distance computation).
            return [top_value for top_value in broadcast_top_values.value
                    if top_value is not None
                    and value != top_value
                    and _edit_distance(value, top_value) <= threshold]

        # Apply the UDF to the column
        df_with_typos = df.withColumn("possible_typos", find_similar_strings(col(column_name)))

        # Filter rows with possible typos and select only the relevant columns
        typos_df = df_with_typos.filter(size("possible_typos") > 0).select(column_name, "possible_typos")

        # Check if there are any potential typos
        typo_count = typos_df.count()
        if typo_count > 0:
            print(f"Potential typos found in column {column_name}: {typo_count}")
            typos_df.show(10, truncate=False)
            return typos_df
        else:
            print(f"No potential typos found in column {column_name}")
            return None

    except AnalysisException as e:
        print(f"Error analyzing column {column_name}: {str(e)}")
        return None
    except Exception as e:
        print(f"Unexpected error in check_typos for column {column_name}: {str(e)}")
        return None
136
+
137
+
138
def transform_string_column(df, column_name):
    """Normalize a string column: lowercase, trim, collapse internal whitespace,
    and strip special characters (keeping date/time characters / : . -)."""
    print(f"Transforming string column: {column_name}")
    # Compose the same four transformations the pipeline applies in sequence.
    cleaned = lower(col(column_name))
    cleaned = trim(cleaned)
    cleaned = regexp_replace(cleaned, "\\s+", " ")
    cleaned = regexp_replace(cleaned, "[^a-zA-Z0-9\\s/:.-]", "")
    return df.withColumn(column_name, cleaned)
149
+
150
+
151
def clean_column(df, column_name):
    """Clean a single column based on its data type.

    Strings: report potential typos, then normalize the text.
    Numerics (int/double): pass through (see NOTE below).
    Other types are left untouched.
    """
    print(f"Cleaning column: {column_name}")
    start_time = perf_counter()
    # Get the data type of the current column
    column_type = df.schema[column_name].dataType

    if isinstance(column_type, StringType):
        # Skip date detection and directly process as string
        # For string columns, check for typos and transform
        typos_df = check_typos(df, column_name)
        if typos_df is not None and typos_df.count() > 0:
            print(f"Detailed typos for column {column_name}:")
            typos_df.show(truncate=False)
        df = transform_string_column(df, column_name)

    elif isinstance(column_type, (DoubleType, IntegerType)):
        # For numeric columns, we'll do a simple null check
        # NOTE(review): replacing nulls with lit(None) leaves every value
        # unchanged — this branch is effectively a no-op; confirm whether
        # imputation or dropping was intended here.
        df = df.withColumn(column_name, when(col(column_name).isNull(), lit(None)).otherwise(col(column_name)))

    end_time = perf_counter()
    print(f"Time taken to clean {column_name}: {end_time - start_time:.6f} seconds")
    return df
173
+
174
+
175
+
176
+
177
def remove_outliers(df, column):
    """Remove rows outside the IQR fences [Q1 - 1.5*IQR, Q3 + 1.5*IQR] of *column*.

    Returns *df* unchanged when the quartiles cannot be computed (e.g. the
    column is entirely null or the DataFrame is empty).
    """
    print(f"Removing outliers from column: {column}")

    # summary() yields one row per requested statistic, value in position 1:
    # [Row('25%', q1), Row('75%', q3)] — values arrive as strings (or None).
    stats = df.select(column).summary("25%", "75%").collect()
    q1_raw, q3_raw = stats[0][1], stats[1][1]
    if q1_raw is None or q3_raw is None:
        # Robustness fix: float(None) would raise a TypeError; skip instead.
        print(f"Quartiles unavailable for column {column}; skipping outlier removal.")
        return df

    q1 = float(q1_raw)
    q3 = float(q3_raw)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    return df.filter((col(column) >= lower_bound) & (col(column) <= upper_bound))
190
+
191
+
192
def calculate_nonconforming_cells(df):
    """Count null (and, for double columns, NaN) cells per column.

    Returns:
        dict mapping column name -> nonconforming cell count.
    """
    nonconforming_cells = {}
    for field in df.schema.fields:
        column = field.name
        predicate = col(column).isNull()
        # BUG FIX: isnan() is only defined for floating-point columns; calling
        # it on string/date/timestamp columns raises an AnalysisException.
        if isinstance(field.dataType, DoubleType):
            predicate = predicate | isnan(column)
        nonconforming_cells[column] = df.filter(predicate).count()
    return nonconforming_cells
198
+
199
+
200
def get_numeric_columns(df):
    """Return the names of integer/double columns in *df*."""
    numeric_types = (IntegerType, DoubleType)
    return [f.name for f in df.schema.fields if isinstance(f.dataType, numeric_types)]
202
+
203
def remove_duplicates_from_primary_key(df, primary_key_column):
    """Keep one row per distinct value of *primary_key_column*."""
    print(f"Removing duplicates based on primary key column: {primary_key_column}")
    deduplicated = df.dropDuplicates([primary_key_column])
    return deduplicated
206
+
207
def clean_data(spark, df, primary_key_column, progress):
    """Run the full cleaning pipeline on *df*.

    Steps: header normalization, empty column/row removal, NA-row dropping,
    per-column cleaning, outlier removal, and primary-key deduplication.

    Args:
        spark: Active SparkSession (kept for API stability; not used directly).
        df: Input DataFrame.
        primary_key_column: Primary key name as selected from the ORIGINAL headers.
        progress: Gradio-style progress callable.

    Returns:
        (cleaned_df, nonconforming_cells_before, process_times) where
        process_times maps step name -> duration in seconds.
    """
    start_time = time.time()
    process_times = {}

    print("Starting data validation and cleaning...")
    print_dataframe_info(df, "Initial - ")

    # Calculate nonconforming cells before cleaning
    nonconforming_cells_before = calculate_nonconforming_cells(df)

    # Step 1: Normalize column headers
    progress(0.1, desc="Normalizing column headers")
    step_start_time = time.time()
    df = check_and_normalize_column_headers(df)
    process_times['Normalize headers'] = time.time() - step_start_time

    # BUG FIX: the primary key was chosen against the ORIGINAL headers; Step 1
    # renames columns (lowercase, spaces -> underscores, specials stripped), so
    # the old name may no longer exist and Steps 6-7 would fail. Apply the same
    # normalization to the key so it matches the renamed column.
    primary_key_column = re.sub(r'[^0-9a-zA-Z_]', '', primary_key_column.lower().replace(' ', '_'))

    # Step 2: Remove empty columns
    progress(0.2, desc="Removing empty columns")
    step_start_time = time.time()
    df = remove_empty_columns(df)
    print('2) count of valid rows:', df.count())
    process_times['Remove empty columns'] = time.time() - step_start_time

    # Step 3: Remove empty rows
    progress(0.3, desc="Removing empty rows")
    step_start_time = time.time()
    df = remove_empty_rows(df)
    print('3) count of valid rows:', df.count())
    process_times['Remove empty rows'] = time.time() - step_start_time

    # Step 4: Drop rows with NAs for columns with more than 50% valid data
    progress(0.4, desc="Dropping rows with NAs")
    step_start_time = time.time()
    df = drop_rows_with_nas(df)
    print('4) count of valid rows:', df.count())
    process_times['Drop rows with NAs'] = time.time() - step_start_time

    # Step 5: Clean columns (including typo checking and string transformation)
    column_cleaning_times = {}
    total_columns = len(df.columns)
    for index, column in enumerate(df.columns):
        progress(0.5 + (0.2 * (index / total_columns)), desc=f"Cleaning column: {column}")
        column_start_time = time.time()
        df = clean_column(df, column)
        print('5) count of valid rows:', df.count())
        column_cleaning_times[f"Clean column: {column}"] = time.time() - column_start_time
    process_times.update(column_cleaning_times)

    # Step 6: Remove outliers from numeric columns (excluding primary key)
    progress(0.7, desc="Removing outliers")
    step_start_time = time.time()
    # Renamed the loop variable (was `col`) to avoid shadowing pyspark's col().
    numeric_columns = [c for c in get_numeric_columns(df) if c != primary_key_column]
    for column in numeric_columns:
        df = remove_outliers(df, column)
    print('6) count of valid rows:', df.count())
    process_times['Remove outliers'] = time.time() - step_start_time

    # Step 7: Remove duplicates from primary key column
    progress(0.8, desc="Removing duplicates from primary key")
    step_start_time = time.time()
    df = remove_duplicates_from_primary_key(df, primary_key_column)
    print('7) count of valid rows:', df.count())
    process_times['Remove duplicates from primary key'] = time.time() - step_start_time

    print("Cleaning process completed.")
    print_dataframe_info(df, "Final - ")

    return df, nonconforming_cells_before, process_times
report.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import Counter
3
+ import numpy as np
4
+ import pandas as pd
5
+ import seaborn as sns
6
+ import matplotlib.pyplot as plt
7
+ from datetime import datetime
8
+
9
+ from pyspark.ml.feature import VectorAssembler
10
+ from pyspark.ml.stat import Correlation
11
+ from pyspark.sql.functions import col, count, when, lit, isnan
12
+ from pyspark.sql.types import DoubleType, IntegerType, LongType, FloatType, StringType, DateType, TimestampType
13
+
14
+ REPORT_DIR = f"cleaning_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
15
+ os.makedirs(REPORT_DIR, exist_ok=True)
16
+
17
+
18
def save_plot(fig, filename):
    """Write *fig* into the report directory at high DPI, then release it."""
    destination = os.path.join(REPORT_DIR, filename)
    fig.savefig(destination, dpi=400, bbox_inches='tight')
    plt.close(fig)
21
+
22
+
23
def plot_heatmap(df, title):
    """Plot per-column missing-value percentages of *df* as a heatmap and save it.

    The output file name is derived from *title* (lowercased, spaces -> underscores).
    """
    def _missing_predicate(field):
        # BUG FIX: isnan() is only valid on floating-point columns; applying it
        # to string/date columns raises an AnalysisException. Only add the NaN
        # check where it is legal.
        predicate = col(field.name).isNull()
        if isinstance(field.dataType, (DoubleType, FloatType)):
            predicate = predicate | isnan(field.name)
        return predicate

    # Percentage of missing values for each column, as a one-row pandas frame.
    null_percentages = df.select([
        (100 * count(when(_missing_predicate(field), field.name)) / count('*')).alias(field.name)
        for field in df.schema.fields
    ]).toPandas()

    plt.figure(figsize=(12, 8))
    sns.heatmap(null_percentages, cbar=True, cmap='Reds', annot=True, fmt='.1f')
    plt.title(title)
    plt.ylabel('Percentage of Missing Values')
    plt.tight_layout()
    save_plot(plt.gcf(), f'{title.lower().replace(" ", "_")}.png')
36
+
37
+
38
def plot_column_schemas(df):
    """Bar-chart the counts of column data types in *df* and save the figure."""
    # Collect a capitalized type name per column.
    type_names = []
    for fld in df.schema.fields:
        dtype_name = fld.dataType.typeName()
        print(f"Column '{fld.name}' has data type '{dtype_name}'")
        type_names.append(dtype_name.capitalize())

    type_counts = Counter(type_names)

    fig, ax = plt.subplots(figsize=(10, 6))

    # One distinct color per data type.
    palette = plt.cm.tab20(np.linspace(0, 1, len(type_counts)))
    bars = ax.bar(type_counts.keys(), type_counts.values(), color=palette)

    ax.set_title('Column Data Types')
    ax.set_xlabel('Data Type')
    ax.set_ylabel('Count')

    # Annotate each bar with its integer count.
    for bar in bars:
        h = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., h,
                f'{int(h)}',
                ha='center', va='bottom')

    plt.xticks(rotation=45)
    plt.tight_layout()
    save_plot(fig, 'column_schemas.png')
72
+
73
+
74
def plot_nonconforming_cells(nonconforming_cells):
    """Bar-chart nonconforming (null/NaN) cell counts per column.

    Prints a message and skips plotting when the input is not a dict.
    """
    # Guard clause instead of the original if/else nesting.
    if not isinstance(nonconforming_cells, dict):
        print(f"Expected nonconforming_cells to be a dictionary, but got {type(nonconforming_cells)}.")
        return

    fig, ax = plt.subplots(figsize=(12, 6))

    # One distinct color per column.
    palette = plt.cm.rainbow(np.linspace(0, 1, len(nonconforming_cells)))
    bars = ax.bar(list(nonconforming_cells.keys()), list(nonconforming_cells.values()), color=palette)

    ax.set_title('Nonconforming Cells by Column')
    ax.set_xlabel('Columns')
    ax.set_ylabel('Number of Nonconforming Cells')
    plt.xticks(rotation=90)

    # Annotate each bar with its comma-formatted count.
    for bar in bars:
        h = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., h,
                f'{h:,}',
                ha='center', va='bottom')

    save_plot(fig, 'nonconforming_cells.png')
101
+
102
+
103
def plot_column_distributions(cleaned_df, primary_key_column):
    """Plot post-cleaning histograms (with KDE) for every numeric column.

    The primary key column is excluded. Saves a grid of subplots to the report
    directory; prints a message and returns when no numeric columns exist.
    """
    print("Plotting distribution charts for numeric columns in the cleaned DataFrame...")

    def get_numeric_columns(df):
        return [field.name for field in df.schema.fields
                if isinstance(field.dataType, (IntegerType, LongType, FloatType, DoubleType))
                and field.name != primary_key_column]

    numeric_columns = get_numeric_columns(cleaned_df)
    num_columns = len(numeric_columns)

    if num_columns == 0:
        print("No numeric columns found in the cleaned DataFrame for distribution plots.")
        return

    # Create subplots for distributions
    ncols = 3
    nrows = (num_columns + ncols - 1) // ncols  # Ceiling division
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 5 * nrows))
    # BUG FIX: plt.subplots(1, 3) returns an ndarray even for one numeric
    # column, so the old `[axes]` branch wrapped the whole array and crashed
    # when indexing. Flatten uniformly instead.
    axes = np.atleast_1d(axes).ravel()

    for i, column in enumerate(numeric_columns):
        # Convert to pandas for plotting
        cleaned_data = cleaned_df.select(column).toPandas()[column].dropna()

        sns.histplot(cleaned_data, ax=axes[i], kde=True, color='orange', label='After Cleaning', alpha=0.7)
        axes[i].set_title(f'{column} - Distribution After Cleaning')
        axes[i].legend()

    # Remove any unused subplots
    for j in range(num_columns, len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()
    save_plot(fig, 'distributions_after_cleaning.png')
138
+
139
+
140
def plot_boxplot_with_outliers(original_df, primary_key_column):
    """Plot pre-cleaning horizontal boxplots for every numeric column.

    The primary key column is excluded. Saves a grid of subplots to the report
    directory; prints a message and returns when no numeric columns exist.
    """
    print("Plotting boxplots for numeric columns in the original DataFrame...")

    def get_numeric_columns(df):
        return [field.name for field in df.schema.fields
                if isinstance(field.dataType, (IntegerType, LongType, FloatType, DoubleType))
                and field.name != primary_key_column]

    numeric_columns = get_numeric_columns(original_df)
    num_columns = len(numeric_columns)

    if num_columns == 0:
        print("No numeric columns found in the original DataFrame for boxplots.")
        return

    # Create subplots based on the number of numeric columns
    ncols = 3
    nrows = (num_columns + ncols - 1) // ncols  # Ceiling division
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 5 * nrows))
    # BUG FIX: plt.subplots(1, 3) returns an ndarray even for one numeric
    # column, so the old `[axes]` branch wrapped the whole array and crashed
    # when indexing. Flatten uniformly instead.
    axes = np.atleast_1d(axes).ravel()

    for i, column in enumerate(numeric_columns):
        # Convert data to pandas for plotting
        data = original_df.select(column).toPandas()[column].dropna()

        sns.boxplot(x=data, ax=axes[i], color='blue', orient='h')
        axes[i].set_title(f'Boxplot of {column} (Before Cleaning)')

    # Remove any unused subplots
    for j in range(num_columns, len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()
    save_plot(fig, 'boxplots_before_cleaning.png')
174
+
175
+
176
def plot_correlation_heatmap(df, primary_key_column):
    """Compute and plot a correlation heatmap over the numeric columns of *df*.

    The primary key column is excluded; prints a message and returns when no
    numeric columns exist.
    """
    numeric_cols = [f.name for f in df.schema.fields
                    if isinstance(f.dataType, (IntegerType, LongType, FloatType, DoubleType))
                    and f.name != primary_key_column]

    if not numeric_cols:
        print("No numeric columns found for correlation heatmap.")
        return

    # Assemble the numeric columns into a single vector column for ml.stat.
    assembler = VectorAssembler(inputCols=numeric_cols, outputCol="features")
    vector_df = assembler.transform(df).select("features")

    # Correlation matrix -> pandas DataFrame for plotting.
    corr_matrix = Correlation.corr(vector_df, "features").collect()[0][0]
    corr_df = pd.DataFrame(corr_matrix.toArray().tolist(), columns=numeric_cols, index=numeric_cols)

    plt.figure(figsize=(15, 10))
    sns.heatmap(corr_df, annot=True, fmt=".2f", cmap='coolwarm', cbar_kws={'label': 'Correlation'})
    plt.title('Correlation Heatmap')
    plt.tight_layout()
    save_plot(plt.gcf(), 'correlation_heatmap.png')
203
+
204
+
205
def plot_process_times(process_times):
    """Plot step durations (in minutes) as two bar charts: main pipeline steps
    on top, per-column cleaning steps below, with the total in the title."""
    # Convert seconds to minutes once, up front.
    minutes = {name: seconds / 60 for name, seconds in process_times.items()}

    def _is_column_step(name):
        return name.startswith("Clean column:")

    main_steps = {n: t for n, t in minutes.items() if not _is_column_step(n)}
    column_steps = {n: t for n, t in minutes.items() if _is_column_step(n)}

    fig, (ax_main, ax_cols) = plt.subplots(2, 1, figsize=(15, 10))

    # Top chart: main pipeline steps.
    main_bars = ax_main.bar(main_steps.keys(), main_steps.values())
    ax_main.set_title('Main Process Times')
    ax_main.set_ylabel('Time (minutes)')
    ax_main.tick_params(axis='x', rotation=45)

    # Bottom chart: per-column cleaning steps.
    col_bars = ax_cols.bar(column_steps.keys(), column_steps.values())
    ax_cols.set_title('Column Cleaning Times')
    ax_cols.set_ylabel('Time (minutes)')
    ax_cols.tick_params(axis='x', rotation=90)

    # Annotate every bar with its duration.
    for axis, bars in ((ax_main, main_bars), (ax_cols, col_bars)):
        for bar in bars:
            h = bar.get_height()
            axis.text(bar.get_x() + bar.get_width() / 2., h,
                      f'{h:.2f}', ha='center', va='bottom')

    # Overall total across both groups.
    fig.suptitle(f'Process Times (Total: {sum(minutes.values()):.2f} minutes)', fontsize=16)

    plt.tight_layout()
    save_plot(fig, 'process_times.png')
241
+
242
+
243
def create_full_report(original_df, cleaned_df, nonconforming_cells_before, process_times, removed_columns,
                       removed_rows, primary_key_column):
    """Generate every report figure into REPORT_DIR.

    Args:
        original_df: DataFrame before cleaning (boxplots, missing-value heatmap).
        cleaned_df: DataFrame after cleaning (distributions, correlations, schema).
        nonconforming_cells_before: dict of column -> nonconforming cell count.
        process_times: dict of step name -> duration in seconds.
        removed_columns: Number of columns dropped (accepted but not used here).
        removed_rows: Number of rows dropped (accepted but not used here).
        primary_key_column: Column excluded from numeric plots.
    """
    os.makedirs(REPORT_DIR, exist_ok=True)

    # Global plotting defaults shared by every figure below.
    sns.set_style("whitegrid")
    plt.rcParams['figure.dpi'] = 400

    print("Plotting nonconforming cells before cleaning...")
    plot_nonconforming_cells(nonconforming_cells_before)

    print("Plotting column distributions...")
    plot_column_distributions(cleaned_df, primary_key_column)

    print("Plotting boxplots for original data...")
    plot_boxplot_with_outliers(original_df, primary_key_column)

    print("Plotting process times...")
    plot_process_times(process_times)

    print("Plotting heatmaps...")
    plot_heatmap(original_df, "Missing Values Before Cleaning")

    print("Plotting correlation heatmap...")
    plot_correlation_heatmap(cleaned_df, primary_key_column)

    print("Plotting column schemas...")
    plot_column_schemas(cleaned_df)

    print(f"All visualization reports saved in directory: {REPORT_DIR}")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ seaborn
4
+ matplotlib
5
+ pyspark
6
+ gradio