# MedSynth – Universal Synthetic EHR (Hugging Face Space, app.py)
import os
import tempfile
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
# Dynamically generate synthetic data to match ANY input schema
def generate_synthetic_like(real_df, num_records=100, seed=None):
    """Generate a synthetic DataFrame mimicking real_df's columns.

    Each column is synthesized independently from the observed values:
      * boolean columns -> random True/False draws
      * integer columns -> uniform integers within the observed [min, max]
      * float columns   -> uniform floats within the observed [min, max]
      * everything else -> values resampled (as strings) from the column
    Columns that are entirely NaN are filled with the placeholder "Unknown".

    Parameters
    ----------
    real_df : pd.DataFrame
        Source data whose column names are mimicked.
    num_records : int, optional
        Number of synthetic rows to generate (default 100).
    seed : int or None, optional
        Seed for reproducible output; None (the default) preserves the
        original non-deterministic behavior.

    Returns
    -------
    pd.DataFrame
    """
    rng = np.random.default_rng(seed)
    synthetic = pd.DataFrame()
    for col in real_df.columns:
        col_data = real_df[col].dropna()
        if col_data.empty:
            # Nothing observed to learn from; fall back to a placeholder.
            synthetic[col] = ["Unknown"] * num_records
            continue
        if pd.api.types.is_bool_dtype(col_data):
            # Bools must be handled before the generic numeric branch:
            # is_numeric_dtype() is True for bool, and min()/max() there
            # would otherwise produce uniform floats in [0, 1].
            synthetic[col] = rng.choice([False, True], size=num_records)
        elif pd.api.types.is_numeric_dtype(col_data):
            if pd.api.types.is_integer_dtype(col_data):
                low, high = int(col_data.min()), int(col_data.max())
                # high + 1: Generator.integers() upper bound is exclusive.
                synthetic[col] = rng.integers(low, high + 1, size=num_records)
            else:
                low, high = float(col_data.min()), float(col_data.max())
                synthetic[col] = rng.uniform(low, high, size=num_records)
        else:
            # Categorical / string (and anything else, e.g. datetimes):
            # resample observed values, stringified.
            values = col_data.astype(str).tolist()
            synthetic[col] = rng.choice(values, size=num_records)
    return synthetic
# Membership Inference Risk Detection
def check_membership_inference_risk(real_data, synthetic_data, target_col=None):
    """Estimate membership-inference risk with a shadow classifier.

    Labels a 50% sample of the real rows as 1 and all synthetic rows as 0,
    then trains a random forest to distinguish them.  If the classifier's
    held-out accuracy stays below 0.6 (barely better than chance), the
    synthetic data is considered low-risk.

    Parameters
    ----------
    real_data, synthetic_data : pd.DataFrame
        Expected to share the same columns.
    target_col : str or None, optional
        Column excluded from the attack features; auto-selected when None.

    Returns
    -------
    bool
        True when attack accuracy < 0.6 (low risk); False otherwise or on
        any error (fail closed).
    """
    try:
        # Use the first non-ID, non-empty column as target if not given.
        if target_col is None:
            for col in real_data.columns:
                if col.lower() not in ['id', 'patient_id', 'record_id'] and len(real_data[col].dropna()) > 0:
                    target_col = col
                    break
            if target_col is None:
                target_col = real_data.columns[0]
        real_data = real_data.copy()
        synthetic_data = synthetic_data.copy()
        real_data['membership'] = 1
        synthetic_data['membership'] = 0
        # Combine half real + all synthetic
        combined = pd.concat([
            real_data.sample(frac=0.5, random_state=42),
            synthetic_data
        ], ignore_index=True)
        # Drop target and membership
        X = combined.drop(columns=[target_col, 'membership'], errors='ignore')
        y = combined['membership']
        # BUG FIX: pd.get_dummies() has no `max_categories` keyword -- the old
        # call raised TypeError on every run, which the broad except silently
        # converted into "unsafe".  Cap cardinality by hand instead: keep the
        # 10 most frequent categories per non-numeric column, bucket the rest.
        for col in X.columns:
            if not pd.api.types.is_numeric_dtype(X[col]):
                as_str = X[col].astype(str)
                top = as_str.value_counts().nlargest(10).index
                X[col] = as_str.where(as_str.isin(top), 'OTHER')
        X = pd.get_dummies(X)
        if X.empty or len(X.columns) == 0:
            return False  # Not enough features to mount the attack.
        from sklearn.model_selection import train_test_split
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.metrics import accuracy_score
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
        model = RandomForestClassifier(n_estimators=50, max_depth=5, random_state=42)
        model.fit(X_train, y_train)
        acc = accuracy_score(y_test, model.predict(X_test))
        print(f"Membership Inference Attack Accuracy: {acc:.3f}")
        return acc < 0.6  # Safe only if the attacker barely beats chance.
    except Exception as e:
        # Fail closed: any error is reported as "unsafe".
        print(f"Privacy check error: {e}")
        return False
# Universal file loader with auto-detection
def load_data(file_path):
    """Load a tabular file, auto-detecting Excel vs. CSV and text encoding.

    Detection uses magic bytes rather than the file extension, so a
    mis-named upload still loads correctly:
      * D0 CF 11 E0 ... -> legacy .xls (OLE2 container, xlrd engine)
      * PK 03 04        -> .xlsx (ZIP container, openpyxl engine)
      * anything else   -> CSV, trying encodings from strictest to laxest

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    ValueError
        If the file cannot be decoded as CSV with any known encoding.
    """
    with open(file_path, 'rb') as f:
        header = f.read(16)
    if header.startswith(b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'):  # Older .xls
        return pd.read_excel(file_path, engine='xlrd')
    if header.startswith(b'PK\x03\x04'):  # .xlsx or .zip-based
        return pd.read_excel(file_path, engine='openpyxl')
    # CSV fallback.  Order matters: utf-8 is strict, cp1252 covers common
    # Windows exports, and latin1 accepts ANY byte sequence so it must come
    # last -- in the old list latin1 came second, making cp1252/ISO-8859-1
    # unreachable dead code and misdecoding Windows-1252 files.
    encodings = ['utf-8', 'cp1252', 'latin1']
    for enc in encodings:
        try:
            return pd.read_csv(file_path, encoding=enc)
        except UnicodeDecodeError:
            # Only retry on decode failures; genuine CSV parse errors
            # propagate instead of being masked as an encoding problem.
            continue
    raise ValueError(f"Could not decode file with any encoding: {encodings}")
# Main processing function
def process_ehr(file):
    """Gradio callback: load an uploaded EHR, synthesize data, audit privacy.

    Parameters
    ----------
    file : Gradio file payload (object with .name), str path, or None

    Returns
    -------
    tuple
        (synthetic DataFrame, status message, JSON path or None, CSV path
        or None) -- matching the four output components wired in the UI.
    """
    try:
        if file is None:
            return pd.DataFrame(), "Please upload a file.", None, None
        # Gradio hands us either a tempfile-like object (with .name) or a
        # plain path string, depending on the component configuration.
        path = file if isinstance(file, str) else file.name
        # Load data (auto-detect format & encoding)
        real_df = load_data(path)
        if real_df.empty:
            return pd.DataFrame(), "Uploaded file is empty.", None, None
        # Generate synthetic data with same schema
        synthetic_df = generate_synthetic_like(real_df, num_records=100)
        # Check privacy risk
        is_safe = check_membership_inference_risk(real_df, synthetic_df)
        risk_msg = (" Synthetic EHR shows low membership inference risk."
                    if is_safe else
                    "⚠️ High risk: Synthetic data may leak sensitive info.")
        # Write outputs to a fresh temp dir instead of fixed names in the
        # CWD, so concurrent sessions can't clobber each other's downloads.
        out_dir = tempfile.mkdtemp(prefix="medsynth_")
        csv_path = os.path.join(out_dir, "synthetic_ehr.csv")
        json_path = os.path.join(out_dir, "synthetic_ehr.json")
        synthetic_df.to_csv(csv_path, index=False)
        synthetic_df.to_json(json_path, orient="records", indent=2)
        return synthetic_df, risk_msg, json_path, csv_path
    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        return pd.DataFrame(), f" Error processing file: {str(e)}", None, None
# ---------------------------------------------------------------- Gradio UI
with gr.Blocks(theme=gr.themes.Soft(), title="MedSynth – Universal Synthetic EHR") as demo:
    gr.Markdown("""
# MedSynth Universal Privacy-Preserving Synthetic Data Generator
Upload **any tabular EHR or patient dataset** (CSV, Excel) — we'll generate realistic synthetic data and check for privacy leaks.
""")
    # Layout: uploader + action button on one row, then preview, status,
    # and the two download slots below.
    with gr.Row():
        file_input = gr.File(label=" Upload Real EHR (CSV or Excel)", file_types=[".csv", ".xls", ".xlsx"])
        generate_btn = gr.Button(" Generate Synthetic Data", variant="primary")
    with gr.Row():
        preview = gr.Dataframe(label=" Synthetic Data Sample", interactive=False)
    with gr.Row():
        risk_box = gr.Textbox(label=" Privacy Risk Status")
    with gr.Row():
        download_json = gr.File(label=" Download JSON")
        download_csv = gr.File(label=" Download CSV")
    # Wire the button to the processing pipeline.
    generate_btn.click(process_ehr, [file_input], [preview, risk_box, download_json, download_csv])

# Launch only when run as a script.
if __name__ == "__main__":
    demo.launch(share=True)