Spaces:
Sleeping
Sleeping
Update models/DiT.py
Browse filesAdded PyTorchModelHubMixin to allow from_pretrained method
- models/DiT.py +3 -1
models/DiT.py
CHANGED
|
@@ -3,6 +3,7 @@ import torch.nn as nn
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import math
|
| 5 |
from timm.models.vision_transformer import PatchEmbed
|
|
|
|
| 6 |
|
| 7 |
class TimestepEmbedder(nn.Module):
|
| 8 |
"""Module to create timestep's embedding."""
|
|
@@ -65,7 +66,8 @@ class DiTBlock(nn.Module):
|
|
| 65 |
x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
|
| 66 |
return x
|
| 67 |
|
| 68 |
-
class DiT(nn.Module
|
|
|
|
| 69 |
def __init__(self,
|
| 70 |
num_blocks=10,
|
| 71 |
hidden_size=640,
|
|
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import math
|
| 5 |
from timm.models.vision_transformer import PatchEmbed
|
| 6 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 7 |
|
| 8 |
class TimestepEmbedder(nn.Module):
|
| 9 |
"""Module to create timestep's embedding."""
|
|
|
|
| 66 |
x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
|
| 67 |
return x
|
| 68 |
|
| 69 |
+
class DiT(nn.Module,
|
| 70 |
+
PyTorchModelHubMixin):
|
| 71 |
def __init__(self,
|
| 72 |
num_blocks=10,
|
| 73 |
hidden_size=640,
|