Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from nilearn import datasets | |
| from nilearn.connectome import ConnectivityMeasure | |
| from nilearn.maskers import MultiNiftiMapsMasker | |
| import numpy as np | |
# Inference is pinned to CPU. To use a GPU when one is present, use:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Load the TorchScript encoder. Failure is fatal: the app cannot serve
# predictions without the model.
try:
    # torch.jit.load returns a ScriptModule (never a DataParallel wrapper), and
    # map_location already places its tensors on `device`.
    scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location=device)
except Exception as e:
    # SystemExit with a message prints it to stderr and exits with status 1;
    # unlike the site-injected `exit()`, it is always available.
    raise SystemExit(f"Error loading model: {e}") from e
scripted_model.to(device)
scripted_model.eval()  # inference mode (disables dropout / batch-norm updates)
# Fetch the DiFuMo probabilistic atlas (`dim` components, 2 mm resolution).
# Only `difumo.maps` is used, so the deprecated `legacy_format` switch (which
# only controls the format of `labels`) is not passed.
dim = 64
try:
    difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2)
except Exception as e:
    # Fatal: without the atlas no features can be extracted. SystemExit with a
    # message prints it to stderr and exits with status 1.
    raise SystemExit(f"Error fetching atlas: {e}") from e
atlas_filename = difumo.maps

# Masker that reduces an fMRI image to one signal per atlas component.
# standardize=True standardizes each extracted signal; n_jobs=-1 lets nilearn
# use all CPU cores where it parallelizes.
masker = MultiNiftiMapsMasker(
    maps_img=atlas_filename,
    standardize=True,
    n_jobs=-1,
    verbose=0,
)

# Correlation between component signals, flattened to a vector with the
# (constant) diagonal dropped, so each subject yields one feature vector.
connectome_measure = ConnectivityMeasure(
    kind="correlation",
    vectorize=True,
    discard_diagonal=True,
)
# Feature extraction function
def extract_features_multiple(func_preproc_files):
    """Compute one vectorized correlation connectome per fMRI image.

    The module-level ``masker`` is fitted once on the first image. Every image,
    including the first, is then reduced to per-component signals and turned
    into a feature vector by the module-level ``connectome_measure``.

    Args:
        func_preproc_files: Sequence of preprocessed fMRI images, in any form
            ``masker`` accepts (presumably upload file paths; confirm).

    Returns:
        A list with one feature vector per input, in input order. Returns an
        empty list when no files are given.
    """
    if not func_preproc_files:
        return []

    first_image = func_preproc_files[0]
    print("Fitting masker on the first subject...")
    masker.fit(first_image)

    features = []
    for subject_number, image in enumerate(func_preproc_files, start=1):
        print(f"Processing subject {subject_number}...")
        component_signals = masker.transform(image)
        # ConnectivityMeasure operates on a list of subjects: wrap the single
        # subject in a list, then unwrap its one result.
        connectome = connectome_measure.fit_transform([component_signals])
        features.append(connectome[0])
    print("All subjects processed.")
    return features
# Function to generate a Plotly probability plot
def plot_probability(probability):
    """Build a two-bar Plotly chart of the predicted class probabilities.

    Args:
        probability: Predicted probability of the "Autism" class, expected in
            [0, 1]. The "No Autism" bar shows ``1 - probability``.

    Returns:
        A ``plotly.graph_objects.Figure`` on a dark background. The y-axis is
        pinned to [0, 1] so bar heights are comparable across predictions
        (an autoscaled axis makes e.g. 0.55 and 0.95 look equally tall).
    """
    labels = ["No Autism", "Autism"]
    probs = [1 - probability, probability]
    colors = ["#6a0dad", "#d896ff"]  # Dark purple and light purple
    fig = go.Figure(
        go.Bar(
            x=labels,
            y=probs,
            marker=dict(color=colors),
            # Bar labels as percentages, derived from the plotted values.
            text=[f"{p * 100:.1f}%" for p in probs],
            textposition="auto",
        )
    )
    fig.update_layout(
        title="Autism Prediction Probability",
        paper_bgcolor="black",
        plot_bgcolor="black",
        font=dict(color="white"),
        xaxis=dict(title="Diagnosis", showgrid=False),
        yaxis=dict(title="Probability", showgrid=True, gridcolor="gray", range=[0, 1]),
    )
    return fig
# Prediction function
def predict_autism(fmri_files, age, gender):
    """Gradio callback: predict autism from fMRI uploads plus age and gender.

    Args:
        fmri_files: Uploaded fMRI files from ``gr.File(file_count="multiple")``.
            May be empty or None when nothing was uploaded.
        age: Subject age from ``gr.Number``. None when the field is empty.
        gender: "0" or "1" from ``gr.Radio``. None when nothing is selected.

    Returns:
        ``(text, figure)``: one prediction line per file, prefixed with the
        subject number when several files are given, and the probability plot
        for the first file. On invalid input or any failure, returns
        ``(message, None)``.
    """
    # Validate cheap inputs before the expensive feature extraction, and give
    # readable messages instead of a raw int()/float() TypeError.
    if not fmri_files:
        return "Please upload at least one valid .nii.gz file.", None
    if age is None:
        return "Please enter the subject's age.", None
    if gender is None or gender == "":
        return "Please select a gender.", None

    try:
        features_list = extract_features_multiple(fmri_files)
        if not features_list:
            return "Error: Failed to extract features from the fMRI files.", None

        age_tensor = torch.tensor([float(age)], dtype=torch.float32, device=device)
        gender_tensor = torch.tensor([int(gender)], dtype=torch.long, device=device)
        label_subjects = len(features_list) > 1

        predictions = []
        first_plot = None
        with torch.no_grad():
            for index, features in enumerate(features_list, start=1):
                features_tensor = torch.tensor(
                    features, dtype=torch.float32, device=device
                ).unsqueeze(0)  # add a batch dimension of 1
                logit = scripted_model(features_tensor, age_tensor, gender_tensor)
                probability = torch.sigmoid(logit).item()  # P(Autism)
                is_autism = probability > 0.5
                # Report confidence in the *predicted* class. Printing
                # P(Autism) for a "No Autism" call was misleading.
                confidence = probability if is_autism else 1 - probability
                result = (
                    f"Prediction: {'Autism' if is_autism else 'No Autism'} "
                    f"(Confidence: {confidence:.2%})"
                )
                if label_subjects:
                    result = f"Subject {index}: {result}"
                predictions.append(result)
                # Only the first subject's plot is shown, so build just that one.
                if first_plot is None:
                    first_plot = plot_probability(probability)
        return "\n".join(predictions), first_plot
    except Exception as e:
        # UI boundary: surface any failure as text rather than crashing the app.
        return f"Error: {e}", None
# Gradio interface: fMRI uploads + age + gender in, text result + plot out.
iface = gr.Interface(
    fn=predict_autism,
    inputs=[
        gr.File(label="Upload preprocessed fMRI files (.nii.gz)", file_count="multiple"),
        gr.Number(label="Age", minimum=0, maximum=120),
        gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
    ],
    outputs=[
        gr.Text(label="Prediction Result"),
        gr.Plot(label="Prediction Probability Plot"),
    ],
    title="Autism Prediction from fMRI Data",
    description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
    theme="default",
    flagging_mode="never",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a web server.
if __name__ == "__main__":
    iface.launch()