Gabriele commited on
Commit
365a9aa
·
1 Parent(s): f46d027

Use cat photo, add print statements to code examples

Browse files
Files changed (1) hide show
  1. README.md +9 -6
README.md CHANGED
@@ -50,20 +50,20 @@ transform = transforms.Compose([
50
  transforms.ToTensor(),
51
  ])
52
 
53
- url = "https://raw.githubusercontent.com/google-deepmind/tips/main/scenic/images/example_image.jpg"
54
  image = Image.open(requests.get(url, stream=True).raw)
55
  pixel_values = transform(image).unsqueeze(0)
56
  out = model.encode_image(pixel_values)
57
 
58
- out.cls_token # (B, 1, 768)
59
- out.patch_tokens # (B, N, 768)
60
  ```
61
 
62
  ### Encode text
63
 
64
  ```python
65
  text_emb = model.encode_text(["a photo of a cat", "a photo of a dog"])
66
- # (2, 768)
67
  ```
68
 
69
  ### Zero-shot classification
@@ -71,10 +71,11 @@ text_emb = model.encode_text(["a photo of a cat", "a photo of a dog"])
71
  ```python
72
  import torch.nn.functional as F
73
 
 
74
  cls = F.normalize(out.cls_token[:, 0, :], dim=-1)
75
- text_emb = F.normalize(model.encode_text(["cat", "dog", "car"]), dim=-1)
76
  similarity = cls @ text_emb.T
77
- prediction = similarity.argmax(dim=-1)
78
  ```
79
 
80
  ### Zero-shot segmentation
@@ -84,6 +85,7 @@ classes = ["cat", "dog", "grass", "sky"]
84
  patch_feats = F.normalize(out.patch_tokens, dim=-1)
85
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
86
  seg_map = (patch_feats @ text_emb.T).reshape(32, 32, len(classes)).argmax(dim=-1)
 
87
  ```
88
 
89
  ### Visualize spatial features
@@ -95,6 +97,7 @@ from sklearn.decomposition import PCA
95
  spatial = out.patch_tokens.reshape(1, 32, 32, 768)
96
  feat = spatial[0].detach().numpy().reshape(-1, 768)
97
  rgb = PCA(n_components=3).fit_transform(feat).reshape(32, 32, 3)
 
98
  ```
99
 
100
  ### GPU inference
 
50
  transforms.ToTensor(),
51
  ])
52
 
53
+ url = "https://raw.githubusercontent.com/google-deepmind/tips/main/scenic/images/example_image_2.jpg"
54
  image = Image.open(requests.get(url, stream=True).raw)
55
  pixel_values = transform(image).unsqueeze(0)
56
  out = model.encode_image(pixel_values)
57
 
58
+ print(out.cls_token.shape) # (1, 1, 768)
59
+ print(out.patch_tokens.shape) # (1, 1024, 768)
60
  ```
61
 
62
  ### Encode text
63
 
64
  ```python
65
  text_emb = model.encode_text(["a photo of a cat", "a photo of a dog"])
66
+ print(text_emb.shape) # (2, 768)
67
  ```
68
 
69
  ### Zero-shot classification
 
71
  ```python
72
  import torch.nn.functional as F
73
 
74
+ classes = ["cat", "dog", "car"]
75
  cls = F.normalize(out.cls_token[:, 0, :], dim=-1)
76
+ text_emb = F.normalize(model.encode_text(classes), dim=-1)
77
  similarity = cls @ text_emb.T
78
+ print(classes[similarity.argmax()]) # cat
79
  ```
80
 
81
  ### Zero-shot segmentation
 
85
  patch_feats = F.normalize(out.patch_tokens, dim=-1)
86
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
87
  seg_map = (patch_feats @ text_emb.T).reshape(32, 32, len(classes)).argmax(dim=-1)
88
+ print(seg_map.shape) # (32, 32)
89
  ```
90
 
91
  ### Visualize spatial features
 
97
  spatial = out.patch_tokens.reshape(1, 32, 32, 768)
98
  feat = spatial[0].detach().numpy().reshape(-1, 768)
99
  rgb = PCA(n_components=3).fit_transform(feat).reshape(32, 32, 3)
100
+ print(rgb.shape) # (32, 32, 3)
101
  ```
102
 
103
  ### GPU inference