Frenchizer committed
Commit 6b27907 · verified · Parent: 65f8dcc

Update app.py

Files changed (1):
  app.py  +19 -22
app.py CHANGED
@@ -4,13 +4,12 @@ from sklearn.metrics.pairwise import cosine_similarity
 import torch
 import numpy as np
 from gradio_client import Client
-
-# Cache the model and tokenizer using lru_cache
 from functools import lru_cache
 
+# Cache the model and tokenizer using lru_cache
 @lru_cache(maxsize=1)
 def load_model_and_tokenizer():
-    model_name = "./all-MiniLM-L6-v2"
+    model_name = "./all-MiniLM-L6-v2"  # Replace with your Space and model path
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModel.from_pretrained(model_name)
     return tokenizer, model
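Note on the loader: @lru_cache(maxsize=1) means the model and tokenizer are loaded once per process, and every later call is a cache hit. A minimal sketch of that behavior, assuming the ./all-MiniLM-L6-v2 weights are bundled with the Space (the Hub id "sentence-transformers/all-MiniLM-L6-v2" is our assumption for a remote fallback, not part of the diff):

from functools import lru_cache
from transformers import AutoModel, AutoTokenizer

@lru_cache(maxsize=1)
def load_model_and_tokenizer():
    # Local path as committed; "sentence-transformers/all-MiniLM-L6-v2"
    # would download the same model from the Hugging Face Hub instead.
    model_name = "./all-MiniLM-L6-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()
# A repeat call returns the exact same objects -- no second load.
assert load_model_and_tokenizer()[1] is model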
@@ -36,33 +35,31 @@ labels = [
 
 @lru_cache(maxsize=1)
 def precompute_label_embeddings():
-    def encode_text(texts):
-        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-        with torch.no_grad():
-            outputs = model(**inputs)
-        return outputs.last_hidden_state.mean(dim=1).numpy()  # Use mean pooling for embeddings
-    return encode_text(labels)
+    inputs = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        outputs = model(**inputs)
+    return outputs.last_hidden_state.mean(dim=1).numpy()  # Mean pooling for embeddings
 
 label_embeddings = precompute_label_embeddings()
 
-# Function to detect context
+# Function to detect context (optimized)
 def detect_context(input_text, high_confidence_threshold=0.9, fallback_threshold=0.8, max_results=3):
-    def encode_text(texts):
-        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-        with torch.no_grad():
-            outputs = model(**inputs)
-        return outputs.last_hidden_state.mean(dim=1).numpy()  # Use mean pooling for embeddings
+    # Encode the input text
+    inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        outputs = model(**inputs)
+    input_embedding = outputs.last_hidden_state.mean(dim=1).numpy()  # Mean pooling for embedding
 
-    input_embedding = encode_text([input_text])
+    # Compute cosine similarities (optimized)
     similarities = cosine_similarity(input_embedding, label_embeddings)[0]
 
-    for label, score in zip(labels, similarities):
-        if score >= high_confidence_threshold:
-            return [label]
+    # Find top-N labels based on thresholds
+    top_indices = np.argsort(similarities)[-max_results:][::-1]
+    top_labels = [labels[i] for i in top_indices if similarities[i] >= fallback_threshold]
 
-    label_scores = [(label, score) for label, score in zip(labels, similarities) if score >= fallback_threshold]
-    sorted_labels = sorted(label_scores, key=lambda x: x[1], reverse=True)[:max_results]
-    return [label for label, score in sorted_labels] if sorted_labels else ["general"]
+    # Return high-confidence labels if any, otherwise fallback labels
+    high_conf_labels = [label for label in top_labels if similarities[labels.index(label)] >= high_confidence_threshold]
+    return high_conf_labels if high_conf_labels else top_labels[:max_results]
 
 # Translation client
 translation_client = Client("Frenchizer/space_3")
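One caveat a reviewer might flag in both new blocks: last_hidden_state.mean(dim=1) averages over padding positions whenever a batch mixes sequence lengths (as the labels batch in precompute_label_embeddings does). A mask-aware mean-pooling sketch -- the helper name encode_text echoes the removed nested function, but this version is ours, not part of the commit:

import torch

def encode_text(texts, tokenizer, model):
    # Tokenize with padding, then mean-pool only over real (non-pad) tokens.
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    mask = inputs["attention_mask"].unsqueeze(-1).float()   # (batch, seq, 1)
    summed = (outputs.last_hidden_state * mask).sum(dim=1)  # zero out pad tokens
    counts = mask.sum(dim=1).clamp(min=1.0)                 # real-token counts
    return (summed / counts).numpy()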
 
4
  import torch
5
  import numpy as np
6
  from gradio_client import Client
 
 
7
  from functools import lru_cache
8
 
9
+ # Cache the model and tokenizer using lru_cache
10
  @lru_cache(maxsize=1)
11
  def load_model_and_tokenizer():
12
+ model_name = "./all-MiniLM-L6-v2" # Replace with your Space and model path
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModel.from_pretrained(model_name)
15
  return tokenizer, model
 
35
 
36
  @lru_cache(maxsize=1)
37
  def precompute_label_embeddings():
38
+ inputs = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
39
+ with torch.no_grad():
40
+ outputs = model(**inputs)
41
+ return outputs.last_hidden_state.mean(dim=1).numpy() # Mean pooling for embeddings
 
 
42
 
43
  label_embeddings = precompute_label_embeddings()
44
 
45
+ # Function to detect context (optimized)
46
  def detect_context(input_text, high_confidence_threshold=0.9, fallback_threshold=0.8, max_results=3):
47
+ # Encode the input text
48
+ inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
49
+ with torch.no_grad():
50
+ outputs = model(**inputs)
51
+ input_embedding = outputs.last_hidden_state.mean(dim=1).numpy() # Mean pooling for embedding
52
 
53
+ # Compute cosine similarities (optimized)
54
  similarities = cosine_similarity(input_embedding, label_embeddings)[0]
55
 
56
+ # Find top-N labels based on thresholds
57
+ top_indices = np.argsort(similarities)[-max_results:][::-1]
58
+ top_labels = [labels[i] for i in top_indices if similarities[i] >= fallback_threshold]
59
 
60
+ # Return high-confidence labels if any, otherwise fallback labels
61
+ high_conf_labels = [label for label in top_labels if similarities[labels.index(label)] >= high_confidence_threshold]
62
+ return high_conf_labels if high_conf_labels else top_labels[:max_results]
63
 
64
  # Translation client
65
  translation_client = Client("Frenchizer/space_3")
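Downstream, detect_context gates what gets sent to the translation Space. A hypothetical end-to-end call -- the api_name and argument order for translation_client.predict are assumptions to verify against space_3's API page, and note that the rewritten detect_context returns [] rather than ["general"] when nothing clears fallback_threshold:

text = "Please review the attached contract before signing."
contexts = detect_context(text)
if not contexts:  # the old code fell back to ["general"]; the new version can return []
    contexts = ["general"]
translation = translation_client.predict(text, api_name="/predict")  # assumed endpoint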