Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import librosa | |
| import torch | |
| from transformers import ( | |
| Wav2Vec2ForCTC, Wav2Vec2Processor, | |
| MarianMTModel, MarianTokenizer, | |
| BertForSequenceClassification, AutoModel, AutoTokenizer,AutoModelForSequenceClassification | |
| ) | |
# Run on the GPU when one is available; every model below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sentiment labels, indexed by the class id predicted by the sentiment model
# (adjust to match the model actually in use).
sentiment_labels = ["Négatif", "Neutre", "Positif"]

# Topic labels in Darija (Arabic script plus Latin transliteration).
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",  # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",  # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",  # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",  # Recharge and plan problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",  # SIM card problem
    "مساعدة تقنية (Mosa3ada technique)",  # Technical support
    "العروض والتخفيضات (Offres w promotions)",  # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",  # Information request
    "شكاية (Chikaya)",  # Complaint
    "حاجة أخرى (Chi haja okhra)",  # Other
]

# Topic labels in English (same order as the Darija list above).
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other",
]

### 🔹 Load Models & Tokenizers Once ###

# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# NOTE(review): the two topic classifiers below are loaded from base (not
# fine-tuned) checkpoints, so their classification heads are presumably
# randomly initialised and the predicted topics are not meaningful until the
# models are fine-tuned — confirm and swap in fine-tuned checkpoints.
# The head size is derived from the label list so that every label is at
# least reachable (previously 2 and 3 outputs for 11 labels each).

# AraBERT for Darija topic classification
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)

# Darija sentiment model and tokenizer.
# NOTE(review): ignore_mismatched_sizes=True lets loading succeed even if the
# checkpoint's head has a different number of classes — in that case the head
# would be re-initialised; verify the checkpoint really has 3 labels.
sentiment_model_name = "BenhamdaneNawfal/sentiment-analysis-darija"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name,
    num_labels=len(sentiment_labels),
    ignore_mismatched_sizes=True,
).to(device)
def classify_topic_by_keywords(text, language='ar'):
    """Pick a topic for *text* by counting keyword hits.

    Each topic owns a list of keywords; a keyword scores one point when it
    occurs anywhere in the lower-cased text (plain substring match, so
    inflected or prefixed forms also match). The topic with the most points
    wins; ties go to the topic listed first.

    Args:
        text: The text to classify. ``None`` or an empty string yields "Other".
        language: 'ar' to use the Arabic keyword lists, 'en' for the English ones.

    Returns:
        The name of the best-scoring topic, or "Other" when no keyword matches.

    Raises:
        ValueError: If *language* is neither 'ar' nor 'en'.
    """
    # Arabic keywords for each topic
    arabic_keywords = {
        "Customer Service": ["خدمة", "استفسار", "مساعدة", "دعم", "سؤال"],
        "résiliation Service": ["نوقف", "تجديد", "خصم", "عرض", "نحي"],
        "Billing Issue": ["فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"],
        "Other": ["شيء آخر", "غير ذلك", "أخرى"],
    }
    # English keywords for each topic
    english_keywords = {
        "Customer Service": ["service", "inquiry", "help", "support", "question", "assistance"],
        "résiliation Service": ["retain", "cut", "discount", "stopped", "promotion", "stop"],
        "Billing Issue": ["bill", "payment", "problem", "error", "amount"],
        "Other": ["other", "none of the above", "something else"],
    }

    if language == 'ar':
        keywords = arabic_keywords
    elif language == 'en':
        keywords = english_keywords
    else:
        raise ValueError("Invalid language specified. Use 'ar' for Arabic or 'en' for English.")

    # Nothing to scan: same answer as "no keyword found".
    if not text:
        return "Other"

    # Lower-case so the English keywords match regardless of input casing.
    text = text.lower()

    # Count each distinct keyword once, so an accidentally repeated entry in a
    # keyword list cannot give its topic extra weight.
    topic_scores = {
        topic: sum(word in text for word in set(words))
        for topic, words in keywords.items()
    }

    # max() returns the first topic among equals, i.e. dictionary order breaks ties.
    best_topic = max(topic_scores, key=topic_scores.get)
    return best_topic if topic_scores[best_topic] > 0 else "Other"
def transcribe_audio(audio):
    """Transcribe an audio file and run the full analysis pipeline on it.

    The audio is transcribed with the Wav2Vec2 model, the transcription is
    translated to English, and both texts are classified by topic (model-based
    and keyword-based). Sentiment is computed on the transcription.

    Args:
        audio: Path of the audio file to process (the Gradio audio component
            is configured with ``type="filepath"``). ``None`` or a blank path
            means no audio was supplied.

    Returns:
        A 7-tuple of strings:
        (transcription, translation, darija_topic, english_topic,
        darija_keyword_topic, english_keyword_topic, sentiment).
        On failure the first element is an error message starting with
        "Error processing audio:" and the remaining six are empty strings.
    """
    # Nothing was uploaded or recorded: report it clearly instead of letting
    # the audio loader fail with an opaque message.
    if audio is None or (isinstance(audio, str) and not audio.strip()):
        return "Error processing audio: no audio provided", "", "", "", "", "", ""

    try:
        # Load the audio resampled to 16 kHz, the rate passed to the processor below.
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): take the most likely token at each frame and decode.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics using the BERT models
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        # Classify topics using keyword matching
        darija_keyword_topic = classify_topic_by_keywords(transcription, language='ar')
        english_keyword_topic = classify_topic_by_keywords(translation, language='en')

        # Sentiment analysis on the Darija transcription
        sentiment = analyze_sentiment(transcription)

        return (
            transcription,
            translation,
            darija_topic,
            english_topic,
            darija_keyword_topic,
            english_keyword_topic,
            sentiment,
        )
    except Exception as e:
        # UI boundary: surface the failure in the first output box rather than crash the app.
        return f"Error processing audio: {str(e)}", "", "", "", "", "", ""
def translate_text(text):
    """Translate Arabic text to English with the module-level MarianMT model.

    Args:
        text: The text to translate.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    encoded = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    encoded = encoded.to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated = translation_model.generate(**encoded)

    first_sequence = generated[0]
    return translation_tokenizer.decode(first_sequence, skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Classify *text* into one of *topic_labels* with a sequence-classification model.

    Args:
        text: The text to classify.
        tokenizer: Tokenizer matching *model*.
        model: Model whose output exposes ``logits``.
        topic_labels: Labels indexed by predicted class id.

    Returns:
        The label at the predicted class index, or "Other" when that index
        falls outside *topic_labels*.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    encoded = encoded.to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    label_index = torch.argmax(logits, dim=1).item()
    if label_index >= len(topic_labels):
        return "Other"
    return topic_labels[label_index]
def analyze_sentiment(text):
    """Classify the sentiment of a Darija text.

    Args:
        text: The text to analyse.

    Returns:
        The entry of ``sentiment_labels`` at the predicted class index, or
        "Inconnu" when that index falls outside the label list.
    """
    # Tokenize and move the inputs to the shared module-level device (the one
    # the sentiment model was moved to at load time), as the other inference
    # helpers do, instead of recomputing it locally.
    inputs = sentiment_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)

    # Prediction (inference only, no gradients needed)
    with torch.no_grad():
        outputs = sentiment_model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()

    # Map the class index to its label
    if predicted_class < len(sentiment_labels):
        return sentiment_labels[predicted_class]
    return "Inconnu"
# 🔹 Gradio interface: one audio input, a button, and seven text outputs
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    # Output boxes, created in the order they are displayed.
    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
    english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
    darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
    english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")
    sentiment_output = gr.Textbox(label="Sentiment (Darija)")

    # Must stay in the same order as the tuple returned by transcribe_audio.
    result_boxes = [
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
        darija_keyword_topic_output,
        english_keyword_topic_output,
        sentiment_output,
    ]
    submit_button.click(transcribe_audio, inputs=[audio_input], outputs=result_boxes)

demo.launch()