Frenchizer committed on
Commit
201752f
·
1 Parent(s): b950319

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -73
app.py CHANGED
@@ -1,74 +1,75 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, pipeline
3
- import numpy as np
4
-
5
- MODEL_FILE = "./model.onnx"
6
- session = ort.InferenceSession(MODEL_FILE)
7
- tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
8
-
9
- # Predefined labels for context detection
10
- labels = [
11
- "general", "pharma", "legal", "technical", "UI", "medicine", "it", "marketing",
12
- "e-commerce", "programming", "website", "html", "keywords", "food commerce",
13
- "personal development", "literature", "poetry", "physics", "chemistry", "biology",
14
- "theater", "finance", "sports", "education", "politics", "economics", "art",
15
- "history", "music", "gaming", "aerospace", "engineering", "robotics", "travel",
16
- "tourism", "healthcare", "psychology", "environment", "fashion", "design",
17
- "real estate", "retail", "news", "entertainment", "social media", "automotive",
18
- "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
19
- "archaeology", "data science"
20
- ]
21
-
22
- # Context detection pipeline
23
- context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
24
-
25
- def detect_context(input_text):
26
- result = context_pipeline(input_text, candidate_labels=labels)
27
- return result["labels"][0]
28
-
29
- def gradio_predict(input_text):
30
- try:
31
- tokenized_input = tokenizer(
32
- input_text, return_tensors="np", padding=True, truncation=True, max_length=512
33
- )
34
- input_ids = tokenized_input["input_ids"].astype(np.int64)
35
- attention_mask = tokenized_input["attention_mask"].astype(np.int64)
36
-
37
- decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
38
- decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
39
-
40
- for _ in range(512):
41
- outputs = session.run(
42
- None,
43
- {
44
- "input_ids": input_ids,
45
- "attention_mask": attention_mask,
46
- "decoder_input_ids": decoder_input_ids,
47
- }
48
- )
49
-
50
- logits = outputs[0]
51
- next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
52
- decoder_input_ids = np.concatenate(
53
- [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
54
- )
55
-
56
- if next_token_id == tokenizer.eos_token_id:
57
- break
58
-
59
- translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
60
-
61
- return {
62
- "translation": translated_text,
63
- "context": detect_context(input_text)
64
- }
65
-
66
- except Exception as e:
67
- return f"Error: {str(e)}"
68
-
69
- gr.Interface(
70
- fn=gradio_predict,
71
- inputs="text",
72
- outputs=["text", "text"],
73
- live=True
 
74
  ).launch()
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ from transformers import AutoTokenizer, pipeline
4
+ import numpy as np
5
+
6
# Path to the exported ONNX translation model, expected alongside this script.
MODEL_FILE = "./model.onnx"
# ONNX Runtime session used for encoder-decoder inference in gradio_predict.
session = ort.InferenceSession(MODEL_FILE)
# Tokenizer matching the Helsinki-NLP MarianMT English->French checkpoint the
# ONNX model was exported from.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

# Predefined labels for context detection (candidate labels fed to the
# zero-shot classification pipeline; detect_context returns one of these).
labels = [
    "general", "pharma", "legal", "technical", "UI", "medicine", "it", "marketing",
    "e-commerce", "programming", "website", "html", "keywords", "food commerce",
    "personal development", "literature", "poetry", "physics", "chemistry", "biology",
    "theater", "finance", "sports", "education", "politics", "economics", "art",
    "history", "music", "gaming", "aerospace", "engineering", "robotics", "travel",
    "tourism", "healthcare", "psychology", "environment", "fashion", "design",
    "real estate", "retail", "news", "entertainment", "social media", "automotive",
    "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
    "archaeology", "data science"
]

# Context detection pipeline
context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
22
+
23
+ # Context detection pipeline
24
+ context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
25
+
26
def detect_context(input_text):
    """Classify *input_text* into one of the predefined domain labels.

    Runs the zero-shot classification pipeline over the full candidate
    label list and returns the single highest-scoring label string.
    """
    classification = context_pipeline(input_text, candidate_labels=labels)
    top_label = classification["labels"][0]
    return top_label
29
+
30
def gradio_predict(input_text):
    """Translate English *input_text* to French and detect its domain context.

    Greedily decodes the MarianMT ONNX model token by token, then classifies
    the source text with the zero-shot pipeline.

    Returns:
        tuple[str, str]: ``(translated_text, context_label)`` on success, or
        ``(error_message, "")`` on failure — always two values, matching the
        two "text" outputs declared on the Gradio interface. (The previous
        version returned a dict / bare string, which does not satisfy a
        two-output Interface; Gradio unpacks multiple outputs from a tuple.)
    """
    try:
        tokenized_input = tokenizer(
            input_text, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        # The ONNX graph expects int64 input tensors.
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Marian tokenizers have no CLS token, so cls_token_id is None and
        # this falls back to pad_token_id — Marian's decoder start token.
        decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

        # Greedy autoregressive decoding, capped at 512 generated tokens.
        for _ in range(512):
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                },
            )

            logits = outputs[0]
            # Most likely next token at the last decoder position.
            next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
            )

            if next_token_id == tokenizer.eos_token_id:
                break

        translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

        # Two values for the Interface's two "text" outputs.
        return translated_text, detect_context(input_text)

    except Exception as e:
        # Top-level UI boundary: surface the error in the first output slot
        # and keep the second empty so the two-output contract still holds.
        return f"Error: {str(e)}", ""
69
+
70
# Wire the predictor into a Gradio UI: one text input, two text outputs
# (translation, detected context).
# NOTE(review): live=True re-runs the prediction on every input change,
# which re-invokes the full ONNX decode loop and the zero-shot classifier
# per keystroke — confirm this latency is acceptable.
gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs=["text", "text"],
    live=True
).launch()