Frenchizer committed on
Commit
587a53a
·
1 Parent(s): f2fd329

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -28
app.py CHANGED
@@ -1,38 +1,59 @@
1
  import gradio as gr
2
  import onnxruntime as ort
3
- from transformers import AutoTokenizer, pipeline
4
  import numpy as np
5
 
6
- MODEL_FILE = "./model.onnx"
7
- session = ort.InferenceSession(MODEL_FILE)
 
 
 
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
9
 
10
  labels = [
11
- "general", "pharma", "legal", "technical", "UI", "user interface", "medicine", "information technology", "IT", "marketing",
12
- "e-commerce", "programming", "website", "html", "keywords", "food commerce",
13
- "personal development", "literature", "poetry", "physics", "chemistry", "biology",
14
- "theater", "finance", "sports", "education", "politics", "economics", "art",
15
- "history", "music", "gaming", "aerospace", "engineering", "robotics", "travel",
16
- "tourism", "healthcare", "psychology", "environment", "fashion", "design",
17
- "real estate", "retail", "news", "entertainment", "social media", "automotive",
18
- "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
19
- "archaeology", "data science"
20
  ]
21
 
22
- #context_pipeline = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-base-mnli-fever-anli")
23
- context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
24
-
25
  def detect_context(input_text):
26
- result = context_pipeline(input_text, candidate_labels=labels)
27
- contexts = [label for label, score in zip(result["labels"], result["scores"]) if label in labels and score > 0.1]
28
- print(contexts)
29
- return contexts or ["general"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
-
32
- def translate_text(input_text, context):
33
  tokenized_input = tokenizer(
34
- input_text, return_tensors="np", padding=True, truncation=True, max_length=512
 
35
  )
 
36
  input_ids = tokenized_input["input_ids"].astype(np.int64)
37
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
38
 
@@ -40,7 +61,7 @@ def translate_text(input_text, context):
40
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
41
 
42
  for _ in range(512):
43
- outputs = session.run(
44
  None,
45
  {
46
  "input_ids": input_ids,
@@ -60,16 +81,14 @@ def translate_text(input_text, context):
60
 
61
  return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
62
 
63
-
64
  def process_request(input_text):
65
- contexts = detect_context(input_text)
66
- translation = translate_text(input_text, contexts[0]) # Use the first detected context
67
  return translation
68
 
69
-
70
  gr.Interface(
71
  fn=process_request,
72
  inputs="text",
73
  outputs="text",
74
  live=True
75
- ).launch()
 
1
  import gradio as gr
2
  import onnxruntime as ort
3
+ from transformers import AutoTokenizer
4
  import numpy as np
5
 
6
+ # Initialize models
7
+ context_model_file = "./bart-base-nmli.onnx"
8
+ translation_model_file = "./model.onnx"
9
+
10
+ # Create inference sessions for both models
11
+ context_session = ort.InferenceSession(context_model_file)
12
+ translation_session = ort.InferenceSession(translation_model_file)
13
+
14
+ # Load tokenizer for translation model
15
  tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
16
 
17
  labels = [
18
+ "general", "pharma", "legal", "technical", "UI", "user interface", "medicine",
19
+ "information technology", "IT", "marketing", "e-commerce", "programming",
20
+ "website", "html", "keywords", "food commerce", "personal development",
21
+ "literature", "poetry", "physics", "chemistry", "biology", "theater",
22
+ "finance", "sports", "education", "politics", "economics", "art", "history",
23
+ "music", "gaming", "aerospace", "engineering", "robotics", "travel", "tourism",
24
+ "healthcare", "psychology", "environment", "fashion", "design", "real estate",
25
+ "retail", "news", "entertainment", "social media","automotive", "machine learning",
26
+ "cryptocurrency","blockchain","philosophy","anthropology","archaeology","data science"
27
  ]
28
 
 
 
 
29
  def detect_context(input_text):
30
+ # Tokenize input text
31
+ inputs = tokenizer(input_text, return_tensors="np")
32
+
33
+ # Prepare input for ONNX model
34
+ input_ids = inputs["input_ids"].astype(np.int64)
35
+
36
+ # Run inference with context model
37
+ outputs = context_session.run(None, {"input_ids": input_ids})
38
+
39
+ # Assuming the output is logits for each label
40
+ scores = outputs[0]
41
+
42
+ # Get the top label based on scores
43
+ top_label_index = np.argmax(scores, axis=1)[0]
44
+
45
+ # Map index to label
46
+ detected_context = labels[top_label_index]
47
+
48
+ print(detected_context)
49
+ return detected_context
50
 
51
+ def translate_text(input_text):
 
52
  tokenized_input = tokenizer(
53
+ input_text, return_tensors="np",
54
+ padding=True, truncation=True, max_length=512
55
  )
56
+
57
  input_ids = tokenized_input["input_ids"].astype(np.int64)
58
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
59
 
 
61
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
62
 
63
  for _ in range(512):
64
+ outputs = translation_session.run(
65
  None,
66
  {
67
  "input_ids": input_ids,
 
81
 
82
  return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
83
 
 
84
  def process_request(input_text):
85
+ context = detect_context(input_text)
86
+ translation = translate_text(input_text) # Translate without needing to pass context explicitly
87
  return translation
88
 
 
89
  gr.Interface(
90
  fn=process_request,
91
  inputs="text",
92
  outputs="text",
93
  live=True
94
+ ).launch()