# STREAMLITE / irm_cancer_module.py
# Stroke-ia's picture
# Update irm_cancer_module.py
# 4b3a3b2 verified
import os
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
class IRMCancerModule:
    """Run a pre-trained segmentation model on a set of MRI volumes.

    The module loads a pickled PyTorch model, turns a list of NIfTI files
    into a single batched tensor, thresholds the model output into a binary
    mask, and writes a PNG preview of one slice of that mask.
    """

    def __init__(self, model_path="unet3d_model.pth", device=None):
        """Load the model from *model_path* and switch it to eval mode.

        Args:
            model_path: Path to a checkpoint containing a fully pickled
                model object (not just a ``state_dict``).
            device: Device to load the model onto. Defaults to ``"cuda"``
                when available, otherwise ``"cpu"``.

        Raises:
            TypeError: If the checkpoint holds a mapping (e.g. a
                ``state_dict``) instead of a model object.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        loaded = torch.load(model_path, map_location=self.device)
        if isinstance(loaded, dict):
            # A state_dict has no .eval() and cannot be called; fail early
            # with an actionable message instead of an AttributeError.
            raise TypeError(
                f"{model_path!r} contains a state_dict/mapping, not a pickled "
                "model; instantiate the network and call load_state_dict() "
                "instead."
            )
        self.model = loaded
        self.model.eval()

    def load_and_preprocess(self, nii_paths) -> torch.Tensor:
        """Load and normalise the MRI sequences into one batched tensor.

        The original author intends four sequences (FLAIR, T1, T1CE, T2),
        but the number of inputs is not enforced here.

        Args:
            nii_paths: Iterable of paths to NIfTI files. All volumes must
                share the same shape so they can be stacked.

        Returns:
            A ``float32`` tensor on ``self.device`` shaped
            ``[1, len(nii_paths), *volume_shape]``.
        """
        volumes = [self.normalize(nib.load(path).get_fdata()) for path in nii_paths]
        # Channels first: [C, *spatial]. The spatial axis order is whatever
        # the NIfTI files store — TODO confirm it matches the model's
        # expected [D, H, W] layout.
        stacked = np.stack(volumes, axis=0)
        # Prepend the batch axis: [1, C, *spatial].
        batch = stacked[np.newaxis, ...]
        return torch.tensor(batch, dtype=torch.float32).to(self.device)

    def normalize(self, data: np.ndarray):
        """Min-max scale *data* to the ``[0, 1]`` range.

        A small epsilon in the denominator avoids division by zero, so a
        constant volume maps to all zeros.
        """
        lowest, highest = data.min(), data.max()
        return (data - lowest) / (highest - lowest + 1e-8)

    def predict(self, nii_paths) -> np.ndarray:
        """Return the binary segmentation mask as a NumPy array.

        Args:
            nii_paths: Paths forwarded to :meth:`load_and_preprocess`.

        Returns:
            The thresholded model output with the batch axis removed,
            as a float array of zeros and ones.
        """
        x = self.load_and_preprocess(nii_paths)
        with torch.no_grad():
            # Original author notes the output as [1, 3, D, H, W]; the actual
            # shape depends on the loaded model.
            logits = self.model(x)
            probabilities = torch.sigmoid(logits)
            binary = (probabilities > 0.5).float()
        return binary.squeeze(0).cpu().numpy()

    def save_mask(self, mask, out_path="pred_mask.png"):
        """Save a preview of the mask taken at a middle slice.

        The preview shows channel 0 of *mask*, sliced at the midpoint of
        the first axis after the channel axis.

        Args:
            mask: Array indexed as ``[channel, axis1, axis2, axis3]``.
            out_path: Destination image path.

        Returns:
            *out_path*, unchanged.
        """
        mid_slice = mask[0, mask.shape[1] // 2, :, :]
        # Use a dedicated figure and always close it: drawing on pyplot's
        # implicit global figure leaks figures and overdraws previous images
        # when this is called repeatedly (e.g. on every Streamlit rerun).
        fig, ax = plt.subplots()
        try:
            ax.imshow(mid_slice, cmap="gray")
            ax.set_title("Segmentation prédite (slice médiane)")
            ax.axis("off")
            fig.savefig(out_path, bbox_inches="tight")
        finally:
            plt.close(fig)
        return out_path

    def run(self, tmp_paths):
        """Run the full pipeline; intended to be called from the Streamlit app.

        Args:
            tmp_paths: Paths of the input NIfTI files.

        Returns:
            A tuple ``(mask, report_text, (tmp_paths, "report.txt", mask_path))``.
        """
        mask = self.predict(tmp_paths)
        mask_path = self.save_mask(mask)
        report_text = (
            "Rapport automatique : détection effectuée.\n"
            f"- Entrées : {len(tmp_paths)} séquences IRM\n"
            f"- Sortie : masque avec {mask.shape[0]} classes\n"
        )
        # NOTE(review): "report.txt" is returned but never written in this
        # class — confirm the caller creates it.
        return mask, report_text, (tmp_paths, "report.txt", mask_path)
# Local smoke test (disable when the module is executed via Streamlit).
if __name__ == "__main__":
    # Four BraTS sequences for one training case.
    sample_inputs = [
        "BraTS20_Training_001_flair.nii",
        "BraTS20_Training_001_t1.nii",
        "BraTS20_Training_001_t1ce.nii",
        "BraTS20_Training_001_t2.nii",
    ]
    segmenter = IRMCancerModule("unet3d_model.pth")
    result = segmenter.run(sample_inputs)
    # run() returns (mask, report, (inputs, report_name, mask_path)).
    report_body, artefacts = result[1], result[2]
    print(report_body)
    print("Mask saved at:", artefacts[2])