Spaces:
Build error
Build error
import logging

import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
# ONNX exports loaded from the working directory: the zero-shot context
# classifier (BART-large-MNLI) and the translation model.
context_model_file, translation_model_file = "./bart-large-mnli.onnx", "./model.onnx"

# Build one ONNX Runtime session per exported graph, context model first.
context_session, translation_session = map(
    ort.InferenceSession, (context_model_file, translation_model_file)
)

# Matching Hugging Face tokenizers, in the same order as the sessions above.
context_tokenizer, translation_tokenizer = map(
    AutoTokenizer.from_pretrained,
    ("facebook/bart-large-mnli", "Helsinki-NLP/opus-mt-en-fr"),
)
# Candidate topics for context detection. Order is significant: ranking uses
# a stable sort, so tied scores keep this sequence, and "general" is also the
# fallback that detect_context returns. Rows are grouped by initial letter but
# preserve the original sequence, which is not strictly alphabetical
# (e.g. "automotive" precedes "astronomy").
labels = [
    "aerospace", "agriculture", "anatomy", "anthropology", "architecture",
    "art", "automotive", "astronomy", "aviation",
    "banking", "biotechnology", "biology", "blockchain", "business",
    "chemistry", "climate change", "communication", "computer science",
    "construction", "consumer goods", "cryptocurrency", "cybersecurity",
    "dance", "diplomacy",
    "ecology", "economics", "education", "energy", "engineering",
    "entrepreneurship", "entertainment", "ethics",
    "fashion", "finance", "film", "fitness", "food commerce",
    "general", "gaming", "geography", "geology", "graphic design",
    "healthcare", "history", "html", "human resources",
    "immigration", "innovation",
    "journalism",
    "keywords",
    "language", "law enforcement", "legal", "logistics", "literature",
    "machine learning", "management", "manufacturing", "mathematics",
    "media", "military", "music",
    "nanotechnology", "nutrition",
    "pharmaceuticals", "photography", "psychology", "public health", "publishing",
    "religion", "renewable energy", "research",
    "sales", "science", "social media", "social work", "space exploration",
    "sports", "statistics", "supply chain", "sustainability",
    "telecommunications", "transportation",
    "urban planning",
    "veterinary medicine", "virtual reality",
    "web development", "writing",
    "zoology",
]
def softmax_with_temperature(logits, temperature=1.0):
    """Return the softmax of ``logits / temperature`` along the last axis.

    Higher temperatures flatten the distribution; lower ones sharpen it.

    Args:
        logits: Array-like of raw scores; normalisation runs over the last axis.
        temperature: Strictly positive divisor applied before the softmax.

    Returns:
        A NumPy array of the same shape whose last-axis entries sum to 1
        (an empty input is returned unchanged as an empty array).

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        # 0 divides by zero and a negative value silently inverts the ranking.
        raise ValueError(f"temperature must be positive, got {temperature!r}")
    scaled = np.asarray(logits) / temperature
    if scaled.size == 0:
        return scaled
    # Softmax is invariant to subtracting a per-row constant; shifting by the
    # max keeps np.exp from overflowing to inf (which produced NaN before).
    shifted = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)
def detect_context(
    input_text,
    temperature=2.0,
    top_n=3,
    score_threshold=0.05,
    hypothesis_template="This example is {}.",
    entailment_index=2,
):
    """Return up to ``top_n`` entries of ``labels`` that best describe ``input_text``.

    facebook/bart-large-mnli is a natural-language-inference model. One
    forward pass scores a single (premise, hypothesis) pair and emits one
    logit per NLI class, not one logit per topic label. The previous version
    zipped a single pass's logits with ``labels``, so only the first few
    labels were ever scored. This version uses the standard zero-shot recipe:
    each label is formatted into a hypothesis with ``hypothesis_template`` and
    paired with ``input_text`` as the premise. That pair's entailment logit is
    the label's raw score, and the raw scores of all labels are normalised
    with a temperature softmax before thresholding.

    Note: this runs one model pass per label, so cost grows with ``len(labels)``.

    Args:
        input_text: Text to classify. Blank or ``None`` input returns
            ``["general"]`` without running the model.
        temperature: Softmax temperature (> 0); higher values flatten scores.
        top_n: Maximum number of labels to return.
        score_threshold: A label must score strictly above this to be returned.
        hypothesis_template: ``str.format`` template with one ``{}`` slot for the label.
        entailment_index: Position of the "entailment" class in the model's
            logits. 2 matches the id2label mapping published with
            facebook/bart-large-mnli (contradiction, neutral, entailment);
            change it if the ONNX export orders classes differently.

    Returns:
        Label strings in descending score order, or ``["general"]`` when the
        input is blank or no label clears ``score_threshold``.
    """
    if not input_text or not input_text.strip():
        return ["general"]

    logger = logging.getLogger(__name__)

    entailment_logits = np.empty(len(labels), dtype=np.float64)
    for i, label in enumerate(labels):
        # Premise = user text, hypothesis = label sentence. Truncate only the
        # premise so the short hypothesis is never cut off.
        encoded = context_tokenizer(
            input_text,
            hypothesis_template.format(label),
            return_tensors="np",
            truncation="only_first",
            max_length=512,
        )
        outputs = context_session.run(None, {
            "input_ids": encoded["input_ids"].astype(np.int64),
            "attention_mask": encoded["attention_mask"].astype(np.int64),
        })
        # One pair per run, so outputs[0][0] holds this pair's class logits
        # (assumes a (batch, num_classes) output, as the original code did).
        entailment_logits[i] = outputs[0][0][entailment_index]

    scores = softmax_with_temperature(entailment_logits, temperature=temperature)
    label_scores = list(zip(labels, scores.tolist()))
    # sorted() is stable even with reverse=True, so ties keep `labels` order.
    ranked = sorted(label_scores, key=lambda pair: pair[1], reverse=True)
    top_contexts = [label for label, score in ranked if score > score_threshold][:max(top_n, 0)]

    logger.debug("All scores: %s", label_scores)
    logger.debug("Selected contexts: %s", top_contexts)
    return top_contexts or ["general"]
def translate_text(input_text, max_new_tokens=512):
    """Translate ``input_text`` with greedy decoding on ``translation_session``.

    The tokenizer is Helsinki-NLP/opus-mt-en-fr, so this translates English to
    French. It presumably expects ./model.onnx to be the matching seq2seq
    export that accepts input_ids / attention_mask / decoder_input_ids and
    returns logits as its first output (TODO confirm against the export).

    Decoding is greedy and uses no past-key-value cache. Every step re-runs
    the model on the full decoder prefix, so cost grows quadratically with
    output length.

    Args:
        input_text: Source text. Blank or ``None`` input returns ``""``
            without running the model.
        max_new_tokens: Upper bound on generated tokens (default 512, the
            previously hard-coded limit).

    Returns:
        The decoded translation with special tokens removed.
    """
    if not input_text or not input_text.strip():
        return ""

    encoded = translation_tokenizer(
        input_text, return_tensors="np",
        padding=True, truncation=True, max_length=512,
    )
    input_ids = encoded["input_ids"].astype(np.int64)
    attention_mask = encoded["attention_mask"].astype(np.int64)

    # Prefer the CLS token as decoder start and fall back to PAD only when no
    # CLS token exists. (`a or b` would also wrongly fall back for id 0.)
    start_id = translation_tokenizer.cls_token_id
    if start_id is None:
        start_id = translation_tokenizer.pad_token_id
    eos_id = translation_tokenizer.eos_token_id
    decoder_input_ids = np.array([[start_id]], dtype=np.int64)

    for _ in range(max_new_tokens):
        logits = translation_session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
        })[0]
        # Greedy pick from the last decoder position (single-item batch).
        next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
        )
        if next_token_id == eos_id:
            break

    return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
def process_request(input_text):
    """Gradio callback: translate ``input_text`` and return the translation.

    Context detection is intentionally not run here. Its result was computed
    and then discarded (the translation never used it), yet it cost extra
    model inference on every request. With ``live=True`` that means every
    change to the input. Call ``detect_context`` directly when the topic
    labels are actually needed.
    """
    return translate_text(input_text)
# Gradio UI: one text input, one text output. live=True re-runs the callback
# whenever the input changes.
demo = gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    live=True,
)

# Start the server only when executed as a script (e.g. `python app.py`),
# not as a side effect of importing this module.
if __name__ == "__main__":
    demo.launch()