import os
import sys
import traceback

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from nilearn import datasets
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import MultiNiftiMapsMasker

# Inference is pinned to CPU. To use a GPU when one is available, switch to:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Load the TorchScript model. torch.jit.load always returns a ScriptModule,
# so no DataParallel unwrapping is needed here.
try:
    scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location=device)
    scripted_model.to(device)
    scripted_model.eval()
except Exception as e:
    print(f"Error loading model: {str(e)}", file=sys.stderr)
    sys.exit(1)

# Fetch the DiFuMo probabilistic atlas (dimension = number of components).
dim = 64
try:
    difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
    atlas_filename = difumo.maps
except Exception as e:
    print(f"Error fetching atlas: {str(e)}", file=sys.stderr)
    sys.exit(1)

# Masker that extracts one time series per atlas map from each fMRI image.
masker = MultiNiftiMapsMasker(
    maps_img=atlas_filename,
    standardize=True,
    n_jobs=-1,
    verbose=0,
)

# Correlation connectome, vectorized to the upper triangle without the diagonal.
connectome_measure = ConnectivityMeasure(
    kind="correlation", vectorize=True, discard_diagonal=True
)


def extract_features_multiple(func_preproc_files):
    """Compute a vectorized correlation connectome for each fMRI file.

    Args:
        func_preproc_files: Sequence of paths to preprocessed 4D fMRI images.

    Returns:
        A list with one 1-D feature vector per input file, in input order.
        An empty list is returned when no files are given.
    """
    all_features = []
    if not func_preproc_files:
        return all_features

    # The masker is fitted once, on the first subject, and reused for the rest.
    print("Fitting masker on the first subject...")
    masker.fit(func_preproc_files[0])

    for i, sub in enumerate(func_preproc_files):
        print(f"Processing subject {i+1}...")
        masked_data = masker.transform(sub)
        # fit_transform takes a list of subjects; a one-element list is passed
        # and the single resulting vector is taken.
        transformed_data = connectome_measure.fit_transform([masked_data])[0]
        all_features.append(transformed_data)
    print("All subjects processed.")
    return all_features


def plot_probability(probability):
    """Build a two-bar Plotly chart of P(No Autism) vs P(Autism).

    Args:
        probability: Model probability of the "Autism" class, in [0, 1].

    Returns:
        A plotly.graph_objects.Figure.
    """
    labels = ["No Autism", "Autism"]
    probs = [1 - probability, probability]
    colors = ["#6a0dad", "#d896ff"]  # Dark purple and light purple

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=probs,
        marker=dict(color=colors),
        text=[f"{(1-probability)*100:.1f}%", f"{probability*100:.1f}%"],
        textposition="auto",
    ))
    fig.update_layout(
        title="Autism Prediction Probability",
        paper_bgcolor="black",
        plot_bgcolor="black",
        font=dict(color="white"),
        xaxis=dict(title="Diagnosis", showgrid=False),
        yaxis=dict(title="Probability", showgrid=True, gridcolor="gray"),
    )
    return fig


def _upload_path(file_obj):
    """Return a filesystem path for one Gradio upload item.

    Gradio may hand over plain path strings or tempfile-like objects that
    expose the path as ``.name``; both forms are accepted.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return file_obj
    return file_obj.name


def predict_autism(fmri_files, age, gender):
    """Gradio handler: predict autism from one or more fMRI files.

    Args:
        fmri_files: Uploaded files from the gr.File component.
        age: Subject age from gr.Number (None when the field is empty).
        gender: "0" or "1" from gr.Radio (None when nothing is selected).

    Returns:
        A (text, figure) tuple. The text has one prediction line per file.
        The figure is the probability chart for the FIRST file, or None when
        an error or validation message is returned instead.
    """
    try:
        if not fmri_files:
            return "Please upload at least one valid .nii.gz file.", None
        if age is None:
            return "Please enter the subject's age.", None
        if gender is None or gender == "":
            return "Please select the subject's gender.", None

        features_list = extract_features_multiple([_upload_path(f) for f in fmri_files])
        if not features_list:
            return "Error: Failed to extract features from the fMRI files.", None

        age_tensor = torch.tensor([float(age)], dtype=torch.float32).to(device)
        gender_tensor = torch.tensor([int(gender)], dtype=torch.long).to(device)

        probabilities = []
        for features in features_list:
            features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                logit = scripted_model(features_tensor, age_tensor, gender_tensor)
            # The model output is treated as a single logit for the Autism class.
            probabilities.append(torch.sigmoid(logit).item())

        # Label lines per subject only when several files were uploaded, so the
        # single-file output stays a plain "Prediction: ..." line.
        label_subjects = len(probabilities) > 1
        predictions = []
        for i, probability in enumerate(probabilities, start=1):
            is_autism = probability > 0.5
            label = "Autism" if is_autism else "No Autism"
            # Confidence is the probability of the class that was predicted.
            confidence = probability if is_autism else 1 - probability
            prefix = f"Subject {i}: " if label_subjects else ""
            predictions.append(f"{prefix}Prediction: {label} (Confidence: {confidence:.2%})")

        # Only the first subject's chart is displayed, so only that one is built.
        return "\n".join(predictions), plot_probability(probabilities[0])
    except Exception as e:
        # Keep the stack trace in the server log; the UI only gets the message.
        traceback.print_exc()
        return f"Error: {str(e)}", None


# Gradio interface
iface = gr.Interface(
    fn=predict_autism,
    inputs=[
        gr.File(label="Upload preprocessed fMRI files (.nii.gz)", file_count="multiple"),
        gr.Number(label="Age", minimum=0, maximum=120),
        gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
    ],
    outputs=[
        gr.Text(label="Prediction Result"),
        gr.Plot(label="Prediction Probability Plot"),
    ],
    title="Autism Prediction from fMRI Data",
    description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
    theme="default",
    flagging_mode="never",
)

iface.launch()