Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| from scripts.inference import ( | |
| zero_shot_inference, | |
| few_shot_inference, | |
| base_model_inference, | |
| fine_tuned_inference | |
| ) | |
| #Lire le README.md | |
| readme_content = """# 📰 AG News Text Classification Demo | |
| Ce projet présente une application de classification de textes d’actualité basée sur le dataset **AG News**. | |
| L'objectif est de comparer plusieurs stratégies d'inférence de modèles Transformers pour la classification de texte. | |
| --- | |
| ## 🚀 Démo en ligne | |
| L’application est disponible ici : | |
| [**Lien vers la démo Hugging Face Space**](https://huggingface.co/spaces/Merwan611/classification-text) | |
| --- | |
| ## 📂 Organisation du projet | |
| - `app.py` : interface Gradio avec deux onglets (`Demo` + `Documentation`) | |
| - `scripts/inference.py` : fonctions d’inférence pour 4 types de modèles | |
| - `scripts/train.py` : script d’entraînement du modèle BERT fine-tuné sur AG News | |
| - `scripts/utils.py` : calcul des métriques d’évaluation (accuracy, F1, etc.) | |
| - `requirements.txt` : liste des dépendances Python | |
| --- | |
| ## 🧠 Description des modèles utilisés | |
| Base model | |
| Modèle BERT préentraîné textattack/bert-base-uncased-ag-news. | |
| Il est utilisé directement sans réentraînement. Le texte est tokenisé avec AutoTokenizer puis passé au modèle pour obtenir une distribution de probabilité via softmax. | |
| Zero-shot | |
| Modèle facebook/bart-large-mnli utilisé via la pipeline zero-shot-classification de Hugging Face. | |
| Le texte est comparé à une liste de labels cibles (World, Sports, Business, Sci/Tech) sans aucun entraînement préalable sur AG News. Ce modèle s’appuie sur la reconnaissance d’implications textuelles pour inférer la classe la plus probable. | |
| Few-shot | |
| Basé sur le modèle google/flan-t5-small avec la pipeline text2text-generation. | |
| Le prompt inclut quelques exemples de classification manuelle (prompt engineering). Le modèle génère ensuite une réponse textuelle correspondant à la catégorie. Les sorties sont nettoyées et validées par correspondance avec les labels autorisés. | |
| Fine-tuned model | |
| Modèle bert-base-uncased fine-tuné sur un sous-ensemble équilibré du dataset AG News (3000 exemples par classe) puis hébergé sur le Hugging Face Hub sous Merwan611/agnews-finetuned-bert. | |
| La prédiction utilise également AutoTokenizer et applique une couche softmax sur les logits du modèle. L’accès au modèle peut nécessiter un token d’authentification via une variable d’environnement CLE. | |
| --- | |
| ## 📊 Données et entraînement | |
| - **Dataset** : AG News (4 classes : World, Sports, Business, Sci/Tech) | |
| - **Préprocessing** : tokenisation avec `AutoTokenizer` BERT | |
| - **Entraînement** : 3 epochs, batch size 32, métrique optimisée : accuracy | |
| - **Sous-échantillonnage** pour accélérer l’entraînement : 3000 exemples par classe pour le train, 1000 par classe pour le test | |
| --- | |
| ## 📈 Performances | |
| | **Model** | **Accuracy** | **F1 Score** | **Precision** | **Recall** | **Loss** | | |
| | ---------------- | ------------ | ------------ | ------------- | ---------- | -------- | | |
| | **Fine-tune** | 0.92 | 0.92 | 0.92 | 0.92 | 0.28 | | |
| | **Base model** | 0.92 | 0.92 | 0.92 | 0.92 | 0.32 | | |
| | **Zero-shot** | 0.68 | 0.68 | 0.69 | 0.68 | 0.87 | | |
| | **Few-shot** | 0.87 | 0.87 | 0.87 | 0.87 | 4.74 | | |
| Le modèle fine-tuné atteint généralement une meilleure précision que le modèle de base ou les approches zero-shot. | |
| --- | |
| ## ⚙️ Lancer l’application localement | |
| 1. Cloner le repo | |
| 2. Créer un environnement virtuel Python | |
| 3. Installer les dépendances : | |
| ```bash | |
| pip install -r requirements.txt | |
| 4. Lancer python app.py""" | |
| def predict_with_model(text, model_type): | |
| """ | |
| Applique une stratégie de classification sur un texte donné. | |
| Args: | |
| text (str): Le texte d’actualité à analyser. | |
| model_type (str): Le modèle choisi ("Zero-shot", "Few-shot", etc.). | |
| Returns: | |
| tuple: | |
| - str: Catégorie prédite. | |
| - pandas.DataFrame: Score de confiance par classe. | |
| """ | |
| if model_type == "Zero-shot": | |
| prediction, scores = zero_shot_inference(text) | |
| elif model_type == "Few-shot": | |
| prediction, scores = few_shot_inference(text) | |
| elif model_type == "Fine-tuned": | |
| prediction, scores = fine_tuned_inference(text) | |
| elif model_type == "Base model": | |
| prediction, scores = base_model_inference(text) | |
| else: | |
| return "Modèle inconnu", pd.DataFrame() | |
| # Convertir le dictionnaire des scores en DataFrame pour affichage | |
| scores_df = pd.DataFrame([ | |
| {"Classe": label, "Score": score} for label, score in scores.items() | |
| ]) | |
| return prediction, scores_df | |
| # === Interface Gradio avec deux onglets === | |
| with gr.Blocks(title="Classification AG News (4 stratégies)") as app: | |
| gr.Markdown("# 📰 Classification de textes AG News") | |
| gr.Markdown("Cette application compare plusieurs approches NLP pour classer des actualités.") | |
| with gr.Tab("🧠 Inférence"): | |
| gr.Markdown("### ✍️ Entrez une phrase d'actualité à analyser") | |
| # Entrées utilisateur | |
| input_text = gr.Textbox( | |
| lines=4, | |
| placeholder="Ex: Apple lance un nouveau produit basé sur l'intelligence artificielle...", | |
| label="Texte à classifier" | |
| ) | |
| model_choice = gr.Radio( | |
| choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"], | |
| label="Choisir le modèle", | |
| value="Base model" | |
| ) | |
| predict_button = gr.Button("🔍 Prédire") | |
| # Sorties | |
| label_output = gr.Label(label="🧾 Catégorie prédite") | |
| scores_output = gr.BarPlot( | |
| label="📊 Scores de confiance", | |
| x="Classe", y="Score", color="Classe" | |
| ) | |
| # Action sur clic bouton | |
| predict_button.click( | |
| fn=predict_with_model, | |
| inputs=[input_text, model_choice], | |
| outputs=[label_output, scores_output] | |
| ) | |
| with gr.Tab("📄 Documentation"): | |
| gr.Markdown(readme_content) | |
| # Lancer l'app | |
| if __name__ == "__main__": | |
| app.launch() | |