Frenchizer committed
Commit 96a8ebf · verified · 1 parent: 7a24f06

Update app.py

Files changed (1): app.py (+73, -25)
app.py CHANGED
@@ -1,37 +1,85 @@
-import os
-import numpy as np
+import gradio as gr
 import onnxruntime as ort
 from transformers import AutoTokenizer
-import gradio as gr
+import numpy as np
+from gradio_client import Client

-# Load the ONNX model and tokenizer
-model_path = "model.onnx"
-translation_session = ort.InferenceSession(model_path)
-translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
+# Initialize the context model
+context_model_file = "./bart-large-mnli.onnx"
+context_session = ort.InferenceSession(context_model_file)
+context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

-def translate_text(input_text):
+# Initialize the Gradio client for the translation model
+translation_client = Client("Frenchizer/Frenchizer-Translation-Model")  # Replace with your Space name
+
+labels = [
+    "aerospace", "anatomy", "anthropology", "art",
+    "automotive", "blockchain", "biology", "chemistry",
+    "cryptocurrency", "data science", "design", "e-commerce",
+    "education", "engineering", "entertainment", "environment",
+    "fashion", "finance", "food commerce", "general",
+    "gaming", "healthcare", "history", "html",
+    "information technology", "IT", "keywords", "legal",
+    "literature", "machine learning", "marketing", "medicine",
+    "music", "personal development", "philosophy", "physics",
+    "politics", "poetry", "programming", "real estate", "retail",
+    "robotics", "slang", "social media", "speech", "sports",
+    "sustained", "technical", "theater", "tourism", "travel"
+]
+
+def softmax_with_temperature(logits, temperature=1.0):
+    exp_logits = np.exp(logits / temperature)
+    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
+
+def detect_context(input_text, top_n=3, score_threshold=0.05):
     # Tokenize input text
-    tokenized_input = translation_tokenizer(
-        input_text, return_tensors="np", padding=True, truncation=True, max_length=512
-    )
+    inputs = context_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=512)

-    input_ids = tokenized_input["input_ids"].astype(np.int64)
-    attention_mask = tokenized_input["attention_mask"].astype(np.int64)
+    input_ids = inputs["input_ids"].astype(np.int64)
+    attention_mask = inputs["attention_mask"].astype(np.int64)

-    # Run inference with the ONNX model
-    outputs = translation_session.run(
-        None,
-        {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
-    )
+    # Run inference with the ONNX context model
+    outputs = context_session.run(None, {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask
+    })

-    # Decode the output tokens
-    translated_tokens = np.argmax(outputs[0], axis=-1)
-    translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
+    scores = outputs[0][0]  # Assuming batch size 1; take the first set of logits

-    return translated_text
+    # Pair labels with scores
+    label_scores = [(label, score) for label, score in zip(labels, scores)]
+
+    # Sort by scores in descending order
+    sorted_labels = sorted(label_scores, key=lambda x: x[1], reverse=True)
+
+    # Filter by threshold and return top_n contexts
+    filtered_labels = [label for label, score in sorted_labels if score > score_threshold]
+    top_contexts = filtered_labels[:top_n]
+
+    return top_contexts if top_contexts else ["general"]
+
+def translate_text(input_text):
+    # Call the translation model via the Gradio client
+    result = translation_client.predict(input_text)
+    return result
+
+def process_request(input_text):
+    # Detect context
+    context = detect_context(input_text)
+    print(f"Detected context: {context}")
+
+    # Translate text
+    translation = translate_text(input_text)
+    return translation

 # Create a Gradio interface
+interface = gr.Interface(
+    fn=process_request,
+    inputs="text",
+    outputs="text",
+    title="Frenchizer",
+    description="Translate text from English to French with context detection."
+)
+
+# Launch the Gradio app
 interface.launch()
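
Note on the new detect_context: facebook/bart-large-mnli is a natural-language-inference model, so its classification head emits three logits per premise/hypothesis pair (contradiction, neutral, entailment) rather than one score per candidate label. Zero-shot labeling over a list like labels above is usually done by scoring each label as its own hypothesis and keeping the entailment probability. A minimal sketch of that pattern against the same ONNX session and input feeds used in the diff (the nli_label_scores helper and its hypothesis template are illustrative assumptions, not part of this commit):

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

session = ort.InferenceSession("./bart-large-mnli.onnx")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

def nli_label_scores(text, candidate_labels):
    # Illustrative helper: one premise/hypothesis pair per candidate label.
    scores = []
    for label in candidate_labels:
        enc = tokenizer(
            text,                            # premise
            f"This text is about {label}.",  # hypothesis (assumed template)
            return_tensors="np", truncation=True, max_length=512,
        )
        logits = session.run(None, {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
        })[0][0]  # shape (3,): contradiction, neutral, entailment
        sel = logits[[0, 2]]            # drop the neutral logit
        pair = np.exp(sel - sel.max())  # stable softmax over the two classes
        scores.append(float(pair[1] / pair.sum()))
    return scores

This mirrors how the transformers zero-shot-classification pipeline scores multi-label inputs; with 51 candidate labels it costs one forward pass per label, so batching the pairs is worth considering in the Space.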