Update README.md
Browse files
README.md
CHANGED
|
@@ -20,7 +20,7 @@ library_name: diffusers
|
|
| 20 |
## Usage
|
| 21 |
```python
|
| 22 |
|
| 23 |
-
#
|
| 24 |
def load_model(model_path, device):
|
| 25 |
# Initialize the same model architecture as during training
|
| 26 |
model = ClassConditionedUnet().to(device)
|
|
@@ -33,7 +33,7 @@ def load_model(model_path, device):
|
|
| 33 |
|
| 34 |
return model
|
| 35 |
|
| 36 |
-
|
| 37 |
def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
|
| 38 |
model.eval() # Ensure the model is in evaluation mode
|
| 39 |
|
|
@@ -60,7 +60,6 @@ def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
|
|
| 60 |
|
| 61 |
return generated_images
|
| 62 |
|
| 63 |
-
# Display predicted images
|
| 64 |
def display_images(images, num_rows=2):
|
| 65 |
# Create a grid of images
|
| 66 |
grid = torchvision.utils.make_grid(images, nrow=num_rows)
|
|
@@ -75,17 +74,9 @@ def display_images(images, num_rows=2):
|
|
| 75 |
# Example of loading a model and generating predictions
|
| 76 |
model_path = "model_epoch_0.pth" # Path to your saved model
|
| 77 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 78 |
-
|
| 79 |
-
# Load the model
|
| 80 |
model = load_model(model_path, device)
|
| 81 |
-
|
| 82 |
-
# Create a noise scheduler
|
| 83 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
| 84 |
-
|
| 85 |
-
# Predict and generate samples for a specific class label
|
| 86 |
class_label = 1 # Example class label, change to your desired class
|
| 87 |
generated_images = predict(model, class_label, noise_scheduler, num_samples=2, device=device)
|
| 88 |
-
|
| 89 |
-
# Display the generated images
|
| 90 |
display_images(generated_images)
|
| 91 |
```
|
|
|
|
| 20 |
## Usage
|
| 21 |
```python
|
| 22 |
|
| 23 |
+
# Predict function to generate images
|
| 24 |
def load_model(model_path, device):
|
| 25 |
# Initialize the same model architecture as during training
|
| 26 |
model = ClassConditionedUnet().to(device)
|
|
|
|
| 33 |
|
| 34 |
return model
|
| 35 |
|
| 36 |
+
|
| 37 |
def predict(model, class_label, noise_scheduler, num_samples=8, device='cuda'):
|
| 38 |
model.eval() # Ensure the model is in evaluation mode
|
| 39 |
|
|
|
|
| 60 |
|
| 61 |
return generated_images
|
| 62 |
|
|
|
|
| 63 |
def display_images(images, num_rows=2):
|
| 64 |
# Create a grid of images
|
| 65 |
grid = torchvision.utils.make_grid(images, nrow=num_rows)
|
|
|
|
| 74 |
# Example of loading a model and generating predictions
|
| 75 |
model_path = "model_epoch_0.pth" # Path to your saved model
|
| 76 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
| 77 |
model = load_model(model_path, device)
|
|
|
|
|
|
|
| 78 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
|
|
|
|
|
|
| 79 |
class_label = 1 # Example class label, change to your desired class
|
| 80 |
generated_images = predict(model, class_label, noise_scheduler, num_samples=2, device=device)
|
|
|
|
|
|
|
| 81 |
display_images(generated_images)
|
| 82 |
```
|