import os
import sys
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

# Add the hf_model_files directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'hf_model_files'))

from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler

class CompatibleUNet(UNet):
    """A UNet model that's compatible with the saved weights."""
    
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256, 512], embed_dim=256,
                 embed_dim_mask=256, input_dim_mask=1 * 256 * 256):  # 1 * 256 * 256 = 65536
        # Override the parent's __init__ to set the correct input channels
        super().__init__(marginal_prob_std, channels, embed_dim, embed_dim_mask, input_dim_mask)
        
        # Replace the first conv layer to accept 1 input channel instead of 4
        self.conv1 = torch.nn.Conv2d(1, channels[0], 3, stride=2, bias=False, padding=1)
        
        # Also need to fix the output layer if it exists
        if hasattr(self, 'tconv0'):
            self.tconv0 = torch.nn.ConvTranspose2d(channels[0], 1, 3, stride=1, padding=1, output_padding=0)
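
# Note: the layer swaps above work because load_state_dict() matches
# parameters by name and shape: rebuilding conv1 (and tconv0) with
# single-channel shapes lines them up with a checkpoint trained on
# 1-channel inputs, while every other layer keeps the parent UNet's shapes.
# A minimal sketch of verifying that alignment before loading; the helper
# name is illustrative, not part of the original service:
def _conv1_matches_checkpoint(model, state_dict):
    """Return True if the model's first conv layer matches the checkpoint."""
    expected = state_dict.get('conv1.weight')
    return expected is not None and tuple(model.conv1.weight.shape) == tuple(expected.shape)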

class HFDiffusionService:
    """Service class for the Hugging Face conditional diffusion model."""
    
    def __init__(self):
        # Check if CUDA is available and print status
        cuda_available = torch.cuda.is_available()
        print(f"CUDA available for HF diffusion: {cuda_available}")
        if not cuda_available:
            print("Warning: CUDA is not available for HF diffusion. Using CPU instead. This might be slower.")
            
        self.device = torch.device('cuda:0' if cuda_available else 'cpu')
        self.Lambda = 25.0
        
        # Initialize the model functions
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
        
        # Model path for the downloaded Hugging Face model
        self.model_path = os.path.join("hf_model_files", "pytorch_model.bin")
        
        try:
            # Load the state dict first to understand the architecture
            state_dict = torch.load(self.model_path, map_location=self.device)
            
            # Analyze the state dict to determine the correct architecture
            conv1_weight = state_dict.get('conv1.weight')
            cond_embed_weight = state_dict.get('cond_embed.1.weight')

            use_single_channel = False
            if conv1_weight is not None:
                actual_input_channels = conv1_weight.shape[1]
                print(f"Detected input channels from state dict: {actual_input_channels}")
                if cond_embed_weight is not None:
                    actual_input_dim_mask = cond_embed_weight.shape[1]
                    print(f"Detected input_dim_mask from state dict: {actual_input_dim_mask}")
                    # A 1-channel conv1 together with a 65536-dim flattened mask
                    # embedding (1 * 256 * 256) indicates the checkpoint was
                    # trained on single-channel inputs
                    use_single_channel = (actual_input_channels == 1
                                          and actual_input_dim_mask == 65536)

            if use_single_channel:
                self.score_model = CompatibleUNet(
                    marginal_prob_std=self.marginal_prob_std_fn,
                    input_dim_mask=65536
                )
                self.input_channels = 1
                self.input_dim_mask = 65536
            else:
                # Fall back to the original 4-channel architecture
                self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
                self.input_channels = 4
                self.input_dim_mask = 262144  # 4 * 256 * 256
            
            # Load the weights
            self.score_model.load_state_dict(state_dict)
            self.score_model.to(self.device)
            self.score_model.eval()
            
            print(f"HF Diffusion model loaded successfully from {self.model_path}")
            print(f"Model configured for {self.input_channels} input channels and {self.input_dim_mask} mask dimensions")
            
        except Exception as e:
            print(f"Error loading HF diffusion model: {e}")
            raise

    def generate_image(self, mask):
        """

        Generate a medical image based on a conditioning mask.

        

        Args:

            mask: Conditioning mask tensor of shape (1, 4, 256, 256) or PIL Image

            

        Returns:

            Generated image as PIL Image

        """
        try:
            # Process the mask input
            processed_mask = self.process_mask(mask)
            
            # Generate the image
            generated_tensor = self.generate_from_mask(processed_mask)
            
            # Convert tensor to PIL Image
            return self.tensor_to_image(generated_tensor)
            
        except Exception as e:
            print(f"Error generating HF diffusion image: {e}")
            return None

    def process_mask(self, mask):
        """

        Process the input mask to the correct format for the model.

        

        Args:

            mask: Input mask (PIL Image, numpy array, or tensor)

            

        Returns:

            Processed mask tensor of shape (1, 1, 256, 256) for 1-channel model

        """
        try:
            # If mask is a PIL Image, convert to tensor
            if isinstance(mask, Image.Image):
                transform = transforms.Compose([
                    transforms.Grayscale(num_output_channels=1),
                    transforms.Resize((256, 256), antialias=True),
                    transforms.ToTensor()
                ])
                tensor = transform(mask).unsqueeze(0)  # Add batch dimension
            elif isinstance(mask, np.ndarray):
                # Convert numpy array to tensor
                if mask.ndim == 2:
                    mask = mask[np.newaxis, :, :]  # Add channel dimension
                tensor = torch.from_numpy(mask).float()
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)  # Add batch dimension
            elif isinstance(mask, torch.Tensor):
                tensor = mask
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)  # Add batch dimension
            else:
                raise ValueError(f"Unsupported mask type: {type(mask)}")
            
            # Ensure the tensor has the correct shape based on model input
            if self.input_channels == 1:
                # Model expects 1 channel: average any extra channels down to one
                if tensor.shape[1] > 1:
                    tensor = tensor.mean(dim=1, keepdim=True)
            else:
                # Model expects 4 channels
                if tensor.shape[1] == 1:
                    # If single channel, repeat to 4 channels
                    tensor = tensor.repeat(1, 4, 1, 1)
                elif tensor.shape[1] != 4:
                    raise ValueError(f"Expected 1 or 4 channels, got {tensor.shape[1]}")
            
            # Ensure correct size
            if tensor.shape[2] != 256 or tensor.shape[3] != 256:
                tensor = torch.nn.functional.interpolate(tensor, size=(256, 256), mode='bilinear', align_corners=False)
            
            print(f"Processed mask shape: {tensor.shape}")
            return tensor.to(self.device)
            
        except Exception as e:
            print(f"Error processing mask: {e}")
            raise

    def generate_from_mask(self, conditioning_mask, num_steps=250, eps=1e-3):
        """

        Generate image from conditioning mask using the diffusion model.

        

        Args:

            conditioning_mask: Conditioning mask tensor

            num_steps: Number of sampling steps

            eps: Smallest time step for numerical stability

            

        Returns:

            Generated image tensor

        """
        try:
            # Determine the output shape based on the model
            if self.input_channels == 1:
                x_shape = (1, 256, 256)
            else:
                x_shape = (4, 256, 256)
            
            with torch.no_grad():
                samples = Euler_Maruyama_sampler(
                    self.score_model,
                    self.marginal_prob_std_fn,
                    self.diffusion_coeff_fn,
                    batch_size=1,
                    x_shape=x_shape,
                    num_steps=num_steps,
                    device=self.device,
                    eps=eps,
                    y=conditioning_mask
                )
            
            # Clamp values to [0, 1] range
            return samples.clamp(0, 1)
            
        except Exception as e:
            print(f"Error in generate_from_mask: {e}")
            raise

    def tensor_to_image(self, tensor):
        """

        Convert tensor to PIL Image.

        

        Args:

            tensor: Generated tensor

            

        Returns:

            PIL Image

        """
        try:
            # Collapse to one channel for visualization (average if multi-channel)
            if tensor.shape[1] > 1:
                # Average the channels
                image_tensor = tensor.squeeze(0).mean(dim=0)
            else:
                image_tensor = tensor.squeeze(0).squeeze(0)
            
            # Convert to numpy and scale to 0-255
            image_array = (image_tensor.cpu().numpy() * 255).astype(np.uint8)
            
            # Create PIL Image
            image = Image.fromarray(image_array, mode='L')
            
            return image
            
        except Exception as e:
            print(f"Error converting tensor to image: {e}")
            raise

    def generate_batch(self, masks, num_steps=250, eps=1e-3):
        """

        Generate multiple images from a batch of masks.

        

        Args:

            masks: List of masks or batch tensor

            num_steps: Number of sampling steps

            eps: Smallest time step for numerical stability

            

        Returns:

            List of generated PIL Images

        """
        try:
            if isinstance(masks, list):
                # Process each mask individually
                results = []
                for mask in masks:
                    result = self.generate_image(mask)
                    results.append(result)
                return results
            else:
                # Process as batch
                processed_masks = self.process_mask(masks)
                batch_size = processed_masks.shape[0]
                
                # Determine the output shape based on the model
                if self.input_channels == 1:
                    x_shape = (1, 256, 256)
                else:
                    x_shape = (4, 256, 256)
                
                with torch.no_grad():
                    samples = Euler_Maruyama_sampler(
                        self.score_model,
                        self.marginal_prob_std_fn,
                        self.diffusion_coeff_fn,
                        batch_size=batch_size,
                        x_shape=x_shape,
                        num_steps=num_steps,
                        device=self.device,
                        eps=eps,
                        y=processed_masks
                    )
                
                # Convert each sample to image
                results = []
                for i in range(batch_size):
                    sample = samples[i:i+1]
                    image = self.tensor_to_image(sample)
                    results.append(image)
                
                return results
                
        except Exception as e:
            print(f"Error in generate_batch: {e}")
            raise
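
if __name__ == "__main__":
    # Minimal usage sketch, not part of the service itself: generate one
    # sample from a blank 256x256 mask and save it. Assumes the checkpoint at
    # hf_model_files/pytorch_model.bin has already been downloaded; the output
    # filename below is illustrative.
    service = HFDiffusionService()
    dummy_mask = Image.new('L', (256, 256), color=0)
    generated = service.generate_image(dummy_mask)
    if generated is not None:
        generated.save('hf_diffusion_sample.png')
        print("Saved sample to hf_diffusion_sample.png")
    else:
        print("Generation failed; see error messages above.")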