Spaces:
Runtime error
Runtime error
Alex Ergasti commited on
Commit ·
0e0633b
1
Parent(s): 9837429
Update model
Browse files
models.py
CHANGED
|
@@ -1,14 +1,3 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
# --------------------------------------------------------
|
| 7 |
-
# References:
|
| 8 |
-
# GLIDE: https://github.com/openai/glide-text2im
|
| 9 |
-
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 10 |
-
# --------------------------------------------------------
|
| 11 |
-
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
| 14 |
import numpy as np
|
|
@@ -16,6 +5,8 @@ import math
|
|
| 16 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 17 |
import einops
|
| 18 |
|
|
|
|
|
|
|
| 19 |
import torch.utils.checkpoint as checkpoint
|
| 20 |
|
| 21 |
from transformers import PreTrainedModel
|
|
@@ -371,7 +362,7 @@ class FinalLayer(nn.Module):
|
|
| 371 |
return x
|
| 372 |
|
| 373 |
|
| 374 |
-
class FLAV(nn.Module):
|
| 375 |
"""
|
| 376 |
Diffusion model with a Transformer backbone.
|
| 377 |
"""
|
|
@@ -748,4 +739,3 @@ FLAV_models = {
|
|
| 748 |
'FLAV-B/1' : FLAV_B_1, 'FLAV-B/2': FLAV_B_2, 'FLAV-B/4': FLAV_B_4, 'FLAV-B/8': FLAV_B_8,
|
| 749 |
'FLAV-S/2' : FLAV_S_2, 'FLAV-S/4': FLAV_S_4, 'FLAV-S/8': FLAV_S_8,
|
| 750 |
}
|
| 751 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import numpy as np
|
|
|
|
| 5 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 6 |
import einops
|
| 7 |
|
| 8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 9 |
+
|
| 10 |
import torch.utils.checkpoint as checkpoint
|
| 11 |
|
| 12 |
from transformers import PreTrainedModel
|
|
|
|
| 362 |
return x
|
| 363 |
|
| 364 |
|
| 365 |
+
class FLAV(nn.Module, PyTorchModelHubMixin):
|
| 366 |
"""
|
| 367 |
Diffusion model with a Transformer backbone.
|
| 368 |
"""
|
|
|
|
| 739 |
'FLAV-B/1' : FLAV_B_1, 'FLAV-B/2': FLAV_B_2, 'FLAV-B/4': FLAV_B_4, 'FLAV-B/8': FLAV_B_8,
|
| 740 |
'FLAV-S/2' : FLAV_S_2, 'FLAV-S/4': FLAV_S_4, 'FLAV-S/8': FLAV_S_8,
|
| 741 |
}
|
|
|