# Med_Synth_AI / app.py
# Uploaded by solfedge ("Upload app.py", commit 04db63a, verified).
import gradio as gr
import pandas as pd
import numpy as np
import os
from io import BytesIO
# Dynamically generate synthetic data to match ANY input schema
def generate_synthetic_like(real_df, num_records=100, random_state=None):
    """Generate synthetic data with the same columns as ``real_df``.

    Every column is sampled independently, so relationships between
    columns are NOT preserved:

    * all-missing columns are filled with the string ``"Unknown"``;
    * boolean columns are resampled from their observed values (stay boolean);
    * integer columns are drawn uniformly from ``[min, max]`` (inclusive);
    * other numeric columns are drawn uniformly from ``[min, max)``;
    * everything else is resampled, as strings, from the observed values —
      so real categorical values appear verbatim in the output.

    Args:
        real_df: Source DataFrame whose columns are mirrored.
        num_records: Number of synthetic rows to generate.
        random_state: Optional int seed or ``np.random.RandomState`` for
            reproducible output. ``None`` (default) uses NumPy's global
            random state, exactly as before, so ``np.random.seed`` still works.

    Returns:
        A new DataFrame with ``num_records`` rows and ``real_df``'s columns.
    """
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    synthetic = pd.DataFrame()
    for col in real_df.columns:
        col_data = real_df[col].dropna()
        if len(col_data) == 0:
            synthetic[col] = ["Unknown"] * num_records
            continue
        if pd.api.types.is_bool_dtype(col_data):
            # pandas classifies bool as numeric; without this branch booleans
            # would hit the float path and become uniform floats in [0, 1).
            synthetic[col] = rng.choice(col_data.to_numpy(dtype=bool), size=num_records)
        elif pd.api.types.is_integer_dtype(col_data):
            low, high = int(col_data.min()), int(col_data.max())
            synthetic[col] = rng.randint(low, high + 1, size=num_records)
        elif pd.api.types.is_numeric_dtype(col_data):
            low, high = float(col_data.min()), float(col_data.max())
            synthetic[col] = rng.uniform(low, high, size=num_records)
        else:
            # Categorical or string: sample observed values (keeps frequencies).
            values = col_data.astype(str).tolist()
            synthetic[col] = rng.choice(values, size=num_records)
    return synthetic
# Membership Inference Risk Detection
def _pick_target_column(df):
    """Return the first non-ID-like column that has a value, else the first column.

    Labels are compared via ``str()`` so non-string column labels (e.g.
    integers) do not raise.
    """
    id_like = {'id', 'patient_id', 'record_id'}
    for col in df.columns:
        if str(col).lower() not in id_like and len(df[col].dropna()) > 0:
            return col
    return df.columns[0]


def _one_hot_limited(X, max_categories=10):
    """One-hot encode non-numeric columns, keeping at most ``max_categories`` levels each.

    Less frequent levels are pooled into ``"__other__"`` so high-cardinality
    text columns cannot explode the feature matrix. Missing values get no
    indicator column. Any column that is still neither numeric nor boolean
    after encoding is dropped.
    """
    X = X.copy()
    cat_cols = [c for c in X.columns if not pd.api.types.is_numeric_dtype(X[c])]
    for c in cat_cols:
        present = X[c].notna()
        as_text = X[c].astype(str).where(present)
        keep = as_text.value_counts().index[:max_categories]
        X[c] = as_text.where(as_text.isin(keep) | ~present, "__other__")
    if cat_cols:
        X = pd.get_dummies(X, columns=cat_cols)
    return X.select_dtypes(include=["number", "bool"])


def check_membership_inference_risk(real_data, synthetic_data, target_col=None, threshold=0.6):
    """Estimate whether synthetic rows can be told apart from real rows.

    A random-forest classifier is trained to separate a 50% sample of
    ``real_data`` (label 1) from all of ``synthetic_data`` (label 0) and is
    scored on a 30% hold-out split. Balanced accuracy is used instead of plain
    accuracy because the two groups usually differ in size; plain accuracy
    would otherwise reward always predicting the larger group, making any
    fixed threshold meaningless.

    Args:
        real_data: Original DataFrame.
        synthetic_data: Synthetic DataFrame with the same columns.
        target_col: Column excluded from the features. Defaults to the first
            non-ID column that has at least one value.
        threshold: Balanced accuracy below which the data is considered safe.

    Returns:
        ``True`` when balanced accuracy < ``threshold`` (low risk). ``False``
        when it is not, when there are no usable features or only one class,
        or on any error (fail closed).
    """
    try:
        from sklearn.model_selection import train_test_split
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.metrics import balanced_accuracy_score

        if target_col is None:
            target_col = _pick_target_column(real_data)
        real = real_data.copy()
        synth = synthetic_data.copy()
        real['membership'] = 1
        synth['membership'] = 0
        # Combine half real + all synthetic
        combined = pd.concat([
            real.sample(frac=0.5, random_state=42),
            synth,
        ], ignore_index=True)
        y = combined['membership']
        if y.nunique() < 2:
            return False  # Cannot train a discriminator on a single class
        X = combined.drop(columns=[target_col, 'membership'], errors='ignore')
        # Note: pd.get_dummies has no max_categories option (that belongs to
        # sklearn's OneHotEncoder); the helper implements the intended cap.
        X = _one_hot_limited(X, max_categories=10)
        if X.empty or X.shape[1] == 0:
            return False  # Not enough features
        # NOTE: some scikit-learn versions' forests reject NaN; impute with
        # column medians, and 0 for columns that are entirely missing.
        X = X.astype(float)
        X = X.fillna(X.median()).fillna(0.0)
        counts = y.value_counts()
        stratify = y if counts.min() >= 2 else None
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=42, stratify=stratify
        )
        model = RandomForestClassifier(n_estimators=50, max_depth=5, random_state=42)
        model.fit(X_train, y_train)
        acc = balanced_accuracy_score(y_test, model.predict(X_test))
        print(f"Membership Inference Attack Balanced Accuracy: {acc:.3f}")
        return bool(acc < threshold)  # plain bool, not np.bool_
    except Exception as e:
        print(f"Privacy check error: {e}")
        return False  # Default to unsafe on error
# Universal file loader with auto-detection
def load_data(file_path):
    """Load a tabular file, detecting Excel by magic bytes and CSV encoding by trial.

    Args:
        file_path: Path to a .xls, .xlsx, or CSV file.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: if the CSV cannot be decoded with any candidate encoding.
        Other pandas errors (e.g. ``ParserError`` for malformed rows,
        ``EmptyDataError`` for an empty file) propagate unchanged instead of
        being misreported as an encoding failure.
    """
    with open(file_path, 'rb') as f:
        header = f.read(16)
    # Detect Excel files by magic bytes
    if header.startswith(b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'):  # Older .xls (OLE2)
        return pd.read_excel(file_path, engine='xlrd')
    if header.startswith(b'PK\x03\x04'):  # .xlsx or other zip-based
        return pd.read_excel(file_path, engine='openpyxl')
    # CSV with fallback encodings. latin1 maps every byte to a character and so
    # never fails -- it must be last. cp1252 goes before it because it decodes
    # 0x80-0x9F to printable characters (e.g. the euro sign) where latin1
    # yields control codes. ('ISO-8859-1' was just an alias of latin1.)
    encodings = ['utf-8', 'cp1252', 'latin1']
    for enc in encodings:
        try:
            return pd.read_csv(file_path, encoding=enc)
        except UnicodeDecodeError:
            continue  # only decoding failures warrant trying another encoding
    raise ValueError(f"Could not decode file with any encoding: {encodings}")
# Main processing function
def process_ehr(file):
    """Gradio callback: load an upload, synthesize data, and assess privacy risk.

    Args:
        file: Value from ``gr.File``. This is either a filepath string or an
            object exposing ``.name``, depending on the Gradio version.
            ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(synthetic_df, status_message, json_path, csv_path)``. The
        paths are ``None`` when no files were produced.
    """
    import tempfile

    try:
        if file is None:
            return pd.DataFrame(), "Please upload a file.", None, None
        # Newer Gradio passes a path string; older versions pass a file wrapper.
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        # Load data (auto-detect format & encoding)
        real_df = load_data(path)
        if real_df.empty:
            return pd.DataFrame(), "Uploaded file is empty.", None, None
        # Generate synthetic data with same schema
        synthetic_df = generate_synthetic_like(real_df, num_records=100)
        # Check privacy risk
        is_safe = check_membership_inference_risk(real_df, synthetic_df)
        risk_msg = " Synthetic EHR shows low membership inference risk." if is_safe else "⚠️ High risk: Synthetic data may leak sensitive info."
        # Save outputs into a fresh per-request directory: fixed filenames in
        # the working directory would let concurrent sessions overwrite (and
        # download) each other's results.
        out_dir = tempfile.mkdtemp(prefix="medsynth_")
        csv_path = os.path.join(out_dir, "synthetic_ehr.csv")
        json_path = os.path.join(out_dir, "synthetic_ehr.json")
        synthetic_df.to_csv(csv_path, index=False)
        synthetic_df.to_json(json_path, orient="records", indent=2)
        return synthetic_df, risk_msg, json_path, csv_path
    except Exception as e:
        return pd.DataFrame(), f" Error processing file: {str(e)}", None, None
# Gradio Interface
# Build the UI. Statement order matters: components are laid out in the order
# they are created inside each `with` context.
with gr.Blocks(theme=gr.themes.Soft(), title="MedSynth – Universal Synthetic EHR") as demo:
    gr.Markdown("""
# MedSynth Universal Privacy-Preserving Synthetic Data Generator
Upload **any tabular EHR or patient dataset** (CSV, Excel) — we'll generate realistic synthetic data and check for privacy leaks.
""")
    # Row 1: file input (restricted to CSV/Excel extensions) and trigger button.
    with gr.Row():
        upload = gr.File(label=" Upload Real EHR (CSV or Excel)", file_types=[".csv", ".xls", ".xlsx"])
        btn = gr.Button(" Generate Synthetic Data", variant="primary")
    # Row 2: read-only view of the generated synthetic table.
    with gr.Row():
        output_df = gr.Dataframe(label=" Synthetic Data Sample", interactive=False)
    # Row 3: status message from the privacy check (or an error message).
    with gr.Row():
        risk_text = gr.Textbox(label=" Privacy Risk Status")
    # Row 4: download slots for the files whose paths process_ehr returns.
    with gr.Row():
        json_out = gr.File(label=" Download JSON")
        csv_out = gr.File(label=" Download CSV")
    # Event handler
    # `outputs` must line up with process_ehr's return tuple:
    # (synthetic DataFrame, status message, JSON path, CSV path).
    btn.click(
        fn=process_ehr,
        inputs=[upload],
        outputs=[output_df, risk_text, json_out, csv_out]
    )
# Launch
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link. Confirm that
    # exposing an upload endpoint for patient data publicly is intended.
    demo.launch(share=True)