File size: 5,110 Bytes
201752f
b55eab0
587a53a
b55eab0
 
587a53a
f285d15
587a53a
 
 
 
 
 
c905911
6754320
c905911
b55eab0
f285d15
99b3521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f285d15
 
c905911
 
 
 
5ad0807
587a53a
848fa80
587a53a
dc20849
1f69442
c905911
 
 
 
1f69442
dc20849
cd1f49d
dc20849
1f69442
 
c905911
 
 
 
 
 
 
 
f285d15
 
c9d02af
5ad0807
 
 
 
 
 
 
 
 
 
 
 
 
36a1938
5ad0807
 
 
 
 
 
 
36a1938
5ad0807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36a1938
5ad0807
 
36a1938
5ad0807
 
 
 
 
 
7e40604
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Initialize models
# Paths to locally exported ONNX graphs: a BART-MNLI classifier used for
# topic/context detection and a Marian seq2seq model used for translation.
# Both files must exist in the current working directory at startup.
context_model_file = "./bart-large-mnli.onnx"
translation_model_file = "./model.onnx"

# Create inference sessions for both models
# (default ONNX Runtime providers — CPU unless a GPU-enabled build is installed).
context_session = ort.InferenceSession(context_model_file)
translation_session = ort.InferenceSession(translation_model_file)

# Load tokenizers for context and translation models
# NOTE(review): these fetch from the Hugging Face hub on first run — confirm
# network access or a pre-populated cache is available at deploy time.
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

# Candidate topic labels for context detection. Order matters: scores are
# paired with these positionally (zip) in detect_context(), so do not reorder.
labels = [
    "aerospace", "agriculture", "anatomy", "anthropology",
    "architecture", "art", "automotive", "astronomy",
    "aviation", "banking", "biotechnology", "biology",
    "blockchain", "business", "chemistry", "climate change",
    "communication", "computer science", "construction", "consumer goods",
    "cryptocurrency", "cybersecurity", "dance", "diplomacy",
    "ecology", "economics", "education", "energy",
    "engineering", "entrepreneurship", "entertainment", "ethics",
    "fashion", "finance", "film", "fitness",
    "food commerce", "general", "gaming", "geography",
    "geology", "graphic design", "healthcare", "history",
    "html", "human resources", "immigration", "innovation",
    "journalism", "keywords", "language", "law enforcement",
    "legal", "logistics", "literature", "machine learning",
    "management", "manufacturing", "mathematics", "media",
    "military", "music", "nanotechnology", "nutrition",
    "pharmaceuticals", "photography", "psychology", "public health",
    "publishing", "religion", "renewable energy", "research",
    "sales", "science", "social media", "social work",
    "space exploration", "sports", "statistics", "supply chain",
    "sustainability", "telecommunications", "transportation", "urban planning",
    "veterinary medicine", "virtual reality", "web development", "writing",
    "zoology",
]

def softmax_with_temperature(logits, temperature=1.0):
    """Return the temperature-scaled softmax of `logits` over the last axis.

    Higher temperature flattens the distribution; temperature=1.0 gives the
    plain softmax. The maximum is subtracted before exponentiating so the
    computation is numerically stable: the naive ``np.exp(logits / T)``
    overflows to ``inf`` (yielding NaN probabilities) for large-magnitude
    logits, which raw model outputs can easily produce.

    Args:
        logits: array-like of raw scores (any shape; softmax over last axis).
        temperature: positive scaling factor applied before the softmax.

    Returns:
        np.ndarray of the same shape as `logits`; entries along the last
        axis are non-negative and sum to 1.
    """
    scaled = np.asarray(logits) / temperature
    # Shift by the row max — exp(x - max) is mathematically equivalent after
    # normalization but cannot overflow upward.
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_scaled = np.exp(scaled)
    return exp_scaled / np.sum(exp_scaled, axis=-1, keepdims=True)

def detect_context(input_text, temperature=2.0, top_n=3, score_threshold=0.05):
    """Classify `input_text` into up to `top_n` topic labels.

    Runs the ONNX context model on the tokenized text, softmaxes the logits
    with `temperature`, pairs scores positionally with the module-level
    `labels` list, and returns the highest-scoring labels whose score
    exceeds `score_threshold` (falling back to ``["general"]`` when none
    qualify).

    NOTE(review): facebook/bart-large-mnli is an NLI model whose
    classification head normally emits 3 logits (contradiction / neutral /
    entailment), while `labels` holds 89 entries. The ``zip`` below silently
    truncates to the shorter sequence, so unless the exported ONNX graph was
    modified to emit one logit per label, most labels can never receive a
    score — confirm the graph's output shape.

    Args:
        input_text: Raw text to classify (truncated to 512 tokens).
        temperature: Softmax temperature; >1 flattens the distribution.
        top_n: Maximum number of labels to return.
        score_threshold: Minimum softmax score for a label to qualify.

    Returns:
        List of label strings, never empty (defaults to ["general"]).
    """
    # Tokenize input text into numpy tensors; ONNX Runtime expects int64 ids.
    inputs = context_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=512)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Debugging: Check tokenized input
    print(f"Tokenized Input IDs: {input_ids}")
    print(f"Tokenized Attention Mask: {attention_mask}")

    # Run inference with the ONNX context model (None = fetch all outputs).
    outputs = context_session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    })

    logits = outputs[0][0]  # Assuming batch size 1; take the first set of logits

    # Debugging: Print raw logits
    print(f"Raw logits: {logits}")

    # Apply softmax with temperature to turn logits into a distribution.
    scores = softmax_with_temperature(logits, temperature=temperature)

    # Pair labels with scores — truncates to the shorter of the two (see NOTE).
    label_scores = [(label, score) for label, score in zip(labels, scores)]

    # Sort by scores in descending order
    sorted_labels = sorted(label_scores, key=lambda x: x[1], reverse=True)

    # Filter by threshold first, then keep at most top_n contexts.
    filtered_labels = [label for label, score in sorted_labels if score > score_threshold]
    top_contexts = filtered_labels[:top_n]

    print(f"All scores: {label_scores}")  # Debugging: Print all scores
    print(f"Selected contexts: {top_contexts}")  # Debugging: Print selected contexts

    return top_contexts if top_contexts else ["general"]

def translate_text(input_text, max_new_tokens=512):
    """Translate `input_text` with the ONNX seq2seq model via greedy decoding.

    Tokenizes the source text, then autoregressively feeds the growing
    decoder prefix back into the model, picking the argmax token each step,
    until EOS is emitted or `max_new_tokens` steps have run.

    Args:
        input_text: Source text; tokenized with truncation at 512 tokens.
        max_new_tokens: Upper bound on generated tokens per call. Defaults
            to 512, matching the previously hard-coded loop limit, so
            existing callers see identical behavior.

    Returns:
        The decoded translation string with special tokens stripped.

    NOTE(review): every step re-runs the decoder over the entire prefix (no
    KV cache appears in this exported graph), so decoding cost grows
    quadratically with output length — acceptable for short inputs only.
    """
    tokenized_input = translation_tokenizer(
        input_text, return_tensors="np",
        padding=True, truncation=True, max_length=512
    )

    # ONNX Runtime expects int64 id tensors.
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)

    # Marian/OPUS tokenizers define no CLS token, so cls_token_id is None and
    # this falls through to pad_token_id — presumably the intended decoder
    # start token for this checkpoint; confirm if the model is swapped.
    decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

    for _ in range(max_new_tokens):
        outputs = translation_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            }
        )

        # Greedy step: take the highest-probability token at the last position.
        logits = outputs[0]
        next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
        )

        # Stop as soon as the model emits end-of-sequence.
        if next_token_id == translation_tokenizer.eos_token_id:
            break

    return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

def process_request(input_text):
    """Gradio handler: return the translation of `input_text`.

    NOTE(review): the `detect_context` result is computed but discarded —
    the translation does not use it, so this call only adds a full extra
    model inference (plus debug printing) per request. Confirm whether the
    context detection is still needed here.
    """
    context = detect_context(input_text)
    translation = translate_text(input_text)  # Translate without needing to pass context explicitly
    return translation

# Build the Gradio UI: a single text box in, translated text out.
# live=True re-runs process_request on every keystroke, which triggers a full
# ONNX inference per change — consider disabling for heavy models.
demo = gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    live=True,
)

if __name__ == "__main__":
    # Launch only when executed as a script, so importing this module (e.g.
    # from tests or a WSGI wrapper) does not start a server as a side effect.
    demo.launch()