Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import numpy as np | |
| import nibabel as nib | |
| import matplotlib.pyplot as plt | |
class IRMCancerModule:
    """Segment MRI volumes with a pre-trained PyTorch model.

    The model is loaded once at construction time. ``run`` is the high-level
    entry point: it loads the input sequences, predicts a binary mask per
    output channel, and writes a PNG preview of one slice.
    """

    def __init__(self, model_path="unet3d_model.pth", device=None):
        """Load the serialized model and put it in inference mode.

        Args:
            model_path: Checkpoint holding the *whole* pickled model object
                (not just a ``state_dict``), as accepted by ``torch.load``.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.

        Raises:
            TypeError: If the checkpoint does not contain a
                ``torch.nn.Module`` (for example a bare ``state_dict``).
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY: a full-model checkpoint is unpickled here, which can run
        # arbitrary code. Only load checkpoints from a trusted source.
        # ``weights_only=False`` is passed explicitly because PyTorch >= 2.6
        # defaults to ``weights_only=True``, which rejects pickled modules.
        model = torch.load(
            model_path, map_location=self.device, weights_only=False
        )
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                "Checkpoint must contain a full torch.nn.Module, got "
                f"{type(model).__name__}; a state_dict cannot be used directly."
            )
        self.model = model
        self.model.eval()

    def load_and_preprocess(self, nii_paths):
        """Load the MRI sequences, normalize each one and build a batch.

        Args:
            nii_paths: Iterable of NIfTI file paths, one per sequence. The
                pipeline is intended for 4 sequences (FLAIR, T1, T1CE, T2).
                All volumes must have the same shape for stacking to succeed.

        Returns:
            A float32 tensor on ``self.device`` shaped
            ``[1, C, *volume_shape]`` where ``C == len(nii_paths)``.
        """
        volumes = [
            self.normalize(nib.load(path).get_fdata()) for path in nii_paths
        ]
        # Channels first, then a leading batch axis of size 1.
        stacked = np.stack(volumes, axis=0)
        batch = np.expand_dims(stacked, axis=0)
        return torch.tensor(batch, dtype=torch.float32).to(self.device)

    def normalize(self, data: np.ndarray):
        """Min-max scale ``data`` to the range [0, 1].

        A small epsilon in the denominator avoids division by zero, so a
        constant volume maps to all zeros.
        """
        data_min, data_max = data.min(), data.max()
        return (data - data_min) / (data_max - data_min + 1e-8)

    def predict(self, nii_paths):
        """Run the model and return the thresholded mask as a NumPy array.

        Args:
            nii_paths: Paths forwarded to ``load_and_preprocess``.

        Returns:
            A float array of 0.0/1.0 values with the batch axis removed,
            i.e. one binary volume per model output channel.
        """
        x = self.load_and_preprocess(nii_paths)
        with torch.no_grad():
            # Sigmoid per channel, then a fixed 0.5 threshold.
            y_pred = torch.sigmoid(self.model(x))
            y_pred = (y_pred > 0.5).float()
        return y_pred.squeeze(0).cpu().numpy()

    def save_mask(self, mask, out_path="pred_mask.png"):
        """Save a PNG preview of the middle slice of the first mask channel.

        Args:
            mask: 4-D array indexed as ``[channel, axis1, axis2, axis3]``;
                the slice is taken at the midpoint of ``axis1``.
            out_path: Destination image path.

        Returns:
            ``out_path``, unchanged.
        """
        mid_slice = mask[0, mask.shape[1] // 2, :, :]
        # Use a dedicated figure and always close it: drawing on pyplot's
        # implicit global figure would pile images onto the same axes across
        # calls in a long-running process.
        fig, ax = plt.subplots()
        try:
            ax.imshow(mid_slice, cmap="gray")
            ax.set_title("Segmentation prédite (slice médiane)")
            ax.axis("off")
            fig.savefig(out_path, bbox_inches="tight")
        finally:
            plt.close(fig)
        return out_path

    def run(self, tmp_paths):
        """Full pipeline: predict, save the preview and build a text report.

        Args:
            tmp_paths: Paths of the input MRI sequences.

        Returns:
            A tuple ``(mask, report_text, files)`` where ``files`` is
            ``(tmp_paths, "report.txt", mask_path)``.
        """
        mask = self.predict(tmp_paths)
        mask_path = self.save_mask(mask)
        report_text = (
            "Rapport automatique : détection effectuée.\n"
            f"- Entrées : {len(tmp_paths)} séquences IRM\n"
            f"- Sortie : masque avec {mask.shape[0]} classes\n"
        )
        # NOTE(review): "report.txt" is only a name here; this method never
        # writes that file. Confirm that the caller does.
        return mask, report_text, (tmp_paths, "report.txt", mask_path)
# Local smoke test (disable when the module is driven by Streamlit).
if __name__ == "__main__":
    # One BraTS case: the four sequences, in the order the pipeline stacks them.
    sample_sequences = [
        "BraTS20_Training_001_flair.nii",
        "BraTS20_Training_001_t1.nii",
        "BraTS20_Training_001_t1ce.nii",
        "BraTS20_Training_001_t2.nii",
    ]
    segmenter = IRMCancerModule("unet3d_model.pth")
    _mask, report, (_inputs, _report_name, preview_path) = segmenter.run(
        sample_sequences
    )
    print(report)
    print("Mask saved at:", preview_path)