tan200224 commited on
Commit
4a5d5c6
·
verified ·
1 Parent(s): 2b4f902

Upload hf_diffusion_service.py

Browse files
Files changed (1) hide show
  1. hf_diffusion_service.py +311 -0
hf_diffusion_service.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+
8
+ # Add the hf_model_files directory to the path
9
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'hf_model_files'))
10
+
11
+ from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
12
+
13
class CompatibleUNet(UNet):
    """UNet variant whose boundary layers match the saved 1-channel weights.

    The published checkpoint was trained with a single input channel and a
    flattened 1*256*256 conditioning mask, whereas the stock ``UNet`` uses a
    4-channel input.  Construction is delegated to the parent, after which
    the first convolution (and the final transposed convolution, when the
    parent defines one) are swapped for 1-channel equivalents so that
    ``load_state_dict`` on the checkpoint succeeds.
    """

    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256, 512],
                 embed_dim=256, embed_dim_mask=256,
                 input_dim_mask=1 * 256 * 256):
        # Build the stock architecture first; only the edge layers differ.
        super().__init__(marginal_prob_std, channels, embed_dim,
                         embed_dim_mask, input_dim_mask)

        # Replace the input conv so it accepts 1 channel instead of 4.
        self.conv1 = torch.nn.Conv2d(
            1, channels[0], 3, stride=2, bias=False, padding=1)

        # Mirror the fix on the output side if the parent defines tconv0.
        if hasattr(self, 'tconv0'):
            self.tconv0 = torch.nn.ConvTranspose2d(
                channels[0], 1, 3, stride=1, padding=1, output_padding=0)
27
+
28
class HFDiffusionService:
    """Service class for the Hugging Face conditional diffusion model.

    Loads a score-based diffusion UNet from ``hf_model_files/pytorch_model.bin``,
    auto-detecting from the state dict whether the checkpoint uses the
    1-channel / 65536-mask layout (``CompatibleUNet``) or the stock 4-channel
    layout, and exposes mask-conditioned image generation.
    """

    def __init__(self):
        """Load and configure the diffusion model.

        Raises:
            Exception: re-raised after logging if the checkpoint cannot be
                loaded or does not match either supported architecture.
        """
        cuda_available = torch.cuda.is_available()
        print(f"CUDA available for HF diffusion: {cuda_available}")
        if not cuda_available:
            print("Warning: CUDA is not available for HF diffusion. Using CPU instead. This might be slower.")

        self.device = torch.device('cuda:0' if cuda_available else 'cpu')
        # SDE noise-scale parameter passed through to the model helpers.
        self.Lambda = 25.0

        # Closures binding Lambda/device so samplers can call fn(t) directly.
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)

        # Path of the downloaded Hugging Face checkpoint.
        self.model_path = os.path.join("hf_model_files", "pytorch_model.bin")

        try:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (weights_only=True would be
            # safer on torch >= 1.13).
            state_dict = torch.load(self.model_path, map_location=self.device)

            # Inspect key layer shapes to infer the trained architecture.
            conv1_weight = state_dict.get('conv1.weight', None)
            cond_embed_weight = state_dict.get('cond_embed.1.weight', None)

            use_compatible = False
            if conv1_weight is not None:
                actual_input_channels = conv1_weight.shape[1]
                print(f"Detected input channels from state dict: {actual_input_channels}")
                if cond_embed_weight is not None:
                    actual_input_dim_mask = cond_embed_weight.shape[1]
                    print(f"Detected input_dim_mask from state dict: {actual_input_dim_mask}")
                    # 1 channel + 1*256*256 flattened mask => CompatibleUNet.
                    use_compatible = (actual_input_channels == 1
                                      and actual_input_dim_mask == 65536)

            if use_compatible:
                self.score_model = CompatibleUNet(
                    marginal_prob_std=self.marginal_prob_std_fn,
                    input_dim_mask=65536,
                )
                self.input_channels = 1
                self.input_dim_mask = 65536
            else:
                self._use_original_architecture()

            self.score_model.load_state_dict(state_dict)
            self.score_model.to(self.device)
            self.score_model.eval()

            print(f"HF Diffusion model loaded successfully from {self.model_path}")
            print(f"Model configured for {self.input_channels} input channels and {self.input_dim_mask} mask dimensions")

        except Exception as e:
            print(f"Error loading HF diffusion model: {e}")
            # Bare raise keeps the original traceback (``raise e`` re-anchors it).
            raise

    def _use_original_architecture(self):
        """Fall back to the stock 4-channel UNet (mask dim 4*256*256)."""
        self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        self.input_channels = 4
        self.input_dim_mask = 262144

    def _x_shape(self):
        """Per-sample tensor shape (C, H, W) implied by the loaded weights."""
        return (self.input_channels, 256, 256)

    def generate_image(self, mask):
        """
        Generate a medical image based on a conditioning mask.

        Args:
            mask: Conditioning mask (PIL Image, numpy array, or tensor).

        Returns:
            Generated image as PIL Image, or None on failure (best-effort
            contract preserved for existing callers).
        """
        try:
            processed_mask = self.process_mask(mask)
            generated_tensor = self.generate_from_mask(processed_mask)
            return self.tensor_to_image(generated_tensor)
        except Exception as e:
            print(f"Error generating HF diffusion image: {e}")
            return None

    def process_mask(self, mask):
        """
        Process the input mask to the correct format for the model.

        Args:
            mask: Input mask (PIL Image, numpy array, or tensor).

        Returns:
            Float tensor of shape (N, C, 256, 256) on self.device, where C
            matches the loaded model's input channels.

        Raises:
            ValueError: unsupported input type or channel count.
        """
        try:
            if isinstance(mask, Image.Image):
                transform = transforms.Compose([
                    transforms.Grayscale(num_output_channels=1),
                    transforms.Resize((256, 256), antialias=True),
                    transforms.ToTensor(),
                ])
                tensor = transform(mask).unsqueeze(0)  # add batch dimension
            elif isinstance(mask, np.ndarray):
                if mask.ndim == 2:
                    mask = mask[np.newaxis, :, :]  # add channel dimension
                tensor = torch.from_numpy(mask).float()
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)  # add batch dimension
            elif isinstance(mask, torch.Tensor):
                tensor = mask
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)  # add batch dimension
            else:
                raise ValueError(f"Unsupported mask type: {type(mask)}")

            # Adapt the channel count to what the loaded model expects.
            if self.input_channels == 1:
                # Multi-channel input is averaged down to one channel.
                # (The original code had an unreachable branch for 0-channel
                # tensors; it produced the same result and was dropped.)
                if tensor.shape[1] > 1:
                    tensor = tensor.mean(dim=1, keepdim=True)
            else:
                if tensor.shape[1] == 1:
                    # Single channel is broadcast to the 4 expected channels.
                    tensor = tensor.repeat(1, 4, 1, 1)
                elif tensor.shape[1] != 4:
                    raise ValueError(f"Expected 1 or 4 channels, got {tensor.shape[1]}")

            # Resize to the model's fixed 256x256 resolution if needed.
            if tensor.shape[2] != 256 or tensor.shape[3] != 256:
                tensor = torch.nn.functional.interpolate(
                    tensor, size=(256, 256), mode='bilinear', align_corners=False)

            print(f"Processed mask shape: {tensor.shape}")
            return tensor.to(self.device)

        except Exception as e:
            print(f"Error processing mask: {e}")
            raise

    def generate_from_mask(self, conditioning_mask, num_steps=250, eps=1e-3):
        """
        Generate one image from a conditioning mask via Euler-Maruyama sampling.

        Args:
            conditioning_mask: Conditioning mask tensor (see process_mask).
            num_steps: Number of sampling steps.
            eps: Smallest time step for numerical stability.

        Returns:
            Generated image tensor clamped to [0, 1].
        """
        try:
            with torch.no_grad():
                samples = Euler_Maruyama_sampler(
                    self.score_model,
                    self.marginal_prob_std_fn,
                    self.diffusion_coeff_fn,
                    batch_size=1,
                    x_shape=self._x_shape(),
                    num_steps=num_steps,
                    device=self.device,
                    eps=eps,
                    y=conditioning_mask,
                )

            # Clamp to [0, 1] so the uint8 conversion downstream is safe.
            return samples.clamp(0, 1)

        except Exception as e:
            print(f"Error in generate_from_mask: {e}")
            raise

    def tensor_to_image(self, tensor):
        """
        Convert a (1, C, H, W) tensor in [0, 1] to a grayscale PIL Image.

        Args:
            tensor: Generated tensor; multi-channel input is averaged.

        Returns:
            PIL Image (mode 'L').
        """
        try:
            if tensor.shape[1] > 1:
                image_tensor = tensor.squeeze(0).mean(dim=0)
            else:
                image_tensor = tensor.squeeze(0).squeeze(0)

            # Scale [0, 1] floats to 0-255; callers must clamp first, since
            # astype(uint8) wraps out-of-range values instead of saturating.
            image_array = (image_tensor.cpu().numpy() * 255).astype(np.uint8)

            return Image.fromarray(image_array, mode='L')

        except Exception as e:
            print(f"Error converting tensor to image: {e}")
            raise

    def generate_batch(self, masks, num_steps=250, eps=1e-3):
        """
        Generate multiple images from a batch of masks.

        Args:
            masks: List of masks (processed individually) or a batch tensor.
            num_steps: Number of sampling steps.
            eps: Smallest time step for numerical stability.

        Returns:
            List of generated PIL Images.
        """
        try:
            if isinstance(masks, list):
                # Per-mask path reuses the single-image pipeline.
                return [self.generate_image(mask) for mask in masks]

            processed_masks = self.process_mask(masks)
            batch_size = processed_masks.shape[0]

            with torch.no_grad():
                samples = Euler_Maruyama_sampler(
                    self.score_model,
                    self.marginal_prob_std_fn,
                    self.diffusion_coeff_fn,
                    batch_size=batch_size,
                    x_shape=self._x_shape(),
                    num_steps=num_steps,
                    device=self.device,
                    eps=eps,
                    y=processed_masks,
                )

            # BUG FIX: the batch path previously skipped the clamp that
            # generate_from_mask applies, so out-of-range sample values
            # wrapped around during the uint8 cast in tensor_to_image.
            samples = samples.clamp(0, 1)

            return [self.tensor_to_image(samples[i:i + 1])
                    for i in range(batch_size)]

        except Exception as e:
            print(f"Error in generate_batch: {e}")
            raise