Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import torch | |
| import open_clip | |
| import faiss | |
| import logging | |
| from PIL import Image | |
| from collections import defaultdict | |
| logger = logging.getLogger("ai_green_assistant.plentclef") | |
| def _load_plantclef_cache(cache_path): | |
| """Load PlantCLEF cache across PyTorch versions. | |
| PyTorch >=2.6 defaults torch.load(..., weights_only=True), which may fail | |
| for legacy cache files serialized with pickle opcodes. | |
| """ | |
| try: | |
| return torch.load(cache_path, map_location="cpu", weights_only=True) | |
| except TypeError: | |
| # Older PyTorch versions don't support weights_only argument. | |
| return torch.load(cache_path, map_location="cpu") | |
| except Exception: | |
| try: | |
| # Legacy trusted cache fallback for PyTorch >=2.6 behavior change. | |
| return torch.load(cache_path, map_location="cpu", weights_only=False) | |
| except Exception: | |
| # Some historical cache files are plain pickle dictionaries. | |
| with open(cache_path, "rb") as f: | |
| return pickle.load(f) | |
| def _rrf_merge(results_list: list[list], k: int = 60, debug_log: bool = False) -> list[tuple]: | |
| """Reciprocal Rank Fusion across multiple ranked result lists. | |
| Each element of *results_list* is a list of (species, score, image_paths) | |
| tuples already sorted by descending score. Returns a merged list sorted | |
| by descending RRF score, same tuple format (image_paths will be empty). | |
| If debug_log=True, logs detailed RRF calculations for each species. | |
| """ | |
| combined: dict[str, float] = defaultdict(float) | |
| species_sources: dict[str, list] = defaultdict(list) # Track which index found each species | |
| for index_idx, ranked in enumerate(results_list): | |
| index_name = ["PlantCLEF", "LeafSnap"][index_idx] if index_idx < 2 else f"Index_{index_idx}" | |
| if debug_log: | |
| logger.info(f" {index_name} results (top {len(ranked)}):") | |
| for rank, (species, _score, _paths) in enumerate(ranked): | |
| rrf_contribution = 1.0 / (k + rank + 1) | |
| combined[species] += rrf_contribution | |
| species_sources[species].append({ | |
| "index": index_name, | |
| "rank": rank, | |
| "score": _score, | |
| "rrf_contribution": rrf_contribution | |
| }) | |
| if debug_log and rank < 10: # Log top 10 from each index | |
| logger.info(f" Rank {rank:2d}: {species:40s} | raw_score: {_score:7.4f} | RRF += {rrf_contribution:.6f}") | |
| final_results = sorted( | |
| [(species, rrf_score, []) for species, rrf_score in combined.items()], | |
| key=lambda x: x[1], | |
| reverse=True, | |
| ) | |
| if debug_log: | |
| logger.info(f"\n RRF Final Results (merged):") | |
| for final_rank, (species, rrf_score, _) in enumerate(final_results[:20]): | |
| sources_str = ", ".join( | |
| f"{s['index']}(rank={s['rank']}, contrib={s['rrf_contribution']:.6f})" | |
| for s in species_sources.get(species, []) | |
| ) | |
| logger.info(f" Rank {final_rank:2d}: {species:40s} | RRF_score: {rrf_score:.6f} | sources: {sources_str}") | |
| return final_results | |
| class PlentClefIndex(): | |
| def __init__(self, model_name, index_path, index_cache, | |
| leafsnap_index_path=None, leafsnap_cache_path=None, | |
| leafsnap_aliases: dict | None = None): | |
| self.model, self.preprocess, self.tokenizer = open_clip.create_model_and_transforms( | |
| model_name=model_name, | |
| pretrained="laion2b_s34b_b79k" | |
| ) | |
| self.index = faiss.read_index(index_path) | |
| self.model.eval() | |
| data = _load_plantclef_cache(index_cache) | |
| if not isinstance(data, dict): | |
| raise TypeError("PlantCLEF cache must be a dict-like object") | |
| self.plantclef_image_embeddings = data.get("embeddings") | |
| if "labels" not in data: | |
| raise KeyError("Missing 'labels' in PlantCLEF cache") | |
| self.plantclef_labels = data["labels"] | |
| logger.info(f"PlantCLEF loaded: {len(self.plantclef_labels)} species") | |
| # Optional LeafSnap index (leaf-only images, same embedding space) | |
| self.leafsnap_index = None | |
| self.leafsnap_labels: list = [] | |
| logger.info(f"Attempting to load LeafSnap from: {leafsnap_index_path}") | |
| if leafsnap_index_path and os.path.exists(leafsnap_index_path): | |
| logger.info(f" β LeafSnap index found at {leafsnap_index_path}") | |
| try: | |
| self.leafsnap_index = faiss.read_index(leafsnap_index_path) | |
| logger.info(f" β LeafSnap index loaded successfully") | |
| except Exception as e: | |
| logger.warning(f" β Failed to load LeafSnap index: {e}") | |
| self.leafsnap_index = None | |
| if leafsnap_cache_path and os.path.exists(leafsnap_cache_path): | |
| logger.info(f" β LeafSnap cache found at {leafsnap_cache_path}") | |
| try: | |
| ls_data = _load_plantclef_cache(leafsnap_cache_path) | |
| if not isinstance(ls_data, dict) or "labels" not in ls_data: | |
| raise KeyError("Missing 'labels' in LeafSnap cache") | |
| self.leafsnap_labels = ls_data["labels"] | |
| logger.info(f" β LeafSnap cache loaded: {len(self.leafsnap_labels)} species") | |
| except Exception as e: | |
| logger.warning(f" β Failed to load LeafSnap cache: {e}") | |
| self.leafsnap_index = None | |
| self.leafsnap_labels = [] | |
| else: | |
| logger.warning(f" β LeafSnap cache not found at {leafsnap_cache_path}") | |
| else: | |
| logger.info(f" β LeafSnap index not found at {leafsnap_index_path}") | |
| if self.leafsnap_index and self.leafsnap_labels: | |
| logger.info(f"β LeafSnap enabled for search") | |
| else: | |
| logger.info(f"β LeafSnap disabled - search will use PlantCLEF only") | |
| # Dict mapping LeafSnap label -> canonical DB species name | |
| self.leafsnap_aliases: dict[str, str] = leafsnap_aliases or {} | |
| def embed_image(self, path): | |
| img = self.preprocess(Image.open(path).convert("RGB")).unsqueeze(0) | |
| with torch.no_grad(): | |
| e = self.model.encode_image(img) | |
| e = e / e.norm(dim=-1, keepdim=True) | |
| return e.cpu().numpy().astype("float32") | |
| def _search_index(self, q, index, labels, k, index_name="Index"): | |
| """Search a single FAISS index and aggregate scores by species.""" | |
| sims, idxs = index.search(q, k) | |
| aggregated: dict = defaultdict(lambda: {'score_sum': 0.0, 'count': 0, 'image_paths': []}) | |
| for score, idx in zip(sims[0], idxs[0]): | |
| species_label = labels[idx] | |
| species_label = self.leafsnap_aliases.get(species_label, species_label) | |
| aggregated[species_label]['score_sum'] += score | |
| aggregated[species_label]['count'] += 1 | |
| results = [ | |
| (cat, d['score_sum'] / d['count'], d['image_paths']) | |
| for cat, d in aggregated.items() | |
| ] | |
| results.sort(key=lambda x: x[1], reverse=True) | |
| return results | |
| def search(self, path, labels, k=5, debug=False, search_k=None, return_scores=False): | |
| q = self.embed_image(path) | |
| effective_k = max(int(search_k or k), int(k)) | |
| planclef_results = self._search_index(q, self.index, labels, effective_k, "PlantCLEF") | |
| top_planclef_score = planclef_results[0][1] if planclef_results else 0.0 | |
| if self.leafsnap_index is not None and self.leafsnap_labels: | |
| leafsnap_results = self._search_index(q, self.leafsnap_index, self.leafsnap_labels, effective_k, "LeafSnap") | |
| merged_all = _rrf_merge([planclef_results, leafsnap_results], debug_log=debug) | |
| merged = merged_all[:k] | |
| if debug: | |
| logger.info(f"\n=== SEARCH SUMMARY (RRF Merged) ===") | |
| logger.info(f"PlantCLEF found {len(planclef_results)} species, top 10:") | |
| for rank, (sp, score, _) in enumerate(planclef_results[:10]): | |
| logger.info(f" {rank}: {sp} ({score:.4f})") | |
| logger.info(f"LeafSnap found {len(leafsnap_results)} species, top 10:") | |
| for rank, (sp, score, _) in enumerate(leafsnap_results[:10]): | |
| logger.info(f" {rank}: {sp} ({score:.4f})") | |
| logger.info(f"\nFinal RRF merged ranking (top 50, requested k={k}, search_k={effective_k}):") | |
| for rank, (sp, score, _) in enumerate(merged_all[:50]): | |
| logger.info(f" {rank}: {sp} (RRF: {score:.6f})") | |
| if return_scores: | |
| return merged, top_planclef_score | |
| return merged | |
| if debug: | |
| logger.info(f"\n=== SEARCH SUMMARY (PlantCLEF only) ===") | |
| logger.info(f"Found {len(planclef_results)} species, showing top 50 (requested k={k}, search_k={effective_k}):") | |
| for rank, (sp, score, _) in enumerate(planclef_results[:50]): | |
| logger.info(f" {rank}: {sp} ({score:.6f})") | |
| if return_scores: | |
| return planclef_results[:k], top_planclef_score | |
| return planclef_results[:k] | |