argus.py: correspond() returns dense {matches, scores, grid} dict; lazy-load class names
Browse files
argus.py
CHANGED
|
@@ -1640,12 +1640,46 @@ class Argus(PreTrainedModel):
|
|
| 1640 |
if torch.isnan(buf).any() or torch.isinf(buf).any():
|
| 1641 |
buf.data.zero_()
|
| 1642 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1643 |
@property
def class_ids(self):
    """Expose the ``class_ids`` attribute stored on the model config."""
    config = self.config
    return config.class_ids
|
| 1646 |
|
| 1647 |
@property
def class_names(self):
    """Expose the ``class_names`` attribute stored on the model config."""
    config = self.config
    return config.class_names
|
| 1650 |
|
| 1651 |
def quantize_int8(self):
|
|
@@ -1782,34 +1816,32 @@ class Argus(PreTrainedModel):
|
|
| 1782 |
self,
|
| 1783 |
src_image: Image.Image,
|
| 1784 |
tgt_image: Image.Image,
|
| 1785 |
-
src_keypoints: list,
|
| 1786 |
resolution: int = 512,
|
| 1787 |
):
|
| 1788 |
-
|
| 1789 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1790 |
transform = make_eupe_transform(resolution)
|
| 1791 |
src_t = transform(src_image).unsqueeze(0).to(self.device)
|
| 1792 |
tgt_t = transform(tgt_image).unsqueeze(0).to(self.device)
|
| 1793 |
|
| 1794 |
-
|
| 1795 |
-
|
| 1796 |
-
|
| 1797 |
-
|
| 1798 |
-
|
| 1799 |
-
|
| 1800 |
-
|
| 1801 |
-
|
| 1802 |
-
|
| 1803 |
-
|
| 1804 |
-
|
| 1805 |
-
|
| 1806 |
-
|
| 1807 |
-
|
| 1808 |
-
sim_map = torch.einsum("d,hwd->hw", src_vec, tgt_feats)
|
| 1809 |
-
flat = sim_map.argmax().item()
|
| 1810 |
-
py, px = flat // resolution, flat % resolution
|
| 1811 |
-
preds.append([px / resolution * tw, py / resolution * th])
|
| 1812 |
-
return preds
|
| 1813 |
|
| 1814 |
@torch.inference_mode()
|
| 1815 |
def detect(
|
|
|
|
| 1640 |
if torch.isnan(buf).any() or torch.isinf(buf).any():
|
| 1641 |
buf.data.zero_()
|
| 1642 |
|
| 1643 |
+
def _load_imagenet_classes(self):
    """Best-effort, one-shot load of ``imagenet_classes.json`` into the config.

    Lookup order:
      1. next to this source file,
      2. inside ``config._name_or_path`` when that is a local directory,
      3. otherwise ``config._name_or_path`` is passed to ``hf_hub_download``.

    On success ``config.class_ids`` and ``config.class_names`` are set from
    the JSON keys of the same name (missing keys become ``[]``).  Failures
    are logged and never raised, so the properties that call this helper
    cannot blow up on a missing or corrupt file.
    """
    if getattr(self, "_imagenet_classes_loaded", False):
        return
    # Flag the attempt before doing any work: a failed lookup must not be
    # retried (and possibly hit the network) on every property access.
    self._imagenet_classes_loaded = True

    import json
    import logging
    import os as _os

    log = logging.getLogger(__name__)
    filename = "imagenet_classes.json"

    def _apply(path):
        """Copy the class lists from *path* onto the config; True on success."""
        try:
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            log.warning("Could not read class list from %s: %s", path, exc)
            return False
        if not isinstance(data, dict):
            log.warning("Ignoring %s: expected a JSON object", path)
            return False
        self.config.class_ids = data.get("class_ids", [])
        self.config.class_names = data.get("class_names", [])
        return True

    name_or_path = getattr(self.config, "_name_or_path", None)
    is_local_dir = bool(name_or_path) and _os.path.isdir(name_or_path)

    here = _os.path.dirname(_os.path.abspath(__file__))
    candidates = [_os.path.join(here, filename)]
    if is_local_dir:
        candidates.append(_os.path.join(name_or_path, filename))

    for path in candidates:
        if _os.path.isfile(path) and _apply(path):
            return

    if name_or_path and not is_local_dir:
        try:
            from huggingface_hub import hf_hub_download

            path = hf_hub_download(name_or_path, filename)
        except Exception as exc:
            # Deliberately broad: the hub client may be missing, offline, or
            # the repo may simply not ship the file.  Stay best-effort, but
            # leave a trace instead of failing silently.
            log.warning(
                "Could not fetch %s from %r: %s", filename, name_or_path, exc
            )
            return
        _apply(path)
|
| 1672 |
+
|
| 1673 |
@property
def class_ids(self):
    """Return ``config.class_ids``, triggering the lazy class-list load when it is empty."""
    if self.config.class_ids:
        return self.config.class_ids
    # Nothing on the config yet: try to populate it, then re-read.
    self._load_imagenet_classes()
    return self.config.class_ids
|
| 1678 |
|
| 1679 |
@property
def class_names(self):
    """Return ``config.class_names``, triggering the lazy class-list load when it is empty."""
    if self.config.class_names:
        return self.config.class_names
    # Nothing on the config yet: try to populate it, then re-read.
    self._load_imagenet_classes()
    return self.config.class_names
|
| 1684 |
|
| 1685 |
def quantize_int8(self):
|
|
|
|
| 1816 |
self,
|
| 1817 |
src_image: Image.Image,
|
| 1818 |
tgt_image: Image.Image,
|
|
|
|
| 1819 |
resolution: int = 512,
|
| 1820 |
):
|
| 1821 |
+
"""Dense patch correspondence between two images.
|
| 1822 |
+
|
| 1823 |
+
Returns a dict with keys `matches` (numpy array of length grid*grid mapping
|
| 1824 |
+
each source patch to its argmax target patch), `scores` (cosine similarity
|
| 1825 |
+
at the match), and `grid` (the patch-grid side length).
|
| 1826 |
+
"""
|
| 1827 |
transform = make_eupe_transform(resolution)
|
| 1828 |
src_t = transform(src_image).unsqueeze(0).to(self.device)
|
| 1829 |
tgt_t = transform(tgt_image).unsqueeze(0).to(self.device)
|
| 1830 |
|
| 1831 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
|
| 1832 |
+
oa = self.backbone.forward_features(src_t)
|
| 1833 |
+
ob = self.backbone.forward_features(tgt_t)
|
| 1834 |
+
pa = F.normalize(oa['x_norm_patchtokens'].float().squeeze(0), dim=-1)
|
| 1835 |
+
pb = F.normalize(ob['x_norm_patchtokens'].float().squeeze(0), dim=-1)
|
| 1836 |
+
sim = pa @ pb.t()
|
| 1837 |
+
m = sim.argmax(dim=-1)
|
| 1838 |
+
s = sim.max(dim=-1).values
|
| 1839 |
+
grid = int(np.sqrt(pa.shape[0]))
|
| 1840 |
+
return {
|
| 1841 |
+
"matches": m.cpu().numpy(),
|
| 1842 |
+
"scores": s.cpu().numpy(),
|
| 1843 |
+
"grid": grid,
|
| 1844 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1845 |
|
| 1846 |
@torch.inference_mode()
|
| 1847 |
def detect(
|