# NOTE: the lines "Spaces: / Running / Running" here were Hugging Face Spaces
# UI status text accidentally captured with the source paste — not code.
| import os | |
| import shutil | |
| from pathlib import Path | |
| from sentence_transformers import CrossEncoder | |
| from nltk import sent_tokenize | |
| import numpy as np | |
| from huggingface_hub import try_to_load_from_cache, snapshot_download | |
| from transformers import AutoConfig | |
class CrossEncoderSimilarity:
    """
    Compute deep semantic similarity between a trademark and a goods/services
    description with a sentence-transformers cross-encoder.

    Three candidate models are tried in order (primary -> fallback -> second
    fallback). On a load failure the local Hugging Face cache for that model
    is wiped and the download retried, making the loader robust to corrupted
    snapshots.
    """

    def __init__(self,
                 primary_model='cross-encoder/stsb-roberta-large',
                 fallback_model='cross-encoder/stsb-distilroberta-base',
                 second_fallback='cross-encoder/ms-marco-MiniLM-L-6-v2'):
        self.primary_model_name = primary_model
        self.fallback_model_name = fallback_model
        self.second_fallback_name = second_fallback
        self._model = None              # loaded lazily by the `model` property
        self.current_model_name = None  # name of the model actually in use

    @property
    def model(self):
        """Lazily load the cross-encoder, with cache clearing and fallbacks.

        FIX: declared as a property — ``similarity()`` accesses
        ``self.model.predict``, so without ``@property`` the attribute would
        be a bound method and ``.predict`` would raise AttributeError.

        Raises:
            RuntimeError: if every candidate model fails to load.
        """
        if self._model is None:
            # Try primary model first.
            self.current_model_name = self.primary_model_name
            self._model = self._load_model_with_retry(self.primary_model_name)
            if self._model is None:
                print(f"⚠️ Primary model failed. Attempting first fallback: {self.fallback_model_name}")
                self.current_model_name = self.fallback_model_name
                self._model = self._load_model_with_retry(self.fallback_model_name)
            if self._model is None:
                print(f"⚠️ First fallback failed. Attempting second fallback: {self.second_fallback_name}")
                self.current_model_name = self.second_fallback_name
                self._model = self._load_model_with_retry(self.second_fallback_name)
            if self._model is None:
                raise RuntimeError("All cross-encoder models failed to load.")
        return self._model

    def _clear_cache_for_model(self, model_name):
        """
        Locate and remove the cached snapshot (or whole cache directory) for
        *model_name* so a retry can download a fresh copy.

        Returns:
            bool: True if something was removed, False otherwise.
        """
        # Probe for a cached file (config.json) to locate the snapshot dir.
        cached_file = try_to_load_from_cache(
            model_name,
            filename="config.json",
            cache_dir=os.environ.get("HF_HOME"),
        )
        # FIX: try_to_load_from_cache returns a str path on a hit, None on a
        # miss, or the _CACHED_NO_EXIST sentinel *object* — the original
        # compared against the string "_CACHED_NOFILE", which never matches.
        # Only a real string path is usable here.
        if isinstance(cached_file, str) and os.path.exists(cached_file):
            # cached_file looks like:
            #   .../hub/models--org--model/snapshots/<rev>/config.json
            # Remove the entire snapshot directory.
            snapshot_dir = Path(cached_file).parent
            if snapshot_dir.is_dir():
                print(f"🗑️ Removing corrupted snapshot: {snapshot_dir}")
                shutil.rmtree(snapshot_dir)
                return True
        # Fall back to removing the whole model cache directory.
        model_id = model_name.replace("/", "--")
        hf_home = os.environ.get("HF_HOME", "/tmp/.cache/huggingface")
        possible_paths = [
            Path(hf_home) / "hub" / f"models--{model_id}",
            # FIX: the original was missing the f-prefix here, producing the
            # literal path ".../models--{model_id}".
            Path(hf_home) / f"models--{model_id}",
            Path(hf_home) / model_id,
        ]
        for candidate in possible_paths:
            if candidate.exists():
                print(f"🗑️ Removing model cache directory: {candidate}")
                shutil.rmtree(candidate)
                return True
        return False

    def _load_model_with_retry(self, model_name):
        """Load *model_name*; on failure clear its cache and retry once with
        ``force_download=True``.

        Returns:
            CrossEncoder | None: the loaded model, or None if both attempts fail.
        """
        try:
            print(f"Loading cross-encoder model: {model_name}")
            model = CrossEncoder(model_name, num_labels=1)
            print(f"✅ Cross-encoder model '{model_name}' loaded.")
            return model
        except Exception as e:
            print(f"❌ Error loading model '{model_name}': {e}. Attempting to clear cache...")
            if not self._clear_cache_for_model(model_name):
                print("Cache directory not found. Cannot clear.")
                return None
            print("Cache cleared. Retrying model load with force_download...")
            try:
                # Force a fresh download past any corrupted local files.
                model = CrossEncoder(model_name, num_labels=1, force_download=True)
                print(f"✅ Cross-encoder model '{model_name}' loaded after cache clear.")
                return model
            except Exception as e2:
                print(f"❌ Still failed after cache clear: {e2}")
                return None

    def similarity(self, mark, goods, return_segments=False):
        """Score *mark* against each sentence of *goods*; return the maximum.

        Args:
            mark: the trademark string.
            goods: goods/services description (split into sentences).
            return_segments: if True, also return the best-matching sentence.

        Returns:
            float, or (float, str | None) when return_segments is True.
            Scores are clamped to [0, 1].
        """
        if not goods:
            return 0.0 if not return_segments else (0.0, None)
        sentences = sent_tokenize(goods)
        if not sentences:
            return 0.0 if not return_segments else (0.0, None)
        pairs = [(mark, sent) for sent in sentences]
        scores = self.model.predict(pairs)
        # Normalize assuming STSB-style output in [0, 5]. NOTE(review): the
        # ms-marco second fallback emits unbounded relevance logits, so this
        # /5 normalization is only approximate for that model — confirm.
        scores_norm = [min(1.0, max(0.0, s / 5.0)) for s in scores]
        max_score = max(scores_norm)
        max_idx = int(np.argmax(scores_norm))
        if return_segments:
            return max_score, sentences[max_idx]
        return max_score

    def similarity_with_explanation(self, mark, goods):
        """Return (score, human-readable explanation of the best segment)."""
        max_score, best_sentence = self.similarity(mark, goods, return_segments=True)
        # Robustness: empty goods yields no segment; don't format 'None'.
        if best_sentence is None:
            return max_score, "No goods text available for comparison."
        explanation = f"Highest similarity with segment: '{best_sentence}' (score: {max_score:.2f})"
        return max_score, explanation