Frenchizer committed on
Commit
6f6ca80
·
1 Parent(s): 64248a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -9
app.py CHANGED
@@ -1,13 +1,4 @@
1
  import gradio as gr
2
- import onnxruntime as ort
3
- from transformers import AutoTokenizer, pipeline
4
- import numpy as np
5
-
6
- MODEL_FILE = "./model.onnx"
7
- session = ort.InferenceSession(MODEL_FILE)
8
- tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
9
-
10
- labels = [
11
  "general", "pharma", "legal", "technical", "UI", "medicine", "it", "marketing",
12
  "e-commerce", "programming", "website", "html", "keywords", "food commerce",
13
  "personal development", "literature", "poetry", "physics", "chemistry", "biology",
@@ -36,4 +27,35 @@ def translate_text(input_text):
36
  decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
37
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  ).launch()
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
2
  "general", "pharma", "legal", "technical", "UI", "medicine", "it", "marketing",
3
  "e-commerce", "programming", "website", "html", "keywords", "food commerce",
4
  "personal development", "literature", "poetry", "physics", "chemistry", "biology",
 
27
  decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
28
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
29
 
30
+ for _ in range(512):
31
+ outputs = session.run(
32
+ None,
33
+ {
34
+ "input_ids": input_ids,
35
+ "attention_mask": attention_mask,
36
+ "decoder_input_ids": decoder_input_ids,
37
+ }
38
+ )
39
+
40
+ logits = outputs[0]
41
+ next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
42
+ decoder_input_ids = np.concatenate(
43
+ [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
44
+ )
45
+
46
+ if next_token_id == tokenizer.eos_token_id:
47
+ break
48
+
49
+ return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
50
+
51
def process_request(input_text):
    """Handle one UI request: classify the text's domain, then translate it.

    Returns a ``(translation, context)`` pair, where ``translation`` is the
    French rendering of *input_text* and ``context`` is the detected domain
    label (e.g. "legal", "medicine").
    """
    # Keep classification before translation — same call order as before.
    detected_context = detect_context(input_text)
    french_text = translate_text(input_text)
    return french_text, detected_context
55
+
56
# Launch the Gradio UI.
# NOTE(fix): process_request returns TWO values (translation, context), but the
# interface previously declared a single output ("text"); Gradio raises a
# runtime error when the number of returned values does not match the number
# of output components. Declare two text outputs so both values are shown.
gr.Interface(
    fn=process_request,
    inputs="text",
    outputs=["text", "text"],  # translation, detected context
    live=True,  # re-run on every keystroke rather than on a submit button
).launch()