Spaces:
Runtime error
Runtime error
Update dataset
Browse files
app.py
CHANGED
|
@@ -8,7 +8,7 @@ from tokenizers import Tokenizer
|
|
| 8 |
from torch.utils.data import Dataset
|
| 9 |
import albumentations as A
|
| 10 |
from tqdm import tqdm
|
| 11 |
-
|
| 12 |
from fourm.vq.vqvae import VQVAE
|
| 13 |
from fourm.models.fm import FM
|
| 14 |
from fourm.models.generate import (
|
|
@@ -28,7 +28,7 @@ IMG_SIZE = 224
|
|
| 28 |
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
|
| 29 |
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
|
| 30 |
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
|
| 31 |
-
IMAGE_DATASET_PATH = "
|
| 32 |
|
| 33 |
# Load models
|
| 34 |
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
|
|
@@ -61,25 +61,24 @@ schedule = build_chained_generation_schedules(
|
|
| 61 |
sampler = GenerationSampler(fm_model)
|
| 62 |
|
| 63 |
|
| 64 |
-
class
|
| 65 |
-
def __init__(self,
|
| 66 |
-
self.
|
| 67 |
-
self.
|
| 68 |
-
|
| 69 |
-
|
| 70 |
|
| 71 |
def __len__(self):
|
| 72 |
-
return len(self.
|
| 73 |
|
| 74 |
def __getitem__(self, idx):
|
| 75 |
-
img =
|
| 76 |
img = np.array(img)
|
| 77 |
img = self.tfms(image=img)["image"]
|
| 78 |
return Image.fromarray(img)
|
| 79 |
|
| 80 |
-
|
| 81 |
-
dataset =
|
| 82 |
-
|
| 83 |
|
| 84 |
@torch.no_grad()
|
| 85 |
def get_image_embeddings(dataset):
|
|
|
|
| 8 |
from torch.utils.data import Dataset
|
| 9 |
import albumentations as A
|
| 10 |
from tqdm import tqdm
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
from fourm.vq.vqvae import VQVAE
|
| 13 |
from fourm.models.fm import FM
|
| 14 |
from fourm.models.generate import (
|
|
|
|
| 28 |
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
|
| 29 |
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
|
| 30 |
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
|
| 31 |
+
IMAGE_DATASET_PATH = "./data"
|
| 32 |
|
| 33 |
# Load models
|
| 34 |
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
|
|
|
|
| 61 |
sampler = GenerationSampler(fm_model)
|
| 62 |
|
| 63 |
|
| 64 |
class HuggingFaceImageDataset(Dataset):
    """Torch ``Dataset`` wrapper around a Hugging Face Hub image dataset.

    Loads the requested split via ``datasets.load_dataset`` and, on each
    access, resizes the image so its shortest side equals ``img_sz``
    (albumentations ``SmallestMaxSize``), returning a PIL image.
    """

    def __init__(self, dataset_name, split="train", img_sz=224):
        # Fetch/load the split once up front; per-item access is then local.
        self.dataset = load_dataset(dataset_name, split=split)
        # Single resize step: shortest side -> img_sz, aspect ratio preserved.
        self.tfms = A.Compose([A.SmallestMaxSize(img_sz)])

    def __len__(self):
        # One example per row of the loaded split.
        return len(self.dataset)

    def __getitem__(self, idx):
        # Rows expose the decoded image under the 'image' key
        # (presumably a PIL image — np.array() accepts it either way).
        raw = self.dataset[idx]['image']
        arr = np.array(raw)
        resized = self.tfms(image=arr)["image"]
        return Image.fromarray(resized)
|
| 79 |
|
| 80 |
# Usage: wrap the demo image dataset hosted on the Hugging Face Hub.
dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
|
|
|
|
| 82 |
|
| 83 |
@torch.no_grad()
|
| 84 |
def get_image_embeddings(dataset):
|