import os
import shutil
from pathlib import Path

import numpy as np
from huggingface_hub import snapshot_download, try_to_load_from_cache
from nltk import sent_tokenize
from sentence_transformers import CrossEncoder
from transformers import AutoConfig


class CrossEncoderSimilarity:
    """
    Score how closely a mark matches a goods description with a cross-encoder.

    The goods text is split into sentences, each (mark, sentence) pair is
    scored by the model, and the best normalized score is reported.  The model
    is loaded lazily on first use; if loading fails the local Hugging Face
    cache entry for that model is removed and the load is retried once before
    moving on to the next fallback model.
    """

    def __init__(self,
                 primary_model='cross-encoder/stsb-roberta-large',
                 fallback_model='cross-encoder/stsb-distilroberta-base',
                 second_fallback='cross-encoder/ms-marco-MiniLM-L-6-v2',
                 score_scale=5.0):
        """
        Store the model names; nothing is downloaded or loaded here.

        Args:
            primary_model: Model tried first.
            fallback_model: Model tried if the primary cannot be loaded.
            second_fallback: Model tried if the first fallback cannot be loaded.
            score_scale: Divisor applied to raw model scores before they are
                clamped to [0, 1].  Defaults to 5.0, the value this class has
                always used.
                NOTE(review): the public model cards for the cross-encoder/stsb-*
                models describe scores between 0 and 1.  If that holds for the
                installed sentence-transformers version, 5.0 caps every result
                at 0.2 and callers should pass 1.0 -- confirm before changing
                the default, since existing thresholds may rely on it.

        Raises:
            ValueError: If score_scale is not strictly positive.
        """
        if score_scale <= 0:
            raise ValueError("score_scale must be positive")
        self.primary_model_name = primary_model
        self.fallback_model_name = fallback_model
        self.second_fallback_name = second_fallback
        self.score_scale = score_scale
        self._model = None
        # Name of the model currently loaded (or being attempted); None when
        # nothing is loaded.
        self.current_model_name = None

    @property
    def model(self):
        """
        Return the loaded cross-encoder, loading it on first access.

        Models are attempted in order: primary, first fallback, second
        fallback.  The first one that loads is cached on the instance.

        Raises:
            RuntimeError: If none of the three models can be loaded.
        """
        if self._model is not None:
            return self._model

        # (model name, message printed before the attempt or None)
        attempts = (
            (self.primary_model_name, None),
            (self.fallback_model_name,
             "⚠️ Primary model failed. Attempting first fallback: "),
            (self.second_fallback_name,
             "⚠️ First fallback failed. Attempting second fallback: "),
        )
        for name, notice in attempts:
            if notice is not None:
                print(f"{notice}{name}")
            self.current_model_name = name
            loaded = self._load_model_with_retry(name)
            if loaded is not None:
                self._model = loaded
                return loaded

        # Nothing loaded, so there is no "current" model to report.
        self.current_model_name = None
        raise RuntimeError("All cross-encoder models failed to load.")

    @staticmethod
    def _remove_tree(path, label):
        """Delete directory *path*; return True on success, False on OSError."""
        print(f"🗑️ Removing {label}: {path}")
        try:
            shutil.rmtree(path)
        except OSError as exc:
            # A failed delete must not abort the fallback chain in `model`.
            print(f"❌ Could not remove {path}: {exc}")
            return False
        return True

    def _clear_cache_for_model(self, model_name):
        """
        Remove the locally cached copy of *model_name* from the Hugging Face cache.

        First asks huggingface_hub where config.json is cached and removes the
        enclosing model directory; if that lookup finds nothing, a few
        conventional cache locations are probed.

        Returns:
            True if a directory was removed, False otherwise.
        """
        model_id = model_name.replace("/", "--")
        repo_dir_name = f"models--{model_id}"

        # cache_dir is left at its default: the library expects the *hub*
        # cache directory there, so passing $HF_HOME itself made the lookup
        # miss whenever HF_HOME was set.
        cached_file = try_to_load_from_cache(model_name, filename="config.json")

        # The lookup can return None or a non-string sentinel meaning "known
        # not to exist"; only a real path may be handed to os.path.exists().
        if isinstance(cached_file, (str, os.PathLike)) and os.path.exists(cached_file):
            # Expected layout:
            #   <cache>/models--org--model/snapshots/<revision>/config.json
            snapshot_dir = Path(cached_file).parent
            repo_dir = snapshot_dir.parent.parent
            if repo_dir.name == repo_dir_name and repo_dir.is_dir():
                # Snapshot entries are normally links into the sibling blobs/
                # directory, so the whole model directory has to go for a
                # damaged blob to be downloaded again.
                if self._remove_tree(repo_dir, "model cache directory"):
                    return True
            elif snapshot_dir.is_dir():
                # Unfamiliar layout: stay conservative and drop only the snapshot.
                if self._remove_tree(snapshot_dir, "corrupted snapshot"):
                    return True

        # If that didn't work, probe conventional locations directly.
        hf_home = Path(os.environ.get("HF_HOME", "/tmp/.cache/huggingface"))
        candidates = (
            hf_home / "hub" / repo_dir_name,
            hf_home / repo_dir_name,
            hf_home / model_id,
        )
        for candidate in candidates:
            if candidate.exists() and self._remove_tree(candidate, "model cache directory"):
                return True
        return False

    def _load_model_with_retry(self, model_name):
        """
        Load *model_name*, clearing its cache and retrying once on failure.

        Returns:
            The loaded CrossEncoder, or None if both attempts fail or there was
            no cache entry to clear.
        """
        print(f"Loading cross-encoder model: {model_name}")
        try:
            loaded = CrossEncoder(model_name, num_labels=1)
        except Exception as exc:
            # Deliberately broad: any load failure should lead to the cache
            # clear / fallback path rather than crash the caller.
            print(f"❌ Error loading model '{model_name}': {exc}. Attempting to clear cache...")
        else:
            print(f"✅ Cross-encoder model '{model_name}' loaded.")
            return loaded

        if not self._clear_cache_for_model(model_name):
            print("Cache directory not found. Cannot clear.")
            return None

        print("Cache cleared. Retrying model load with a fresh download...")
        try:
            # The cached copy is gone, so a plain load fetches the files
            # again.  The previous code passed force_download=True here;
            # CrossEncoder's documented constructor has no such parameter,
            # which would make this retry fail with TypeError.
            loaded = CrossEncoder(model_name, num_labels=1)
        except Exception as exc:
            print(f"❌ Still failed after cache clear: {exc}")
            return None
        print(f"✅ Cross-encoder model '{model_name}' loaded after cache clear.")
        return loaded

    def similarity(self, mark, goods, return_segments=False):
        """
        Return the highest per-sentence similarity between *mark* and *goods*.

        Args:
            mark: Text of the mark.
            goods: Goods description; it is split into sentences with
                nltk.sent_tokenize and each sentence is scored against *mark*.
            return_segments: When True, also return the best-scoring sentence.

        Returns:
            A float in [0, 1], or a (float, sentence) tuple when
            *return_segments* is True.  Empty *goods* (or text that yields no
            sentences) gives 0.0, with None as the sentence.
        """
        nothing = (0.0, None) if return_segments else 0.0
        if not goods:
            return nothing
        sentences = sent_tokenize(goods)
        if not sentences:
            return nothing

        pairs = [(mark, sentence) for sentence in sentences]
        raw_scores = np.asarray(self.model.predict(pairs)).reshape(-1)

        # Scale, then clamp to [0, 1].  float() keeps the return type a plain
        # Python float on every path instead of a NumPy scalar on some.
        scale = self.score_scale
        normalized = [min(1.0, max(0.0, float(raw) / scale)) for raw in raw_scores]

        best_idx = int(np.argmax(normalized))
        best_score = normalized[best_idx]
        if return_segments:
            return best_score, sentences[best_idx]
        return best_score

    def similarity_with_explanation(self, mark, goods):
        """
        Return (score, explanation) where the explanation names the best sentence.

        When *goods* is empty or yields no sentences the score is 0.0 and the
        explanation says so instead of quoting a missing segment.
        """
        max_score, best_sentence = self.similarity(mark, goods, return_segments=True)
        if best_sentence is None:
            return max_score, f"No goods text to compare (score: {max_score:.2f})"
        explanation = (
            f"Highest similarity with segment: '{best_sentence}' (score: {max_score:.2f})"
        )
        return max_score, explanation