# MedSynth – privacy-preserving synthetic EHR generator (Gradio app).
import gradio as gr
import pandas as pd
import numpy as np
import os
from io import BytesIO
# Dynamically generate synthetic data to match ANY input schema
def generate_synthetic_like(real_df, num_records=100):
    """Generate a synthetic DataFrame mirroring real_df's column schema.

    Per-column strategy:
      * all-NaN columns become the placeholder string "Unknown";
      * boolean columns are resampled (with replacement) from observed values;
      * integer columns draw uniform ints within the observed [min, max];
      * float columns draw uniform floats within the observed [min, max];
      * everything else (strings/categoricals) is resampled with replacement
        from the observed values, cast to str.

    Args:
        real_df: Source DataFrame whose schema is mimicked.
        num_records: Number of synthetic rows to produce.

    Returns:
        pd.DataFrame with the same column names as real_df.
    """
    synthetic = pd.DataFrame()
    for col in real_df.columns:
        col_data = real_df[col].dropna()
        if len(col_data) == 0:
            # Nothing observed to learn from — emit a placeholder column.
            synthetic[col] = ["Unknown"] * num_records
            continue
        if pd.api.types.is_bool_dtype(col_data):
            # Bug fix: bools are "numeric" to pandas but not "integer", so they
            # previously fell into the float branch and came back as random
            # floats. Resample the observed True/False values instead.
            synthetic[col] = np.random.choice(col_data.tolist(), size=num_records)
        elif pd.api.types.is_numeric_dtype(col_data):
            if pd.api.types.is_integer_dtype(col_data):
                low, high = int(col_data.min()), int(col_data.max())
                # randint's upper bound is exclusive, hence high + 1.
                synthetic[col] = np.random.randint(low, high + 1, size=num_records)
            else:
                low, high = float(col_data.min()), float(col_data.max())
                synthetic[col] = np.random.uniform(low, high, size=num_records)
        else:
            # Categorical/string: sample observed values with replacement.
            values = col_data.astype(str).tolist()
            synthetic[col] = np.random.choice(values, size=num_records)
    return synthetic
# Membership Inference Risk Detection
def check_membership_inference_risk(real_data, synthetic_data, target_col=None):
    """Estimate membership-inference risk with a shadow classifier.

    Trains a RandomForest to distinguish real rows (label 1) from synthetic
    rows (label 0). If the attacker model can barely beat chance, the
    synthetic data is considered low-risk.

    Args:
        real_data: DataFrame of real records.
        synthetic_data: DataFrame with (broadly) the same columns.
        target_col: Column excluded from the attacker's features; when None,
            the first non-ID column with any data is used.

    Returns:
        True when attack accuracy < 0.6 (low risk); False otherwise, or on
        any error (fail-closed by design).
    """
    try:
        # Use first non-ID, non-trivial column as target if not given.
        if target_col is None:
            for col in real_data.columns:
                if col.lower() not in ['id', 'patient_id', 'record_id'] and len(real_data[col].dropna()) > 0:
                    target_col = col
                    break
            if target_col is None:
                target_col = real_data.columns[0]
        real_data = real_data.copy()
        synthetic_data = synthetic_data.copy()
        real_data['membership'] = 1
        synthetic_data['membership'] = 0
        # Combine half of the real data with all of the synthetic data.
        combined = pd.concat([
            real_data.sample(frac=0.5, random_state=42),
            synthetic_data
        ], ignore_index=True)
        # Drop the target and the membership label from the feature matrix.
        X = combined.drop(columns=[target_col, 'membership'], errors='ignore')
        y = combined['membership']
        # Cap high-cardinality text columns to their 10 most frequent values
        # so one-hot encoding stays small.
        # (Bug fix: pd.get_dummies has no `max_categories` kwarg — the old
        # call raised TypeError, which the except below swallowed, so this
        # function always reported "unsafe" no matter the data.)
        for c in X.select_dtypes(include=['object', 'category']).columns:
            top = X[c].value_counts().nlargest(10).index
            X[c] = X[c].where(X[c].isin(top), '__other__')
        X = pd.get_dummies(X)
        if X.empty or len(X.columns) == 0:
            return False  # Not enough features to attack with.
        # Lazy imports: sklearn is only needed on this path, and an absent
        # install degrades to the fail-closed return below.
        from sklearn.model_selection import train_test_split
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.metrics import accuracy_score
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
        model = RandomForestClassifier(n_estimators=50, max_depth=5, random_state=42)
        model.fit(X_train, y_train)
        acc = accuracy_score(y_test, model.predict(X_test))
        print(f"Membership Inference Attack Accuracy: {acc:.3f}")
        return acc < 0.6  # Safe if below 60%
    except Exception as e:
        print(f"Privacy check error: {e}")
        return False  # Default to unsafe on error
# Universal file loader with auto-detection
def load_data(file_path):
    """Load a tabular file, auto-detecting the format from magic bytes.

    Dispatch:
      * OLE2 signature (D0 CF 11 E0 ...) -> legacy .xls via xlrd;
      * ZIP signature (PK\\x03\\x04)      -> .xlsx via openpyxl;
      * anything else                    -> CSV, trying several encodings.

    Args:
        file_path: Path to the uploaded file.

    Returns:
        pd.DataFrame with the file's contents.

    Raises:
        ValueError: If no candidate encoding produced a readable CSV; the
            underlying pandas error is chained for diagnosis.
    """
    # Read only the signature, and close the handle before handing the path
    # off to the pandas readers.
    with open(file_path, 'rb') as f:
        header = f.read(16)
    if header.startswith(b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'):  # Older .xls
        return pd.read_excel(file_path, engine='xlrd')
    if header.startswith(b'PK\x03\x04'):  # .xlsx or .zip-based
        return pd.read_excel(file_path, engine='openpyxl')
    # Try CSV with fallback encodings; keep the last failure for the
    # chained error message (previously the real cause was discarded).
    encodings = ['utf-8', 'latin1', 'cp1252', 'ISO-8859-1']
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(file_path, encoding=enc)
        except Exception as e:
            last_err = e
    raise ValueError(f"Could not decode file with any encoding: {encodings}") from last_err
# Main processing function
def process_ehr(file):
    """Run the full pipeline for one uploaded file.

    Loads the upload, synthesizes a schema-matched dataset of 100 rows,
    assesses membership-inference risk, and writes CSV/JSON downloads.

    Args:
        file: Gradio file object (has a `.name` path) or None.

    Returns:
        Tuple of (synthetic DataFrame, status message, JSON path, CSV path);
        the paths are None when nothing was generated.
    """
    nothing = pd.DataFrame()
    try:
        if file is None:
            return nothing, "Please upload a file.", None, None

        # Load data (auto-detect format & encoding).
        real_df = load_data(file.name)
        if real_df.empty:
            return nothing, "Uploaded file is empty.", None, None

        # Generate synthetic data with the same schema, then audit it.
        synthetic_df = generate_synthetic_like(real_df, num_records=100)
        if check_membership_inference_risk(real_df, synthetic_df):
            risk_msg = " Synthetic EHR shows low membership inference risk."
        else:
            risk_msg = "⚠️ High risk: Synthetic data may leak sensitive info."

        # Persist both download formats in the working directory.
        synthetic_df.to_csv("synthetic_ehr.csv", index=False)
        synthetic_df.to_json("synthetic_ehr.json", orient="records", indent=2)
        return synthetic_df, risk_msg, "synthetic_ehr.json", "synthetic_ehr.csv"
    except Exception as e:
        return nothing, f" Error processing file: {str(e)}", None, None
# Gradio Interface
# Gradio interface — built at import time so `demo` exists for both the
# __main__ guard below and hosted runtimes that import this module.
with gr.Blocks(theme=gr.themes.Soft(), title="MedSynth – Universal Synthetic EHR") as demo:
    gr.Markdown("""
# MedSynth Universal Privacy-Preserving Synthetic Data Generator
Upload **any tabular EHR or patient dataset** (CSV, Excel) — we'll generate realistic synthetic data and check for privacy leaks.
""")
    # Input row: file picker (restricted to CSV/Excel) plus the trigger button.
    with gr.Row():
        upload = gr.File(label=" Upload Real EHR (CSV or Excel)", file_types=[".csv", ".xls", ".xlsx"])
        btn = gr.Button(" Generate Synthetic Data", variant="primary")
    # Preview of the generated synthetic rows (read-only).
    with gr.Row():
        output_df = gr.Dataframe(label=" Synthetic Data Sample", interactive=False)
    # Privacy verdict from the membership-inference check.
    with gr.Row():
        risk_text = gr.Textbox(label=" Privacy Risk Status")
    # Download links for the two files process_ehr writes.
    with gr.Row():
        json_out = gr.File(label=" Download JSON")
        csv_out = gr.File(label=" Download CSV")
    # Event handler: process_ehr returns a 4-tuple matching `outputs` order.
    btn.click(
        fn=process_ehr,
        inputs=[upload],
        outputs=[output_df, risk_text, json_out, csv_out]
    )
# Launch
# Launch when run as a script.
# NOTE(review): share=True opens a public tunnel URL to this app — confirm
# that exposing the demo (and any uploaded EHR data) publicly is intended.
if __name__ == "__main__":
    demo.launch(share=True)