Gabriele committed on
Commit
a34ff2b
·
1 Parent(s): 44dd5af

Improve code examples: better classes, descriptive comments

Browse files
Files changed (1) hide show
  1. README.md +7 -7
README.md CHANGED
@@ -55,15 +55,15 @@ image = Image.open(requests.get(url, stream=True).raw)
55
  pixel_values = transform(image).unsqueeze(0)
56
  out = model.encode_image(pixel_values)
57
 
58
- print(out.cls_token.shape) # (1, 1, 1024)
59
- print(out.patch_tokens.shape) # (1, 1024, 1024)
60
  ```
61
 
62
  ### Encode text
63
 
64
  ```python
65
  text_emb = model.encode_text(["a photo of a cat", "a photo of a dog"])
66
- print(text_emb.shape) # (2, 1024)
67
  ```
68
 
69
  ### Zero-shot classification
@@ -75,17 +75,17 @@ classes = ["cat", "dog", "car"]
75
  cls = F.normalize(out.cls_token[:, 0, :], dim=-1)
76
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
77
  similarity = cls @ text_emb.T
78
- print(classes[similarity.argmax()]) # cat
79
  ```
80
 
81
  ### Zero-shot segmentation
82
 
83
  ```python
84
- classes = ["cat", "dog", "grass", "sky"]
85
  patch_feats = F.normalize(out.patch_tokens, dim=-1)
86
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
87
  seg_map = (patch_feats @ text_emb.T).reshape(32, 32, len(classes)).argmax(dim=-1)
88
- print(seg_map.shape) # (32, 32)
89
  ```
90
 
91
  ### Visualize spatial features
@@ -97,7 +97,7 @@ from sklearn.decomposition import PCA
97
  spatial = out.patch_tokens.reshape(1, 32, 32, 1024)
98
  feat = spatial[0].detach().numpy().reshape(-1, 1024)
99
  rgb = PCA(n_components=3).fit_transform(feat).reshape(32, 32, 3)
100
- print(rgb.shape) # (32, 32, 3)
101
  ```
102
 
103
  ### GPU inference
 
55
  pixel_values = transform(image).unsqueeze(0)
56
  out = model.encode_image(pixel_values)
57
 
58
+ print(out.cls_token.shape) # (1, 1, 1024) — global image embedding
59
+ print(out.patch_tokens.shape) # (1, 1024, 1024) — per-patch spatial features
60
  ```
61
 
62
  ### Encode text
63
 
64
  ```python
65
  text_emb = model.encode_text(["a photo of a cat", "a photo of a dog"])
66
+ print(text_emb.shape) # (2, 1024) — one embedding per query
67
  ```
68
 
69
  ### Zero-shot classification
 
75
  cls = F.normalize(out.cls_token[:, 0, :], dim=-1)
76
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
77
  similarity = cls @ text_emb.T
78
+ print(classes[similarity.argmax()]) # cat — predicted class
79
  ```
80
 
81
  ### Zero-shot segmentation
82
 
83
  ```python
84
+ classes = ["cat", "carpet", "floor", "furniture"]
85
  patch_feats = F.normalize(out.patch_tokens, dim=-1)
86
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
87
  seg_map = (patch_feats @ text_emb.T).reshape(32, 32, len(classes)).argmax(dim=-1)
88
+ print(seg_map.shape) # (32, 32) — per-patch class prediction
89
  ```
90
 
91
  ### Visualize spatial features
 
97
  spatial = out.patch_tokens.reshape(1, 32, 32, 1024)
98
  feat = spatial[0].detach().numpy().reshape(-1, 1024)
99
  rgb = PCA(n_components=3).fit_transform(feat).reshape(32, 32, 3)
100
+ print(rgb.shape) # (32, 32, 3) — PCA of patch features as RGB
101
  ```
102
 
103
  ### GPU inference