Spaces:
Running
Running
File size: 7,123 Bytes
48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 01a4dbe 48c35f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import io
import base64
from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
class CompatibleUNet(UNet):
    """A UNet whose first/last layers are rebuilt for 1-channel checkpoints.

    The base ``UNet`` is constructed first. Then ``conv1`` is replaced by a
    ``Conv2d`` with a single input channel. If the base model defines a
    ``tconv0`` layer, that layer is replaced by a ``ConvTranspose2d`` producing
    a single channel. The replacement layers start with fresh weights; they
    are expected to be filled by a later ``load_state_dict``.

    Args:
        marginal_prob_std: Forwarded to ``UNet``.
        channels: Per-level feature widths forwarded to ``UNet``; ``channels[0]``
            also sizes the replacement layers. ``None`` (the default) means
            ``[32, 64, 128, 256, 512]``.
        embed_dim: Forwarded to ``UNet``.
        embed_dim_mask: Forwarded to ``UNet``.
        input_dim_mask: Forwarded to ``UNet``; defaults to ``1 * 256 * 256``.
    """

    def __init__(self, marginal_prob_std, channels=None, embed_dim=256,
                 embed_dim_mask=256, input_dim_mask=1 * 256 * 256):
        # Build a fresh list per call. A list literal as the default would be
        # a single object shared by every instance (and by UNet, if it keeps it).
        if channels is None:
            channels = [32, 64, 128, 256, 512]
        super().__init__(marginal_prob_std, channels, embed_dim, embed_dim_mask, input_dim_mask)
        # Accept 1-channel input.
        self.conv1 = torch.nn.Conv2d(1, channels[0], 3, stride=2, bias=False, padding=1)
        # Only swap the transposed conv if the base UNet actually defines tconv0.
        if hasattr(self, 'tconv0'):
            self.tconv0 = torch.nn.ConvTranspose2d(channels[0], 1, 3, stride=1, padding=1, output_padding=0)
class HFDiffusionService:
    """Load the conditional diffusion model and generate CT images from masks.

    Construction picks the device (``cuda:0`` if CUDA is available, else CPU).
    It binds ``marginal_prob_std`` / ``diffusion_coeff`` to ``Lambda = 25.0``
    and the chosen device. It then loads weights from ``pytorch_model.bin``,
    resolved relative to the current working directory.
    """

    def __init__(self):
        cuda_available = torch.cuda.is_available()
        print(f"CUDA available for HF diffusion: {cuda_available}")
        if not cuda_available:
            print("⚠ Warning: CUDA is not available. Using CPU (this will be slow).")
        self.device = torch.device('cuda:0' if cuda_available else 'cpu')
        # Lambda argument shared by both SDE helper functions below.
        self.Lambda = 25.0
        # Bind Lambda and device so the sampler can call these with t alone.
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
        # Model path (make sure pytorch_model.bin is present)
        self.model_path = "pytorch_model.bin"
        # Fallback architecture values. _load_model overwrites them when the
        # checkpoint contains the weights they are inferred from.
        self.input_channels = 1
        self.input_dim_mask = 65536  # 1 * 256 * 256
        self._load_model()

    def _load_model(self):
        """Load the checkpoint, infer its layer sizes and build a matching UNet.

        ``input_channels`` is taken from dim 1 of ``conv1.weight``.
        ``input_dim_mask`` is taken from dim 1 of ``cond_embed.1.weight``.
        Either value is only updated when its key exists. A 1-channel checkpoint
        is loaded into ``CompatibleUNet`` built with the detected mask
        dimension; any other channel count uses the stock ``UNet``.

        Raises:
            Exception: any failure while reading the file, building the model or
                loading the weights is printed and re-raised unchanged.
        """
        try:
            print(f"Loading diffusion model from: {self.model_path}")
            # NOTE(review): torch.load unpickles the file. Only load trusted
            # checkpoints; consider weights_only=True (torch >= 1.13) if the
            # file may come from elsewhere.
            state_dict = torch.load(self.model_path, map_location=self.device)
            conv1_weight = state_dict.get('conv1.weight')
            cond_embed_weight = state_dict.get('cond_embed.1.weight')
            if conv1_weight is not None:
                # Conv2d weights are laid out (out_channels, in_channels, kH, kW).
                self.input_channels = conv1_weight.shape[1]
                print(f"Detected input channels: {self.input_channels}")
            if cond_embed_weight is not None:
                # Presumably a Linear layer: weights are (out_features, in_features).
                self.input_dim_mask = cond_embed_weight.shape[1]
                print(f"Detected input_dim_mask: {self.input_dim_mask}")
            if self.input_channels == 1:
                # CompatibleUNet takes the mask dimension explicitly. Passing the
                # detected value covers every 1-channel checkpoint, not only
                # those whose mask dimension happens to be 65536.
                self.score_model = CompatibleUNet(
                    marginal_prob_std=self.marginal_prob_std_fn,
                    input_dim_mask=self.input_dim_mask,
                )
            else:
                self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
            self.score_model.load_state_dict(state_dict)
            self.score_model.to(self.device)
            self.score_model.eval()
            print(f"✅ HF Diffusion model loaded successfully\n Input channels: {self.input_channels}, Mask dim: {self.input_dim_mask}")
        except Exception as e:
            print(f"❌ Error loading HF diffusion model: {e}")
            raise

    def generate_image(self, mask):
        """Generate a CT image from a segmentation mask.

        Args:
            mask: PIL image, numpy array or torch tensor (see ``_process_mask``).

        Returns:
            An RGB PIL image. Returns ``None`` if any step fails; the error is
            printed and deliberately swallowed.
        """
        try:
            processed_mask = self._process_mask(mask)
            tensor_image = self._generate_from_mask(processed_mask)
            return self._tensor_to_image(tensor_image)
        except Exception as e:
            print(f"❌ Error generating image: {e}")
            return None

    def generate_image_base64(self, mask):
        """Generate a CT image and return it as a PNG data URI.

        Returns:
            ``"data:image/png;base64,..."``, or ``None`` when generation failed.
        """
        image = self.generate_image(mask)
        if image is None:
            return None
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        base64_img = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{base64_img}"

    def _process_mask(self, mask):
        """Convert a mask into a model-ready float tensor on ``self.device``.

        Accepted inputs:
          * PIL image: converted to grayscale, resized to 256x256 and scaled to
            [0, 1] by ``ToTensor``. The result has shape [1, 1, 256, 256].
          * numpy array or torch tensor: 2-D ``[H, W]`` or 3-D ``[C, H, W]``
            inputs get leading axes added to reach 4-D. Values are cast to
            float32 but otherwise used as-is (not rescaled).

        Non-PIL inputs whose spatial size is not 256x256 are bilinearly
        resized. A single-channel result is repeated to ``self.input_channels``.

        Raises:
            ValueError: for unsupported input types.
        """
        try:
            if isinstance(mask, Image.Image):
                transform = transforms.Compose([
                    transforms.Grayscale(num_output_channels=1),
                    transforms.Resize((256, 256), antialias=True),
                    transforms.ToTensor()
                ])
                tensor = transform(mask).unsqueeze(0)  # [1, 1, 256, 256]
            elif isinstance(mask, np.ndarray):
                if mask.ndim == 2:
                    mask = mask[np.newaxis, :, :]
                # NOTE(review): a 3-D array is assumed channels-first [C, H, W].
                # ascontiguousarray lets from_numpy accept negative-stride views
                # (e.g. np.flip results).
                tensor = torch.from_numpy(np.ascontiguousarray(mask)).float()
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)  # [1, C, H, W]
            elif isinstance(mask, torch.Tensor):
                # Mirror the numpy branch: cast to float32 (interpolate and the
                # float32 model reject integer/float64 input) and accept [H, W].
                tensor = mask.float()
                if tensor.dim() == 2:
                    tensor = tensor.unsqueeze(0)
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)
            else:
                raise ValueError(f"Unsupported mask type: {type(mask)}")
            if tensor.shape[2:] != (256, 256):
                tensor = torch.nn.functional.interpolate(tensor, size=(256, 256), mode='bilinear', align_corners=False)
            if tensor.shape[1] == 1 and self.input_channels > 1:
                tensor = tensor.repeat(1, self.input_channels, 1, 1)
            return tensor.to(self.device)
        except Exception as e:
            print(f"❌ Error processing mask: {e}")
            raise

    def _generate_from_mask(self, conditioning_mask, num_steps=250, eps=1e-3):
        """Run conditional diffusion sampling and clamp the result to [0, 1].

        Args:
            conditioning_mask: Tensor from ``_process_mask``, passed to the
                sampler as ``y``.
            num_steps: Number of sampler steps.
            eps: Forwarded to ``Euler_Maruyama_sampler``.

        Returns:
            The sampler output for ``batch_size=1`` and
            ``x_shape=(input_channels, 256, 256)``, clamped to [0, 1].
            Presumably shaped [1, C, 256, 256]; this depends on the sampler.
        """
        try:
            x_shape = (self.input_channels, 256, 256)
            with torch.no_grad():
                samples = Euler_Maruyama_sampler(
                    self.score_model,
                    self.marginal_prob_std_fn,
                    self.diffusion_coeff_fn,
                    batch_size=1,
                    x_shape=x_shape,
                    num_steps=num_steps,
                    device=self.device,
                    eps=eps,
                    y=conditioning_mask
                )
            return samples.clamp(0, 1)
        except Exception as e:
            print(f"❌ Error in diffusion sampling: {e}")
            raise

    def _tensor_to_image(self, tensor):
        """Convert a [1, C, H, W] tensor with values in [0, 1] to an RGB PIL image.

        Multi-channel tensors are averaged over channels. Values are scaled by
        255 and truncated to uint8, giving a grayscale image that is converted
        to RGB for the frontend.
        """
        try:
            tensor = tensor.squeeze(0)  # [C, H, W]
            if tensor.shape[0] > 1:
                image_array = (tensor.mean(dim=0).cpu().numpy() * 255).astype(np.uint8)
            else:
                image_array = (tensor[0].cpu().numpy() * 255).astype(np.uint8)
            # A 2-D uint8 array already maps to mode 'L'. The explicit `mode`
            # argument is deprecated since Pillow 11.3.
            img_gray = Image.fromarray(image_array)
            return img_gray.convert("RGB")  # Always RGB for frontend
        except Exception as e:
            print(f"❌ Error converting tensor to image: {e}")
            raise
|