DIVYA-NSHU99 committed on
Commit
4c5dfd9
·
verified ·
1 Parent(s): 80de863

Update app/src/cross_encoder.py

Browse files
Files changed (1) hide show
  1. app/src/cross_encoder.py +29 -15
app/src/cross_encoder.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from sentence_transformers import CrossEncoder
2
  from nltk import sent_tokenize
3
  import numpy as np
@@ -5,19 +7,37 @@ import numpy as np
5
  class CrossEncoderSimilarity:
6
  """
7
  Uses a cross‑encoder to compute deep semantic similarity between mark and goods.
8
- Supports sentence‑level segmentation and returns attention weights for explainability.
9
  """
10
 
11
  def __init__(self, model_name='cross-encoder/stsb-roberta-large'):
12
- self.model = CrossEncoder(model_name, num_labels=1) # regression output
13
- # We'll store the last attention scores if needed (for explainability)
14
- self.last_attention = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def similarity(self, mark, goods, return_segments=False):
17
- """
18
- Returns a score between 0 and 1. If return_segments=True, also returns
19
- the maximum segment score and the segment text.
20
- """
21
  if not goods:
22
  return 0.0 if not return_segments else (0.0, None)
23
  sentences = sent_tokenize(goods)
@@ -26,8 +46,7 @@ class CrossEncoderSimilarity:
26
 
27
  pairs = [(mark, sent) for sent in sentences]
28
  scores = self.model.predict(pairs)
29
- # Normalize: assume model output range roughly 0-5 (for stsb models)
30
- # If using a different model, adjust normalization accordingly.
31
  scores_norm = [min(1.0, max(0.0, s / 5.0)) for s in scores]
32
  max_score = max(scores_norm)
33
  max_idx = int(np.argmax(scores_norm))
@@ -37,11 +56,6 @@ class CrossEncoderSimilarity:
37
  return max_score
38
 
39
  def similarity_with_explanation(self, mark, goods):
40
- """
41
- Returns score and the most relevant sentence from goods, plus optionally attention.
42
- For attention, we'd need a model that returns cross‑attention; not all do.
43
- This method provides a simple explanation.
44
- """
45
  max_score, best_sentence = self.similarity(mark, goods, return_segments=True)
46
  explanation = f"Highest similarity with segment: '{best_sentence}' (score: {max_score:.2f})"
47
  return max_score, explanation
 
1
+ import os
2
+ import shutil
3
  from sentence_transformers import CrossEncoder
4
  from nltk import sent_tokenize
5
  import numpy as np
 
7
  class CrossEncoderSimilarity:
8
  """
9
  Uses a cross‑encoder to compute deep semantic similarity between mark and goods.
10
+ Supports sentence‑level segmentation and lazy model loading with auto cache clearing.
11
  """
12
 
13
  def __init__(self, model_name='cross-encoder/stsb-roberta-large'):
14
+ self.model_name = model_name
15
+ self._model = None
16
+
17
+ @property
18
+ def model(self):
19
+ """Lazy load the cross-encoder model, with retry and cache clearing on failure."""
20
+ if self._model is None:
21
+ try:
22
+ print(f"Loading cross-encoder model: {self.model_name}")
23
+ self._model = CrossEncoder(self.model_name, num_labels=1)
24
+ except Exception as e:
25
+ print(f"❌ Error loading model: {e}. Attempting to clear cache and retry...")
26
+ # Determine cache directory for this model
27
+ cache_dir = os.path.join(
28
+ os.environ.get("HF_HOME", "/tmp/.cache/huggingface"),
29
+ "models",
30
+ self.model_name.replace("/", "--")
31
+ )
32
+ if os.path.exists(cache_dir):
33
+ print(f"Removing corrupted cache: {cache_dir}")
34
+ shutil.rmtree(cache_dir)
35
+ print("Retrying model load...")
36
+ self._model = CrossEncoder(self.model_name, num_labels=1)
37
+ print("✅ Cross-encoder model loaded successfully after cache clear.")
38
+ return self._model
39
 
40
  # NOTE(review): this is a diff excerpt — original file lines 44-45 and
  # 53-55 (presumably the empty-`sentences` guard and the `return_segments`
  # branch implied by the tuple return on the empty-goods path) are NOT
  # shown here. Confirm against the full file before editing this method.
  # Purpose (from visible code): sentence-segment `goods`, score each
  # (mark, sentence) pair with the cross-encoder, and return the best
  # normalized score.
  def similarity(self, mark, goods, return_segments=False):




41
  if not goods:
42
  return 0.0 if not return_segments else (0.0, None)
43
  sentences = sent_tokenize(goods)

46

47
  pairs = [(mark, sent) for sent in sentences]
48
  scores = self.model.predict(pairs)
49
+ # Normalize (assuming stsb model output range 0-5)

50
  scores_norm = [min(1.0, max(0.0, s / 5.0)) for s in scores]
51
  max_score = max(scores_norm)
52
  max_idx = int(np.argmax(scores_norm))

56
  return max_score
57
 
58
  def similarity_with_explanation(self, mark, goods):
 
 
 
 
 
59
  max_score, best_sentence = self.similarity(mark, goods, return_segments=True)
60
  explanation = f"Highest similarity with segment: '{best_sentence}' (score: {max_score:.2f})"
61
  return max_score, explanation