Frenchizer committed on
Commit
c905911
·
verified ·
1 Parent(s): 581a0fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -21
app.py CHANGED
@@ -4,16 +4,16 @@ from transformers import AutoTokenizer
4
  import numpy as np
5
 
6
  # Initialize models
7
- context_model_file = "./bart-large-mnli.onnx"
8
  translation_model_file = "./model.onnx"
9
 
10
  # Create inference sessions for both models
11
  context_session = ort.InferenceSession(context_model_file)
12
  translation_session = ort.InferenceSession(translation_model_file)
13
 
14
- # Load tokenizer for translation model and for context model
15
- tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
16
- context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
17
 
18
  labels = [
19
  "general", "pharma", "legal", "technical", "UI", "user interface", "medicine",
@@ -23,51 +23,71 @@ labels = [
23
  "finance", "sports", "education", "politics", "economics", "art", "history",
24
  "music", "gaming", "aerospace", "engineering", "robotics", "travel", "tourism",
25
  "healthcare", "psychology", "environment", "fashion", "design", "real estate",
26
- "retail", "news", "entertainment", "social media","automotive", "machine learning",
27
- "cryptocurrency","blockchain","philosophy","anthropology","archaeology","data science"
28
  ]
29
 
30
- def detect_context(input_text, top_n=3, score_threshold=0.05):
 
 
 
 
31
  # Tokenize input text
32
  inputs = context_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=512)
33
-
34
  input_ids = inputs["input_ids"].astype(np.int64)
35
  attention_mask = inputs["attention_mask"].astype(np.int64)
36
 
 
 
 
 
37
  # Run inference with the ONNX context model
38
  outputs = context_session.run(None, {
39
  "input_ids": input_ids,
40
  "attention_mask": attention_mask
41
  })
42
 
43
- scores = outputs[0][0] # Assuming batch size 1; take the first set of logits
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Pair labels with scores
46
  label_scores = [(label, score) for label, score in zip(labels, scores)]
47
 
 
 
 
48
  # Sort by scores in descending order
49
  sorted_labels = sorted(label_scores, key=lambda x: x[1], reverse=True)
50
 
51
- # Filter by threshold and return top_n contexts
52
  filtered_labels = [label for label, score in sorted_labels if score > score_threshold]
53
- top_contexts = filtered_labels[:top_n]
54
-
55
- print(f"All scores: {label_scores}") # Debugging: Print all scores
56
- print(f"Selected contexts: {top_contexts}") # Debugging: Print selected contexts
57
 
58
- return top_contexts if top_contexts else ["general"]
 
59
 
 
 
60
 
61
  def translate_text(input_text):
62
- tokenized_input = tokenizer(
63
  input_text, return_tensors="np",
64
  padding=True, truncation=True, max_length=512
65
  )
66
-
67
  input_ids = tokenized_input["input_ids"].astype(np.int64)
68
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
69
 
70
- decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
71
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
72
 
73
  for _ in range(512):
@@ -86,10 +106,10 @@ def translate_text(input_text):
86
  [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
87
  )
88
 
89
- if next_token_id == tokenizer.eos_token_id:
90
  break
91
 
92
- return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
93
 
94
  def process_request(input_text):
95
  context = detect_context(input_text)
@@ -101,4 +121,4 @@ gr.Interface(
101
  inputs="text",
102
  outputs="text",
103
  live=True
104
- ).launch()
 
4
  import numpy as np
5
 
6
  # Initialize models
7
+ context_model_file = "./bart-base-mnli.onnx" # Using bart-base-mnli for faster inference
8
  translation_model_file = "./model.onnx"
9
 
10
  # Create inference sessions for both models
11
  context_session = ort.InferenceSession(context_model_file)
12
  translation_session = ort.InferenceSession(translation_model_file)
13
 
14
+ # Load tokenizers for context and translation models
15
+ context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base-mnli")
16
+ translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
17
 
18
  labels = [
19
  "general", "pharma", "legal", "technical", "UI", "user interface", "medicine",
 
23
  "finance", "sports", "education", "politics", "economics", "art", "history",
24
  "music", "gaming", "aerospace", "engineering", "robotics", "travel", "tourism",
25
  "healthcare", "psychology", "environment", "fashion", "design", "real estate",
26
+ "retail", "news", "entertainment", "social media", "automotive", "machine learning",
27
+ "cryptocurrency", "blockchain", "philosophy", "anthropology", "archaeology", "data science"
28
  ]
29
 
30
def softmax_with_temperature(logits, temperature=1.0):
    """Convert raw logits to a probability distribution via temperature-scaled softmax.

    Args:
        logits: array-like of raw model scores (last axis is the class axis).
        temperature: scaling factor; > 1 flattens the distribution, < 1 sharpens it.

    Returns:
        np.ndarray of probabilities summing to 1 along the last axis.
    """
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    # Subtract the per-row max before exponentiating: softmax is shift-invariant,
    # and this prevents np.exp overflowing to inf/nan for large logits.
    exp_logits = np.exp(scaled - np.max(scaled, axis=-1, keepdims=True))
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
33
+
34
def detect_context(input_text, temperature=2.0, score_threshold=0.05):
    """Classify *input_text* against the module-level ``labels`` list.

    Runs the ONNX context model (``context_session``), temperature-softmaxes the
    logits, and returns every label whose probability exceeds
    ``score_threshold`` (sorted by descending score), or ``["general"]`` if none
    qualify.

    NOTE(review): ``zip(labels, scores)`` assumes the model emits one logit per
    entry of ``labels``. A stock BART-MNLI head outputs 3 NLI logits
    (contradiction/neutral/entailment), which would silently pair only the
    first few labels — confirm the exported ONNX model really has a
    per-label classification head.

    Args:
        input_text: raw text to classify.
        temperature: softmax temperature; 2.0 flattens the distribution.
        score_threshold: minimum softmax probability for a label to be kept.

    Returns:
        list[str]: matching labels, best first; never empty.
    """
    # Tokenize input text (numpy tensors, truncated to the model's 512-token limit)
    inputs = context_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=512)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Debugging: Check tokenized input
    print(f"Tokenized Input IDs: {input_ids}")
    print(f"Tokenized Attention Mask: {attention_mask}")

    # Run inference with the ONNX context model
    outputs = context_session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    })

    # Debugging: Check output shape
    print(f"Logits shape: {outputs[0].shape}")  # Expected: (batch_size, num_labels)

    logits = outputs[0][0]  # Assuming batch size 1; take the first set of logits

    # Debugging: Print raw logits
    print(f"Raw logits: {logits}")

    # Apply softmax with temperature (temperature > 1 flattens the distribution
    # so more labels can clear the threshold)
    scores = softmax_with_temperature(logits, temperature=temperature)

    # Debugging: Print scores
    print(f"Scores with softmax: {scores}")

    # Pair labels with scores; zip stops at the shorter sequence — see NOTE above
    label_scores = [(label, score) for label, score in zip(labels, scores)]

    # Debugging: Print all label scores
    print(f"All label scores: {label_scores}")

    # Sort by scores in descending order
    sorted_labels = sorted(label_scores, key=lambda x: x[1], reverse=True)

    # Filter by threshold
    filtered_labels = [label for label, score in sorted_labels if score > score_threshold]

    # Debugging: Print filtered labels
    print(f"Filtered labels: {filtered_labels}")

    # Default to "general" if no valid context is found
    return filtered_labels if filtered_labels else ["general"]
81
 
82
  def translate_text(input_text):
83
+ tokenized_input = translation_tokenizer(
84
  input_text, return_tensors="np",
85
  padding=True, truncation=True, max_length=512
86
  )
 
87
  input_ids = tokenized_input["input_ids"].astype(np.int64)
88
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
89
 
90
+ decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
91
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
92
 
93
  for _ in range(512):
 
106
  [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
107
  )
108
 
109
+ if next_token_id == translation_tokenizer.eos_token_id:
110
  break
111
 
112
+ return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
113
 
114
  def process_request(input_text):
115
  context = detect_context(input_text)
 
121
  inputs="text",
122
  outputs="text",
123
  live=True
124
+ ).launch()