# Gradio app: zero-shot context detection (BART-large-MNLI, ONNX) + EN->FR
# translation (Helsinki-NLP opus-mt-en-fr, ONNX).
# (Removed non-Python page-scrape residue: "Spaces:", "Build error", file-size
# line, commit-hash list, and a line-number gutter that preceded the code.)
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
# Initialize models
# Paths to locally exported ONNX graphs, expected next to this script.
context_model_file = "./bart-large-mnli.onnx"      # zero-shot NLI classifier
translation_model_file = "./model.onnx"            # EN->FR seq2seq translator
# Create inference sessions for both models (CPU provider by default;
# raises if the .onnx files are missing).
context_session = ort.InferenceSession(context_model_file)
translation_session = ort.InferenceSession(translation_model_file)
# Load tokenizers for context and translation models.
# NOTE(review): these download from the HuggingFace Hub on first run —
# the tokenizer vocab must match the exported ONNX graphs above.
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
# Candidate topic labels for zero-shot context detection (alphabetical-ish).
# "general" serves as the fallback label in detect_context().
labels = [
'aerospace', 'agriculture', 'anatomy', 'anthropology', 'architecture',
'art', 'automotive', 'astronomy', 'aviation', 'banking',
'biotechnology', 'biology', 'blockchain', 'business',
'chemistry', 'climate change', 'communication', 'computer science',
'construction', 'consumer goods', 'cryptocurrency', 'cybersecurity',
'dance', 'diplomacy', 'ecology', 'economics',
'education', 'energy', 'engineering', 'entrepreneurship',
'entertainment', 'ethics', 'fashion', 'finance',
'film', 'fitness', 'food commerce', 'general',
'gaming', 'geography', 'geology', 'graphic design',
'healthcare', 'history', 'html', 'human resources',
'immigration', 'innovation', 'journalism',
'keywords','language','law enforcement','legal','logistics','literature',
'machine learning','management','manufacturing','mathematics','media','military',
'music','nanotechnology','nutrition','pharmaceuticals','photography',
'psychology','public health','publishing','religion','renewable energy',
'research','sales','science','social media','social work',
'space exploration','sports','statistics','supply chain',
'sustainability','telecommunications','transportation',
'urban planning','veterinary medicine','virtual reality',
'web development','writing','zoology'
]
def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities via a temperature-scaled softmax.

    Higher temperature flattens the distribution; ``temperature=1.0`` is the
    standard softmax.

    Args:
        logits: array-like of raw model scores.
        temperature: positive divisor applied to the logits before softmax.

    Returns:
        ``np.ndarray`` of probabilities summing to 1 along the last axis.
    """
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    # BUG FIX: subtract the row max before exponentiating. exp() of a large
    # logit overflows to inf and yields nan probabilities; exp(x - max(x))
    # is mathematically identical and numerically stable.
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_logits = np.exp(scaled)
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
def detect_context(input_text, temperature=2.0, top_n=3, score_threshold=0.05):
    """Zero-shot classify *input_text* against the module-level ``labels``.

    Uses the standard NLI formulation: for each candidate label the model is
    run on (premise = *input_text*, hypothesis = "This example is <label>.")
    and the per-label entailment logits are softmaxed with *temperature*.

    Args:
        input_text: text whose topical context should be detected.
        temperature: softmax temperature; >1 flattens the score distribution.
        top_n: maximum number of contexts to return.
        score_threshold: minimum softmax score a label needs to be kept.

    Returns:
        Up to *top_n* label strings, best first, or ``["general"]`` when no
        label clears *score_threshold*.
    """
    # BUG FIX: the original ran the model once on the bare text and zipped its
    # logits against ~90 labels. bart-large-mnli outputs only 3 logits
    # (contradiction / neutral / entailment), so zip() silently truncated the
    # candidate set to the first three labels. Zero-shot classification must
    # score each label through an NLI hypothesis, as done below.
    entailment_logits = []
    for label in labels:
        inputs = context_tokenizer(
            input_text,
            f"This example is {label}.",
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=512,
        )
        outputs = context_session.run(None, {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        })
        # facebook/bart-large-mnli label order: 0=contradiction, 1=neutral,
        # 2=entailment -- we keep the entailment logit per label.
        entailment_logits.append(outputs[0][0][2])
    # Apply softmax with temperature across the per-label entailment logits.
    scores = softmax_with_temperature(np.array(entailment_logits), temperature=temperature)
    label_scores = list(zip(labels, scores))
    print(f"All scores: {label_scores}")  # Debugging: print all scores
    # Sort by score descending, filter by threshold, keep the top_n.
    sorted_labels = sorted(label_scores, key=lambda x: x[1], reverse=True)
    top_contexts = [label for label, score in sorted_labels if score > score_threshold][:top_n]
    print(f"Selected contexts: {top_contexts}")  # Debugging: print selected contexts
    return top_contexts if top_contexts else ["general"]
def translate_text(input_text, max_new_tokens=512):
    """Greedy-decode an EN->FR translation with the ONNX Marian model.

    Args:
        input_text: English source text (source side truncated to 512 tokens).
        max_new_tokens: cap on generated tokens; decoding stops earlier at EOS.
            (New optional parameter; defaults to the previous hard-coded 512.)

    Returns:
        The decoded French string with special tokens stripped.
    """
    tokenized_input = translation_tokenizer(
        input_text, return_tensors="np",
        padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)
    # Marian tokenizers have no CLS token (cls_token_id is None), so this
    # resolves to pad_token_id -- which is Marian's decoder start token.
    decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
    # NOTE(review): each step re-runs the full encoder+decoder (no KV cache),
    # so decoding is O(n^2) in output length -- acceptable for a demo.
    for _ in range(max_new_tokens):
        outputs = translation_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            }
        )
        logits = outputs[0]
        # Greedy choice: most probable token at the last decoder position.
        next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
        )
        if next_token_id == translation_tokenizer.eos_token_id:
            break
    return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
def process_request(input_text):
    """Gradio callback: detect the text's context, then translate it.

    The detected context is currently informational only (it is printed by
    detect_context and not fed into translation); the return value is just
    the French translation of *input_text*.
    """
    detected = detect_context(input_text)
    translated = translate_text(input_text)
    return translated
# Wire the pipeline into a simple text-in / text-out Gradio UI.
# live=True re-runs process_request on every keystroke, which re-runs both
# ONNX models -- convenient for a demo, expensive for long inputs.
gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    live=True
).launch()