# app/src/cross_encoder.py
# Author: DIVYANSHU99 — "Update app/src/cross_encoder.py" (commit dd15286, verified)
import os
import shutil
from pathlib import Path
from sentence_transformers import CrossEncoder
from nltk import sent_tokenize
import numpy as np
from huggingface_hub import try_to_load_from_cache, snapshot_download
from transformers import AutoConfig
class CrossEncoderSimilarity:
    """
    Compute deep semantic similarity between a trademark ("mark") and a
    goods/services description using a sentence-transformers cross-encoder.

    The model is loaded lazily.  If loading fails (e.g. a corrupted Hugging
    Face cache), the cached snapshot is cleared and the download is retried
    with ``force_download=True``; if the primary model still cannot be loaded,
    two fallback models are tried in order.
    """

    def __init__(self,
                 primary_model='cross-encoder/stsb-roberta-large',
                 fallback_model='cross-encoder/stsb-distilroberta-base',
                 second_fallback='cross-encoder/ms-marco-MiniLM-L-6-v2'):
        self.primary_model_name = primary_model
        self.fallback_model_name = fallback_model
        self.second_fallback_name = second_fallback
        self._model = None  # loaded lazily via the `model` property
        self.current_model_name = None  # repo id of whichever model last attempted/loaded

    @property
    def model(self):
        """Lazily load the cross-encoder, cascading through the fallbacks.

        Returns:
            The loaded ``CrossEncoder`` instance.

        Raises:
            RuntimeError: if the primary model and both fallbacks all fail.
        """
        if self._model is None:
            # Try primary model
            self.current_model_name = self.primary_model_name
            self._model = self._load_model_with_retry(self.primary_model_name)
            if self._model is None:
                print(f"⚠️ Primary model failed. Attempting first fallback: {self.fallback_model_name}")
                self.current_model_name = self.fallback_model_name
                self._model = self._load_model_with_retry(self.fallback_model_name)
            if self._model is None:
                print(f"⚠️ First fallback failed. Attempting second fallback: {self.second_fallback_name}")
                self.current_model_name = self.second_fallback_name
                self._model = self._load_model_with_retry(self.second_fallback_name)
            if self._model is None:
                raise RuntimeError("All cross-encoder models failed to load.")
        return self._model

    def _clear_cache_for_model(self, model_name):
        """
        Locate and delete the cached snapshot (or whole cache directory) for
        *model_name* so a subsequent load re-downloads it.

        Args:
            model_name: Hub repo id, e.g. ``'cross-encoder/stsb-roberta-large'``.

        Returns:
            True if a cache directory was removed, False otherwise.
        """
        # Ask the hub cache for a known file (config.json) to locate the snapshot.
        # BUG FIX: try_to_load_from_cache returns a filesystem path (str) on a
        # cache hit, the _CACHED_NO_EXIST sentinel *object* when non-existence
        # is cached, or None on a miss.  The original code compared against the
        # string "_CACHED_NOFILE", which is not the real sentinel, so the
        # sentinel object could reach os.path.exists() and raise TypeError.
        # Checking isinstance(..., str) handles every case safely.
        cached_file = try_to_load_from_cache(
            model_name,
            filename="config.json",
            # NOTE(review): the hub cache usually lives under $HF_HOME/hub, not
            # $HF_HOME itself — confirm this cache_dir actually matches.
            cache_dir=os.environ.get("HF_HOME")
        )
        if isinstance(cached_file, str) and os.path.exists(cached_file):
            # cached_file looks like:
            #   .../hub/models--org--model/snapshots/<revision>/config.json
            # Remove the whole snapshot revision directory.  (blobs/refs are
            # left behind, but the force_download retry refreshes them.)
            snapshot_dir = Path(cached_file).parent
            if snapshot_dir.is_dir():
                print(f"🗑️ Removing corrupted snapshot: {snapshot_dir}")
                shutil.rmtree(snapshot_dir)
                return True
        # If that didn't work, try the conventional cache layouts by hand.
        model_id = model_name.replace("/", "--")
        hf_home = os.environ.get("HF_HOME", "/tmp/.cache/huggingface")
        possible_paths = [
            Path(hf_home) / "hub" / f"models--{model_id}",
            # BUG FIX: this entry was missing its f-prefix, so it searched for
            # a directory literally named "models--{model_id}".
            Path(hf_home) / f"models--{model_id}",
            Path(hf_home) / model_id,
        ]
        for cache_path in possible_paths:
            if cache_path.exists():
                print(f"🗑️ Removing model cache directory: {cache_path}")
                shutil.rmtree(cache_path)
                return True
        return False

    def _load_model_with_retry(self, model_name):
        """Attempt to load a model; on failure clear its cache and retry once
        with ``force_download=True``.

        Args:
            model_name: Hub repo id to load.

        Returns:
            The loaded ``CrossEncoder``, or None if both attempts failed.
        """
        try:
            print(f"Loading cross-encoder model: {model_name}")
            model = CrossEncoder(model_name, num_labels=1)
            print(f"✅ Cross-encoder model '{model_name}' loaded.")
            return model
        except Exception as e:
            print(f"❌ Error loading model '{model_name}': {e}. Attempting to clear cache...")
            if self._clear_cache_for_model(model_name):
                print("Cache cleared. Retrying model load with force_download...")
                try:
                    # Force a fresh download past whatever was corrupted.
                    model = CrossEncoder(model_name, num_labels=1, force_download=True)
                    print(f"✅ Cross-encoder model '{model_name}' loaded after cache clear.")
                    return model
                except Exception as e2:
                    print(f"❌ Still failed after cache clear: {e2}")
                    return None
            else:
                print("Cache directory not found. Cannot clear.")
                return None

    def similarity(self, mark, goods, return_segments=False):
        """Score *mark* against each sentence of *goods*; return the maximum.

        Args:
            mark: the trademark string.
            goods: goods/services description (may be empty/None).
            return_segments: if True, also return the best-matching sentence.

        Returns:
            A float in [0, 1], or ``(score, best_sentence)`` when
            ``return_segments`` is True (``(0.0, None)`` for empty input).
        """
        if not goods:
            return 0.0 if not return_segments else (0.0, None)
        sentences = sent_tokenize(goods)
        if not sentences:
            return 0.0 if not return_segments else (0.0, None)
        pairs = [(mark, sent) for sent in sentences]
        scores = self.model.predict(pairs)
        # Normalize to [0, 1], clamping outliers.
        # NOTE(review): the /5.0 assumes an STSB-style 0-5 output range;
        # sentence-transformers STSB cross-encoders typically emit 0-1 and the
        # ms-marco fallback emits unbounded logits — confirm the scale.
        scores_norm = [min(1.0, max(0.0, s / 5.0)) for s in scores]
        max_score = max(scores_norm)
        max_idx = int(np.argmax(scores_norm))
        if return_segments:
            return max_score, sentences[max_idx]
        return max_score

    def similarity_with_explanation(self, mark, goods):
        """Return ``(score, explanation)`` where the explanation names the
        best-matching sentence of *goods*."""
        max_score, best_sentence = self.similarity(mark, goods, return_segments=True)
        explanation = f"Highest similarity with segment: '{best_sentence}' (score: {max_score:.2f})"
        return max_score, explanation