"""Gradio demo: English-to-French translation plus a zero-shot topic detector.

Two ONNX models are loaded at import time:

* ``./bart-large-mnli.onnx`` -- used by :func:`detect_context` as an NLI model
  for zero-shot topic classification over ``labels`` (tokenized with the
  ``facebook/bart-large-mnli`` tokenizer).
* ``./model.onnx`` -- an encoder-decoder model (it takes ``input_ids``,
  ``attention_mask`` and ``decoder_input_ids``) used by :func:`translate_text`
  with the ``Helsinki-NLP/opus-mt-en-fr`` tokenizer.
"""
import logging

import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoConfig, AutoTokenizer

logger = logging.getLogger(__name__)

# Paths to the exported ONNX models.
context_model_file = "./bart-large-mnli.onnx"
translation_model_file = "./model.onnx"

# Create inference sessions for both models.
context_session = ort.InferenceSession(context_model_file)
translation_session = ort.InferenceSession(translation_model_file)

# Load tokenizers for the context and translation models.
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

# bart-large-mnli classifies a (premise, hypothesis) pair as contradiction /
# neutral / entailment; it has no per-topic outputs. Find the "entailment"
# column from the model config rather than hard-coding it (this mirrors how the
# Hugging Face zero-shot-classification pipeline locates it).
# NOTE(review): assumes the ONNX export kept the original MNLI head -- confirm.
_context_config = AutoConfig.from_pretrained("facebook/bart-large-mnli")
ENTAILMENT_INDEX = next(
    (
        index
        for name, index in _context_config.label2id.items()
        if name.lower().startswith("entail")
    ),
    None,
)
if ENTAILMENT_INDEX is None:
    raise RuntimeError(
        "Could not find an 'entailment' label in the bart-large-mnli config."
    )

# Template that turns a topic label into an NLI hypothesis.
HYPOTHESIS_TEMPLATE = "This example is {}."

labels = [
    'aerospace', 'agriculture', 'anatomy', 'anthropology', 'architecture',
    'art', 'automotive', 'astronomy', 'aviation', 'banking', 'biotechnology',
    'biology', 'blockchain', 'business', 'chemistry', 'climate change',
    'communication', 'computer science', 'construction', 'consumer goods',
    'cryptocurrency', 'cybersecurity', 'dance', 'diplomacy', 'ecology',
    'economics', 'education', 'energy', 'engineering', 'entrepreneurship',
    'entertainment', 'ethics', 'fashion', 'finance', 'film', 'fitness',
    'food commerce', 'general', 'gaming', 'geography', 'geology',
    'graphic design', 'healthcare', 'history', 'html', 'human resources',
    'immigration', 'innovation', 'journalism', 'keywords', 'language',
    'law enforcement', 'legal', 'logistics', 'literature', 'machine learning',
    'management', 'manufacturing', 'mathematics', 'media', 'military',
    'music', 'nanotechnology', 'nutrition', 'pharmaceuticals', 'photography',
    'psychology', 'public health', 'publishing', 'religion',
    'renewable energy', 'research', 'sales', 'science', 'social media',
    'social work', 'space exploration', 'sports', 'statistics',
    'supply chain', 'sustainability', 'telecommunications', 'transportation',
    'urban planning', 'veterinary medicine', 'virtual reality',
    'web development', 'writing', 'zoology',
]


def softmax_with_temperature(logits, temperature=1.0):
    """Return ``softmax(logits / temperature)`` along the last axis.

    Args:
        logits: Array-like of raw scores.
        temperature: Strictly positive scale; values > 1 flatten the
            distribution, values < 1 sharpen it.

    Raises:
        ValueError: If *temperature* is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    scaled = np.asarray(logits) / temperature
    # Subtracting the row max leaves the result unchanged mathematically but
    # keeps np.exp from overflowing to inf (which would yield NaNs).
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_scaled = np.exp(scaled)
    return exp_scaled / np.sum(exp_scaled, axis=-1, keepdims=True)


def _entailment_logit(premise, hypothesis):
    """Run the NLI model on one (premise, hypothesis) pair; return its entailment logit."""
    encoded = context_tokenizer(
        premise,
        hypothesis,
        return_tensors="np",
        truncation="only_first",  # shorten the user text, never the hypothesis
        max_length=512,
    )
    # One pair per run (batch size 1), matching how this model was already
    # invoked; could be batched if the export has a dynamic batch axis.
    outputs = context_session.run(
        None,
        {
            "input_ids": encoded["input_ids"].astype(np.int64),
            "attention_mask": encoded["attention_mask"].astype(np.int64),
        },
    )
    return float(outputs[0][0, ENTAILMENT_INDEX])


def detect_context(input_text, temperature=2.0, top_n=3, score_threshold=0.05):
    """Return up to *top_n* labels from ``labels`` that best describe *input_text*.

    Zero-shot classification via NLI: each label becomes the hypothesis
    ``HYPOTHESIS_TEMPLATE.format(label)``, paired with *input_text* as the
    premise. The model's entailment logit for each pair is that label's raw
    score. A temperature-scaled softmax across all labels turns those scores
    into probabilities. This costs one model run per label.

    Args:
        input_text: Text to classify.
        temperature: Softmax temperature (see ``softmax_with_temperature``).
        top_n: Maximum number of labels to return.
        score_threshold: Labels whose probability is not strictly above this
            value are dropped.

    Returns:
        Labels sorted by descending probability, or ``["general"]`` when no
        label clears *score_threshold*.
    """
    entailment_logits = np.array(
        [
            _entailment_logit(input_text, HYPOTHESIS_TEMPLATE.format(label))
            for label in labels
        ]
    )
    logger.debug("Entailment logits: %s", entailment_logits)

    scores = softmax_with_temperature(entailment_logits, temperature=temperature)
    label_scores = list(zip(labels, scores))
    ranked = sorted(label_scores, key=lambda pair: pair[1], reverse=True)
    top_contexts = [label for label, score in ranked if score > score_threshold][:top_n]

    logger.debug("All scores: %s", label_scores)
    logger.debug("Selected contexts: %s", top_contexts)
    return top_contexts if top_contexts else ["general"]


def translate_text(input_text, max_new_tokens=512):
    """Translate *input_text* using greedy decoding on the ONNX translation model.

    Args:
        input_text: Source text; tokenized with truncation to 512 tokens.
        max_new_tokens: Upper bound on decoding steps (default 512).

    Returns:
        The decoded output with special tokens removed.
    """
    tokenized_input = translation_tokenizer(
        input_text,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=512,
    )
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)

    # Prefer the CLS token, else fall back to pad as the decoder start token.
    # Explicit None check so a legitimate token id of 0 is not skipped.
    # (For opus-mt/Marian models the decoder start is presumably the pad
    # token -- confirm against the model config.)
    decoder_start_token_id = translation_tokenizer.cls_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = translation_tokenizer.pad_token_id
    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
    eos_token_id = translation_tokenizer.eos_token_id

    # No KV cache: every step feeds the full inputs and decoder prefix again,
    # so total cost grows quadratically with output length.
    for _ in range(max_new_tokens):
        logits = translation_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )[0]
        next_token_id = int(np.argmax(logits[0, -1, :]))
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)],
            axis=1,
        )
        if next_token_id == eos_token_id:
            break

    return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)


def process_request(input_text):
    """Gradio callback: return the translation of *input_text* ("" for blank input)."""
    # With live=True this fires on every edit; skip model work for blank input.
    if not input_text or not input_text.strip():
        return ""
    # detect_context() is intentionally not called: the translation does not
    # use its result, and it costs one NLI model run per label.
    return translate_text(input_text)


demo = gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    live=True,
)

if __name__ == "__main__":
    demo.launch()