# File size: 6,067 Bytes
# 04db63a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161

import gradio as gr
import pandas as pd
import numpy as np
import os
from io import BytesIO

# Dynamically generate synthetic data to match ANY input schema
def generate_synthetic_like(real_df, num_records=100):
    """Generate random data with the same columns and compatible dtypes as real_df.

    Every column is synthesised independently from its non-null values:

    * no non-null values  -> the constant string "Unknown"
    * boolean             -> values resampled from the observed booleans
    * integer             -> uniform integers in [min, max] (inclusive)
    * other numeric       -> uniform floats in [min, max)
    * anything else       -> observed values, converted to str and resampled

    Note that the resampled branches copy observed values verbatim into the
    output; only the row-wise combination of values is randomised.

    Args:
        real_df: Source DataFrame whose schema should be mimicked.
        num_records: Number of synthetic rows to produce.

    Returns:
        A DataFrame with ``num_records`` rows and the columns of ``real_df``
        in the same order.
    """
    generated = {}

    for col in real_df.columns:
        observed = real_df[col].dropna()

        if observed.empty:
            generated[col] = ["Unknown"] * num_records
        elif pd.api.types.is_bool_dtype(observed):
            # Booleans count as "numeric" for pandas but not as integers, so
            # without this branch they would be emitted as floats in [0, 1).
            generated[col] = np.random.choice(
                observed.astype(bool).to_numpy(), size=num_records
            )
        elif pd.api.types.is_integer_dtype(observed):
            low, high = int(observed.min()), int(observed.max())
            # randint's upper bound is exclusive, hence the + 1.
            generated[col] = np.random.randint(low, high + 1, size=num_records)
        elif pd.api.types.is_numeric_dtype(observed):
            low, high = float(observed.min()), float(observed.max())
            generated[col] = np.random.uniform(low, high, size=num_records)
        else:
            # Categorical or string
            values = observed.astype(str).tolist()
            generated[col] = np.random.choice(values, size=num_records)

    # Building the frame in one go avoids repeated column inserts.
    return pd.DataFrame(generated)

# Membership Inference Risk Detection
def check_membership_inference_risk(real_data, synthetic_data, target_col=None):
    """Estimate whether synthetic rows can be told apart from real rows.

    A random half of ``real_data`` (label 1) is combined with all of
    ``synthetic_data`` (label 0), the ``target_col`` is dropped, and a small
    random forest is trained to predict the label. The hold-out accuracy of
    that classifier is printed and compared against a 0.6 threshold.

    Args:
        real_data: DataFrame of real records. Not modified.
        synthetic_data: DataFrame of synthetic records. Not modified.
        target_col: Column to exclude from the features. When None, the first
            column that is not named id/patient_id/record_id and has at least
            one non-null value is used, falling back to the first column.

    Returns:
        True if the classifier accuracy is below 0.6 ("safe"), otherwise
        False. False is also returned when no feature columns remain or when
        any exception occurs (including scikit-learn being unavailable).
    """
    try:
        # Use first non-ID, non-trivial column as target if not found
        if target_col is None:
            id_like = {'id', 'patient_id', 'record_id'}
            target_col = next(
                (
                    col for col in real_data.columns
                    # str() so that non-string column labels do not raise.
                    if str(col).lower() not in id_like
                    and real_data[col].notna().any()
                ),
                real_data.columns[0],
            )

        # Work on copies so the caller's frames do not gain a label column.
        real_labeled = real_data.copy()
        synthetic_labeled = synthetic_data.copy()
        real_labeled['membership'] = 1
        synthetic_labeled['membership'] = 0

        # Combine half real + all synthetic
        combined = pd.concat(
            [real_labeled.sample(frac=0.5, random_state=42), synthetic_labeled],
            ignore_index=True,
        )

        # Drop target and membership
        features = combined.drop(columns=[target_col, 'membership'], errors='ignore')
        labels = combined['membership']

        if features.shape[1] == 0:
            return False  # Not enough features

        features = _encode_membership_features(features, max_categories=10)

        if features.empty:
            return False  # Not enough features

        from sklearn.model_selection import train_test_split
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.metrics import accuracy_score

        X_train, X_test, y_train, y_test = train_test_split(
            features, labels, test_size=0.3, random_state=42
        )
        model = RandomForestClassifier(n_estimators=50, max_depth=5, random_state=42)
        model.fit(X_train, y_train)
        acc = accuracy_score(y_test, model.predict(X_test))

        print(f"Membership Inference Attack Accuracy: {acc:.3f}")
        # bool() so callers get a plain Python bool rather than a NumPy scalar.
        return bool(acc < 0.6)  # Safe if below 60%

    except Exception as e:
        print(f"Privacy check error: {e}")
        return False  # Default to unsafe on error


def _encode_membership_features(features, max_categories=10):
    """Turn a mixed-type feature frame into a purely numeric/boolean one.

    Numeric columns are cast to float and missing values are filled with the
    column median (0.0 if the column has no values at all). Every other
    column is converted to strings, limited to its ``max_categories`` most
    frequent values (the rest become "__other__", missing values become
    "__missing__") and then one-hot encoded. pandas.get_dummies has no
    built-in category cap, so the cap is applied here by hand.
    """
    prepared = {}
    for col in features.columns:
        series = features[col]
        if pd.api.types.is_numeric_dtype(series):
            numeric = series.astype("float64")
            median = numeric.median()
            prepared[col] = numeric.fillna(0.0 if pd.isna(median) else median)
        else:
            text = series.astype(object).where(series.notna(), "__missing__").astype(str)
            frequent = text.value_counts().index[:max_categories]
            prepared[col] = text.where(text.isin(frequent), "__other__")

    encoded = pd.get_dummies(pd.DataFrame(prepared, index=features.index))
    # Dummy columns are strings; make every label a string so the estimator
    # is not handed a mix of string and non-string feature names.
    encoded.columns = encoded.columns.astype(str)
    return encoded

# Universal file loader with auto-detection
def load_data(file_path):
    """Load a tabular file into a DataFrame, detecting Excel vs. CSV by content.

    The first bytes of the file decide the reader: the OLE2 signature selects
    the legacy ``.xls`` reader, a ZIP signature selects the ``.xlsx`` reader,
    and anything else is parsed as CSV.

    For CSV, encodings are tried from strictest to most permissive. ``latin1``
    maps every possible byte, so it has to come last: anything listed after it
    could never be reached, and trying it before ``cp1252`` would mis-decode
    Windows-1252 punctuation such as curly quotes or the euro sign.

    Args:
        file_path: Path of the file to read.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the file cannot be decoded with any candidate encoding.
        Exception: Parse errors from pandas (for example an empty or malformed
            file) propagate unchanged instead of being retried.
    """
    with open(file_path, 'rb') as f:
        header = f.read(16)

    # Detect Excel files by magic bytes
    if header.startswith(b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'):  # Older .xls
        return pd.read_excel(file_path, engine='xlrd')
    if header.startswith(b'PK\x03\x04'):  # .xlsx or .zip-based
        return pd.read_excel(file_path, engine='openpyxl')

    # Try CSV with fallback encodings
    encodings = ['utf-8', 'cp1252', 'latin1']
    for enc in encodings:
        try:
            return pd.read_csv(file_path, encoding=enc)
        except UnicodeDecodeError:
            # Only a decoding failure justifies trying the next encoding.
            continue
    raise ValueError(f"Could not decode file with any encoding: {encodings}")

# Main processing function
def process_ehr(file):
    """Handle one upload: load it, synthesise data, assess risk, write downloads.

    Args:
        file: The value delivered by the upload component. Either a path
            (str / os.PathLike) or an object exposing the path as ``.name``;
            None when nothing was uploaded.

    Returns:
        A 4-tuple ``(synthetic_df, status_message, json_path, csv_path)``.
        On missing input, empty input or any error the DataFrame is empty,
        the message explains why, and both paths are None.
    """
    import tempfile

    try:
        if file is None:
            return pd.DataFrame(), "Please upload a file.", None, None

        # Depending on the Gradio version/configuration the callback receives
        # either a plain path or a temp-file wrapper; accept both.
        file_path = file if isinstance(file, (str, os.PathLike)) else file.name

        # Load data (auto-detect format & encoding)
        real_df = load_data(file_path)

        if real_df.empty:
            return pd.DataFrame(), "Uploaded file is empty.", None, None

        # Generate synthetic data with same schema
        synthetic_df = generate_synthetic_like(real_df, num_records=100)

        # Check privacy risk
        is_safe = check_membership_inference_risk(real_df, synthetic_df)
        risk_msg = " Synthetic EHR shows low membership inference risk." if is_safe else "⚠️ High risk: Synthetic data may leak sensitive info."

        # Save outputs into a per-request directory so that concurrent
        # requests cannot overwrite each other's download files.
        out_dir = tempfile.mkdtemp(prefix="medsynth_")
        csv_path = os.path.join(out_dir, "synthetic_ehr.csv")
        json_path = os.path.join(out_dir, "synthetic_ehr.json")
        synthetic_df.to_csv(csv_path, index=False)
        synthetic_df.to_json(json_path, orient="records", indent=2)

        return synthetic_df, risk_msg, json_path, csv_path

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return pd.DataFrame(), f" Error processing file: {str(e)}", None, None

# Gradio Interface
# Declares the UI as a module-level `demo` object: an upload + button row,
# a table for the generated data, a status text box and two download slots.
with gr.Blocks(theme=gr.themes.Soft(), title="MedSynth – Universal Synthetic EHR") as demo:
    # Page heading and short usage description.
    gr.Markdown("""
    #  MedSynth  Universal Privacy-Preserving Synthetic Data Generator
    Upload **any tabular EHR or patient dataset** (CSV, Excel) — we'll generate realistic synthetic data and check for privacy leaks.
    """)
    
    # Input row: file picker restricted to the extensions listed, plus the trigger button.
    with gr.Row():
        upload = gr.File(label=" Upload Real EHR (CSV or Excel)", file_types=[".csv", ".xls", ".xlsx"])
        btn = gr.Button(" Generate Synthetic Data", variant="primary")
    
    # Read-only table receiving the first value returned by process_ehr.
    with gr.Row():
        output_df = gr.Dataframe(label=" Synthetic Data Sample", interactive=False)
    
    # Status message (second value returned by process_ehr).
    with gr.Row():
        risk_text = gr.Textbox(label=" Privacy Risk Status")
    
    # Download slots for the JSON and CSV paths (third and fourth return values).
    with gr.Row():
        json_out = gr.File(label=" Download JSON")
        csv_out = gr.File(label=" Download CSV")
    
    # Event handler
    # The order of `outputs` must match the order of the tuple process_ehr returns.
    btn.click(
        fn=process_ehr,
        inputs=[upload],
        outputs=[output_df, risk_text, json_out, csv_out]
    )

# Launch
# Start the app only when this file is executed as a script, not on import.
# NOTE(review): share=True asks Gradio for a publicly reachable link; since
# the app accepts patient data uploads, confirm that public exposure is intended.
if __name__ == "__main__":
    demo.launch(share=True)