tan200224 committed on
Commit
48c35f9
·
verified ·
1 Parent(s): e667bbe

Update hf_diffusion_service.py

Browse files
Files changed (1) hide show
  1. hf_diffusion_service.py +194 -311
hf_diffusion_service.py CHANGED
@@ -1,311 +1,194 @@
1
- import os
2
- import sys
3
- import torch
4
- import numpy as np
5
- import torchvision.transforms as transforms
6
- from PIL import Image
7
-
8
- # Add the hf_model_files directory to the path
9
- sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'hf_model_files'))
10
-
11
- from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
12
-
13
- class CompatibleUNet(UNet):
14
- """A UNet model that's compatible with the saved weights."""
15
-
16
- def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256, 512], embed_dim=256,
17
- embed_dim_mask=256, input_dim_mask=1*256*256): # Changed to 1*256*256
18
- # Override the parent's __init__ to set the correct input channels
19
- super().__init__(marginal_prob_std, channels, embed_dim, embed_dim_mask, input_dim_mask)
20
-
21
- # Replace the first conv layer to accept 1 input channel instead of 4
22
- self.conv1 = torch.nn.Conv2d(1, channels[0], 3, stride=2, bias=False, padding=1)
23
-
24
- # Also need to fix the output layer if it exists
25
- if hasattr(self, 'tconv0'):
26
- self.tconv0 = torch.nn.ConvTranspose2d(channels[0], 1, 3, stride=1, padding=1, output_padding=0)
27
-
28
- class HFDiffusionService:
29
- """Service class for the Hugging Face conditional diffusion model."""
30
-
31
- def __init__(self):
32
- # Check if CUDA is available and print status
33
- cuda_available = torch.cuda.is_available()
34
- print(f"CUDA available for HF diffusion: {cuda_available}")
35
- if not cuda_available:
36
- print("Warning: CUDA is not available for HF diffusion. Using CPU instead. This might be slower.")
37
-
38
- self.device = torch.device('cuda:0' if cuda_available else 'cpu')
39
- self.Lambda = 25.0
40
-
41
- # Initialize the model functions
42
- self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
43
- self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
44
-
45
- # Model path for the downloaded Hugging Face model
46
- self.model_path = os.path.join("hf_model_files", "pytorch_model.bin")
47
-
48
- try:
49
- # Load the state dict first to understand the architecture
50
- state_dict = torch.load(self.model_path, map_location=self.device)
51
-
52
- # Analyze the state dict to determine the correct architecture
53
- conv1_weight = state_dict.get('conv1.weight', None)
54
- cond_embed_weight = state_dict.get('cond_embed.1.weight', None)
55
-
56
- if conv1_weight is not None:
57
- actual_input_channels = conv1_weight.shape[1]
58
- print(f"Detected input channels from state dict: {actual_input_channels}")
59
-
60
- if cond_embed_weight is not None:
61
- actual_input_dim_mask = cond_embed_weight.shape[1]
62
- print(f"Detected input_dim_mask from state dict: {actual_input_dim_mask}")
63
-
64
- # Create a compatible model
65
- if actual_input_channels == 1 and actual_input_dim_mask == 65536:
66
- # The saved model expects 1 input channel and 65536 flattened input
67
- # This suggests it was trained with 1*256*256 = 65536
68
- self.score_model = CompatibleUNet(
69
- marginal_prob_std=self.marginal_prob_std_fn,
70
- input_dim_mask=65536
71
- )
72
- self.input_channels = 1
73
- self.input_dim_mask = 65536
74
- else:
75
- # Use the original architecture
76
- self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
77
- self.input_channels = 4
78
- self.input_dim_mask = 262144
79
- else:
80
- # Fallback to original
81
- self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
82
- self.input_channels = 4
83
- self.input_dim_mask = 262144
84
- else:
85
- # Fallback to original
86
- self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
87
- self.input_channels = 4
88
- self.input_dim_mask = 262144
89
-
90
- # Load the weights
91
- self.score_model.load_state_dict(state_dict)
92
- self.score_model.to(self.device)
93
- self.score_model.eval()
94
-
95
- print(f"HF Diffusion model loaded successfully from {self.model_path}")
96
- print(f"Model configured for {self.input_channels} input channels and {self.input_dim_mask} mask dimensions")
97
-
98
- except Exception as e:
99
- print(f"Error loading HF diffusion model: {e}")
100
- raise e
101
-
102
- def generate_image(self, mask):
103
- """
104
- Generate a medical image based on a conditioning mask.
105
-
106
- Args:
107
- mask: Conditioning mask tensor of shape (1, 4, 256, 256) or PIL Image
108
-
109
- Returns:
110
- Generated image as PIL Image
111
- """
112
- try:
113
- # Process the mask input
114
- processed_mask = self.process_mask(mask)
115
-
116
- # Generate the image
117
- generated_tensor = self.generate_from_mask(processed_mask)
118
-
119
- # Convert tensor to PIL Image
120
- return self.tensor_to_image(generated_tensor)
121
-
122
- except Exception as e:
123
- print(f"Error generating HF diffusion image: {e}")
124
- return None
125
-
126
- def process_mask(self, mask):
127
- """
128
- Process the input mask to the correct format for the model.
129
-
130
- Args:
131
- mask: Input mask (PIL Image, numpy array, or tensor)
132
-
133
- Returns:
134
- Processed mask tensor of shape (1, 1, 256, 256) for 1-channel model
135
- """
136
- try:
137
- # If mask is a PIL Image, convert to tensor
138
- if isinstance(mask, Image.Image):
139
- transform = transforms.Compose([
140
- transforms.Grayscale(num_output_channels=1),
141
- transforms.Resize((256, 256), antialias=True),
142
- transforms.ToTensor()
143
- ])
144
- tensor = transform(mask).unsqueeze(0) # Add batch dimension
145
- elif isinstance(mask, np.ndarray):
146
- # Convert numpy array to tensor
147
- if mask.ndim == 2:
148
- mask = mask[np.newaxis, :, :] # Add channel dimension
149
- tensor = torch.from_numpy(mask).float()
150
- if tensor.dim() == 3:
151
- tensor = tensor.unsqueeze(0) # Add batch dimension
152
- elif isinstance(mask, torch.Tensor):
153
- tensor = mask
154
- if tensor.dim() == 3:
155
- tensor = tensor.unsqueeze(0) # Add batch dimension
156
- else:
157
- raise ValueError(f"Unsupported mask type: {type(mask)}")
158
-
159
- # Ensure the tensor has the correct shape based on model input
160
- if self.input_channels == 1:
161
- # Model expects 1 channel
162
- if tensor.shape[1] != 1:
163
- # Take the first channel or average if multiple channels
164
- if tensor.shape[1] > 1:
165
- tensor = tensor.mean(dim=1, keepdim=True)
166
- else:
167
- tensor = tensor[:, :1, :, :]
168
- else:
169
- # Model expects 4 channels
170
- if tensor.shape[1] == 1:
171
- # If single channel, repeat to 4 channels
172
- tensor = tensor.repeat(1, 4, 1, 1)
173
- elif tensor.shape[1] != 4:
174
- raise ValueError(f"Expected 1 or 4 channels, got {tensor.shape[1]}")
175
-
176
- # Ensure correct size
177
- if tensor.shape[2] != 256 or tensor.shape[3] != 256:
178
- tensor = torch.nn.functional.interpolate(tensor, size=(256, 256), mode='bilinear', align_corners=False)
179
-
180
- print(f"Processed mask shape: {tensor.shape}")
181
- return tensor.to(self.device)
182
-
183
- except Exception as e:
184
- print(f"Error processing mask: {e}")
185
- raise e
186
-
187
- def generate_from_mask(self, conditioning_mask, num_steps=250, eps=1e-3):
188
- """
189
- Generate image from conditioning mask using the diffusion model.
190
-
191
- Args:
192
- conditioning_mask: Conditioning mask tensor
193
- num_steps: Number of sampling steps
194
- eps: Smallest time step for numerical stability
195
-
196
- Returns:
197
- Generated image tensor
198
- """
199
- try:
200
- # Determine the output shape based on the model
201
- if self.input_channels == 1:
202
- x_shape = (1, 256, 256)
203
- else:
204
- x_shape = (4, 256, 256)
205
-
206
- with torch.no_grad():
207
- samples = Euler_Maruyama_sampler(
208
- self.score_model,
209
- self.marginal_prob_std_fn,
210
- self.diffusion_coeff_fn,
211
- batch_size=1,
212
- x_shape=x_shape,
213
- num_steps=num_steps,
214
- device=self.device,
215
- eps=eps,
216
- y=conditioning_mask
217
- )
218
-
219
- # Clamp values to [0, 1] range
220
- return samples.clamp(0, 1)
221
-
222
- except Exception as e:
223
- print(f"Error in generate_from_mask: {e}")
224
- raise e
225
-
226
- def tensor_to_image(self, tensor):
227
- """
228
- Convert tensor to PIL Image.
229
-
230
- Args:
231
- tensor: Generated tensor
232
-
233
- Returns:
234
- PIL Image
235
- """
236
- try:
237
- # Take the first channel for visualization (or average all channels)
238
- if tensor.shape[1] > 1:
239
- # Average the channels
240
- image_tensor = tensor.squeeze(0).mean(dim=0)
241
- else:
242
- image_tensor = tensor.squeeze(0).squeeze(0)
243
-
244
- # Convert to numpy and scale to 0-255
245
- image_array = (image_tensor.cpu().numpy() * 255).astype(np.uint8)
246
-
247
- # Create PIL Image
248
- image = Image.fromarray(image_array, mode='L')
249
-
250
- return image
251
-
252
- except Exception as e:
253
- print(f"Error converting tensor to image: {e}")
254
- raise e
255
-
256
- def generate_batch(self, masks, num_steps=250, eps=1e-3):
257
- """
258
- Generate multiple images from a batch of masks.
259
-
260
- Args:
261
- masks: List of masks or batch tensor
262
- num_steps: Number of sampling steps
263
- eps: Smallest time step for numerical stability
264
-
265
- Returns:
266
- List of generated PIL Images
267
- """
268
- try:
269
- if isinstance(masks, list):
270
- # Process each mask individually
271
- results = []
272
- for mask in masks:
273
- result = self.generate_image(mask)
274
- results.append(result)
275
- return results
276
- else:
277
- # Process as batch
278
- processed_masks = self.process_mask(masks)
279
- batch_size = processed_masks.shape[0]
280
-
281
- # Determine the output shape based on the model
282
- if self.input_channels == 1:
283
- x_shape = (1, 256, 256)
284
- else:
285
- x_shape = (4, 256, 256)
286
-
287
- with torch.no_grad():
288
- samples = Euler_Maruyama_sampler(
289
- self.score_model,
290
- self.marginal_prob_std_fn,
291
- self.diffusion_coeff_fn,
292
- batch_size=batch_size,
293
- x_shape=x_shape,
294
- num_steps=num_steps,
295
- device=self.device,
296
- eps=eps,
297
- y=processed_masks
298
- )
299
-
300
- # Convert each sample to image
301
- results = []
302
- for i in range(batch_size):
303
- sample = samples[i:i+1]
304
- image = self.tensor_to_image(sample)
305
- results.append(image)
306
-
307
- return results
308
-
309
- except Exception as e:
310
- print(f"Error in generate_batch: {e}")
311
- raise e
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+
7
+ from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
8
+
9
+
10
+ class CompatibleUNet(UNet):
11
+ """A UNet model that's compatible with the saved weights."""
12
+
13
+ def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256, 512], embed_dim=256,
14
+ embed_dim_mask=256, input_dim_mask=1*256*256):
15
+ # Override the parent's __init__ to set the correct input channels
16
+ super().__init__(marginal_prob_std, channels, embed_dim, embed_dim_mask, input_dim_mask)
17
+
18
+ # Replace the first conv layer to accept 1 input channel instead of 4
19
+ self.conv1 = torch.nn.Conv2d(1, channels[0], 3, stride=2, bias=False, padding=1)
20
+
21
+ # Also fix the output layer if it exists
22
+ if hasattr(self, 'tconv0'):
23
+ self.tconv0 = torch.nn.ConvTranspose2d(
24
+ channels[0], 1, 3, stride=1, padding=1, output_padding=0
25
+ )
26
+
27
+
28
+ class HFDiffusionService:
29
+ """Service class for the Hugging Face conditional diffusion model."""
30
+
31
+ def __init__(self):
32
+ # Check CUDA
33
+ cuda_available = torch.cuda.is_available()
34
+ print(f"CUDA available for HF diffusion: {cuda_available}")
35
+ if not cuda_available:
36
+ print("Warning: CUDA is not available. Using CPU (this will be slow).")
37
+
38
+ self.device = torch.device('cuda:0' if cuda_available else 'cpu')
39
+ self.Lambda = 25.0
40
+
41
+ # Initialize model functions
42
+ self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
43
+ self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
44
+
45
+ # Auto-detect model file path
46
+ model_candidates = [
47
+ "hf_model_files/pytorch_model.bin",
48
+ "pytorch_model.bin"
49
+ ]
50
+ self.model_path = next((path for path in model_candidates if os.path.exists(path)), None)
51
+
52
+ if not self.model_path:
53
+ raise FileNotFoundError("pytorch_model.bin not found in root or hf_model_files folder.")
54
+
55
+ print(f"Loading diffusion model from: {self.model_path}")
56
+
57
+ # Load model weights
58
+ try:
59
+ state_dict = torch.load(self.model_path, map_location=self.device)
60
+
61
+ # Analyze the state dict to configure model
62
+ conv1_weight = state_dict.get('conv1.weight', None)
63
+ cond_embed_weight = state_dict.get('cond_embed.1.weight', None)
64
+
65
+ if conv1_weight is not None:
66
+ actual_input_channels = conv1_weight.shape[1]
67
+ print(f"Detected input channels: {actual_input_channels}")
68
+
69
+ if cond_embed_weight is not None:
70
+ actual_input_dim_mask = cond_embed_weight.shape[1]
71
+ print(f"Detected input_dim_mask: {actual_input_dim_mask}")
72
+
73
+ # 1-channel model with 256*256 flattened mask
74
+ if actual_input_channels == 1 and actual_input_dim_mask == 65536:
75
+ self.score_model = CompatibleUNet(
76
+ marginal_prob_std=self.marginal_prob_std_fn,
77
+ input_dim_mask=65536
78
+ )
79
+ self.input_channels = 1
80
+ self.input_dim_mask = 65536
81
+ else:
82
+ self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
83
+ self.input_channels = 4
84
+ self.input_dim_mask = 262144
85
+ else:
86
+ self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
87
+ self.input_channels = 4
88
+ self.input_dim_mask = 262144
89
+ else:
90
+ # Default to original architecture
91
+ self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
92
+ self.input_channels = 4
93
+ self.input_dim_mask = 262144
94
+
95
+ self.score_model.load_state_dict(state_dict)
96
+ self.score_model.to(self.device)
97
+ self.score_model.eval()
98
+
99
+ print(f" HF Diffusion model loaded successfully")
100
+ print(f" Input channels: {self.input_channels}, Mask dim: {self.input_dim_mask}")
101
+
102
+ except Exception as e:
103
+ print(f"❌ Error loading HF diffusion model: {e}")
104
+ raise e
105
+
106
+ def generate_image(self, mask):
107
+ """Generate a medical image based on a conditioning mask."""
108
+ try:
109
+ processed_mask = self.process_mask(mask)
110
+ generated_tensor = self.generate_from_mask(processed_mask)
111
+ return self.tensor_to_image(generated_tensor)
112
+ except Exception as e:
113
+ print(f"❌ Error generating image: {e}")
114
+ return None
115
+
116
+ def process_mask(self, mask):
117
+ """Process the input mask to the correct format for the model."""
118
+ try:
119
+ if isinstance(mask, Image.Image):
120
+ transform = transforms.Compose([
121
+ transforms.Grayscale(num_output_channels=1),
122
+ transforms.Resize((256, 256), antialias=True),
123
+ transforms.ToTensor()
124
+ ])
125
+ tensor = transform(mask).unsqueeze(0)
126
+ elif isinstance(mask, np.ndarray):
127
+ if mask.ndim == 2:
128
+ mask = mask[np.newaxis, :, :]
129
+ tensor = torch.from_numpy(mask).float()
130
+ if tensor.dim() == 3:
131
+ tensor = tensor.unsqueeze(0)
132
+ elif isinstance(mask, torch.Tensor):
133
+ tensor = mask
134
+ if tensor.dim() == 3:
135
+ tensor = tensor.unsqueeze(0)
136
+ else:
137
+ raise ValueError(f"Unsupported mask type: {type(mask)}")
138
+
139
+ # Adjust channels
140
+ if self.input_channels == 1:
141
+ if tensor.shape[1] != 1:
142
+ tensor = tensor.mean(dim=1, keepdim=True)
143
+ else:
144
+ if tensor.shape[1] == 1:
145
+ tensor = tensor.repeat(1, 4, 1, 1)
146
+ elif tensor.shape[1] != 4:
147
+ raise ValueError(f"Expected 1 or 4 channels, got {tensor.shape[1]}")
148
+
149
+ # Ensure 256x256 size
150
+ if tensor.shape[2] != 256 or tensor.shape[3] != 256:
151
+ tensor = torch.nn.functional.interpolate(
152
+ tensor, size=(256, 256), mode='bilinear', align_corners=False
153
+ )
154
+
155
+ print(f"Processed mask shape: {tensor.shape}")
156
+ return tensor.to(self.device)
157
+ except Exception as e:
158
+ print(f"❌ Error processing mask: {e}")
159
+ raise e
160
+
161
+ def generate_from_mask(self, conditioning_mask, num_steps=250, eps=1e-3):
162
+ """Generate image from conditioning mask using diffusion model."""
163
+ try:
164
+ x_shape = (1, 256, 256) if self.input_channels == 1 else (4, 256, 256)
165
+ with torch.no_grad():
166
+ samples = Euler_Maruyama_sampler(
167
+ self.score_model,
168
+ self.marginal_prob_std_fn,
169
+ self.diffusion_coeff_fn,
170
+ batch_size=1,
171
+ x_shape=x_shape,
172
+ num_steps=num_steps,
173
+ device=self.device,
174
+ eps=eps,
175
+ y=conditioning_mask
176
+ )
177
+ return samples.clamp(0, 1)
178
+ except Exception as e:
179
+ print(f"❌ Error in generate_from_mask: {e}")
180
+ raise e
181
+
182
+ def tensor_to_image(self, tensor):
183
+ """Convert tensor to PIL Image."""
184
+ try:
185
+ if tensor.shape[1] > 1:
186
+ image_tensor = tensor.squeeze(0).mean(dim=0)
187
+ else:
188
+ image_tensor = tensor.squeeze(0).squeeze(0)
189
+
190
+ image_array = (image_tensor.cpu().numpy() * 255).astype(np.uint8)
191
+ return Image.fromarray(image_array, mode='L')
192
+ except Exception as e:
193
+ print(f"❌ Error converting tensor to image: {e}")
194
+ raise e