Gabriele commited on
Commit
be4b3bf
·
1 Parent(s): d867321

Switch to bus image with matching classes

Browse files
Files changed (1) hide show
  1. README.md +5 -5
README.md CHANGED
@@ -50,7 +50,7 @@ transform = transforms.Compose([
50
  transforms.ToTensor(),
51
  ])
52
 
53
- url = "https://raw.githubusercontent.com/google-deepmind/tips/main/scenic/images/example_image_2.jpg"
54
  image = Image.open(requests.get(url, stream=True).raw)
55
  pixel_values = transform(image).unsqueeze(0)
56
  out = model.encode_image(pixel_values)
@@ -62,7 +62,7 @@ print(out.patch_tokens.shape) # (1, 1024, 1536) — per-patch spatial features
62
  ### Encode text
63
 
64
  ```python
65
- text_emb = model.encode_text(["a photo of a cat", "a photo of a dog"])
66
  print(text_emb.shape) # (2, 1536) — one embedding per query
67
  ```
68
 
@@ -71,17 +71,17 @@ print(text_emb.shape) # (2, 1536) — one embedding per query
71
  ```python
72
  import torch.nn.functional as F
73
 
74
- classes = ["cat", "dog", "car"]
75
  cls = F.normalize(out.cls_token[:, 0, :], dim=-1)
76
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
77
  similarity = cls @ text_emb.T
78
- print(classes[similarity.argmax()]) # cat — predicted class
79
  ```
80
 
81
  ### Zero-shot segmentation
82
 
83
  ```python
84
- classes = ["cat", "carpet", "floor", "furniture"]
85
  patch_feats = F.normalize(out.patch_tokens, dim=-1)
86
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
87
  seg_map = (patch_feats @ text_emb.T).reshape(32, 32, len(classes)).argmax(dim=-1)
 
50
  transforms.ToTensor(),
51
  ])
52
 
53
+ url = "https://huggingface.co/spaces/google/tipsv2-gpu-explorer/resolve/main/examples/zeroseg/pascal_context_00049_image.png"
54
  image = Image.open(requests.get(url, stream=True).raw)
55
  pixel_values = transform(image).unsqueeze(0)
56
  out = model.encode_image(pixel_values)
 
62
  ### Encode text
63
 
64
  ```python
65
+ text_emb = model.encode_text(["a photo of a bus", "a photo of a dog"])
66
  print(text_emb.shape) # (2, 1536) — one embedding per query
67
  ```
68
 
 
71
  ```python
72
  import torch.nn.functional as F
73
 
74
+ classes = ["bus", "car", "dog", "cat"]
75
  cls = F.normalize(out.cls_token[:, 0, :], dim=-1)
76
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
77
  similarity = cls @ text_emb.T
78
+ print(classes[similarity.argmax()]) # bus — predicted class
79
  ```
80
 
81
  ### Zero-shot segmentation
82
 
83
  ```python
84
+ classes = ["bus", "snow", "mountain", "house", "road"]
85
  patch_feats = F.normalize(out.patch_tokens, dim=-1)
86
  text_emb = F.normalize(model.encode_text(classes), dim=-1)
87
  seg_map = (patch_feats @ text_emb.T).reshape(32, 32, len(classes)).argmax(dim=-1)