Gabriele commited on
Commit
2bad776
·
1 Parent(s): 3d0bb08

Fix zero-shot segmentation section, use public example image

Browse files
Files changed (1) hide show
  1. README.md +16 -4
README.md CHANGED
@@ -24,7 +24,7 @@ TIPSv2 (Text-Image Pre-training with Spatial awareness) is a family of contrasti
24
  ## Usage
25
 
26
  ```bash
27
- pip install transformers torch torchvision sentencepiece
28
  ```
29
 
30
  ### Load the model
@@ -43,14 +43,17 @@ Images should be tensors in `[0, 1]` range (just `ToTensor()`, no ImageNet norma
43
  ```python
44
  from torchvision import transforms
45
  from PIL import Image
 
46
 
47
  transform = transforms.Compose([
48
  transforms.Resize((448, 448)),
49
  transforms.ToTensor(),
50
  ])
51
 
52
- image = transform(Image.open("photo.jpg")).unsqueeze(0)
53
- out = model.encode_image(image)
 
 
54
 
55
  out.cls_token # (B, 1, 1152)
56
  out.patch_tokens # (B, N, 1152)
@@ -76,6 +79,15 @@ prediction = similarity.argmax(dim=-1)
76
 
77
  ### Zero-shot segmentation
78
 
 
 
 
 
 
 
 
 
 
79
  ```python
80
  import numpy as np
81
  from sklearn.decomposition import PCA
@@ -89,7 +101,7 @@ rgb = PCA(n_components=3).fit_transform(feat).reshape(32, 32, 3)
89
 
90
  ```python
91
  model = model.cuda()
92
- out = model.encode_image(image.cuda())
93
  text_emb = model.encode_text(["a city"])
94
  ```
95
 
 
24
  ## Usage
25
 
26
  ```bash
27
+ pip install transformers torch torchvision sentencepiece scikit-learn
28
  ```
29
 
30
  ### Load the model
 
43
  ```python
44
  from torchvision import transforms
45
  from PIL import Image
46
+ import requests
47
 
48
  transform = transforms.Compose([
49
  transforms.Resize((448, 448)),
50
  transforms.ToTensor(),
51
  ])
52
 
53
+ url = "https://raw.githubusercontent.com/google-deepmind/tips/main/scenic/images/example_image.jpg"
54
+ image = Image.open(requests.get(url, stream=True).raw)
55
+ pixel_values = transform(image).unsqueeze(0)
56
+ out = model.encode_image(pixel_values)
57
 
58
  out.cls_token # (B, 1, 1152)
59
  out.patch_tokens # (B, N, 1152)
 
79
 
80
  ### Zero-shot segmentation
81
 
82
+ ```python
83
+ classes = ["cat", "dog", "grass", "sky"]
84
+ patch_feats = F.normalize(out.patch_tokens, dim=-1)
85
+ text_emb = F.normalize(model.encode_text(classes), dim=-1)
86
+ seg_map = (patch_feats @ text_emb.T).reshape(32, 32, len(classes)).argmax(dim=-1)
87
+ ```
88
+
89
+ ### Visualize spatial features
90
+
91
  ```python
92
  import numpy as np
93
  from sklearn.decomposition import PCA
 
101
 
102
  ```python
103
  model = model.cuda()
104
+ out = model.encode_image(pixel_values.cuda())
105
  text_emb = model.encode_text(["a city"])
106
  ```
107