prthm11 commited on
Commit
5727fc7
·
verified ·
1 Parent(s): 6a9f87c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -138
app.py CHANGED
@@ -703,6 +703,146 @@ def choose_top_candidates(embedding_results, phash_results, imgmatch_results, to
703
 
704
  return result
705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
  def is_subpath(path: str, base: str) -> bool:
707
  """Return True if path is inside base (works across OSes)."""
708
  try:
@@ -1731,144 +1871,6 @@ def similarity_matching(sprites_data: dict, project_folder: str, top_k: int = 1,
1731
  img.save(buffer, format="PNG")
1732
  buffer.seek(0)
1733
  sprite_images_bytes.append(buffer)
1734
-
1735
- def hybrid_similarity_matching(sprite_images_bytes, sprite_ids, min_similarity=None, top_k=5, method_weights=(0.5,0.3,0.2)):
1736
- from PIL import Image
1737
- # Local safe defaults
1738
- embeddings_path = os.path.join(BLOCKS_DIR, "hybrid_embeddings.json")
1739
- hash_path = os.path.join(BLOCKS_DIR, "phash_data.json")
1740
- signature_path = os.path.join(BLOCKS_DIR, "signature_data.json")
1741
-
1742
- # Load embeddings
1743
- embedding_json = {}
1744
- if os.path.exists(embeddings_path):
1745
- with open(embeddings_path, "r", encoding="utf-8") as f:
1746
- embedding_json = json.load(f)
1747
-
1748
- # Load phash data (if exists) -> ensure hash_dict variable exists
1749
- hash_dict = {}
1750
- if os.path.exists(hash_path):
1751
- try:
1752
- with open(hash_path, "r", encoding="utf-8") as f:
1753
- hash_data = json.load(f)
1754
- for path, hash_str in hash_data.items():
1755
- try:
1756
- hash_dict[path] = hash_str
1757
- except Exception:
1758
- pass
1759
- except Exception:
1760
- pass
1761
-
1762
- # Load signature data (if exists) -> ensure signature_dict exists
1763
- signature_dict = {}
1764
- sig_data = {}
1765
- if os.path.exists(signature_path):
1766
- try:
1767
- with open(signature_path, "r", encoding="utf-8") as f:
1768
- sig_data = json.load(f)
1769
- for path, sig_list in sig_data.items():
1770
- try:
1771
- signature_dict[path] = np.array(sig_list)
1772
- except Exception:
1773
- pass
1774
- except Exception:
1775
- pass
1776
-
1777
- # Parse embeddings into lists
1778
- paths_list = []
1779
- embeddings_list = []
1780
- if isinstance(embedding_json, dict):
1781
- for p, emb in embedding_json.items():
1782
- if isinstance(emb, dict):
1783
- maybe_emb = emb.get("embedding") or emb.get("embeddings") or emb.get("emb")
1784
- if maybe_emb is None:
1785
- continue
1786
- arr = np.asarray(maybe_emb, dtype=np.float32)
1787
- elif isinstance(emb, list):
1788
- arr = np.asarray(emb, dtype=np.float32)
1789
- else:
1790
- continue
1791
- paths_list.append(os.path.normpath(str(p)))
1792
- embeddings_list.append(arr)
1793
- elif isinstance(embedding_json, list):
1794
- for item in embedding_json:
1795
- if not isinstance(item, dict):
1796
- continue
1797
- p = item.get("path") or item.get("image_path") or item.get("file") or item.get("filename") or item.get("img_path")
1798
- emb = item.get("embeddings") or item.get("embedding") or item.get("features") or item.get("vector") or item.get("emb")
1799
- if p is None or emb is None:
1800
- continue
1801
- paths_list.append(os.path.normpath(str(p)))
1802
- embeddings_list.append(np.asarray(emb, dtype=np.float32))
1803
-
1804
- if len(paths_list) == 0:
1805
- print("⚠ No reference images/embeddings found (this test harness may be running without data)")
1806
- # Return empty results gracefully
1807
- return [[] for _ in sprite_images_bytes], [[] for _ in sprite_images_bytes], []
1808
-
1809
- ref_matrix = np.vstack(embeddings_list).astype(np.float32)
1810
-
1811
- # Batch: Get all sprite embeddings, phash, sigs first
1812
- sprite_emb_list = []
1813
- sprite_phash_list = []
1814
- sprite_sig_list = []
1815
- per_sprite_final_indices = []
1816
- per_sprite_final_scores = []
1817
- per_sprite_rerank_debug = []
1818
- for i, sprite_bytes in enumerate(sprite_images_bytes):
1819
- sprite_pil = Image.open(sprite_bytes)
1820
- enhanced_sprite = process_image_cv2_from_pil(sprite_pil, scale=2) or sprite_pil
1821
- # sprite_emb = get_dinov2_embedding_from_pil(preprocess_for_model(enhanced_sprite)) or np.zeros(ref_matrix.shape[1])
1822
- # sprite_emb_list.append(sprite_emb)
1823
- sprite_emb = get_dinov2_embedding_from_pil(preprocess_for_model(enhanced_sprite))
1824
- sprite_emb = sprite_emb if sprite_emb is not None else np.zeros(ref_matrix.shape[1])
1825
- sprite_emb_list.append(sprite_emb)
1826
- # Perceptual hash
1827
- sprite_hash_arr = preprocess_for_hash(enhanced_sprite)
1828
- sprite_phash = None
1829
- if sprite_hash_arr is not None:
1830
- try: sprite_phash = phash.encode_image(image_array=sprite_hash_arr)
1831
- except: pass
1832
- sprite_phash_list.append(sprite_phash)
1833
- # Signature
1834
- sprite_sig = None
1835
- embedding_results, phash_results, imgmatch_results, combined_results = run_query_search_flow(
1836
- query_b64=sprite_b64_clean[i],
1837
- processed_dir=BLOCKS_DIR,
1838
- embeddings_dict=embedding_json,
1839
- hash_dict=hash_data,
1840
- signature_obj_map=sig_data,
1841
- gis=gis,
1842
- phash=phash,
1843
- MAX_PHASH_BITS=64,
1844
- k=5
1845
- )
1846
- # Call the advanced re-ranker
1847
- rerank_result = choose_top_candidates(embedding_results, phash_results, imgmatch_results,
1848
- top_k=top_k, method_weights=method_weights, verbose=True)
1849
- per_sprite_rerank_debug.append(rerank_result)
1850
-
1851
- # Selection logic: prefer consensus, else weighted top-1
1852
- final = None
1853
- if len(rerank_result["consensus_topk"]) > 0:
1854
- consensus = rerank_result["consensus_topk"]
1855
- best = max(consensus, key=lambda p: rerank_result["weighted_scores_full"].get(p, 0.0))
1856
- final = best
1857
- else:
1858
- final = rerank_result["weighted_topk"][0][0] if rerank_result["weighted_topk"] else None
1859
-
1860
- # Store index and score for downstream use
1861
- if final is not None and final in paths_list:
1862
- idx = paths_list.index(final)
1863
- score = rerank_result["weighted_scores_full"].get(final, 0.0)
1864
- per_sprite_final_indices.append([idx])
1865
- per_sprite_final_scores.append([score])
1866
- print(f"Sprite '{sprite_ids}' FINAL selected: {final} (index {idx}) score={score:.4f}")
1867
- else:
1868
- per_sprite_final_indices.append([])
1869
- per_sprite_final_scores.append([])
1870
-
1871
- return per_sprite_final_indices, per_sprite_final_scores, paths_list#, per_sprite_rerank_debug
1872
 
1873
  # Use hybrid matching system
1874
  per_sprite_matched_indices, per_sprite_scores, paths_list = hybrid_similarity_matching(
 
703
 
704
  return result
705
 
706
+
707
+ def hybrid_similarity_matching(sprite_images_bytes, sprite_ids, min_similarity=None, top_k=5, method_weights=(0.5,0.3,0.2)):
708
+ from PIL import Image
709
+ # Local safe defaults
710
+ embeddings_path = os.path.join(BLOCKS_DIR, "hybrid_embeddings.json")
711
+ hash_path = os.path.join(BLOCKS_DIR, "phash_data.json")
712
+ signature_path = os.path.join(BLOCKS_DIR, "signature_data.json")
713
+
714
+ # Load embeddings
715
+ embedding_json = {}
716
+ if os.path.exists(embeddings_path):
717
+ with open(embeddings_path, "r", encoding="utf-8") as f:
718
+ embedding_json = json.load(f)
719
+
720
+ # Load phash data (if exists) -> ensure hash_dict variable exists
721
+ hash_dict = {}
722
+ if os.path.exists(hash_path):
723
+ try:
724
+ with open(hash_path, "r", encoding="utf-8") as f:
725
+ hash_data = json.load(f)
726
+ for path, hash_str in hash_data.items():
727
+ try:
728
+ hash_dict[path] = hash_str
729
+ except Exception:
730
+ pass
731
+ except Exception:
732
+ pass
733
+
734
+ # Load signature data (if exists) -> ensure signature_dict exists
735
+ signature_dict = {}
736
+ sig_data = {}
737
+ if os.path.exists(signature_path):
738
+ try:
739
+ with open(signature_path, "r", encoding="utf-8") as f:
740
+ sig_data = json.load(f)
741
+ for path, sig_list in sig_data.items():
742
+ try:
743
+ signature_dict[path] = np.array(sig_list)
744
+ except Exception:
745
+ pass
746
+ except Exception:
747
+ pass
748
+
749
+ # Parse embeddings into lists
750
+ paths_list = []
751
+ embeddings_list = []
752
+ if isinstance(embedding_json, dict):
753
+ for p, emb in embedding_json.items():
754
+ if isinstance(emb, dict):
755
+ maybe_emb = emb.get("embedding") or emb.get("embeddings") or emb.get("emb")
756
+ if maybe_emb is None:
757
+ continue
758
+ arr = np.asarray(maybe_emb, dtype=np.float32)
759
+ elif isinstance(emb, list):
760
+ arr = np.asarray(emb, dtype=np.float32)
761
+ else:
762
+ continue
763
+ paths_list.append(os.path.normpath(str(p)))
764
+ embeddings_list.append(arr)
765
+ elif isinstance(embedding_json, list):
766
+ for item in embedding_json:
767
+ if not isinstance(item, dict):
768
+ continue
769
+ p = item.get("path") or item.get("image_path") or item.get("file") or item.get("filename") or item.get("img_path")
770
+ emb = item.get("embeddings") or item.get("embedding") or item.get("features") or item.get("vector") or item.get("emb")
771
+ if p is None or emb is None:
772
+ continue
773
+ paths_list.append(os.path.normpath(str(p)))
774
+ embeddings_list.append(np.asarray(emb, dtype=np.float32))
775
+
776
+ if len(paths_list) == 0:
777
+ print("⚠ No reference images/embeddings found (this test harness may be running without data)")
778
+ # Return empty results gracefully
779
+ return [[] for _ in sprite_images_bytes], [[] for _ in sprite_images_bytes], []
780
+
781
+ ref_matrix = np.vstack(embeddings_list).astype(np.float32)
782
+
783
+ # Batch: Get all sprite embeddings, phash, sigs first
784
+ sprite_emb_list = []
785
+ sprite_phash_list = []
786
+ sprite_sig_list = []
787
+ per_sprite_final_indices = []
788
+ per_sprite_final_scores = []
789
+ per_sprite_rerank_debug = []
790
+ for i, sprite_bytes in enumerate(sprite_images_bytes):
791
+ sprite_pil = Image.open(sprite_bytes)
792
+ enhanced_sprite = process_image_cv2_from_pil(sprite_pil, scale=2) or sprite_pil
793
+ # sprite_emb = get_dinov2_embedding_from_pil(preprocess_for_model(enhanced_sprite)) or np.zeros(ref_matrix.shape[1])
794
+ # sprite_emb_list.append(sprite_emb)
795
+ sprite_emb = get_dinov2_embedding_from_pil(preprocess_for_model(enhanced_sprite))
796
+ sprite_emb = sprite_emb if sprite_emb is not None else np.zeros(ref_matrix.shape[1])
797
+ sprite_emb_list.append(sprite_emb)
798
+ # Perceptual hash
799
+ sprite_hash_arr = preprocess_for_hash(enhanced_sprite)
800
+ sprite_phash = None
801
+ if sprite_hash_arr is not None:
802
+ try: sprite_phash = phash.encode_image(image_array=sprite_hash_arr)
803
+ except: pass
804
+ sprite_phash_list.append(sprite_phash)
805
+ # Signature
806
+ sprite_sig = None
807
+ embedding_results, phash_results, imgmatch_results, combined_results = run_query_search_flow(
808
+ query_b64=sprite_b64_clean[i],
809
+ processed_dir=BLOCKS_DIR,
810
+ embeddings_dict=embedding_json,
811
+ hash_dict=hash_data,
812
+ signature_obj_map=sig_data,
813
+ gis=gis,
814
+ phash=phash,
815
+ MAX_PHASH_BITS=64,
816
+ k=5
817
+ )
818
+ # Call the advanced re-ranker
819
+ rerank_result = choose_top_candidates(embedding_results, phash_results, imgmatch_results,
820
+ top_k=top_k, method_weights=method_weights, verbose=True)
821
+ per_sprite_rerank_debug.append(rerank_result)
822
+
823
+ # Selection logic: prefer consensus, else weighted top-1
824
+ final = None
825
+ if len(rerank_result["consensus_topk"]) > 0:
826
+ consensus = rerank_result["consensus_topk"]
827
+ best = max(consensus, key=lambda p: rerank_result["weighted_scores_full"].get(p, 0.0))
828
+ final = best
829
+ else:
830
+ final = rerank_result["weighted_topk"][0][0] if rerank_result["weighted_topk"] else None
831
+
832
+ # Store index and score for downstream use
833
+ if final is not None and final in paths_list:
834
+ idx = paths_list.index(final)
835
+ score = rerank_result["weighted_scores_full"].get(final, 0.0)
836
+ per_sprite_final_indices.append([idx])
837
+ per_sprite_final_scores.append([score])
838
+ print(f"Sprite '{sprite_ids}' FINAL selected: {final} (index {idx}) score={score:.4f}")
839
+ else:
840
+ per_sprite_final_indices.append([])
841
+ per_sprite_final_scores.append([])
842
+
843
+ return per_sprite_final_indices, per_sprite_final_scores, paths_list#, per_sprite_rerank_debug
844
+
845
+
846
  def is_subpath(path: str, base: str) -> bool:
847
  """Return True if path is inside base (works across OSes)."""
848
  try:
 
1871
  img.save(buffer, format="PNG")
1872
  buffer.seek(0)
1873
  sprite_images_bytes.append(buffer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1874
 
1875
  # Use hybrid matching system
1876
  per_sprite_matched_indices, per_sprite_scores, paths_list = hybrid_similarity_matching(