tan200224 committed on
Commit
01a4dbe
·
verified ·
1 Parent(s): 871c63b

Update hf_diffusion_service.py

Browse files
Files changed (1) hide show
  1. hf_diffusion_service.py +77 -91
hf_diffusion_service.py CHANGED
@@ -1,35 +1,29 @@
1
- import os
2
  import torch
3
  import numpy as np
4
- import torchvision.transforms as transforms
5
  from PIL import Image
 
 
 
6
 
7
  from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
8
 
9
 
10
  class CompatibleUNet(UNet):
11
- """A UNet model that's compatible with the saved weights."""
12
-
13
  def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256, 512], embed_dim=256,
14
  embed_dim_mask=256, input_dim_mask=1*256*256):
15
- # Override the parent's __init__ to set the correct input channels
16
  super().__init__(marginal_prob_std, channels, embed_dim, embed_dim_mask, input_dim_mask)
17
-
18
- # Replace the first conv layer to accept 1 input channel instead of 4
19
  self.conv1 = torch.nn.Conv2d(1, channels[0], 3, stride=2, bias=False, padding=1)
20
-
21
- # Also fix the output layer if it exists
22
  if hasattr(self, 'tconv0'):
23
- self.tconv0 = torch.nn.ConvTranspose2d(
24
- channels[0], 1, 3, stride=1, padding=1, output_padding=0
25
- )
26
 
27
 
28
  class HFDiffusionService:
29
- """Service class for the Hugging Face conditional diffusion model."""
30
 
31
  def __init__(self):
32
- # Check CUDA
33
  cuda_available = torch.cuda.is_available()
34
  print(f"CUDA available for HF diffusion: {cuda_available}")
35
  if not cuda_available:
@@ -37,84 +31,80 @@ class HFDiffusionService:
37
 
38
  self.device = torch.device('cuda:0' if cuda_available else 'cpu')
39
  self.Lambda = 25.0
40
-
41
- # Initialize model functions
42
  self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
43
  self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
44
 
45
- # Auto-detect model file path
46
- model_candidates = [
47
- "hf_model_files/pytorch_model.bin",
48
- "pytorch_model.bin"
49
- ]
50
- self.model_path = next((path for path in model_candidates if os.path.exists(path)), None)
51
-
52
- if not self.model_path:
53
- raise FileNotFoundError("pytorch_model.bin not found in root or hf_model_files folder.")
54
 
55
- print(f"Loading diffusion model from: {self.model_path}")
 
56
 
57
- # Load model weights
58
  try:
 
59
  state_dict = torch.load(self.model_path, map_location=self.device)
60
 
61
- # Analyze the state dict to configure model
62
  conv1_weight = state_dict.get('conv1.weight', None)
63
  cond_embed_weight = state_dict.get('cond_embed.1.weight', None)
64
 
65
  if conv1_weight is not None:
66
- actual_input_channels = conv1_weight.shape[1]
67
- print(f"Detected input channels: {actual_input_channels}")
68
-
69
- if cond_embed_weight is not None:
70
- actual_input_dim_mask = cond_embed_weight.shape[1]
71
- print(f"Detected input_dim_mask: {actual_input_dim_mask}")
72
-
73
- # 1-channel model with 256*256 flattened mask
74
- if actual_input_channels == 1 and actual_input_dim_mask == 65536:
75
- self.score_model = CompatibleUNet(
76
- marginal_prob_std=self.marginal_prob_std_fn,
77
- input_dim_mask=65536
78
- )
79
- self.input_channels = 1
80
- self.input_dim_mask = 65536
81
- else:
82
- self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
83
- self.input_channels = 4
84
- self.input_dim_mask = 262144
85
- else:
86
- self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
87
- self.input_channels = 4
88
- self.input_dim_mask = 262144
89
  else:
90
- # Default to original architecture
91
  self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
92
- self.input_channels = 4
93
- self.input_dim_mask = 262144
94
 
95
  self.score_model.load_state_dict(state_dict)
96
  self.score_model.to(self.device)
97
  self.score_model.eval()
98
 
99
- print(f"✅ HF Diffusion model loaded successfully")
100
- print(f" Input channels: {self.input_channels}, Mask dim: {self.input_dim_mask}")
101
 
102
  except Exception as e:
103
  print(f"❌ Error loading HF diffusion model: {e}")
104
  raise e
105
 
106
  def generate_image(self, mask):
107
- """Generate a medical image based on a conditioning mask."""
 
 
108
  try:
109
- processed_mask = self.process_mask(mask)
110
- generated_tensor = self.generate_from_mask(processed_mask)
111
- return self.tensor_to_image(generated_tensor)
112
  except Exception as e:
113
  print(f"❌ Error generating image: {e}")
114
  return None
115
 
116
- def process_mask(self, mask):
117
- """Process the input mask to the correct format for the model."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  try:
119
  if isinstance(mask, Image.Image):
120
  transform = transforms.Compose([
@@ -122,13 +112,15 @@ class HFDiffusionService:
122
  transforms.Resize((256, 256), antialias=True),
123
  transforms.ToTensor()
124
  ])
125
- tensor = transform(mask).unsqueeze(0)
 
126
  elif isinstance(mask, np.ndarray):
127
  if mask.ndim == 2:
128
  mask = mask[np.newaxis, :, :]
129
  tensor = torch.from_numpy(mask).float()
130
  if tensor.dim() == 3:
131
- tensor = tensor.unsqueeze(0)
 
132
  elif isinstance(mask, torch.Tensor):
133
  tensor = mask
134
  if tensor.dim() == 3:
@@ -136,32 +128,23 @@ class HFDiffusionService:
136
  else:
137
  raise ValueError(f"Unsupported mask type: {type(mask)}")
138
 
139
- # Adjust channels
140
- if self.input_channels == 1:
141
- if tensor.shape[1] != 1:
142
- tensor = tensor.mean(dim=1, keepdim=True)
143
- else:
144
- if tensor.shape[1] == 1:
145
- tensor = tensor.repeat(1, 4, 1, 1)
146
- elif tensor.shape[1] != 4:
147
- raise ValueError(f"Expected 1 or 4 channels, got {tensor.shape[1]}")
148
-
149
- # Ensure 256x256 size
150
- if tensor.shape[2] != 256 or tensor.shape[3] != 256:
151
- tensor = torch.nn.functional.interpolate(
152
- tensor, size=(256, 256), mode='bilinear', align_corners=False
153
- )
154
 
155
- print(f"Processed mask shape: {tensor.shape}")
156
  return tensor.to(self.device)
157
  except Exception as e:
158
  print(f"❌ Error processing mask: {e}")
159
  raise e
160
 
161
- def generate_from_mask(self, conditioning_mask, num_steps=250, eps=1e-3):
162
- """Generate image from conditioning mask using diffusion model."""
 
 
163
  try:
164
- x_shape = (1, 256, 256) if self.input_channels == 1 else (4, 256, 256)
165
  with torch.no_grad():
166
  samples = Euler_Maruyama_sampler(
167
  self.score_model,
@@ -176,19 +159,22 @@ class HFDiffusionService:
176
  )
177
  return samples.clamp(0, 1)
178
  except Exception as e:
179
- print(f"❌ Error in generate_from_mask: {e}")
180
  raise e
181
 
182
- def tensor_to_image(self, tensor):
183
- """Convert tensor to PIL Image."""
 
 
184
  try:
185
- if tensor.shape[1] > 1:
186
- image_tensor = tensor.squeeze(0).mean(dim=0)
 
187
  else:
188
- image_tensor = tensor.squeeze(0).squeeze(0)
189
 
190
- image_array = (image_tensor.cpu().numpy() * 255).astype(np.uint8)
191
- return Image.fromarray(image_array, mode='L')
192
  except Exception as e:
193
  print(f"❌ Error converting tensor to image: {e}")
194
  raise e
 
 
1
  import torch
2
  import numpy as np
 
3
  from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import io
6
+ import base64
7
 
8
  from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
9
 
10
 
11
  class CompatibleUNet(UNet):
12
+ """A UNet model that's compatible with saved weights (handles 1-channel input)."""
13
+
14
  def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256, 512], embed_dim=256,
15
  embed_dim_mask=256, input_dim_mask=1*256*256):
 
16
  super().__init__(marginal_prob_std, channels, embed_dim, embed_dim_mask, input_dim_mask)
17
+ # Accept 1-channel input
 
18
  self.conv1 = torch.nn.Conv2d(1, channels[0], 3, stride=2, bias=False, padding=1)
 
 
19
  if hasattr(self, 'tconv0'):
20
+ self.tconv0 = torch.nn.ConvTranspose2d(channels[0], 1, 3, stride=1, padding=1, output_padding=0)
 
 
21
 
22
 
23
  class HFDiffusionService:
24
+ """Handles loading the conditional diffusion model and generating CT images."""
25
 
26
  def __init__(self):
 
27
  cuda_available = torch.cuda.is_available()
28
  print(f"CUDA available for HF diffusion: {cuda_available}")
29
  if not cuda_available:
 
31
 
32
  self.device = torch.device('cuda:0' if cuda_available else 'cpu')
33
  self.Lambda = 25.0
 
 
34
  self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
35
  self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
36
 
37
+ # Model path (make sure pytorch_model.bin is present)
38
+ self.model_path = "pytorch_model.bin"
39
+ self.input_channels = 1
40
+ self.input_dim_mask = 65536
 
 
 
 
 
41
 
42
+ # Load model
43
+ self._load_model()
44
 
45
def _load_model(self):
    """Load the checkpoint at ``self.model_path`` into ``self.score_model``.

    The checkpoint is inspected before the network is built: axis 1 of
    ``conv1.weight`` overrides ``self.input_channels`` and axis 1 of
    ``cond_embed.1.weight`` overrides ``self.input_dim_mask`` when those
    keys are present. A ``CompatibleUNet`` is built when the result is
    1 channel with a 65536-wide mask input; otherwise a default ``UNet``.
    The weights are then loaded, and the model is moved to ``self.device``
    and put in eval mode.

    Raises:
        Exception: Any failure is printed and then re-raised.
    """
    try:
        print(f"Loading diffusion model from: {self.model_path}")
        # NOTE(review): torch.load unpickles the file; only point
        # model_path at checkpoints you trust.
        weights = torch.load(self.model_path, map_location=self.device)

        first_conv = weights.get('conv1.weight', None)
        mask_embed = weights.get('cond_embed.1.weight', None)

        # NOTE(review): the two detection checks are treated as independent
        # (sequential, not nested) — confirm against the real file.
        if first_conv is not None:
            self.input_channels = first_conv.shape[1]
            print(f"Detected input channels: {self.input_channels}")
        if mask_embed is not None:
            self.input_dim_mask = mask_embed.shape[1]
            print(f"Detected input_dim_mask: {self.input_dim_mask}")

        # Pick the architecture that matches the detected layer shapes.
        single_channel_layout = (
            self.input_channels == 1 and self.input_dim_mask == 65536
        )
        if not single_channel_layout:
            self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        else:
            self.score_model = CompatibleUNet(
                marginal_prob_std=self.marginal_prob_std_fn,
                input_dim_mask=self.input_dim_mask,
            )

        network = self.score_model
        network.load_state_dict(weights)
        network.to(self.device)
        network.eval()

        print(f"✅ HF Diffusion model loaded successfully\n Input channels: {self.input_channels}, Mask dim: {self.input_dim_mask}")

    except Exception as e:
        print(f"❌ Error loading HF diffusion model: {e}")
        raise e
78
 
79
  def generate_image(self, mask):
80
+ """
81
+ Generate a CT image from a segmentation mask and return it as PIL Image.
82
+ """
83
  try:
84
+ processed_mask = self._process_mask(mask)
85
+ tensor_image = self._generate_from_mask(processed_mask)
86
+ return self._tensor_to_image(tensor_image)
87
  except Exception as e:
88
  print(f"❌ Error generating image: {e}")
89
  return None
90
 
91
def generate_image_base64(self, mask):
    """Generate an image for ``mask`` and return it as a PNG data URI.

    Args:
        mask: Conditioning mask, passed as-is to ``generate_image``.

    Returns:
        A ``data:image/png;base64,...`` string, or ``None`` when
        ``generate_image`` returned ``None``.
    """
    image = self.generate_image(mask)
    if image is None:
        return None

    # Encode the image as PNG in memory, then base64 the raw bytes.
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        encoded = base64.b64encode(png_stream.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
103
+
104
+ def _process_mask(self, mask):
105
+ """
106
+ Convert input mask (PIL, np.array, or tensor) into model-ready tensor.
107
+ """
108
  try:
109
  if isinstance(mask, Image.Image):
110
  transform = transforms.Compose([
 
112
  transforms.Resize((256, 256), antialias=True),
113
  transforms.ToTensor()
114
  ])
115
+ tensor = transform(mask).unsqueeze(0) # [1, 1, 256, 256]
116
+
117
  elif isinstance(mask, np.ndarray):
118
  if mask.ndim == 2:
119
  mask = mask[np.newaxis, :, :]
120
  tensor = torch.from_numpy(mask).float()
121
  if tensor.dim() == 3:
122
+ tensor = tensor.unsqueeze(0) # [1, 1, 256, 256]
123
+
124
  elif isinstance(mask, torch.Tensor):
125
  tensor = mask
126
  if tensor.dim() == 3:
 
128
  else:
129
  raise ValueError(f"Unsupported mask type: {type(mask)}")
130
 
131
+ if tensor.shape[2:] != (256, 256):
132
+ tensor = torch.nn.functional.interpolate(tensor, size=(256, 256), mode='bilinear', align_corners=False)
133
+
134
+ if tensor.shape[1] == 1 and self.input_channels > 1:
135
+ tensor = tensor.repeat(1, self.input_channels, 1, 1)
 
 
 
 
 
 
 
 
 
 
136
 
 
137
  return tensor.to(self.device)
138
  except Exception as e:
139
  print(f"❌ Error processing mask: {e}")
140
  raise e
141
 
142
+ def _generate_from_mask(self, conditioning_mask, num_steps=250, eps=1e-3):
143
+ """
144
+ Diffusion sampling given a mask, returns tensor in [0,1].
145
+ """
146
  try:
147
+ x_shape = (self.input_channels, 256, 256)
148
  with torch.no_grad():
149
  samples = Euler_Maruyama_sampler(
150
  self.score_model,
 
159
  )
160
  return samples.clamp(0, 1)
161
  except Exception as e:
162
+ print(f"❌ Error in diffusion sampling: {e}")
163
  raise e
164
 
165
def _tensor_to_image(self, tensor):
    """Convert a sampled tensor into an RGB PIL image.

    The leading dimension is squeezed, multi-channel inputs are averaged
    over their first remaining axis, and the result is scaled by 255 and
    cast to ``uint8`` before being wrapped as a grayscale image and
    converted to RGB.

    Args:
        tensor: Sampler output; assumed to be ``[1, C, H, W]`` with values
            in ``[0, 1]`` — TODO confirm against the sampler.

    Returns:
        A PIL image in ``"RGB"`` mode.

    Raises:
        Exception: Any failure is printed and then re-raised.
    """
    try:
        chw = tensor.squeeze(0)  # [C, H, W]
        # Collapse channels to one plane: average if several, else take it.
        plane = chw.mean(dim=0) if chw.shape[0] > 1 else chw[0]
        pixels = (plane.cpu().numpy() * 255).astype(np.uint8)

        # Always RGB for the frontend.
        return Image.fromarray(pixels, mode='L').convert("RGB")
    except Exception as e:
        print(f"❌ Error converting tensor to image: {e}")
        raise e