Improve "Usage" section with correct tokenizer example (#2)

Commit: cfa1e16f340c543dd7871a103af7e2596b4cd0fe
Co-authored-by: Niels Rogge <nielsr@users.noreply.huggingface.co>
README.md CHANGED

@@ -1,7 +1,7 @@
 ---
 library_name: pytorch
 license: mit
-pipeline_tag:
+pipeline_tag: image-feature-extraction
 tags:
 - computer-vision
 - image-generation

@@ -65,11 +65,71 @@ We evaluate our approach across six generative models on ImageNet 256×256 and o
 
 ### Installation
 
-
-
-
-
-
+To use DeTok for extracting latent embeddings from images, you need to:
+
+1. **Clone the official DeTok repository**:
+```bash
+git clone https://github.com/Jiawei-Yang/DeTok.git
+cd DeTok
+pip install -r requirements.txt
+```
+2. **Download the pre-trained tokenizer weights**:
+You can download the `DeTok-BB-decoder_ft` checkpoint (recommended) from [here](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/detok-BB-gamm3.0-m0.7-decoder_tuned.pth) and place it in your working directory (e.g., `detok-BB-gamm3.0-m0.7-decoder_tuned.pth`).
+
+### Extract latent embeddings
+
+Here's a sample Python code snippet for feature extraction using the `DeTok_BB` tokenizer:
+
+```python
+import torch
+from PIL import Image
+from torchvision import transforms
+from models.detok import DeTok_BB  # import from the cloned DeTok repository
+
+# --- Configuration (matching the DeTok-BB-decoder_ft architecture from the paper) ---
+model_params = {
+    "img_size": 256,
+    "patch_size": 16,
+    "in_chans": 3,
+    "embed_dim": 768,
+    "depths": [2, 2, 8, 2],
+    "num_heads": [3, 6, 12, 24],
+}
+tokenizer_weights_path = "detok-BB-gamm3.0-m0.7-decoder_tuned.pth"  # path to the downloaded weights
+
+# 1. Initialize and load the tokenizer
+tokenizer = DeTok_BB(**model_params).eval()
+if torch.cuda.is_available():
+    tokenizer = tokenizer.cuda()
+
+# Load the checkpoint state_dict
+checkpoint = torch.load(tokenizer_weights_path, map_location="cpu")
+tokenizer.load_state_dict(checkpoint["model"])
+
+# 2. Prepare your image
+transform = transforms.Compose([
+    transforms.Resize(model_params["img_size"]),
+    transforms.CenterCrop(model_params["img_size"]),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+])
+
+# Replace 'path/to/your/image.jpg' with your actual image file
+image = Image.new("RGB", (model_params["img_size"], model_params["img_size"]), color="red")  # example dummy image
+# image = Image.open("path/to/your/image.jpg").convert("RGB")
+
+pixel_values = transform(image).unsqueeze(0)  # add batch dimension
+
+if torch.cuda.is_available():
+    pixel_values = pixel_values.cuda()
+
+# 3. Extract latent embeddings
+with torch.no_grad():
+    latent_embeddings = tokenizer.encode(pixel_values)
+
+print(f"Shape of latent embeddings: {latent_embeddings.shape}")
+# Expected output for a 256x256 input image with 16x16 patches is (1, 256, 768),
+# representing 256 image patches with 768-dimensional embeddings.
 ```
 
 ## Training Details
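
The `(1, 256, 768)` latent shape quoted in the README snippet follows directly from the patch geometry in `model_params`. A minimal sanity check of that arithmetic (plain Python, no DeTok or GPU required; it only assumes the encoder emits one token per non-overlapping patch, as the snippet's comment states):

```python
# Recompute the expected latent shape for DeTok-BB from its configuration:
# a 256x256 image cut into non-overlapping 16x16 patches gives 16*16 = 256
# tokens, each a 768-dimensional embedding; batch size is 1.
img_size, patch_size, embed_dim = 256, 16, 768

tokens_per_side = img_size // patch_size        # 256 / 16 = 16 patches per side
num_tokens = tokens_per_side ** 2               # 16 * 16 = 256 patches total
expected_shape = (1, num_tokens, embed_dim)     # leading 1 is the batch dimension

print(expected_shape)  # (1, 256, 768)
```

If you change `img_size` or `patch_size`, the token count scales as `(img_size // patch_size) ** 2`, so the printed shape is a quick way to confirm a resized input will produce the sequence length you expect downstream.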
|