phanerozoic commited on
Commit
d78d51e
·
verified ·
1 Parent(s): 87b8116

argus.py: correspond() returns dense {matches, scores, grid} dict; lazy-load class names

Browse files
Files changed (1) hide show
  1. argus.py +54 -22
argus.py CHANGED
@@ -1640,12 +1640,46 @@ class Argus(PreTrainedModel):
1640
  if torch.isnan(buf).any() or torch.isinf(buf).any():
1641
  buf.data.zero_()
1642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1643
  @property
1644
  def class_ids(self):
 
 
1645
  return self.config.class_ids
1646
 
1647
  @property
1648
  def class_names(self):
 
 
1649
  return self.config.class_names
1650
 
1651
  def quantize_int8(self):
@@ -1782,34 +1816,32 @@ class Argus(PreTrainedModel):
1782
  self,
1783
  src_image: Image.Image,
1784
  tgt_image: Image.Image,
1785
- src_keypoints: list,
1786
  resolution: int = 512,
1787
  ):
1788
- sw, sh = src_image.size
1789
- tw, th = tgt_image.size
 
 
 
 
1790
  transform = make_eupe_transform(resolution)
1791
  src_t = transform(src_image).unsqueeze(0).to(self.device)
1792
  tgt_t = transform(tgt_image).unsqueeze(0).to(self.device)
1793
 
1794
- _, src_feats = self._extract(src_t)
1795
- _, tgt_feats = self._extract(tgt_t)
1796
-
1797
- src_feats = F.interpolate(src_feats, size=(resolution, resolution), mode="bilinear", align_corners=False)
1798
- tgt_feats = F.interpolate(tgt_feats, size=(resolution, resolution), mode="bilinear", align_corners=False)
1799
-
1800
- src_feats = F.normalize(src_feats[0].permute(1, 2, 0), dim=-1)
1801
- tgt_feats = F.normalize(tgt_feats[0].permute(1, 2, 0), dim=-1)
1802
-
1803
- preds = []
1804
- for kp in src_keypoints:
1805
- sx = min(max(int(kp[0] / sw * resolution), 0), resolution - 1)
1806
- sy = min(max(int(kp[1] / sh * resolution), 0), resolution - 1)
1807
- src_vec = src_feats[sy, sx]
1808
- sim_map = torch.einsum("d,hwd->hw", src_vec, tgt_feats)
1809
- flat = sim_map.argmax().item()
1810
- py, px = flat // resolution, flat % resolution
1811
- preds.append([px / resolution * tw, py / resolution * th])
1812
- return preds
1813
 
1814
  @torch.inference_mode()
1815
  def detect(
 
1640
  if torch.isnan(buf).any() or torch.isinf(buf).any():
1641
  buf.data.zero_()
1642
 
1643
+ def _load_imagenet_classes(self):
1644
+ if getattr(self, "_imagenet_classes_loaded", False):
1645
+ return
1646
+ self._imagenet_classes_loaded = True
1647
+ import json
1648
+ import os as _os
1649
+ candidates = []
1650
+ here = _os.path.dirname(_os.path.abspath(__file__))
1651
+ candidates.append(_os.path.join(here, "imagenet_classes.json"))
1652
+ name_or_path = getattr(self.config, "_name_or_path", None)
1653
+ if name_or_path and _os.path.isdir(name_or_path):
1654
+ candidates.append(_os.path.join(name_or_path, "imagenet_classes.json"))
1655
+ for path in candidates:
1656
+ if _os.path.isfile(path):
1657
+ with open(path) as f:
1658
+ data = json.load(f)
1659
+ self.config.class_ids = data.get("class_ids", [])
1660
+ self.config.class_names = data.get("class_names", [])
1661
+ return
1662
+ if name_or_path and not _os.path.isdir(name_or_path):
1663
+ try:
1664
+ from huggingface_hub import hf_hub_download
1665
+ path = hf_hub_download(name_or_path, "imagenet_classes.json")
1666
+ with open(path) as f:
1667
+ data = json.load(f)
1668
+ self.config.class_ids = data.get("class_ids", [])
1669
+ self.config.class_names = data.get("class_names", [])
1670
+ except Exception:
1671
+ pass
1672
+
1673
  @property
1674
  def class_ids(self):
1675
+ if not self.config.class_ids:
1676
+ self._load_imagenet_classes()
1677
  return self.config.class_ids
1678
 
1679
  @property
1680
  def class_names(self):
1681
+ if not self.config.class_names:
1682
+ self._load_imagenet_classes()
1683
  return self.config.class_names
1684
 
1685
  def quantize_int8(self):
 
1816
  self,
1817
  src_image: Image.Image,
1818
  tgt_image: Image.Image,
 
1819
  resolution: int = 512,
1820
  ):
1821
+ """Dense patch correspondence between two images.
1822
+
1823
+ Returns a dict with keys `matches` (numpy array of length grid*grid mapping
1824
+ each source patch to its argmax target patch), `scores` (cosine similarity
1825
+ at the match), and `grid` (the patch-grid side length).
1826
+ """
1827
  transform = make_eupe_transform(resolution)
1828
  src_t = transform(src_image).unsqueeze(0).to(self.device)
1829
  tgt_t = transform(tgt_image).unsqueeze(0).to(self.device)
1830
 
1831
+ with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
1832
+ oa = self.backbone.forward_features(src_t)
1833
+ ob = self.backbone.forward_features(tgt_t)
1834
+ pa = F.normalize(oa['x_norm_patchtokens'].float().squeeze(0), dim=-1)
1835
+ pb = F.normalize(ob['x_norm_patchtokens'].float().squeeze(0), dim=-1)
1836
+ sim = pa @ pb.t()
1837
+ m = sim.argmax(dim=-1)
1838
+ s = sim.max(dim=-1).values
1839
+ grid = int(np.sqrt(pa.shape[0]))
1840
+ return {
1841
+ "matches": m.cpu().numpy(),
1842
+ "scores": s.cpu().numpy(),
1843
+ "grid": grid,
1844
+ }
 
 
 
 
 
1845
 
1846
  @torch.inference_mode()
1847
  def detect(