Frenchizer committed on
Commit
1f69442
·
1 Parent(s): 848fa80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -27,31 +27,36 @@ labels = [
27
  "cryptocurrency","blockchain","philosophy","anthropology","archaeology","data science"
28
  ]
29
 
30
def detect_context(input_text):
    """Classify *input_text* into one of the predefined context ``labels``.

    The text is tokenized with ``context_tokenizer``, run through the ONNX
    ``context_session`` classifier, and the label with the highest logit is
    returned.

    Args:
        input_text: Raw text to classify.

    Returns:
        The single best-matching label string from ``labels``.
    """
    # Encode the text as NumPy tensors, capped at the model's 512-token limit.
    encoded = context_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=512)

    # ONNX Runtime expects int64 id/mask tensors.
    ids = encoded["input_ids"].astype(np.int64)
    mask = encoded["attention_mask"].astype(np.int64)

    # Run the classifier; the first output holds the per-label logits.
    model_outputs = context_session.run(None, {
        "input_ids": ids,
        "attention_mask": mask
    })
    logits = model_outputs[0]

    # Pick the highest-scoring label for the (single) batch item.
    best = np.argmax(logits, axis=1)[0]
    detected_context = labels[best]

    print(detected_context)  # debug trace of the chosen label
    return detected_context
 
 
 
 
 
 
55
 
56
  def translate_text(input_text):
57
  tokenized_input = tokenizer(
 
27
  "cryptocurrency","blockchain","philosophy","anthropology","archaeology","data science"
28
  ]
29
 
30
def detect_context(input_text, top_n=3, score_threshold=0.05):
    """Detect up to *top_n* context labels for *input_text*.

    The text is tokenized with ``context_tokenizer`` and run through the ONNX
    ``context_session`` classifier; the resulting logits are converted to a
    probability distribution and filtered.

    Args:
        input_text: Raw text to classify.
        top_n: Maximum number of context labels to return.
        score_threshold: Minimum probability a label must reach to be kept.

    Returns:
        A list of at most *top_n* label strings, ordered by descending score;
        ``["general"]`` when no label clears the threshold.
    """
    # Tokenize input text (NumPy tensors, capped at the model's 512-token limit)
    inputs = context_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=512)

    # ONNX Runtime expects int64 id/mask tensors
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Run inference with the ONNX context model
    outputs = context_session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    })

    logits = outputs[0][0]  # assumes batch size 1: take the first row of logits

    # BUG FIX: raw logits are unbounded (often negative or far above 1), so
    # comparing them to a probability-style threshold like 0.05 is meaningless.
    # Convert to probabilities with a numerically stable softmax first.
    shifted = np.exp(logits - np.max(logits))
    probs = shifted / shifted.sum()

    # Pair each label with its probability
    label_scores = [(label, score) for label, score in zip(labels, probs)]

    # Sort by probability in descending order
    sorted_labels = sorted(label_scores, key=lambda x: x[1], reverse=True)

    # Filter by threshold and keep at most top_n contexts
    filtered_labels = [label for label, score in sorted_labels if score > score_threshold]
    top_contexts = filtered_labels[:top_n]

    print(f"All scores: {label_scores}")  # Debugging: Print all scores
    print(f"Selected contexts: {top_contexts}")  # Debugging: Print selected contexts

    # Fall back to a generic context when nothing clears the threshold
    return top_contexts if top_contexts else ["general"]
59
+
60
 
61
  def translate_text(input_text):
62
  tokenized_input = tokenizer(